aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorFinn Plummer <50529406+inbelic@users.noreply.github.com>2024-03-24 18:56:47 -0700
committerGitHub <noreply@github.com>2024-03-24 18:56:47 -0700
commit230b1895c493c511c11541af3b5bc819887c82a8 (patch)
tree70873e9d03b8f6bf95f4c7da259810bb8edfebde /mlir
parentcceedc939a43c7c732a5888364251775bffc2dba (diff)
downloadllvm-230b1895c493c511c11541af3b5bc819887c82a8.zip
llvm-230b1895c493c511c11541af3b5bc819887c82a8.tar.gz
llvm-230b1895c493c511c11541af3b5bc819887c82a8.tar.bz2
[mlir][spirv] Add folding for [S|U|LessThan[Equal] (#85435)
Add missing constant propogation folder for [S|U]LessThan[Equal]. Implement additional folding when the operands are equal for all ops. Allows for constant folding in the IndexToSPIRV pass. Part of work #70704
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td8
-rw-r--r--mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp82
-rw-r--r--mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir176
3 files changed, 266 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index 3ee239d..14d639b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -716,6 +716,8 @@ def SPIRV_SLessThanOp : SPIRV_LogicalBinaryOp<"SLessThan",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -745,6 +747,8 @@ def SPIRV_SLessThanEqualOp : SPIRV_LogicalBinaryOp<"SLessThanEqual",
%5 = spirv.SLessThanEqual %2, %3 : vector<4xi32>
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -886,6 +890,8 @@ def SPIRV_ULessThanOp : SPIRV_LogicalBinaryOp<"ULessThan",
%5 = spirv.ULessThan %2, %3 : vector<4xi32>
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -949,6 +955,8 @@ def SPIRV_ULessThanEqualOp : SPIRV_LogicalBinaryOp<"ULessThanEqual",
%5 = spirv.ULessThanEqual %2, %3 : vector<4xi32>
```
}];
+
+ let hasFolder = 1;
}
#endif // MLIR_DIALECT_SPIRV_IR_LOGICAL_OPS
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index ff4bace..3f1a782 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -881,6 +881,88 @@ OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
}
//===----------------------------------------------------------------------===//
+// spirv.SLessThan
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
+ // x == x -> false
+ if (getOperand1() == getOperand2()) {
+ auto falseAttr = BoolAttr::get(getContext(), false);
+ if (isa<IntegerType>(getType()))
+ return falseAttr;
+ if (auto vecTy = dyn_cast<VectorType>(getType()))
+ return SplatElementsAttr::get(vecTy, falseAttr);
+ }
+
+ return constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
+ return a.slt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SLessThanEqual
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
+ // x == x -> true
+ if (getOperand1() == getOperand2()) {
+ auto trueAttr = BoolAttr::get(getContext(), true);
+ if (isa<IntegerType>(getType()))
+ return trueAttr;
+ if (auto vecTy = dyn_cast<VectorType>(getType()))
+ return SplatElementsAttr::get(vecTy, trueAttr);
+ }
+
+ return constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
+ return a.sle(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ULessThan
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
+ // x == x -> false
+ if (getOperand1() == getOperand2()) {
+ auto falseAttr = BoolAttr::get(getContext(), false);
+ if (isa<IntegerType>(getType()))
+ return falseAttr;
+ if (auto vecTy = dyn_cast<VectorType>(getType()))
+ return SplatElementsAttr::get(vecTy, falseAttr);
+ }
+
+ return constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
+ return a.ult(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ULessThanEqual
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
+ // x == x -> true
+ if (getOperand1() == getOperand2()) {
+ auto trueAttr = BoolAttr::get(getContext(), true);
+ if (isa<IntegerType>(getType()))
+ return trueAttr;
+ if (auto vecTy = dyn_cast<VectorType>(getType()))
+ return SplatElementsAttr::get(vecTy, trueAttr);
+ }
+
+ return constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
+ return a.ule(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
+ });
+}
+
+//===----------------------------------------------------------------------===//
// spirv.ShiftLeftLogical
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index de21d114..ed0bd07 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1479,6 +1479,182 @@ func.func @const_fold_vector_inotequal() -> vector<3xi1> {
// -----
//===----------------------------------------------------------------------===//
+// spirv.SLessThan
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @slt_same
+func.func @slt_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
+ // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+ // CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
+ %0 = spirv.SLessThan %arg0, %arg0 : i32
+ %1 = spirv.SLessThan %arg1, %arg1 : vector<3xi32>
+
+ // CHECK: return %[[CFALSE]], %[[CVFALSE]]
+ return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_slt
+func.func @const_fold_scalar_slt() -> (i1, i1) {
+ %c4 = spirv.Constant 4 : i32
+ %c5 = spirv.Constant 5 : i32
+ %c6 = spirv.Constant 6 : i32
+
+ // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+ // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+ %0 = spirv.SLessThan %c5, %c6 : i32
+ %1 = spirv.SLessThan %c5, %c4 : i32
+
+ // CHECK: return %[[CTRUE]], %[[CFALSE]]
+ return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_slt
+func.func @const_fold_vector_slt() -> vector<3xi1> {
+ %cv0 = spirv.Constant dense<[-1, -4, 3]> : vector<3xi32>
+ %cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
+
+ // CHECK: %[[RET:.*]] = spirv.Constant dense<[false, true, false]>
+ %0 = spirv.SLessThan %cv0, %cv1 : vector<3xi32>
+
+ // CHECK: return %[[RET]]
+ return %0 : vector<3xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SLessThanEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @sle_same
+func.func @sle_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
+ // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+ // CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>
+ %0 = spirv.SLessThanEqual %arg0, %arg0 : i32
+ %1 = spirv.SLessThanEqual %arg1, %arg1 : vector<3xi32>
+
+ // CHECK: return %[[CTRUE]], %[[CVTRUE]]
+ return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_sle
+func.func @const_fold_scalar_sle() -> (i1, i1) {
+ %c4 = spirv.Constant 4 : i32
+ %c5 = spirv.Constant 5 : i32
+ %c6 = spirv.Constant 6 : i32
+
+ // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+ // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+ %0 = spirv.SLessThanEqual %c5, %c6 : i32
+ %1 = spirv.SLessThanEqual %c5, %c4 : i32
+
+ // CHECK: return %[[CTRUE]], %[[CFALSE]]
+ return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_sle
+func.func @const_fold_vector_sle() -> vector<3xi1> {
+ %cv0 = spirv.Constant dense<[-1, -4, 3]> : vector<3xi32>
+ %cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
+
+ // CHECK: %[[RET:.*]] = spirv.Constant dense<[true, true, false]>
+ %0 = spirv.SLessThanEqual %cv0, %cv1 : vector<3xi32>
+
+ // CHECK: return %[[RET]]
+ return %0 : vector<3xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ULessThan
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @ult_same
+func.func @ult_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
+ // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+ // CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
+ %0 = spirv.ULessThan %arg0, %arg0 : i32
+ %1 = spirv.ULessThan %arg1, %arg1 : vector<3xi32>
+
+ // CHECK: return %[[CFALSE]], %[[CVFALSE]]
+ return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_ult
+func.func @const_fold_scalar_ult() -> (i1, i1) {
+ %c4 = spirv.Constant 4 : i32
+ %c5 = spirv.Constant 5 : i32
+ %cn6 = spirv.Constant -6 : i32
+
+ // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+ // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+ %0 = spirv.ULessThan %c5, %cn6 : i32
+ %1 = spirv.ULessThan %c5, %c4 : i32
+
+ // CHECK: return %[[CTRUE]], %[[CFALSE]]
+ return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_ult
+func.func @const_fold_vector_ult() -> vector<3xi1> {
+ %cv0 = spirv.Constant dense<[-1, -4, 3]> : vector<3xi32>
+ %cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
+
+ // CHECK: %[[RET:.*]] = spirv.Constant dense<[false, true, false]>
+ %0 = spirv.ULessThan %cv0, %cv1 : vector<3xi32>
+
+ // CHECK: return %[[RET]]
+ return %0 : vector<3xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ULessThanEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @ule_same
+func.func @ule_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
+ // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+ // CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>
+ %0 = spirv.ULessThanEqual %arg0, %arg0 : i32
+ %1 = spirv.ULessThanEqual %arg1, %arg1 : vector<3xi32>
+
+ // CHECK: return %[[CTRUE]], %[[CVTRUE]]
+ return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_ule
+func.func @const_fold_scalar_ule() -> (i1, i1) {
+ %c4 = spirv.Constant 4 : i32
+ %c5 = spirv.Constant 5 : i32
+ %cn6 = spirv.Constant -6 : i32
+
+ // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+ // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+ %0 = spirv.ULessThanEqual %c5, %cn6 : i32
+ %1 = spirv.ULessThanEqual %c5, %c4 : i32
+
+ // CHECK: return %[[CTRUE]], %[[CFALSE]]
+ return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_ule
+func.func @const_fold_vector_ule() -> vector<3xi1> {
+ %cv0 = spirv.Constant dense<[-1, -4, 3]> : vector<3xi32>
+ %cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
+
+ // CHECK: %[[RET:.*]] = spirv.Constant dense<[true, true, false]>
+ %0 = spirv.ULessThanEqual %cv0, %cv1 : vector<3xi32>
+
+ // CHECK: return %[[RET]]
+ return %0 : vector<3xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
// spirv.LeftShiftLogical
//===----------------------------------------------------------------------===//