diff options
Diffstat (limited to 'mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp')
-rw-r--r-- | mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp index 08ddc7c..4c62289 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -644,6 +644,45 @@ OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) { } //===----------------------------------------------------------------------===// +// spirv.SNegate +//===----------------------------------------------------------------------===// + +OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) { + // -(-x) = 0 - (0 - x) = x + auto op = getOperand(); + if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>()) + return negateOp->getOperand(0); + + // According to the SPIR-V spec: + // + // Signed-integer subtract of Operand from zero. + return constFoldUnaryOp<IntegerAttr>( + adaptor.getOperands(), [](const APInt &a) { + APInt zero = APInt::getZero(a.getBitWidth()); + return zero - a; + }); +} + +//===----------------------------------------------------------------------===// +// spirv.NotOp +//===----------------------------------------------------------------------===// + +OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) { + // !(!x) = x + auto op = getOperand(); + if (auto notOp = op.getDefiningOp<spirv::NotOp>()) + return notOp->getOperand(0); + + // According to the SPIR-V spec: + // + // Complement the bits of Operand. + return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [&](APInt a) { + a.flipAllBits(); + return a; + }); +} + +//===----------------------------------------------------------------------===// // spirv.LogicalAnd //===----------------------------------------------------------------------===// @@ -714,6 +753,22 @@ OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) { // spirv.LogicalNot //===----------------------------------------------------------------------===// +OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) { + // !(!x) = x + auto op = getOperand(); + if (auto notOp = op.getDefiningOp<spirv::LogicalNotOp>()) + return notOp->getOperand(0); + + // According to the SPIR-V spec: + // + // Complement the bits of Operand. + return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), + [](const APInt &a) { + APInt zero = APInt::getZero(1); + return a == 1 ? zero : (zero + 1); + }); +} + void spirv::LogicalNotOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { results |