diff options
Diffstat (limited to 'mlir/test/python')
-rw-r--r-- | mlir/test/python/dialects/transform_structured_ext.py | 60 | ||||
-rw-r--r-- | mlir/test/python/rewrite.py | 7 |
2 files changed, 61 insertions, 6 deletions
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py index 8785d6d..d6b70dc 100644 --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -109,13 +109,29 @@ def testFuseOpCompact(target): ) # CHECK-LABEL: TEST: testFuseOpCompact # CHECK: transform.sequence - # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8] - # CHECK-SAME: interchange [0, 1] apply_cleanup = true + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}} tile_sizes [4, 8] + # CHECK-SAME: interchange [0, 1] {apply_cleanup} # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) @run @create_sequence +def testFuseOpCompactForall(target): + structured.FuseOp( + target, + tile_sizes=[4, 8], + apply_cleanup=True, + use_forall=True, + ) + # CHECK-LABEL: TEST: testFuseOpCompact + # CHECK: transform.sequence + # CHECK: %{{.+}}, %{{.+}} = transform.structured.fuse %{{.*}} tile_sizes [4, 8] + # CHECK-SAME: {apply_cleanup, use_forall} + # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + +@run +@create_sequence def testFuseOpNoArg(target): structured.FuseOp(target) # CHECK-LABEL: TEST: testFuseOpNoArg @@ -126,13 +142,51 @@ def testFuseOpNoArg(target): @run @create_sequence +def testFuseOpParams(target): + structured.FuseOp( + target, + tile_sizes=[constant_param(4), Attribute.parse("8")], + tile_interchange=[constant_param(0), Attribute.parse("1")], + ) + # CHECK-LABEL: TEST: testFuseOpParams + # CHECK: transform.sequence + # CHECK-DAG: %[[P:.*]] = transform.param.constant 4 + # CHECK-DAG: %[[I:.*]] = transform.param.constant 0 + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse + # CHECK-SAME: tile_sizes [%[[P]], 8] + # CHECK-SAME: interchange [%[[I]], 1] + # CHECK-SAME: (!transform.any_op, !transform.param<i64>, !transform.param<i64>) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + +@run +@create_sequence +def testFuseOpHandles(target): + size1 = structured.MatchOp.match_op_names(target, ["arith.constant"]) + ichange1 = structured.MatchOp.match_op_names(target, ["arith.constant"]) + structured.FuseOp( + target, + tile_sizes=[size1, 8], + tile_interchange=[ichange1, 1], + ) + # CHECK-LABEL: TEST: testFuseOpHandles + # CHECK: transform.sequence + # CHECK: %[[H:.*]] = transform.structured.match + # CHECK: %[[I:.*]] = transform.structured.match + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse + # CHECK-SAME: tile_sizes [%[[H]], 8] + # CHECK-SAME: interchange [%[[I]], 1] + # CHECK-SAME: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + +@run +@create_sequence def testFuseOpAttributes(target): attr = DenseI64ArrayAttr.get([4, 8]) ichange = DenseI64ArrayAttr.get([0, 1]) structured.FuseOp(target, tile_sizes=attr, tile_interchange=ichange) # CHECK-LABEL: TEST: testFuseOpAttributes # CHECK: transform.sequence - # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8] + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}} tile_sizes [4, 8] # CHECK-SAME: interchange [0, 1] # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py index acf7db2..821e470 100644 --- a/mlir/test/python/rewrite.py +++ b/mlir/test/python/rewrite.py @@ -17,15 +17,16 @@ def run(f): def testRewritePattern(): def to_muli(op, rewriter): with rewriter.ip: - new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location) + assert isinstance(op, arith.AddIOp) + new_op = arith.muli(op.lhs, op.rhs, loc=op.location) rewriter.replace_op(op, new_op.owner) def constant_1_to_2(op, rewriter): - c = op.attributes["value"].value + c = op.value.value if c != 1: return True # failed to match with rewriter.ip: - new_op = arith.constant(op.result.type, 2, loc=op.location) + new_op = arith.constant(op.type, 2, loc=op.location) rewriter.replace_op(op, [new_op]) with Context(): |