diff options
Diffstat (limited to 'mlir/lib/Dialect/Arith/IR/ArithOps.cpp')
-rw-r--r-- | mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index d8b314a..e016a6e 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -580,11 +580,31 @@ void arith::MulUIExtendedOp::getCanonicalizationPatterns( // DivUIOp //===----------------------------------------------------------------------===// +/// Fold `(a * b) / b -> a` +static Value foldDivMul(Value lhs, Value rhs, + arith::IntegerOverflowFlags ovfFlags) { + auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>(); + if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags)) + return {}; + + if (mul.getLhs() == rhs) + return mul.getRhs(); + + if (mul.getRhs() == rhs) + return mul.getLhs(); + + return {}; +} + OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) { // divui (x, 1) -> x. if (matchPattern(adaptor.getRhs(), m_One())) return getLhs(); + // (a * b) / b -> a + if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw)) + return val; + // Don't fold if it would require a division by zero. bool div0 = false; auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), @@ -621,6 +641,10 @@ OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) { if (matchPattern(adaptor.getRhs(), m_One())) return getLhs(); + // (a * b) / b -> a + if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw)) + return val; + // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; auto result = constFoldBinaryOp<IntegerAttr>( |