diff options
Diffstat (limited to 'mlir/test/python')
-rw-r--r-- | mlir/test/python/rewrite.py | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py new file mode 100644 index 0000000..acf7db2 --- /dev/null +++ b/mlir/test/python/rewrite.py @@ -0,0 +1,69 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +from mlir.ir import * +from mlir.passmanager import * +from mlir.dialects.builtin import ModuleOp +from mlir.dialects import arith +from mlir.rewrite import * + + +def run(f): + print("\nTEST:", f.__name__) + f() + + +# CHECK-LABEL: TEST: testRewritePattern +@run +def testRewritePattern(): + def to_muli(op, rewriter): + with rewriter.ip: + new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location) + rewriter.replace_op(op, new_op.owner) + + def constant_1_to_2(op, rewriter): + c = op.attributes["value"].value + if c != 1: + return True # failed to match + with rewriter.ip: + new_op = arith.constant(op.result.type, 2, loc=op.location) + rewriter.replace_op(op, [new_op]) + + with Context(): + patterns = RewritePatternSet() + patterns.add(arith.AddIOp, to_muli) + patterns.add(arith.ConstantOp, constant_1_to_2) + frozen = patterns.freeze() + + module = ModuleOp.parse( + r""" + module { + func.func @add(%a: i64, %b: i64) -> i64 { + %sum = arith.addi %a, %b : i64 + return %sum : i64 + } + } + """ + ) + + apply_patterns_and_fold_greedily(module, frozen) + # CHECK: %0 = arith.muli %arg0, %arg1 : i64 + # CHECK: return %0 : i64 + print(module) + + module = ModuleOp.parse( + r""" + module { + func.func @const() -> (i64, i64) { + %0 = arith.constant 1 : i64 + %1 = arith.constant 3 : i64 + return %0, %1 : i64, i64 + } + } + """ + ) + + apply_patterns_and_fold_greedily(module, frozen) + # CHECK: %c2_i64 = arith.constant 2 : i64 + # CHECK: %c3_i64 = arith.constant 3 : i64 + # CHECK: return %c2_i64, %c3_i64 : i64, i64 + print(module) |