aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Arith/IR/ArithOps.cpp')
-rw-r--r--mlir/lib/Dialect/Arith/IR/ArithOps.cpp24
1 files changed, 24 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d8b314a..e016a6e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -580,11 +580,31 @@ void arith::MulUIExtendedOp::getCanonicalizationPatterns(
// DivUIOp
//===----------------------------------------------------------------------===//
+/// Fold `(a * b) / b -> a`
+static Value foldDivMul(Value lhs, Value rhs,
+ arith::IntegerOverflowFlags ovfFlags) {
+ auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
+ if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
+ return {};
+
+ if (mul.getLhs() == rhs)
+ return mul.getRhs();
+
+ if (mul.getRhs() == rhs)
+ return mul.getLhs();
+
+ return {};
+}
+
OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
// divui (x, 1) -> x.
if (matchPattern(adaptor.getRhs(), m_One()))
return getLhs();
+ // (a * b) / b -> a
+ if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
+ return val;
+
// Don't fold if it would require a division by zero.
bool div0 = false;
auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
@@ -621,6 +641,10 @@ OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_One()))
return getLhs();
+ // (a * b) / b -> a
+ if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
+ return val;
+
// Don't fold if it would overflow or if it requires a division by zero.
bool overflowOrDiv0 = false;
auto result = constFoldBinaryOp<IntegerAttr>(