aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/python
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/python')
-rw-r--r--mlir/test/python/dialects/transform_structured_ext.py60
-rw-r--r--mlir/test/python/rewrite.py7
2 files changed, 61 insertions, 6 deletions
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 8785d6d..d6b70dc 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -109,13 +109,29 @@ def testFuseOpCompact(target):
)
# CHECK-LABEL: TEST: testFuseOpCompact
# CHECK: transform.sequence
- # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
- # CHECK-SAME: interchange [0, 1] apply_cleanup = true
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}} tile_sizes [4, 8]
+ # CHECK-SAME: interchange [0, 1] {apply_cleanup}
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
@run
@create_sequence
+def testFuseOpCompactForall(target):
+ structured.FuseOp(
+ target,
+ tile_sizes=[4, 8],
+ apply_cleanup=True,
+ use_forall=True,
+ )
+ # CHECK-LABEL: TEST: testFuseOpCompact
+ # CHECK: transform.sequence
+ # CHECK: %{{.+}}, %{{.+}} = transform.structured.fuse %{{.*}} tile_sizes [4, 8]
+ # CHECK-SAME: {apply_cleanup, use_forall}
+ # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+
+@run
+@create_sequence
def testFuseOpNoArg(target):
structured.FuseOp(target)
# CHECK-LABEL: TEST: testFuseOpNoArg
@@ -126,13 +142,51 @@ def testFuseOpNoArg(target):
@run
@create_sequence
+def testFuseOpParams(target):
+ structured.FuseOp(
+ target,
+ tile_sizes=[constant_param(4), Attribute.parse("8")],
+ tile_interchange=[constant_param(0), Attribute.parse("1")],
+ )
+ # CHECK-LABEL: TEST: testFuseOpParams
+ # CHECK: transform.sequence
+ # CHECK-DAG: %[[P:.*]] = transform.param.constant 4
+ # CHECK-DAG: %[[I:.*]] = transform.param.constant 0
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse
+ # CHECK-SAME: tile_sizes [%[[P]], 8]
+ # CHECK-SAME: interchange [%[[I]], 1]
+ # CHECK-SAME: (!transform.any_op, !transform.param<i64>, !transform.param<i64>) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+
+@run
+@create_sequence
+def testFuseOpHandles(target):
+ size1 = structured.MatchOp.match_op_names(target, ["arith.constant"])
+ ichange1 = structured.MatchOp.match_op_names(target, ["arith.constant"])
+ structured.FuseOp(
+ target,
+ tile_sizes=[size1, 8],
+ tile_interchange=[ichange1, 1],
+ )
+ # CHECK-LABEL: TEST: testFuseOpHandles
+ # CHECK: transform.sequence
+ # CHECK: %[[H:.*]] = transform.structured.match
+ # CHECK: %[[I:.*]] = transform.structured.match
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse
+ # CHECK-SAME: tile_sizes [%[[H]], 8]
+ # CHECK-SAME: interchange [%[[I]], 1]
+ # CHECK-SAME: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+
+@run
+@create_sequence
def testFuseOpAttributes(target):
attr = DenseI64ArrayAttr.get([4, 8])
ichange = DenseI64ArrayAttr.get([0, 1])
structured.FuseOp(target, tile_sizes=attr, tile_interchange=ichange)
# CHECK-LABEL: TEST: testFuseOpAttributes
# CHECK: transform.sequence
- # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}} tile_sizes [4, 8]
# CHECK-SAME: interchange [0, 1]
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index acf7db2..821e470 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -17,15 +17,16 @@ def run(f):
def testRewritePattern():
def to_muli(op, rewriter):
with rewriter.ip:
- new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
+ assert isinstance(op, arith.AddIOp)
+ new_op = arith.muli(op.lhs, op.rhs, loc=op.location)
rewriter.replace_op(op, new_op.owner)
def constant_1_to_2(op, rewriter):
- c = op.attributes["value"].value
+ c = op.value.value
if c != 1:
return True # failed to match
with rewriter.ip:
- new_op = arith.constant(op.result.type, 2, loc=op.location)
+ new_op = arith.constant(op.type, 2, loc=op.location)
rewriter.replace_op(op, [new_op])
with Context():