diff options
Diffstat (limited to 'mlir/test/python/integration/dialects/pdl.py')
-rw-r--r-- | mlir/test/python/integration/dialects/pdl.py | 91 |
1 files changed, 89 insertions, 2 deletions
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py index c8e6197..fe27dd4 100644 --- a/mlir/test/python/integration/dialects/pdl.py +++ b/mlir/test/python/integration/dialects/pdl.py @@ -16,6 +16,7 @@ def construct_and_print_in_module(f): print(module) return f + def get_pdl_patterns(): # Create a rewrite from add to mul. This will match # - operation name is arith.addi @@ -121,8 +122,10 @@ def load_myint_dialect(): # This PDL pattern is to fold constant additions, -# i.e. add(constant0, constant1) -> constant2 -# where constant2 = constant0 + constant1. +# including two patterns: +# 1. add(constant0, constant1) -> constant2 +# where constant2 = constant0 + constant1; +# 2. add(x, 0) or add(0, x) -> x. def get_pdl_pattern_fold(): m = Module.create() i32 = IntegerType.get_signless(32) @@ -237,3 +240,87 @@ def test_pdl_register_function_constraint(module_): apply_patterns_and_fold_greedily(module_, frozen) return module_ + + +# This pattern is to expand constant to additions +# unless the constant is no more than 1, +# e.g. 3 -> 1 + 2 -> 1 + (1 + 1). +def get_pdl_pattern_expand(): + m = Module.create() + i32 = IntegerType.get_signless(32) + with InsertionPoint(m.body): + + @pdl.pattern(benefit=1, sym_name="myint_constant_expand") + def pat(): + t = pdl.TypeOp(i32) + cst = pdl.AttributeOp() + pdl.apply_native_constraint([], "is_one", [cst]) + op0 = pdl.OperationOp( + name="myint.constant", attributes={"value": cst}, types=[t] + ) + + @pdl.rewrite() + def rew(): + expanded = pdl.apply_native_rewrite( + [pdl.OperationType.get()], "expand", [cst] + ) + pdl.ReplaceOp(op0, with_op=expanded) + + def is_one(rewriter, results, values): + cst = values[0].value + return cst <= 1 + + def expand(rewriter, results, values): + cst = values[0].value + c1 = cst // 2 + c2 = cst - c1 + with rewriter.ip: + op1 = Operation.create( + "myint.constant", + results=[i32], + attributes={"value": IntegerAttr.get(i32, c1)}, + ) + op2 = Operation.create( + "myint.constant", + results=[i32], + attributes={"value": IntegerAttr.get(i32, c2)}, + ) + res = Operation.create( + "myint.add", results=[i32], operands=[op1.result, op2.result] + ) + results.append(res) + + pdl_module = PDLModule(m) + pdl_module.register_constraint_function("is_one", is_one) + pdl_module.register_rewrite_function("expand", expand) + return pdl_module.freeze() + + +# CHECK-LABEL: TEST: test_pdl_register_function_expand +# CHECK: %0 = "myint.constant"() {value = 1 : i32} : () -> i32 +# CHECK: %1 = "myint.constant"() {value = 1 : i32} : () -> i32 +# CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32 +# CHECK: %3 = "myint.constant"() {value = 1 : i32} : () -> i32 +# CHECK: %4 = "myint.constant"() {value = 1 : i32} : () -> i32 +# CHECK: %5 = "myint.constant"() {value = 1 : i32} : () -> i32 +# CHECK: %6 = "myint.add"(%4, %5) : (i32, i32) -> i32 +# CHECK: %7 = "myint.add"(%3, %6) : (i32, i32) -> i32 +# CHECK: %8 = "myint.add"(%2, %7) : (i32, i32) -> i32 +# CHECK: return %8 : i32 +@construct_and_print_in_module +def test_pdl_register_function_expand(module_): + load_myint_dialect() + + module_ = Module.parse( + """ + func.func @f() -> i32 { + %0 = "myint.constant"() { value = 5 }: () -> (i32) + return %0 : i32 + } + """ + ) + + frozen = get_pdl_pattern_expand() + apply_patterns_and_fold_greedily(module_, frozen) + + return module_ |