diff options
-rw-r--r-- | mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp | 69 | ||||
-rw-r--r-- | mlir/test/Dialect/Arith/int-narrowing.mlir | 155 |
2 files changed, 222 insertions, 2 deletions
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp index 01507e3..cb6e437 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp @@ -237,6 +237,10 @@ struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> { /// this, taking into account `BinaryOp` semantics. virtual unsigned getResultBitsProduced(unsigned operandBits) const = 0; + /// Customization point for patterns that should only apply with + /// zero/sign-extension ops as arguments. + virtual bool isSupported(ExtensionOp) const { return true; } + LogicalResult matchAndRewrite(BinaryOp op, PatternRewriter &rewriter) const final { Type origTy = op.getType(); @@ -247,7 +251,7 @@ struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> { // For the optimization to apply, we expect the lhs to be an extension op, // and for the rhs to either be the same extension op or a constant. FailureOr<ExtensionOp> ext = ExtensionOp::from(op.getLhs().getDefiningOp()); - if (failed(ext)) + if (failed(ext) || !isSupported(*ext)) return failure(); FailureOr<unsigned> lhsBitsRequired = @@ -286,6 +290,27 @@ struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> { struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> { using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; + // Addition may require one extra bit for the result. + // Example: `UINT8_MAX + 1 == 255 + 1 == 256`. + unsigned getResultBitsProduced(unsigned operandBits) const override { + return operandBits + 1; + } +}; + +//===----------------------------------------------------------------------===// +// SubIOp Pattern +//===----------------------------------------------------------------------===// + +struct SubIPattern final : BinaryOpNarrowingPattern<arith::SubIOp> { + using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; + + // This optimization only applies to signed arguments. 
+ bool isSupported(ExtensionOp ext) const override { + return ext.getKind() == ExtensionKind::Sign; + } + + // Subtraction may require one extra bit for the result. + // Example: `INT8_MAX - (-1) == 127 - (-1) == 128`. unsigned getResultBitsProduced(unsigned operandBits) const override { return operandBits + 1; } }; @@ -298,12 +323,51 @@ struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> { struct MulIPattern final : BinaryOpNarrowingPattern<arith::MulIOp> { using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; + // Multiplication may require up to double the operand bits. + // Example: `UINT8_MAX * UINT8_MAX == 255 * 255 == 65025`. unsigned getResultBitsProduced(unsigned operandBits) const override { return 2 * operandBits; } }; //===----------------------------------------------------------------------===// +// DivSIOp Pattern +//===----------------------------------------------------------------------===// + +struct DivSIPattern final : BinaryOpNarrowingPattern<arith::DivSIOp> { + using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; + + // This optimization only applies to signed arguments. + bool isSupported(ExtensionOp ext) const override { + return ext.getKind() == ExtensionKind::Sign; + } + + // Unlike multiplication, signed division requires only one more result bit. + // Example: `INT8_MIN / (-1) == -128 / (-1) == 128`. + unsigned getResultBitsProduced(unsigned operandBits) const override { + return operandBits + 1; + } +}; + +//===----------------------------------------------------------------------===// +// DivUIOp Pattern +//===----------------------------------------------------------------------===// + +struct DivUIPattern final : BinaryOpNarrowingPattern<arith::DivUIOp> { + using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; + + // This optimization only applies to unsigned arguments. 
+ bool isSupported(ExtensionOp ext) const override { + return ext.getKind() == ExtensionKind::Zero; + } + + // Unsigned division does not require any extra result bits. + unsigned getResultBitsProduced(unsigned operandBits) const override { + return operandBits; + } +}; + +//===----------------------------------------------------------------------===// // *IToFPOp Patterns //===----------------------------------------------------------------------===// @@ -625,7 +689,8 @@ void populateArithIntNarrowingPatterns( ExtensionOverTranspose, ExtensionOverFlatTranspose>( patterns.getContext(), options, PatternBenefit(2)); - patterns.add<AddIPattern, MulIPattern, SIToFPPattern, UIToFPPattern>( + patterns.add<AddIPattern, SubIPattern, MulIPattern, DivSIPattern, + DivUIPattern, SIToFPPattern, UIToFPPattern>( patterns.getContext(), options); } diff --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir index 966e34c..4b155ad 100644 --- a/mlir/test/Dialect/Arith/int-narrowing.mlir +++ b/mlir/test/Dialect/Arith/int-narrowing.mlir @@ -102,6 +102,75 @@ func.func @addi_extsi_3xi8_cst(%lhs: vector<3xi8>) -> vector<3xi32> { } //===----------------------------------------------------------------------===// +// arith.subi +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func.func @subi_extsi_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16 +// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16 +// CHECK-NEXT: %[[SUB:.+]] = arith.subi %[[LHS]], %[[RHS]] : i16 +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[SUB]] : i16 to i32 +// CHECK-NEXT: return %[[RET]] : i32 +func.func @subi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extsi %lhs : i8 to i32 + %b = arith.extsi %rhs : i8 to i32 + 
%r = arith.subi %a, %b : i32 + return %r : i32 +} + +// This pattern should only apply to `arith.subi` ops with sign-extended +// arguments. +// +// CHECK-LABEL: func.func @subi_extui_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[SUB:.+]] = arith.subi %[[EXT0]], %[[EXT1]] : i32 +// CHECK-NEXT: return %[[SUB]] : i32 +func.func @subi_extui_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extui %lhs : i8 to i32 + %b = arith.extui %rhs : i8 to i32 + %r = arith.subi %a, %b : i32 + return %r : i32 +} + +// This case should not get optimized because of mixed extensions. +// +// CHECK-LABEL: func.func @subi_mixed_ext_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[ADD:.+]] = arith.subi %[[EXT0]], %[[EXT1]] : i32 +// CHECK-NEXT: return %[[ADD]] : i32 +func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extsi %lhs : i8 to i32 + %b = arith.extui %rhs : i8 to i32 + %r = arith.subi %a, %b : i32 + return %r : i32 +} + +// arith.subi produces one more bit of result than the operand bitwidth. 
+// +// CHECK-LABEL: func.func @subi_extsi_i24 +// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i16 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i16 to i32 +// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i24 +// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i24 +// CHECK-NEXT: %[[ADD:.+]] = arith.subi %[[LHS]], %[[RHS]] : i24 +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i24 to i32 +// CHECK-NEXT: return %[[RET]] : i32 +func.func @subi_extsi_i24(%lhs: i16, %rhs: i16) -> i32 { + %a = arith.extsi %lhs : i16 to i32 + %b = arith.extsi %rhs : i16 to i32 + %r = arith.subi %a, %b : i32 + return %r : i32 +} + +//===----------------------------------------------------------------------===// +// arith.muli +//===----------------------------------------------------------------------===// @@ -184,6 +253,92 @@ func.func @muli_extsi_3xi8_cst(%lhs: vector<3xi8>) -> vector<3xi32> { } //===----------------------------------------------------------------------===// +// arith.divsi +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func.func @divsi_extsi_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16 +// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16 +// CHECK-NEXT: %[[SUB:.+]] = arith.divsi %[[LHS]], %[[RHS]] : i16 +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[SUB]] : i16 to i32 +// CHECK-NEXT: return %[[RET]] : i32 +func.func @divsi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extsi %lhs : i8 to i32 + %b = arith.extsi %rhs : i8 to i32 + %r = arith.divsi %a, %b : i32 + return %r : i32 +} + +// This pattern should only apply to `arith.divsi` ops with sign-extended +// arguments. 
+// +// CHECK-LABEL: func.func @divsi_extui_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[SUB:.+]] = arith.divsi %[[EXT0]], %[[EXT1]] : i32 +// CHECK-NEXT: return %[[SUB]] : i32 +func.func @divsi_extui_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extui %lhs : i8 to i32 + %b = arith.extui %rhs : i8 to i32 + %r = arith.divsi %a, %b : i32 + return %r : i32 +} + +// arith.divsi produces one more bit of result than the operand bitwidth. +// +// CHECK-LABEL: func.func @divsi_extsi_i24 +// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i16 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i16 to i32 +// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i24 +// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i24 +// CHECK-NEXT: %[[ADD:.+]] = arith.divsi %[[LHS]], %[[RHS]] : i24 +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i24 to i32 +// CHECK-NEXT: return %[[RET]] : i32 +func.func @divsi_extsi_i24(%lhs: i16, %rhs: i16) -> i32 { + %a = arith.extsi %lhs : i16 to i32 + %b = arith.extsi %rhs : i16 to i32 + %r = arith.divsi %a, %b : i32 + return %r : i32 +} + +//===----------------------------------------------------------------------===// +// arith.divui +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func.func @divui_extui_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[SUB:.+]] = arith.divui %[[ARG0]], %[[ARG1]] : i8 +// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[SUB]] : i8 to i32 +// CHECK-NEXT: return %[[RET]] : i32 +func.func @divui_extui_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extui %lhs : i8 to i32 + %b = arith.extui %rhs : i8 to i32 + %r = arith.divui %a, %b : i32 + return %r : i32 +} + +// This pattern should only apply to 
`arith.divui` ops with zero-extended +// arguments. +// +// CHECK-LABEL: func.func @divui_extsi_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[SUB:.+]] = arith.divui %[[EXT0]], %[[EXT1]] : i32 +// CHECK-NEXT: return %[[SUB]] : i32 +func.func @divui_extsi_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extsi %lhs : i8 to i32 + %b = arith.extsi %rhs : i8 to i32 + %r = arith.divui %a, %b : i32 + return %r : i32 +} + +//===----------------------------------------------------------------------===// // arith.*itofp //===----------------------------------------------------------------------===// |