# RUN: %PYTHON %s 2>&1 | FileCheck %s

from mlir.dialects import arith, func, pdl
from mlir.dialects.builtin import module
from mlir.ir import *
from mlir.rewrite import *


def construct_and_print_in_module(f):
    """Decorator that runs test body `f` immediately and prints its result.

    Prints a "TEST: <function name>" banner (the text the FileCheck label
    directives in this file key on), then, inside a fresh Context and an
    unknown Location, creates an empty Module, calls `f(module)` with the
    insertion point set to the module body, and prints whatever `f` returns
    unless that is None.

    Returns `f` unchanged.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        # NOTE: inside this function the local name `module` shadows the
        # `module` decorator imported from mlir.dialects.builtin.
        module = Module.create()
        with InsertionPoint(module.body):
            # The test may return a different module than the one passed in.
            module = f(module)
        if module is not None:
            print(module)
    return f


def get_pdl_patterns():
    """Build the "addi_to_mul" PDL pattern and return it frozen.

    Returns the result of `PDLModule(m).freeze()`, which the tests pass to
    `apply_patterns_and_fold_greedily`.
    """
    # Create a rewrite from add to mul. This will match
    # - operation name is arith.addi
    # - operands are index types.
    # - there are two operands.
    with Location.unknown():
        m = Module.create()
        with InsertionPoint(m.body):
            # Change all arith.addi with index types to arith.muli.
            @pdl.pattern(benefit=1, sym_name="addi_to_mul")
            def pat():
                # Match arith.addi with index types.
                index_type = pdl.TypeOp(IndexType.get())
                operand0 = pdl.OperandOp(index_type)
                operand1 = pdl.OperandOp(index_type)
                op0 = pdl.OperationOp(
                    name="arith.addi", args=[operand0, operand1], types=[index_type]
                )

                # Replace the matched op with arith.muli.
                @pdl.rewrite()
                def rew():
                    # Same operands and result type as the matched op; only
                    # the operation name differs.
                    newOp = pdl.OperationOp(
                        name="arith.muli", args=[operand0, operand1], types=[index_type]
                    )
                    pdl.ReplaceOp(op0, with_op=newOp)

    # Create a PDL module from module and freeze it. At this point the ownership
    # of the module is transferred to the PDL module. This ownership transfer is
    # not yet captured Python side/has sharp edges. So best to construct the
    # module and PDL module in same scope.
    # FIXME: This should be made more robust.
    return PDLModule(m).freeze()


# CHECK-LABEL: TEST: test_add_to_mul
# CHECK: arith.muli
@construct_and_print_in_module
def test_add_to_mul(module_):
    """Apply the addi-to-muli patterns, passing the Module object itself."""
    index_type = IndexType.get()

    # Create a test case.
    @module(sym_name="ir")
    def ir():
        @func.func(index_type, index_type)
        def add_func(a, b):
            return arith.addi(a, b)

    frozen = get_pdl_patterns()
    # Could apply frozen pattern set multiple times.
    apply_patterns_and_fold_greedily(module_, frozen)
    return module_


# CHECK-LABEL: TEST: test_add_to_mul_with_op
# CHECK: arith.muli
@construct_and_print_in_module
def test_add_to_mul_with_op(module_):
    """Apply the addi-to-muli patterns, passing `module_.operation`.

    Identical to test_add_to_mul except for the first argument given to
    `apply_patterns_and_fold_greedily`.
    """
    index_type = IndexType.get()

    # Create a test case.
    @module(sym_name="ir")
    def ir():
        @func.func(index_type, index_type)
        def add_func(a, b):
            return arith.addi(a, b)

    frozen = get_pdl_patterns()
    apply_patterns_and_fold_greedily(module_.operation, frozen)
    return module_


# If we used arith.constant and arith.addi here, their C++-defined
# folders/canonicalizations would be applied implicitly by the greedy
# pattern rewrite driver and make our Python-defined folding useless,
# so we define a new dialect to work around this.
def load_myint_dialect():
    """Define the "myint" dialect with IRDL and load it.

    The dialect has two operations:
    - "constant": a "value" attribute, one result named "cst".
    - "add": operands "lhs" and "rhs", one result named "res".
    Operands and results are constrained to signless 32-bit integers.
    """
    from mlir.dialects import irdl

    m = Module.create()
    with InsertionPoint(m.body):
        myint = irdl.dialect("myint")
        with InsertionPoint(myint.body):
            constant = irdl.operation_("constant")
            with InsertionPoint(constant.body):
                iattr = irdl.base(base_name="#builtin.integer")
                i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
                irdl.attributes_([iattr], ["value"])
                irdl.results_([i32], ["cst"], [irdl.Variadicity.single])
            add = irdl.operation_("add")
            with InsertionPoint(add.body):
                # A separate constraint value is created inside this op's body
                # rather than reusing the one from "constant".
                i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
                irdl.operands_(
                    [i32, i32],
                    ["lhs", "rhs"],
                    [irdl.Variadicity.single, irdl.Variadicity.single],
                )
                irdl.results_([i32], ["res"], [irdl.Variadicity.single])

    # Verify the IRDL definition before loading it.
    m.operation.verify()
    irdl.load_dialects(m)


# These PDL patterns fold constant additions.
# There are two patterns:
# 1. add(constant0, constant1) -> constant2
#    where constant2 = constant0 + constant1;
# 2. add(x, 0) or add(0, x) -> x.
def get_pdl_pattern_fold():
    """Build the two myint.add folding patterns and return them frozen.

    The patterns refer to Python callbacks by name: "add_fold" (a native
    rewrite) and "has_zero" (a native constraint). Both are registered on the
    PDLModule below before it is frozen.
    """
    m = Module.create()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(m.body):

        # Pattern 1: add(constant, constant) -> constant.
        @pdl.pattern(benefit=1, sym_name="myint_add_fold")
        def pat():
            t = pdl.TypeOp(i32)
            # a0/a1 stand for the "value" attributes of the two constants.
            a0 = pdl.AttributeOp()
            a1 = pdl.AttributeOp()
            c0 = pdl.OperationOp(
                name="myint.constant", attributes={"value": a0}, types=[t]
            )
            c1 = pdl.OperationOp(
                name="myint.constant", attributes={"value": a1}, types=[t]
            )
            v0 = pdl.ResultOp(c0, 0)
            v1 = pdl.ResultOp(c1, 0)
            op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])

            @pdl.rewrite()
            def rew():
                # The new attribute is computed by the Python `add_fold`
                # callback registered under the name "add_fold".
                sum = pdl.apply_native_rewrite(
                    [pdl.AttributeType.get()], "add_fold", [a0, a1]
                )
                newOp = pdl.OperationOp(
                    name="myint.constant", attributes={"value": sum}, types=[t]
                )
                pdl.ReplaceOp(op0, with_op=newOp)

        # Pattern 2: add(x, 0) or add(0, x) -> x.
        # NOTE(review): this rebinds the name `pat`; presumably harmless since
        # the first pattern was already handed to its decorator — confirm.
        @pdl.pattern(benefit=1, sym_name="myint_add_zero_fold")
        def pat():
            t = pdl.TypeOp(i32)
            v0 = pdl.OperandOp()
            v1 = pdl.OperandOp()
            # `v` is the value that `has_zero` appends to its results,
            # i.e. the operand that is not the constant zero.
            v = pdl.apply_native_constraint([pdl.ValueType.get()], "has_zero", [v0, v1])
            op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])

            @pdl.rewrite()
            def rew():
                pdl.ReplaceOp(op0, with_values=[v])

    def add_fold(rewriter, results, values):
        """Native rewrite: append an i32 IntegerAttr holding a0 + a1."""
        a0, a1 = values
        results.append(IntegerAttr.get(i32, a0.value + a1.value))

    def is_zero(value):
        """Return True if `value` is owned by a myint.constant with value 0."""
        op = value.owner
        # Only values whose owner is an Operation can be constants; any other
        # owner is reported as non-zero.
        if isinstance(op, Operation):
            return op.name == "myint.constant" and op.attributes["value"].value == 0
        return False

    # Check if either operand is a constant zero,
    # and append the other operand to the results if so.
    # NOTE(review): both "found a zero" paths return False and the
    # fall-through returns True, so a falsy return presumably tells the
    # driver that the constraint is satisfied — confirm against the bindings.
    def has_zero(rewriter, results, values):
        v0, v1 = values
        if is_zero(v0):
            results.append(v1)
            return False
        if is_zero(v1):
            results.append(v0)
            return False
        return True

    pdl_module = PDLModule(m)
    # The registered names must match the strings used inside the patterns.
    pdl_module.register_rewrite_function("add_fold", add_fold)
    pdl_module.register_constraint_function("has_zero", has_zero)
    return pdl_module.freeze()


# CHECK-LABEL: TEST: test_pdl_register_function
# CHECK: "myint.constant"() {value = 8 : i32} : () -> i32
@construct_and_print_in_module
def test_pdl_register_function(module_):
    """Fold a chain of constant additions with the registered rewrite.

    The input computes (2 + 3) + 3; the expectation above is a single
    constant with value 8.
    """
    load_myint_dialect()

    # The module passed in by the harness is replaced by the parsed one.
    module_ = Module.parse(
        """
        %c0 = "myint.constant"() { value = 2 }: () -> (i32)
        %c1 = "myint.constant"() { value = 3 }: () -> (i32)
        %x = "myint.add"(%c0, %c1): (i32, i32) -> (i32)
        "myint.add"(%x, %c1): (i32, i32) -> (i32)
        """
    )

    frozen = get_pdl_pattern_fold()
    apply_patterns_and_fold_greedily(module_, frozen)

    return module_


# CHECK-LABEL: TEST: test_pdl_register_function_constraint
# CHECK: return %arg0 : i32
@construct_and_print_in_module
def test_pdl_register_function_constraint(module_):
    """Fold additions with zero using the registered constraint.

    In the input, %a = 1 + (-1), and %a is then added on either side of the
    function argument; the expectation above is that the function returns
    its argument directly.
    """
    load_myint_dialect()

    # The module passed in by the harness is replaced by the parsed one.
    module_ = Module.parse(
        """
        func.func @f(%x : i32) -> i32 {
          %c0 = "myint.constant"() { value = 1 }: () -> (i32)
          %c1 = "myint.constant"() { value = -1 }: () -> (i32)
          %a = "myint.add"(%c0, %c1): (i32, i32) -> (i32)
          %b = "myint.add"(%a, %x): (i32, i32) -> (i32)
          %c = "myint.add"(%b, %a): (i32, i32) -> (i32)
          func.return %c : i32
        }
        """
    )

    frozen = get_pdl_pattern_fold()
    apply_patterns_and_fold_greedily(module_, frozen)

    return module_


# This pattern expands a constant into additions
# unless the constant is no more than 1,
# e.g. 3 -> 1 + 2 -> 1 + (1 + 1).
def get_pdl_pattern_expand():
    """Build the "myint_constant_expand" pattern and return it frozen.

    The pattern matches a myint.constant, applies the native constraint
    "is_one" to its "value" attribute, and replaces the op with the
    operation produced by the native rewrite "expand". Both callbacks are
    registered on the PDLModule below before it is frozen.
    """
    m = Module.create()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(m.body):

        @pdl.pattern(benefit=1, sym_name="myint_constant_expand")
        def pat():
            t = pdl.TypeOp(i32)
            cst = pdl.AttributeOp()
            # No result types are requested from this constraint.
            pdl.apply_native_constraint([], "is_one", [cst])
            op0 = pdl.OperationOp(
                name="myint.constant", attributes={"value": cst}, types=[t]
            )

            @pdl.rewrite()
            def rew():
                # `expand` hands back the operation that replaces the constant.
                expanded = pdl.apply_native_rewrite(
                    [pdl.OperationType.get()], "expand", [cst]
                )
                pdl.ReplaceOp(op0, with_op=expanded)

    # Returns True when the constant is <= 1.
    # NOTE(review): the pattern is meant to leave constants that are no more
    # than 1 alone, so a truthy return presumably rejects the match despite
    # the function's name — confirm against the bindings.
    def is_one(rewriter, results, values):
        cst = values[0].value
        return cst <= 1

    def expand(rewriter, results, values):
        """Native rewrite: build c1 + c2 == cst and append the add op."""
        cst = values[0].value
        # Split into the floor half and the remainder, e.g. 5 -> 2 and 3.
        c1 = cst // 2
        c2 = cst - c1
        # Build the replacement ops under the rewriter's insertion point.
        with rewriter.ip:
            op1 = Operation.create(
                "myint.constant",
                results=[i32],
                attributes={"value": IntegerAttr.get(i32, c1)},
            )
            op2 = Operation.create(
                "myint.constant",
                results=[i32],
                attributes={"value": IntegerAttr.get(i32, c2)},
            )
            res = Operation.create(
                "myint.add", results=[i32], operands=[op1.result, op2.result]
            )
        results.append(res)

    pdl_module = PDLModule(m)
    # The registered names must match the strings used inside the pattern.
    pdl_module.register_constraint_function("is_one", is_one)
    pdl_module.register_rewrite_function("expand", expand)
    return pdl_module.freeze()


# CHECK-LABEL: TEST: test_pdl_register_function_expand
# CHECK: %0 = "myint.constant"() {value = 1 : i32} : () -> i32
# CHECK: %1 = "myint.constant"() {value = 1 : i32} : () -> i32
# CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
# CHECK: %3 = "myint.constant"() {value = 1 : i32} : () -> i32
# CHECK: %4 = "myint.constant"() {value = 1 : i32} : () -> i32
# CHECK: %5 = "myint.constant"() {value = 1 : i32} : () -> i32
# CHECK: %6 = "myint.add"(%4, %5) : (i32, i32) -> i32
# CHECK: %7 = "myint.add"(%3, %6) : (i32, i32) -> i32
# CHECK: %8 = "myint.add"(%2, %7) : (i32, i32) -> i32
# CHECK: return %8 : i32
@construct_and_print_in_module
def test_pdl_register_function_expand(module_):
    """Expand the constant 5 into a tree of additions of 1s.

    The expectations above spell out the result: (1 + 1) + (1 + (1 + 1)).
    """
    load_myint_dialect()

    # The module passed in by the harness is replaced by the parsed one.
    module_ = Module.parse(
        """
        func.func @f() -> i32 {
          %0 = "myint.constant"() { value = 5 }: () -> (i32)
          return %0 : i32
        }
        """
    )

    frozen = get_pdl_pattern_expand()
    apply_patterns_and_fold_greedily(module_, frozen)

    return module_