aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/python/integration/dialects/pdl.py
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/python/integration/dialects/pdl.py')
-rw-r--r--mlir/test/python/integration/dialects/pdl.py91
1 files changed, 89 insertions, 2 deletions
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index c8e6197..fe27dd4 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -16,6 +16,7 @@ def construct_and_print_in_module(f):
print(module)
return f
+
def get_pdl_patterns():
# Create a rewrite from add to mul. This will match
# - operation name is arith.addi
@@ -121,8 +122,10 @@ def load_myint_dialect():
# This PDL pattern is to fold constant additions,
-# i.e. add(constant0, constant1) -> constant2
-# where constant2 = constant0 + constant1.
+# including two patterns:
+# 1. add(constant0, constant1) -> constant2
+# where constant2 = constant0 + constant1;
+# 2. add(x, 0) or add(0, x) -> x.
def get_pdl_pattern_fold():
m = Module.create()
i32 = IntegerType.get_signless(32)
@@ -237,3 +240,87 @@ def test_pdl_register_function_constraint(module_):
apply_patterns_and_fold_greedily(module_, frozen)
return module_
+
+
+# This pattern is to expand constant to additions
+# unless the constant is no more than 1,
+# e.g. 3 -> 1 + 2 -> 1 + (1 + 1).
+def get_pdl_pattern_expand():
+ m = Module.create()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(m.body):
+
+ @pdl.pattern(benefit=1, sym_name="myint_constant_expand")
+ def pat():
+ t = pdl.TypeOp(i32)
+ cst = pdl.AttributeOp()
+ pdl.apply_native_constraint([], "is_one", [cst])
+ op0 = pdl.OperationOp(
+ name="myint.constant", attributes={"value": cst}, types=[t]
+ )
+
+ @pdl.rewrite()
+ def rew():
+ expanded = pdl.apply_native_rewrite(
+ [pdl.OperationType.get()], "expand", [cst]
+ )
+ pdl.ReplaceOp(op0, with_op=expanded)
+
+ def is_one(rewriter, results, values):
+ cst = values[0].value
+ return cst <= 1
+
+ def expand(rewriter, results, values):
+ cst = values[0].value
+ c1 = cst // 2
+ c2 = cst - c1
+ with rewriter.ip:
+ op1 = Operation.create(
+ "myint.constant",
+ results=[i32],
+ attributes={"value": IntegerAttr.get(i32, c1)},
+ )
+ op2 = Operation.create(
+ "myint.constant",
+ results=[i32],
+ attributes={"value": IntegerAttr.get(i32, c2)},
+ )
+ res = Operation.create(
+ "myint.add", results=[i32], operands=[op1.result, op2.result]
+ )
+ results.append(res)
+
+ pdl_module = PDLModule(m)
+ pdl_module.register_constraint_function("is_one", is_one)
+ pdl_module.register_rewrite_function("expand", expand)
+ return pdl_module.freeze()
+
+
+# CHECK-LABEL: TEST: test_pdl_register_function_expand
+# CHECK: %0 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %1 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
+# CHECK: %3 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %4 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %5 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %6 = "myint.add"(%4, %5) : (i32, i32) -> i32
+# CHECK: %7 = "myint.add"(%3, %6) : (i32, i32) -> i32
+# CHECK: %8 = "myint.add"(%2, %7) : (i32, i32) -> i32
+# CHECK: return %8 : i32
+@construct_and_print_in_module
+def test_pdl_register_function_expand(module_):
+ load_myint_dialect()
+
+ module_ = Module.parse(
+ """
+ func.func @f() -> i32 {
+ %0 = "myint.constant"() { value = 5 }: () -> (i32)
+ return %0 : i32
+ }
+ """
+ )
+
+ frozen = get_pdl_pattern_expand()
+ apply_patterns_and_fold_greedily(module_, frozen)
+
+ return module_