diff options
author | long.chen <lipracer@gmail.com> | 2024-03-22 23:52:47 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-22 23:52:47 +0800 |
commit | 631e54aa1a0b7a79d0dec8dce7ec0f5e506acf6c (patch) | |
tree | df022721db3779286765f3299514708f99781f02 /mlir | |
parent | 3054d0dae7a813c493d2bb8e969aa2321145a83b (diff) | |
download | llvm-631e54aa1a0b7a79d0dec8dce7ec0f5e506acf6c.zip llvm-631e54aa1a0b7a79d0dec8dce7ec0f5e506acf6c.tar.gz llvm-631e54aa1a0b7a79d0dec8dce7ec0f5e506acf6c.tar.bz2 |
[mlir][arith] fix wrong floordivsi fold (#83248)
Fixes https://github.com/llvm/llvm-project/issues/83079
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 36 | ||||
-rw-r--r-- | mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 58 | ||||
-rw-r--r-- | mlir/test/Dialect/Arith/expand-ops.mlir | 84 | ||||
-rw-r--r-- | mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir | 30 | ||||
-rw-r--r-- | mlir/test/Transforms/canonicalize.mlir | 9 |
5 files changed, 102 insertions, 115 deletions
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 9f64a07..2f32d9a 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -689,43 +689,17 @@ OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) { return getLhs(); // Don't fold if it would overflow or if it requires a division by zero. - bool overflowOrDiv0 = false; + bool overflowOrDiv = false; auto result = constFoldBinaryOp<IntegerAttr>( adaptor.getOperands(), [&](APInt a, const APInt &b) { - if (overflowOrDiv0 || !b) { - overflowOrDiv0 = true; + if (b.isZero()) { + overflowOrDiv = true; return a; } - if (!a) - return a; - // After this point we know that neither a or b are zero. - unsigned bits = a.getBitWidth(); - APInt zero = APInt::getZero(bits); - bool aGtZero = a.sgt(zero); - bool bGtZero = b.sgt(zero); - if (aGtZero && bGtZero) { - // Both positive, return a / b. - return a.sdiv_ov(b, overflowOrDiv0); - } - if (!aGtZero && !bGtZero) { - // Both negative, return -a / -b. - APInt posA = zero.ssub_ov(a, overflowOrDiv0); - APInt posB = zero.ssub_ov(b, overflowOrDiv0); - return posA.sdiv_ov(posB, overflowOrDiv0); - } - if (!aGtZero && bGtZero) { - // A is negative, b is positive, return - ceil(-a, b). - APInt posA = zero.ssub_ov(a, overflowOrDiv0); - APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); - return zero.ssub_ov(ceil, overflowOrDiv0); - } - // A is positive, b is negative, return - ceil(a, -b). - APInt posB = zero.ssub_ov(b, overflowOrDiv0); - APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); - return zero.ssub_ov(ceil, overflowOrDiv0); + return a.sfloordiv_ov(b, overflowOrDiv); }); - return overflowOrDiv0 ? Attribute() : result; + return overflowOrDiv ? 
Attribute() : result; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index 7f246da..71e14a1 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -110,9 +110,13 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> { } }; -/// Expands FloorDivSIOp (n, m) into -/// 1) x = (m<0) ? 1 : -1 -/// 2) return (n*m<0) ? - ((-n+x) / m) -1 : n / m +/// Expands FloorDivSIOp (x, y) into +/// z = x / y +/// if (z * y != x && (x < 0) != (y < 0)) { +/// return z - 1; +/// } else { +/// return z; +/// } struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(arith::FloorDivSIOp op, @@ -121,41 +125,29 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> { Type type = op.getType(); Value a = op.getLhs(); Value b = op.getRhs(); - Value plusOne = createConst(loc, type, 1, rewriter); + + Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b); + Value product = rewriter.create<arith::MulIOp>(loc, quotient, b); + Value notEqualDivisor = rewriter.create<arith::CmpIOp>( + loc, arith::CmpIPredicate::ne, a, product); Value zero = createConst(loc, type, 0, rewriter); - Value minusOne = createConst(loc, type, -1, rewriter); - // Compute x = (b<0) ? 1 : -1. - Value compare = - rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); - Value x = rewriter.create<arith::SelectOp>(loc, compare, plusOne, minusOne); - // Compute negative res: -1 - ((x-a)/b). - Value xMinusA = rewriter.create<arith::SubIOp>(loc, x, a); - Value xMinusADivB = rewriter.create<arith::DivSIOp>(loc, xMinusA, b); - Value negRes = rewriter.create<arith::SubIOp>(loc, minusOne, xMinusADivB); - // Compute positive res: a/b. 
- Value posRes = rewriter.create<arith::DivSIOp>(loc, a, b); - // Result is (a*b<0) ? negative result : positive result. - // Note, we want to avoid using a*b because of possible overflow. - // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do - // not particuliarly care if a*b<0 is true or false when b is zero - // as this will result in an illegal divide. So `a*b<0` can be reformulated - // as `(a>0 && b<0) || (a>0 && b<0)' or `(a>0 && b<0) || (a>0 && b<=0)'. - // We pick the first expression here. + Value aNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero); - Value aPos = - rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero); Value bNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); - Value bPos = - rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); - Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bPos); - Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg); - Value compareRes = - rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm); - // Perform substitution and return success. 
- rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, negRes, - posRes); + + Value signOpposite = rewriter.create<arith::CmpIOp>( + loc, arith::CmpIPredicate::ne, aNeg, bNeg); + Value cond = + rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite); + + Value minusOne = createConst(loc, type, -1, rewriter); + Value quotientMinusOne = + rewriter.create<arith::AddIOp>(loc, quotient, minusOne); + + rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne, + quotient); return success(); } }; diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir index 91f652e..6bed93e 100644 --- a/mlir/test/Dialect/Arith/expand-ops.mlir +++ b/mlir/test/Dialect/Arith/expand-ops.mlir @@ -66,23 +66,17 @@ func.func @ceildivi_index(%arg0: index, %arg1: index) -> (index) { func.func @floordivi(%arg0: i32, %arg1: i32) -> (i32) { %res = arith.floordivsi %arg0, %arg1 : i32 return %res : i32 -// CHECK: [[ONE:%.+]] = arith.constant 1 : i32 -// CHECK: [[ZERO:%.+]] = arith.constant 0 : i32 -// CHECK: [[MIN1:%.+]] = arith.constant -1 : i32 -// CHECK: [[CMP1:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32 -// CHECK: [[X:%.+]] = arith.select [[CMP1]], [[ONE]], [[MIN1]] : i32 -// CHECK: [[TRUE1:%.+]] = arith.subi [[X]], [[ARG0]] : i32 -// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : i32 -// CHECK: [[TRUE3:%.+]] = arith.subi [[MIN1]], [[TRUE2]] : i32 -// CHECK: [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : i32 -// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : i32 -// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : i32 -// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32 -// CHECK: [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32 -// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MPOS]] : i1 -// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MNEG]] : i1 -// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1 -// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], 
[[FALSE]] : i32 +// CHECK: %[[QUOTIENT:.*]] = arith.divsi %arg0, %arg1 : i32 +// CHECK: %[[PRODUCT:.*]] = arith.muli %[[QUOTIENT]], %arg1 : i32 +// CHECK: %[[NOT_EQ_PRODUCT:.*]] = arith.cmpi ne, %arg0, %[[PRODUCT]] : i32 +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 +// CHECK: %[[NEG_DIVISOR:.*]] = arith.cmpi slt, %arg0, %[[ZERO]] : i32 +// CHECK: %[[NEG_DIVIDEND:.*]] = arith.cmpi slt, %arg1, %[[ZERO]] : i32 +// CHECK: %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVISOR]], %[[NEG_DIVIDEND]] : i1 +// CHECK: %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : i1 +// CHECK-DAG: %[[NEG_ONE:.*]] = arith.constant -1 : i32 +// CHECK: %[[MINUS_ONE:.*]] = arith.addi %[[QUOTIENT]], %[[NEG_ONE]] : i32 +// CHECK: %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : i32 } // ----- @@ -93,23 +87,17 @@ func.func @floordivi(%arg0: i32, %arg1: i32) -> (i32) { func.func @floordivi_index(%arg0: index, %arg1: index) -> (index) { %res = arith.floordivsi %arg0, %arg1 : index return %res : index -// CHECK: [[ONE:%.+]] = arith.constant 1 : index -// CHECK: [[ZERO:%.+]] = arith.constant 0 : index -// CHECK: [[MIN1:%.+]] = arith.constant -1 : index -// CHECK: [[CMP1:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index -// CHECK: [[X:%.+]] = arith.select [[CMP1]], [[ONE]], [[MIN1]] : index -// CHECK: [[TRUE1:%.+]] = arith.subi [[X]], [[ARG0]] : index -// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index -// CHECK: [[TRUE3:%.+]] = arith.subi [[MIN1]], [[TRUE2]] : index -// CHECK: [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : index -// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index -// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index -// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index -// CHECK: [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index -// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MPOS]] : i1 -// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MNEG]] : i1 -// 
CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1 -// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : index +// CHECK: %[[QUOTIENT:.*]] = arith.divsi %arg0, %arg1 : index +// CHECK: %[[PRODUCT:.*]] = arith.muli %[[QUOTIENT]], %arg1 : index +// CHECK: %[[NOT_EQ_PRODUCT:.*]] = arith.cmpi ne, %arg0, %[[PRODUCT]] : index +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index +// CHECK: %[[NEG_DIVISOR:.*]] = arith.cmpi slt, %arg0, %[[ZERO]] : index +// CHECK: %[[NEG_DIVIDEND:.*]] = arith.cmpi slt, %arg1, %[[ZERO]] : index +// CHECK: %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVISOR]], %[[NEG_DIVIDEND]] : i1 +// CHECK: %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : i1 +// CHECK: %[[NEG_ONE:.*]] = arith.constant -1 : index +// CHECK-DAG: %[[MINUS_ONE:.*]] = arith.addi %[[QUOTIENT]], %[[NEG_ONE]] : index +// CHECK: %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : index } // ----- @@ -121,23 +109,17 @@ func.func @floordivi_index(%arg0: index, %arg1: index) -> (index) { func.func @floordivi_vec(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> (vector<4xi32>) { %res = arith.floordivsi %arg0, %arg1 : vector<4xi32> return %res : vector<4xi32> -// CHECK: %[[VAL_2:.*]] = arith.constant dense<1> : vector<4xi32> -// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<4xi32> -// CHECK: %[[VAL_4:.*]] = arith.constant dense<-1> : vector<4xi32> -// CHECK: %[[VAL_5:.*]] = arith.cmpi slt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32> -// CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_5]], %[[VAL_2]], %[[VAL_4]] : vector<4xi1>, vector<4xi32> -// CHECK: %[[VAL_7:.*]] = arith.subi %[[VAL_6]], %[[VAL_0]] : vector<4xi32> -// CHECK: %[[VAL_8:.*]] = arith.divsi %[[VAL_7]], %[[VAL_1]] : vector<4xi32> -// CHECK: %[[VAL_9:.*]] = arith.subi %[[VAL_4]], %[[VAL_8]] : vector<4xi32> -// CHECK: %[[VAL_10:.*]] = arith.divsi %[[VAL_0]], %[[VAL_1]] : vector<4xi32> -// CHECK: %[[VAL_11:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_3]] : 
vector<4xi32> -// CHECK: %[[VAL_12:.*]] = arith.cmpi sgt, %[[VAL_0]], %[[VAL_3]] : vector<4xi32> -// CHECK: %[[VAL_13:.*]] = arith.cmpi slt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32> -// CHECK: %[[VAL_14:.*]] = arith.cmpi sgt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32> -// CHECK: %[[VAL_15:.*]] = arith.andi %[[VAL_11]], %[[VAL_14]] : vector<4xi1> -// CHECK: %[[VAL_16:.*]] = arith.andi %[[VAL_12]], %[[VAL_13]] : vector<4xi1> -// CHECK: %[[VAL_17:.*]] = arith.ori %[[VAL_15]], %[[VAL_16]] : vector<4xi1> -// CHECK: %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_9]], %[[VAL_10]] : vector<4xi1>, vector<4xi32> +// CHECK: %[[QUOTIENT:.*]] = arith.divsi %arg0, %arg1 : vector<4xi32> +// CHECK: %[[PRODUCT:.*]] = arith.muli %[[QUOTIENT]], %arg1 : vector<4xi32> +// CHECK: %[[NOT_EQ_PRODUCT:.*]] = arith.cmpi ne, %arg0, %[[PRODUCT]] : vector<4xi32> +// CHECK-DAG: %[[ZERO:.*]] = arith.constant dense<0> : vector<4xi32> +// CHECK: %[[NEG_DIVISOR:.*]] = arith.cmpi slt, %arg0, %[[ZERO]] : vector<4xi32> +// CHECK: %[[NEG_DIVIDEND:.*]] = arith.cmpi slt, %arg1, %[[ZERO]] : vector<4xi32> +// CHECK: %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVISOR]], %[[NEG_DIVIDEND]] : vector<4xi1> +// CHECK: %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : vector<4xi1> +// CHECK-DAG: %[[NEG_ONE:.*]] = arith.constant dense<-1> : vector<4xi32> +// CHECK: %[[MINUS_ONE:.*]] = arith.addi %[[QUOTIENT]], %[[NEG_ONE]] : vector<4xi32> +// CHECK: %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : vector<4xi1>, vector<4xi32> } // ----- diff --git a/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir b/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir index 39fbb67..a7013ea 100644 --- a/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir +++ b/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir @@ -2,6 +2,10 @@ // RUN: mlir-cpu-runner -e entry -entry-point-result=void \ // RUN: 
-shared-libs=%mlir_c_runner_utils | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf,lower-affine,convert-scf-to-cf,memref-expand,arith-expand),convert-vector-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)" | \ +// RUN: mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_c_runner_utils | \ +// RUN: FileCheck %s --check-prefix=SCHECK func.func @transfer_read_2d(%A : memref<40xi32>, %base1: index) { %i42 = arith.constant -42: i32 @@ -101,3 +105,29 @@ func.func @entry() { // CHECK:( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 ) // CHECK:( 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4 ) // CHECK:( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) + +// ----- + +func.func @non_inline_function() -> (i64, i64, i64, i64, i64, i64) { + %MIN_INT_MINUS_ONE = arith.constant -9223372036854775807 : i64 + %NEG_ONE = arith.constant -1 : i64 + %MIN_INT = arith.constant -9223372036854775808 : i64 + %ONE = arith.constant 1 : i64 + %MAX_INT = arith.constant 9223372036854775807 : i64 + return %MIN_INT_MINUS_ONE, %NEG_ONE, %MIN_INT, %ONE, %MAX_INT, %NEG_ONE : i64, i64, i64, i64, i64, i64 +} + +func.func @main() { + %0:6 = call @non_inline_function() : () -> (i64, i64, i64, i64, i64, i64) + %1 = arith.floordivsi %0#0, %0#1 : i64 + %2 = arith.floordivsi %0#2, %0#3 : i64 + %3 = arith.floordivsi %0#4, %0#5 : i64 + vector.print %1 : i64 + vector.print %2 : i64 + vector.print %3 : i64 + return +} + +// SCHECK: 9223372036854775807 +// SCHECK: -9223372036854775808 +// SCHECK: -9223372036854775807 diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index 2cf86b50..d2c2c12 100644 --- 
a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -989,6 +989,15 @@ func.func @tensor_arith.floordivsi_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5x return %res : tensor<4x5xi32> } +// CHECK-LABEL: func @arith.floordivsi_by_one_overflow +func.func @arith.floordivsi_by_one_overflow() -> i64 { + %neg_one = arith.constant -1 : i64 + %min_int = arith.constant -9223372036854775808 : i64 + // CHECK: arith.floordivsi + %poision = arith.floordivsi %min_int, %neg_one : i64 + return %poision : i64 +} + // ----- // CHECK-LABEL: func @arith.ceildivsi_by_one |