diff options
-rw-r--r-- | mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td | 28 | ||||
-rw-r--r-- | mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 6 | ||||
-rw-r--r-- | mlir/test/Dialect/Arith/canonicalize.mlir | 114 |
3 files changed, 5 insertions, 143 deletions
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td index 13eb97a..5d9c2dd 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -273,7 +273,7 @@ def RedundantSelectFalse : Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)), (SelectOp $pred, $a, $c)>; -// select(pred, false, true) => not(pred) +// select(pred, false, true) => not(pred) def SelectI1ToNot : Pat<(SelectOp $pred, (ConstantLikeMatcher ConstantAttr<I1Attr, "0">), @@ -361,10 +361,6 @@ def OrOfExtSI : // TruncIOp //===----------------------------------------------------------------------===// -def ValuesWithSameType : - Constraint< - CPred<"llvm::all_equal({$0.getType(), $1.getType(), $2.getType()})">>; - def ValueWiderThan : Constraint<And<[ CPred<"getScalarOrElementWidth($0) > getScalarOrElementWidth($1)">, @@ -397,28 +393,6 @@ def TruncIShrSIToTrunciShrUI : (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0)))), [(TruncationMatchesShiftAmount $x, $tr, $c0)]>; -// trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y) -def TruncIShrUIMulIToMulSIExtended : - Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp - (Arith_MulIOp:$mul - (Arith_ExtSIOp $x), (Arith_ExtSIOp $y), $ovf1), - (ConstantLikeMatcher AnyAttr:$c0))), - (Arith_MulSIExtendedOp:$res__1 $x, $y), - [(ValuesWithSameType $tr, $x, $y), - (ValueWiderThan $mul, $x), - (TruncationMatchesShiftAmount $mul, $x, $c0)]>; - -// trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y) -def TruncIShrUIMulIToMulUIExtended : - Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp - (Arith_MulIOp:$mul - (Arith_ExtUIOp $x), (Arith_ExtUIOp $y), $ovf1), - (ConstantLikeMatcher AnyAttr:$c0))), - (Arith_MulUIExtendedOp:$res__1 $x, $y), - [(ValuesWithSameType $tr, $x, $y), - (ValueWiderThan $mul, $x), - (TruncationMatchesShiftAmount $mul, $x, $c0)]>; - //===----------------------------------------------------------------------===// // TruncIOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 41f2d0f..082abf2 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1496,9 +1496,9 @@ bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI, - TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>( - context); + patterns + .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>( + context); } LogicalResult arith::TruncIOp::verify() { diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index b6188c8..076f3a9 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -1000,7 +1000,7 @@ func.func @tripleAddAddOvf2(%arg0: index) -> index { // CHECK-LABEL: @foldSubXX_tensor -// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32> +// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32> // CHECK: %[[sub:.+]] = arith.subi // CHECK: return %[[c0]], %[[sub]] func.func @foldSubXX_tensor(%static : tensor<10xi32>, %dyn : tensor<?x?xi32>) -> (tensor<10xi32>, tensor<?x?xi32>) { @@ -2952,118 +2952,6 @@ func.func @truncIShrSIToTrunciShrUIBadShiftAmt2(%a: i64) -> i32 { return %hi : i32 } -// CHECK-LABEL: @wideMulToMulSIExtended -// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32) -// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : i32 -// CHECK-NEXT: return %[[HIGH]] : i32 -func.func @wideMulToMulSIExtended(%a: i32, %b: i32) -> i32 { - %x = arith.extsi %a: i32 to i64 - %y = arith.extsi %b: i32 to i64 - %m = arith.muli %x, %y: i64 - %c32 = arith.constant 32: i64 - %sh = arith.shrui %m, %c32 : i64 - %hi = arith.trunci %sh: i64 to i32 - return %hi : i32 -} - -// CHECK-LABEL: @wideMulToMulSIExtendedVector -// CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>) -// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : vector<3xi32> -// CHECK-NEXT: return %[[HIGH]] : vector<3xi32> -func.func @wideMulToMulSIExtendedVector(%a: vector<3xi32>, %b: vector<3xi32>) -> vector<3xi32> { - %x = arith.extsi %a: vector<3xi32> to vector<3xi64> - %y = arith.extsi %b: vector<3xi32> to vector<3xi64> - %m = arith.muli %x, %y: vector<3xi64> - %c32 = arith.constant dense<32>: vector<3xi64> - %sh = arith.shrui %m, %c32 : vector<3xi64> - %hi = arith.trunci %sh: vector<3xi64> to vector<3xi32> - return %hi : vector<3xi32> -} - -// CHECK-LABEL: @wideMulToMulUIExtended -// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32) -// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : i32 -// CHECK-NEXT: return %[[HIGH]] : i32 -func.func @wideMulToMulUIExtended(%a: i32, %b: i32) -> i32 { - %x = arith.extui %a: i32 to i64 - %y = arith.extui %b: i32 to i64 - %m = arith.muli %x, %y: i64 - %c32 = arith.constant 32: i64 - %sh = arith.shrui %m, %c32 : i64 - %hi = arith.trunci %sh: i64 to i32 - return %hi : i32 -} - -// CHECK-LABEL: @wideMulToMulUIExtendedVector -// CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>) -// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : vector<3xi32> -// CHECK-NEXT: return %[[HIGH]] : vector<3xi32> -func.func @wideMulToMulUIExtendedVector(%a: vector<3xi32>, %b: vector<3xi32>) -> vector<3xi32> { - %x = arith.extui %a: vector<3xi32> to vector<3xi64> - %y = arith.extui %b: vector<3xi32> to vector<3xi64> - %m = arith.muli %x, %y: vector<3xi64> - %c32 = arith.constant dense<32>: vector<3xi64> - %sh = arith.shrui %m, %c32 : vector<3xi64> - %hi = arith.trunci %sh: vector<3xi64> to vector<3xi32> - return %hi : vector<3xi32> -} - -// CHECK-LABEL: @wideMulToMulIExtendedMixedExt -// CHECK: arith.muli -// CHECK: arith.shrui -// CHECK: arith.trunci -func.func @wideMulToMulIExtendedMixedExt(%a: i32, %b: i32) -> i32 { - %x = arith.extsi %a: i32 to i64 - %y = arith.extui %b: i32 to i64 - %m = arith.muli %x, %y: i64 - %c32 = arith.constant 32: i64 - %sh = arith.shrui %m, %c32 : i64 - %hi = arith.trunci %sh: i64 to i32 - return %hi : i32 -} - -// CHECK-LABEL: @wideMulToMulSIExtendedBadExt -// CHECK: arith.muli -// CHECK: arith.shrui -// CHECK: arith.trunci -func.func @wideMulToMulSIExtendedBadExt(%a: i16, %b: i16) -> i32 { - %x = arith.extsi %a: i16 to i64 - %y = arith.extsi %b: i16 to i64 - %m = arith.muli %x, %y: i64 - %c32 = arith.constant 32: i64 - %sh = arith.shrui %m, %c32 : i64 - %hi = arith.trunci %sh: i64 to i32 - return %hi : i32 -} - -// CHECK-LABEL: @wideMulToMulSIExtendedBadShift1 -// CHECK: arith.muli -// CHECK: arith.shrui -// CHECK: arith.trunci -func.func @wideMulToMulSIExtendedBadShift1(%a: i32, %b: i32) -> i32 { - %x = arith.extsi %a: i32 to i64 - %y = arith.extsi %b: i32 to i64 - %m = arith.muli %x, %y: i64 - %c33 = arith.constant 33: i64 - %sh = arith.shrui %m, %c33 : i64 - %hi = arith.trunci %sh: i64 to i32 - return %hi : i32 -} - -// CHECK-LABEL: @wideMulToMulSIExtendedBadShift2 -// CHECK: arith.muli -// CHECK: arith.shrui -// CHECK: arith.trunci -func.func @wideMulToMulSIExtendedBadShift2(%a: i32, %b: i32) -> i32 { - %x = arith.extsi %a: i32 to i64 - %y = arith.extsi %b: i32 to i64 - %m = arith.muli %x, %y: i64 - %c31 = arith.constant 31: i64 - %sh = arith.shrui %m, %c31 : i64 - %hi = arith.trunci %sh: i64 to i32 - return %hi : i32 -} - // CHECK-LABEL: @foldShli0 // CHECK-SAME: (%[[ARG:.*]]: i64) // CHECK: return %[[ARG]] : i64 |