aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorgoldsteinn <35538541+goldsteinn@users.noreply.github.com>2024-10-29 07:41:59 -0700
committerGitHub <noreply@github.com>2024-10-29 09:41:59 -0500
commit2e612f8d868b3fb88a44964a3d4efd61ee63e06a (patch)
treed21945696fa80ad6697847bf5382d993741b50a1
parenta388df712700f38ad9a51d49a657a28e739f5eb4 (diff)
downloadllvm-2e612f8d868b3fb88a44964a3d4efd61ee63e06a.zip
llvm-2e612f8d868b3fb88a44964a3d4efd61ee63e06a.tar.gz
llvm-2e612f8d868b3fb88a44964a3d4efd61ee63e06a.tar.bz2
[MLIR][Arith] Improve accuracy of `inferDivU` (#113789)
1) We can always bound the maximum with the numerator. - https://alive2.llvm.org/ce/z/PqHvuT 2) Even if denominator min can be zero, we can still bound the minimum result with `lhs.umin u/ rhs.umax`. This is similar to https://github.com/llvm/llvm-project/pull/110169
-rw-r--r--mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp10
-rw-r--r--mlir/test/Dialect/Arith/int-range-interface.mlir21
2 files changed, 25 insertions, 6 deletions
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index ec9ed87..a2acf3e 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -298,8 +298,14 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
/*isSigned=*/false);
}
- // Otherwise, it's possible we might divide by 0.
- return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
+
+ APInt umin = APInt::getZero(rhsMin.getBitWidth());
+ if (lhsMin.uge(rhsMax) && !rhsMax.isZero())
+ umin = lhsMin.udiv(rhsMax);
+
+ // X u/ Y u<= X.
+ APInt umax = lhsMax;
+ return ConstantIntRanges::fromUnsigned(umin, umax);
}
ConstantIntRanges
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 4b04229..6d66da2 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -178,8 +178,8 @@ func.func @div_bounds_negative(%arg0 : index) -> i1 {
}
// CHECK-LABEL: func @div_zero_undefined
-// CHECK: %[[ret:.*]] = arith.cmpi ule
-// CHECK: return %[[ret]]
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
func.func @div_zero_undefined(%arg0 : index) -> i1 {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -190,6 +190,19 @@ func.func @div_zero_undefined(%arg0 : index) -> i1 {
func.return %2 : i1
}
+// CHECK-LABEL: func @div_refine_min
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @div_refine_min(%arg0 : index) -> i1 {
+ %c0 = arith.constant 1 : index
+ %c1 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+ %0 = arith.andi %arg0, %c1 : index
+ %1 = arith.divui %c4, %0 : index
+ %2 = arith.cmpi uge, %1, %c0 : index
+ func.return %2 : i1
+}
+
// CHECK-LABEL: func @ceil_divui
// CHECK: %[[ret:.*]] = arith.cmpi eq
// CHECK: return %[[ret]]
@@ -271,13 +284,13 @@ func.func @remui_base(%arg0 : index, %arg1 : index ) -> i1 {
// CHECK: return %[[true]]
func.func @remui_base_maybe_zero(%arg0 : index, %arg1 : index ) -> i1 {
%c4 = arith.constant 4 : index
- %c5 = arith.constant 5 : index
+ %c5 = arith.constant 5 : index
%0 = arith.minui %arg1, %c4 : index
%1 = arith.remui %arg0, %0 : index
%2 = arith.cmpi ult, %1, %c5 : index
func.return %2 : i1
-}
+}
// CHECK-LABEL: func @remsi_base
// CHECK: %[[ret:.*]] = arith.cmpi sge