diff options
author | Adrian Kuegel <akuegel@google.com> | 2021-12-22 18:09:59 +0100 |
---|---|---|
committer | Adrian Kuegel <akuegel@google.com> | 2021-12-22 18:12:24 +0100 |
commit | 4a10457d33e92f2e024f6d024d168ddcd49c3a59 (patch) | |
tree | 01927cec4c70580ef578a6d4e14d3a8b7fe031d3 | |
parent | 4639461531e7325458dc6a37bc6d857c0df109f7 (diff) | |
download | llvm-4a10457d33e92f2e024f6d024d168ddcd49c3a59.zip llvm-4a10457d33e92f2e024f6d024d168ddcd49c3a59.tar.gz llvm-4a10457d33e92f2e024f6d024d168ddcd49c3a59.tar.bz2 |
[mlir][arith] Fix CmpIOP folding for vector types.
Previously, the folding assumed that it always operates on scalar types.
Differential Revision: https://reviews.llvm.org/D116151
-rw-r--r-- | mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp | 10 | ||||
-rw-r--r-- | mlir/test/Dialect/Arithmetic/canonicalize.mlir | 26 |
2 files changed, 35 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp index aca6e4f..a413fb2 100644 --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -1056,13 +1056,21 @@ static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { llvm_unreachable("unknown cmpi predicate kind"); } +static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { + auto boolAttr = BoolAttr::get(ctx, value); + ShapedType shapedType = type.dyn_cast_or_null<ShapedType>(); + if (!shapedType) + return boolAttr; + return DenseElementsAttr::get(shapedType, boolAttr); +} + OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) { assert(operands.size() == 2 && "cmpi takes two operands"); // cmpi(pred, x, x) if (getLhs() == getRhs()) { auto val = applyCmpPredicateToEqualOperands(getPredicate()); - return BoolAttr::get(getContext(), val); + return getBoolAttribute(getType(), getContext(), val); } auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir index 328ed1d..96a630a 100644 --- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir +++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir @@ -22,6 +22,32 @@ func @cmpi_equal_operands(%arg0: i64) : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1 } +// Test case: Folding of comparisons with equal vector operands. +// CHECK-LABEL: @cmpi_equal_vector_operands +// CHECK-DAG: %[[T:.*]] = arith.constant dense<true> +// CHECK-DAG: %[[F:.*]] = arith.constant dense<false> +// CHECK: return %[[T]], %[[T]], %[[T]], %[[T]], %[[T]], +// CHECK-SAME: %[[F]], %[[F]], %[[F]], %[[F]], %[[F]] +func @cmpi_equal_vector_operands(%arg0: vector<1x8xi64>) + -> (vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, + vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, + vector<1x8xi1>, vector<1x8xi1>) { + %0 = arith.cmpi eq, %arg0, %arg0 : vector<1x8xi64> + %1 = arith.cmpi sle, %arg0, %arg0 : vector<1x8xi64> + %2 = arith.cmpi sge, %arg0, %arg0 : vector<1x8xi64> + %3 = arith.cmpi ule, %arg0, %arg0 : vector<1x8xi64> + %4 = arith.cmpi uge, %arg0, %arg0 : vector<1x8xi64> + %5 = arith.cmpi ne, %arg0, %arg0 : vector<1x8xi64> + %6 = arith.cmpi slt, %arg0, %arg0 : vector<1x8xi64> + %7 = arith.cmpi sgt, %arg0, %arg0 : vector<1x8xi64> + %8 = arith.cmpi ult, %arg0, %arg0 : vector<1x8xi64> + %9 = arith.cmpi ugt, %arg0, %arg0 : vector<1x8xi64> + return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9 + : vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, + vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, + vector<1x8xi1>, vector<1x8xi1> +} + // ----- // CHECK-LABEL: @indexCastOfSignExtend |