diff options
| -rw-r--r-- | mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td | 2 | ||||
| -rw-r--r-- | mlir/lib/Dialect/Complex/IR/ComplexOps.cpp | 38 | ||||
| -rw-r--r-- | mlir/test/Dialect/Complex/canonicalize.mlir | 22 |
3 files changed, 62 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td index 7116bed..dd7c1a8 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -290,6 +290,7 @@ def ImOp : ComplexUnaryOp<"im", let results = (outs AnyFloat:$imaginary); let hasFolder = 1; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// @@ -436,6 +437,7 @@ def ReOp : ComplexUnaryOp<"re", let results = (outs AnyFloat:$real); let hasFolder = 1; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp index f2d1a96..f8c9b63 100644 --- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp @@ -6,9 +6,12 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" using namespace mlir; using namespace mlir::complex; @@ -99,6 +102,36 @@ OpFoldResult ImOp::fold(FoldAdaptor adaptor) { return {}; } +namespace { +template <typename OpKind, int ComponentIndex> +struct FoldComponentNeg final : OpRewritePattern<OpKind> { + using OpRewritePattern<OpKind>::OpRewritePattern; + + LogicalResult matchAndRewrite(OpKind op, + PatternRewriter &rewriter) const override { + auto negOp = op.getOperand().template getDefiningOp<NegOp>(); + if (!negOp) + return failure(); + + auto createOp = negOp.getComplex().template getDefiningOp<CreateOp>(); + if (!createOp) + return failure(); + + Type elementType = createOp.getType().getElementType(); + assert(isa<FloatType>(elementType)); + + rewriter.replaceOpWithNewOp<arith::NegFOp>( + op, elementType, createOp.getOperand(ComponentIndex)); + return success(); + } +}; +} // namespace + +void ImOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add<FoldComponentNeg<ImOp, 1>>(context); +} + //===----------------------------------------------------------------------===// // ReOp //===----------------------------------------------------------------------===// @@ -113,6 +146,11 @@ OpFoldResult ReOp::fold(FoldAdaptor adaptor) { return {}; } +void ReOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add<FoldComponentNeg<ReOp, 0>>(context); +} + //===----------------------------------------------------------------------===// // AddOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir index f0d287f..2fd2002 100644 --- a/mlir/test/Dialect/Complex/canonicalize.mlir +++ b/mlir/test/Dialect/Complex/canonicalize.mlir @@ -155,3 +155,25 @@ func.func @complex_sub_zero() -> complex<f32> { %sub = complex.sub %complex1, %complex2 : complex<f32> return %sub : complex<f32> } + +// CHECK-LABEL: func @re_neg +// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) +func.func @re_neg(%arg0: f32, %arg1: f32) -> f32 { + %create = complex.create %arg0, %arg1: complex<f32> + // CHECK: %[[NEG:.*]] = arith.negf %[[ARG0]] + %neg = complex.neg %create : complex<f32> + %re = complex.re %neg : complex<f32> + // CHECK-NEXT: return %[[NEG]] + return %re : f32 +} + +// CHECK-LABEL: func @im_neg +// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) +func.func @im_neg(%arg0: f32, %arg1: f32) -> f32 { + %create = complex.create %arg0, %arg1: complex<f32> + // CHECK: %[[NEG:.*]] = arith.negf %[[ARG1]] + %neg = complex.neg %create : complex<f32> + %im = complex.im %neg : complex<f32> + // CHECK-NEXT: return %[[NEG]] + return %im : f32 +} |
