aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/python/rewrite.py
blob: acf7db23db9142145c509b6aec24167f7025ec3d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# RUN: %PYTHON %s 2>&1 | FileCheck %s

from mlir.ir import *
from mlir.passmanager import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import arith
from mlir.rewrite import *


def run(f):
    print("\nTEST:", f.__name__)
    f()


# CHECK-LABEL: TEST: testRewritePattern
@run
def testRewritePattern():
    def to_muli(op, rewriter):
        with rewriter.ip:
            new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
        rewriter.replace_op(op, new_op.owner)

    def constant_1_to_2(op, rewriter):
        c = op.attributes["value"].value
        if c != 1:
            return True  # failed to match
        with rewriter.ip:
            new_op = arith.constant(op.result.type, 2, loc=op.location)
        rewriter.replace_op(op, [new_op])

    with Context():
        patterns = RewritePatternSet()
        patterns.add(arith.AddIOp, to_muli)
        patterns.add(arith.ConstantOp, constant_1_to_2)
        frozen = patterns.freeze()

        module = ModuleOp.parse(
            r"""
            module {
              func.func @add(%a: i64, %b: i64) -> i64 {
                %sum = arith.addi %a, %b : i64
                return %sum : i64
              }
            }
            """
        )

        apply_patterns_and_fold_greedily(module, frozen)
        # CHECK: %0 = arith.muli %arg0, %arg1 : i64
        # CHECK: return %0 : i64
        print(module)

        module = ModuleOp.parse(
            r"""
            module {
              func.func @const() -> (i64, i64) {
                %0 = arith.constant 1 : i64
                %1 = arith.constant 3 : i64
                return %0, %1 : i64, i64
              }
            }
            """
        )

        apply_patterns_and_fold_greedily(module, frozen)
        # CHECK: %c2_i64 = arith.constant 2 : i64
        # CHECK: %c3_i64 = arith.constant 3 : i64
        # CHECK: return %c2_i64, %c3_i64 : i64, i64
        print(module)