aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdrian Kuegel <akuegel@google.com>2021-12-22 18:09:59 +0100
committerAdrian Kuegel <akuegel@google.com>2021-12-22 18:12:24 +0100
commit4a10457d33e92f2e024f6d024d168ddcd49c3a59 (patch)
tree01927cec4c70580ef578a6d4e14d3a8b7fe031d3
parent4639461531e7325458dc6a37bc6d857c0df109f7 (diff)
downloadllvm-4a10457d33e92f2e024f6d024d168ddcd49c3a59.zip
llvm-4a10457d33e92f2e024f6d024d168ddcd49c3a59.tar.gz
llvm-4a10457d33e92f2e024f6d024d168ddcd49c3a59.tar.bz2
[mlir][arith] Fix CmpIOP folding for vector types.
Previously, the folding assumed that it always operates on scalar types. Differential Revision: https://reviews.llvm.org/D116151
-rw-r--r--mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp10
-rw-r--r--mlir/test/Dialect/Arithmetic/canonicalize.mlir26
2 files changed, 35 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index aca6e4f..a413fb2 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -1056,13 +1056,21 @@ static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
llvm_unreachable("unknown cmpi predicate kind");
}
+static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
+ auto boolAttr = BoolAttr::get(ctx, value);
+ ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
+ if (!shapedType)
+ return boolAttr;
+ return DenseElementsAttr::get(shapedType, boolAttr);
+}
+
OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "cmpi takes two operands");
// cmpi(pred, x, x)
if (getLhs() == getRhs()) {
auto val = applyCmpPredicateToEqualOperands(getPredicate());
- return BoolAttr::get(getContext(), val);
+ return getBoolAttribute(getType(), getContext(), val);
}
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index 328ed1d..96a630a 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -22,6 +22,32 @@ func @cmpi_equal_operands(%arg0: i64)
: i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
}
+// Test case: Folding of comparisons with equal vector operands.
+// CHECK-LABEL: @cmpi_equal_vector_operands
+// CHECK-DAG: %[[T:.*]] = arith.constant dense<true>
+// CHECK-DAG: %[[F:.*]] = arith.constant dense<false>
+// CHECK: return %[[T]], %[[T]], %[[T]], %[[T]], %[[T]],
+// CHECK-SAME: %[[F]], %[[F]], %[[F]], %[[F]], %[[F]]
+func @cmpi_equal_vector_operands(%arg0: vector<1x8xi64>)
+ -> (vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>,
+ vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>,
+ vector<1x8xi1>, vector<1x8xi1>) {
+ %0 = arith.cmpi eq, %arg0, %arg0 : vector<1x8xi64>
+ %1 = arith.cmpi sle, %arg0, %arg0 : vector<1x8xi64>
+ %2 = arith.cmpi sge, %arg0, %arg0 : vector<1x8xi64>
+ %3 = arith.cmpi ule, %arg0, %arg0 : vector<1x8xi64>
+ %4 = arith.cmpi uge, %arg0, %arg0 : vector<1x8xi64>
+ %5 = arith.cmpi ne, %arg0, %arg0 : vector<1x8xi64>
+ %6 = arith.cmpi slt, %arg0, %arg0 : vector<1x8xi64>
+ %7 = arith.cmpi sgt, %arg0, %arg0 : vector<1x8xi64>
+ %8 = arith.cmpi ult, %arg0, %arg0 : vector<1x8xi64>
+ %9 = arith.cmpi ugt, %arg0, %arg0 : vector<1x8xi64>
+ return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
+ : vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>,
+ vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>, vector<1x8xi1>,
+ vector<1x8xi1>, vector<1x8xi1>
+}
+
// -----
// CHECK-LABEL: @indexCastOfSignExtend