aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp')
-rw-r--r--mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp55
1 files changed, 55 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 08ddc7c..4c62289 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -644,6 +644,45 @@ OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
}
//===----------------------------------------------------------------------===//
+// spirv.SNegate
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
+ // -(-x) = 0 - (0 - x) = x
+ auto op = getOperand();
+ if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
+ return negateOp->getOperand(0);
+
+ // According to the SPIR-V spec:
+ //
+ // Signed-integer subtract of Operand from zero.
+ return constFoldUnaryOp<IntegerAttr>(
+ adaptor.getOperands(), [](const APInt &a) {
+ APInt zero = APInt::getZero(a.getBitWidth());
+ return zero - a;
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.NotOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
+ // !(!x) = x
+ auto op = getOperand();
+ if (auto notOp = op.getDefiningOp<spirv::NotOp>())
+ return notOp->getOperand(0);
+
+ // According to the SPIR-V spec:
+ //
+ // Complement the bits of Operand.
+ return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [&](APInt a) {
+ a.flipAllBits();
+ return a;
+ });
+}
+
+//===----------------------------------------------------------------------===//
// spirv.LogicalAnd
//===----------------------------------------------------------------------===//
@@ -714,6 +753,22 @@ OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
// spirv.LogicalNot
//===----------------------------------------------------------------------===//
+OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
+ // !(!x) = x
+ auto op = getOperand();
+ if (auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
+ return notOp->getOperand(0);
+
+ // According to the SPIR-V spec:
+ //
+ // Complement the bits of Operand.
+ return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
+ [](const APInt &a) {
+ APInt zero = APInt::getZero(1);
+ return a == 1 ? zero : (zero + 1);
+ });
+}
+
void spirv::LogicalNotOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results