aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/python/rewrite.py
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/python/rewrite.py')
-rw-r--r--mlir/test/python/rewrite.py69
1 files changed, 69 insertions, 0 deletions
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
new file mode 100644
index 0000000..acf7db2
--- /dev/null
+++ b/mlir/test/python/rewrite.py
@@ -0,0 +1,69 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import arith
+from mlir.rewrite import *
+
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ f()
+
+
+# CHECK-LABEL: TEST: testRewritePattern
+@run
+def testRewritePattern():
+ def to_muli(op, rewriter):
+ with rewriter.ip:
+ new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
+ rewriter.replace_op(op, new_op.owner)
+
+ def constant_1_to_2(op, rewriter):
+ c = op.attributes["value"].value
+ if c != 1:
+ return True # failed to match
+ with rewriter.ip:
+ new_op = arith.constant(op.result.type, 2, loc=op.location)
+ rewriter.replace_op(op, [new_op])
+
+ with Context():
+ patterns = RewritePatternSet()
+ patterns.add(arith.AddIOp, to_muli)
+ patterns.add(arith.ConstantOp, constant_1_to_2)
+ frozen = patterns.freeze()
+
+ module = ModuleOp.parse(
+ r"""
+ module {
+ func.func @add(%a: i64, %b: i64) -> i64 {
+ %sum = arith.addi %a, %b : i64
+ return %sum : i64
+ }
+ }
+ """
+ )
+
+ apply_patterns_and_fold_greedily(module, frozen)
+ # CHECK: %0 = arith.muli %arg0, %arg1 : i64
+ # CHECK: return %0 : i64
+ print(module)
+
+ module = ModuleOp.parse(
+ r"""
+ module {
+ func.func @const() -> (i64, i64) {
+ %0 = arith.constant 1 : i64
+ %1 = arith.constant 3 : i64
+ return %0, %1 : i64, i64
+ }
+ }
+ """
+ )
+
+ apply_patterns_and_fold_greedily(module, frozen)
+ # CHECK: %c2_i64 = arith.constant 2 : i64
+ # CHECK: %c3_i64 = arith.constant 3 : i64
+ # CHECK: return %c2_i64, %c3_i64 : i64, i64
+ print(module)