aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Tobias Gysi <tobias.gysi@nextsilicon.com> 2025-06-19 16:32:48 +0200
committer Tobias Gysi <tobias.gysi@nextsilicon.com> 2025-06-20 06:07:08 +0000
commit e2d291fcf6732a60db323a0d6f337ce28fd997f1 (patch)
tree 5853ad3bd4db1f183f29c3026cd3e8b41797c6d4
parent 5a531b115844a038d7bd0108ebafe6bacbef75e3 (diff)
downloadllvm-users/gysit/remove-arith-ext-mul-patterns.zip
llvm-users/gysit/remove-arith-ext-mul-patterns.tar.gz
llvm-users/gysit/remove-arith-ext-mul-patterns.tar.bz2
[mlir][arith] Delete mul ext canonicalizations (#144844)
users/gysit/remove-arith-ext-mul-patterns
The Arith dialect includes patterns that canonicalize a sequence of:
- trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
- trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)

These patterns return the high word of an extended multiplication, which assumes that the shift amount is equal to the bit width of the original operands. This check was missing, leading to incorrect canonicalizations when the shift amount was less than the bit width. For example, the following code:
```
%x = arith.extui %a: i32 to i33
%y = arith.extui %b: i32 to i33
%m = arith.muli %x, %y: i33
%c1 = arith.constant 1: i33
%sh = arith.shrui %m, %c1 : i33
%hi = arith.trunci %sh: i33 to i32
```
would incorrectly be canonicalized to:
```
_, %hi = arith.mului_extended %a, %b : i32
```
This commit removes the faulty canonicalizations since they are not believed to be generally beneficial (cf. the discussion of the alternative https://github.com/llvm/llvm-project/pull/144787, which fixes the canonicalizations).
-rw-r--r-- mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td 28
-rw-r--r-- mlir/lib/Dialect/Arith/IR/ArithOps.cpp 6
-rw-r--r-- mlir/test/Dialect/Arith/canonicalize.mlir 114
3 files changed, 5 insertions, 143 deletions
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 13eb97a..5d9c2dd 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -273,7 +273,7 @@ def RedundantSelectFalse :
Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
(SelectOp $pred, $a, $c)>;
-// select(pred, false, true) => not(pred)
+// select(pred, false, true) => not(pred)
def SelectI1ToNot :
Pat<(SelectOp $pred,
(ConstantLikeMatcher ConstantAttr<I1Attr, "0">),
@@ -361,10 +361,6 @@ def OrOfExtSI :
// TruncIOp
//===----------------------------------------------------------------------===//
-def ValuesWithSameType :
- Constraint<
- CPred<"llvm::all_equal({$0.getType(), $1.getType(), $2.getType()})">>;
-
def ValueWiderThan :
Constraint<And<[
CPred<"getScalarOrElementWidth($0) > getScalarOrElementWidth($1)">,
@@ -397,28 +393,6 @@ def TruncIShrSIToTrunciShrUI :
(Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0)))),
[(TruncationMatchesShiftAmount $x, $tr, $c0)]>;
-// trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
-def TruncIShrUIMulIToMulSIExtended :
- Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
- (Arith_MulIOp:$mul
- (Arith_ExtSIOp $x), (Arith_ExtSIOp $y), $ovf1),
- (ConstantLikeMatcher AnyAttr:$c0))),
- (Arith_MulSIExtendedOp:$res__1 $x, $y),
- [(ValuesWithSameType $tr, $x, $y),
- (ValueWiderThan $mul, $x),
- (TruncationMatchesShiftAmount $mul, $x, $c0)]>;
-
-// trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)
-def TruncIShrUIMulIToMulUIExtended :
- Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
- (Arith_MulIOp:$mul
- (Arith_ExtUIOp $x), (Arith_ExtUIOp $y), $ovf1),
- (ConstantLikeMatcher AnyAttr:$c0))),
- (Arith_MulUIExtendedOp:$res__1 $x, $y),
- [(ValuesWithSameType $tr, $x, $y),
- (ValueWiderThan $mul, $x),
- (TruncationMatchesShiftAmount $mul, $x, $c0)]>;
-
//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 41f2d0f..082abf2 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1496,9 +1496,9 @@ bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
- TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
- context);
+ patterns
+ .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
+ context);
}
LogicalResult arith::TruncIOp::verify() {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index b6188c8..076f3a9 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1000,7 +1000,7 @@ func.func @tripleAddAddOvf2(%arg0: index) -> index {
// CHECK-LABEL: @foldSubXX_tensor
-// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32>
+// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32>
// CHECK: %[[sub:.+]] = arith.subi
// CHECK: return %[[c0]], %[[sub]]
func.func @foldSubXX_tensor(%static : tensor<10xi32>, %dyn : tensor<?x?xi32>) -> (tensor<10xi32>, tensor<?x?xi32>) {
@@ -2952,118 +2952,6 @@ func.func @truncIShrSIToTrunciShrUIBadShiftAmt2(%a: i64) -> i32 {
return %hi : i32
}
-// CHECK-LABEL: @wideMulToMulSIExtended
-// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32)
-// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : i32
-// CHECK-NEXT: return %[[HIGH]] : i32
-func.func @wideMulToMulSIExtended(%a: i32, %b: i32) -> i32 {
- %x = arith.extsi %a: i32 to i64
- %y = arith.extsi %b: i32 to i64
- %m = arith.muli %x, %y: i64
- %c32 = arith.constant 32: i64
- %sh = arith.shrui %m, %c32 : i64
- %hi = arith.trunci %sh: i64 to i32
- return %hi : i32
-}
-
-// CHECK-LABEL: @wideMulToMulSIExtendedVector
-// CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
-// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : vector<3xi32>
-// CHECK-NEXT: return %[[HIGH]] : vector<3xi32>
-func.func @wideMulToMulSIExtendedVector(%a: vector<3xi32>, %b: vector<3xi32>) -> vector<3xi32> {
- %x = arith.extsi %a: vector<3xi32> to vector<3xi64>
- %y = arith.extsi %b: vector<3xi32> to vector<3xi64>
- %m = arith.muli %x, %y: vector<3xi64>
- %c32 = arith.constant dense<32>: vector<3xi64>
- %sh = arith.shrui %m, %c32 : vector<3xi64>
- %hi = arith.trunci %sh: vector<3xi64> to vector<3xi32>
- return %hi : vector<3xi32>
-}
-
-// CHECK-LABEL: @wideMulToMulUIExtended
-// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32)
-// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : i32
-// CHECK-NEXT: return %[[HIGH]] : i32
-func.func @wideMulToMulUIExtended(%a: i32, %b: i32) -> i32 {
- %x = arith.extui %a: i32 to i64
- %y = arith.extui %b: i32 to i64
- %m = arith.muli %x, %y: i64
- %c32 = arith.constant 32: i64
- %sh = arith.shrui %m, %c32 : i64
- %hi = arith.trunci %sh: i64 to i32
- return %hi : i32
-}
-
-// CHECK-LABEL: @wideMulToMulUIExtendedVector
-// CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
-// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : vector<3xi32>
-// CHECK-NEXT: return %[[HIGH]] : vector<3xi32>
-func.func @wideMulToMulUIExtendedVector(%a: vector<3xi32>, %b: vector<3xi32>) -> vector<3xi32> {
- %x = arith.extui %a: vector<3xi32> to vector<3xi64>
- %y = arith.extui %b: vector<3xi32> to vector<3xi64>
- %m = arith.muli %x, %y: vector<3xi64>
- %c32 = arith.constant dense<32>: vector<3xi64>
- %sh = arith.shrui %m, %c32 : vector<3xi64>
- %hi = arith.trunci %sh: vector<3xi64> to vector<3xi32>
- return %hi : vector<3xi32>
-}
-
-// CHECK-LABEL: @wideMulToMulIExtendedMixedExt
-// CHECK: arith.muli
-// CHECK: arith.shrui
-// CHECK: arith.trunci
-func.func @wideMulToMulIExtendedMixedExt(%a: i32, %b: i32) -> i32 {
- %x = arith.extsi %a: i32 to i64
- %y = arith.extui %b: i32 to i64
- %m = arith.muli %x, %y: i64
- %c32 = arith.constant 32: i64
- %sh = arith.shrui %m, %c32 : i64
- %hi = arith.trunci %sh: i64 to i32
- return %hi : i32
-}
-
-// CHECK-LABEL: @wideMulToMulSIExtendedBadExt
-// CHECK: arith.muli
-// CHECK: arith.shrui
-// CHECK: arith.trunci
-func.func @wideMulToMulSIExtendedBadExt(%a: i16, %b: i16) -> i32 {
- %x = arith.extsi %a: i16 to i64
- %y = arith.extsi %b: i16 to i64
- %m = arith.muli %x, %y: i64
- %c32 = arith.constant 32: i64
- %sh = arith.shrui %m, %c32 : i64
- %hi = arith.trunci %sh: i64 to i32
- return %hi : i32
-}
-
-// CHECK-LABEL: @wideMulToMulSIExtendedBadShift1
-// CHECK: arith.muli
-// CHECK: arith.shrui
-// CHECK: arith.trunci
-func.func @wideMulToMulSIExtendedBadShift1(%a: i32, %b: i32) -> i32 {
- %x = arith.extsi %a: i32 to i64
- %y = arith.extsi %b: i32 to i64
- %m = arith.muli %x, %y: i64
- %c33 = arith.constant 33: i64
- %sh = arith.shrui %m, %c33 : i64
- %hi = arith.trunci %sh: i64 to i32
- return %hi : i32
-}
-
-// CHECK-LABEL: @wideMulToMulSIExtendedBadShift2
-// CHECK: arith.muli
-// CHECK: arith.shrui
-// CHECK: arith.trunci
-func.func @wideMulToMulSIExtendedBadShift2(%a: i32, %b: i32) -> i32 {
- %x = arith.extsi %a: i32 to i64
- %y = arith.extsi %b: i32 to i64
- %m = arith.muli %x, %y: i64
- %c31 = arith.constant 31: i64
- %sh = arith.shrui %m, %c31 : i64
- %hi = arith.trunci %sh: i64 to i32
- return %hi : i32
-}
-
// CHECK-LABEL: @foldShli0
// CHECK-SAME: (%[[ARG:.*]]: i64)
// CHECK: return %[[ARG]] : i64