aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/python
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/python')
-rw-r--r--mlir/test/python/dialects/gpu/dialect.py93
-rw-r--r--mlir/test/python/dialects/openacc.py171
-rw-r--r--mlir/test/python/dialects/transform_structured_ext.py60
-rw-r--r--mlir/test/python/ir/operation.py28
-rw-r--r--mlir/test/python/rewrite.py7
5 files changed, 348 insertions, 11 deletions
diff --git a/mlir/test/python/dialects/gpu/dialect.py b/mlir/test/python/dialects/gpu/dialect.py
index 26ee9f3..66c4018 100644
--- a/mlir/test/python/dialects/gpu/dialect.py
+++ b/mlir/test/python/dialects/gpu/dialect.py
@@ -1,6 +1,7 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
+import mlir.ir as ir
import mlir.dialects.gpu as gpu
import mlir.dialects.gpu.passes
from mlir.passmanager import *
@@ -64,3 +65,95 @@ def testObjectAttr():
# CHECK: #gpu.object<#nvvm.target, kernels = <[#gpu.kernel_metadata<"kernel", () -> ()>]>, "BC\C0\DE5\14\00\00\05\00\00\00b\0C0$MY\BEf">
print(o)
assert o.kernels == kernelTable
+
+
+# CHECK-LABEL: testGPUFuncOp
+@run
+def testGPUFuncOp():
+ assert gpu.GPUFuncOp.__doc__ is not None
+ module = Module.create()
+ with InsertionPoint(module.body):
+ gpu_module_name = StringAttr.get("gpu_module")
+ gpumodule = gpu.GPUModuleOp(gpu_module_name)
+ block = gpumodule.bodyRegion.blocks.append()
+
+ def builder(func: gpu.GPUFuncOp) -> None:
+ gpu.GlobalIdOp(gpu.Dimension.x)
+ gpu.ReturnOp([])
+
+ with InsertionPoint(block):
+ name = StringAttr.get("kernel0")
+ func_type = ir.FunctionType.get(inputs=[], results=[])
+ type_attr = TypeAttr.get(func_type)
+ func = gpu.GPUFuncOp(type_attr, name)
+ func.attributes["sym_name"] = name
+ func.attributes["gpu.kernel"] = UnitAttr.get()
+
+ try:
+ func.entry_block
+ assert False, "Expected RuntimeError"
+ except RuntimeError as e:
+ assert (
+ str(e)
+ == "Entry block does not exist for kernel0. Do you need to call the add_entry_block() method on this GPUFuncOp?"
+ )
+
+ block = func.add_entry_block()
+ with InsertionPoint(block):
+ builder(func)
+
+ try:
+ func.add_entry_block()
+ assert False, "Expected RuntimeError"
+ except RuntimeError as e:
+ assert str(e) == "Entry block already exists for kernel0"
+
+ func = gpu.GPUFuncOp(
+ func_type,
+ sym_name="kernel1",
+ kernel=True,
+ body_builder=builder,
+ known_block_size=[1, 2, 3],
+ known_grid_size=DenseI32ArrayAttr.get([4, 5, 6]),
+ )
+
+ assert func.name.value == "kernel1"
+ assert func.function_type.value == func_type
+ assert func.arg_attrs == None
+ assert func.res_attrs == None
+ assert func.arguments == []
+ assert func.entry_block == func.body.blocks[0]
+ assert func.is_kernel
+ assert func.known_block_size == DenseI32ArrayAttr.get(
+ [1, 2, 3]
+ ), func.known_block_size
+ assert func.known_grid_size == DenseI32ArrayAttr.get(
+ [4, 5, 6]
+ ), func.known_grid_size
+
+ func = gpu.GPUFuncOp(
+ func_type,
+ sym_name="non_kernel_func",
+ body_builder=builder,
+ )
+ assert not func.is_kernel
+ assert func.known_block_size is None
+ assert func.known_grid_size is None
+
+ print(module)
+
+ # CHECK: gpu.module @gpu_module
+ # CHECK: gpu.func @kernel0() kernel {
+ # CHECK: %[[VAL_0:.*]] = gpu.global_id x
+ # CHECK: gpu.return
+ # CHECK: }
+ # CHECK: gpu.func @kernel1() kernel attributes
+ # CHECK-SAME: known_block_size = array<i32: 1, 2, 3>
+ # CHECK-SAME: known_grid_size = array<i32: 4, 5, 6>
+ # CHECK: %[[VAL_0:.*]] = gpu.global_id x
+ # CHECK: gpu.return
+ # CHECK: }
+ # CHECK: gpu.func @non_kernel_func() {
+ # CHECK: %[[VAL_0:.*]] = gpu.global_id x
+ # CHECK: gpu.return
+ # CHECK: }
diff --git a/mlir/test/python/dialects/openacc.py b/mlir/test/python/dialects/openacc.py
new file mode 100644
index 0000000..8f2142a
--- /dev/null
+++ b/mlir/test/python/dialects/openacc.py
@@ -0,0 +1,171 @@
+# RUN: %PYTHON %s | FileCheck %s
+from unittest import result
+from mlir.ir import (
+ Context,
+ FunctionType,
+ Location,
+ Module,
+ InsertionPoint,
+ IntegerType,
+ IndexType,
+ MemRefType,
+ F32Type,
+ Block,
+ ArrayAttr,
+ Attribute,
+ UnitAttr,
+ StringAttr,
+ DenseI32ArrayAttr,
+ ShapedType,
+)
+from mlir.dialects import openacc, func, arith, memref
+from mlir.extras import types
+
+
+def run(f):
+ print("\n// TEST:", f.__name__)
+ with Context(), Location.unknown():
+ f()
+ return f
+
+
+@run
+def testParallelMemcpy():
+ module = Module.create()
+
+ dynamic = ShapedType.get_dynamic_size()
+ memref_f32_1d_any = MemRefType.get([dynamic], types.f32())
+
+ with InsertionPoint(module.body):
+ function_type = FunctionType.get(
+ [memref_f32_1d_any, memref_f32_1d_any, types.i64()], []
+ )
+ f = func.FuncOp(
+ type=function_type,
+ name="memcpy_idiom",
+ )
+ f.attributes["sym_visibility"] = StringAttr.get("public")
+
+ with InsertionPoint(f.add_entry_block()):
+ c1024 = arith.ConstantOp(types.i32(), 1024)
+ c128 = arith.ConstantOp(types.i32(), 128)
+
+ arg0, arg1, arg2 = f.arguments
+
+ copied = openacc.copyin(
+ acc_var=arg0.type,
+ var=arg0,
+ var_type=types.f32(),
+ bounds=[],
+ async_operands=[],
+ implicit=False,
+ structured=True,
+ )
+ created = openacc.create_(
+ acc_var=arg1.type,
+ var=arg1,
+ var_type=types.f32(),
+ bounds=[],
+ async_operands=[],
+ implicit=False,
+ structured=True,
+ )
+
+ parallel_op = openacc.ParallelOp(
+ asyncOperands=[],
+ waitOperands=[],
+ numGangs=[c1024],
+ numWorkers=[],
+ vectorLength=[c128],
+ reductionOperands=[],
+ privateOperands=[],
+ firstprivateOperands=[],
+ dataClauseOperands=[],
+ )
+
+ # Set required device_type and segment attributes to satisfy verifier
+ acc_device_none = ArrayAttr.get([Attribute.parse("#acc.device_type<none>")])
+ parallel_op.numGangsDeviceType = acc_device_none
+ parallel_op.numGangsSegments = DenseI32ArrayAttr.get([1])
+ parallel_op.vectorLengthDeviceType = acc_device_none
+
+ parallel_block = Block.create_at_start(parent=parallel_op.region, arg_types=[])
+
+ with InsertionPoint(parallel_block):
+ c0 = arith.ConstantOp(types.i64(), 0)
+ c1 = arith.ConstantOp(types.i64(), 1)
+
+ loop_op = openacc.LoopOp(
+ results_=[],
+ lowerbound=[c0],
+ upperbound=[f.arguments[2]],
+ step=[c1],
+ gangOperands=[],
+ workerNumOperands=[],
+ vectorOperands=[],
+ tileOperands=[],
+ cacheOperands=[],
+ privateOperands=[],
+ reductionOperands=[],
+ firstprivateOperands=[],
+ )
+
+ # Set loop attributes: gang and independent on device_type<none>
+ acc_device_none = ArrayAttr.get([Attribute.parse("#acc.device_type<none>")])
+ loop_op.gang = acc_device_none
+ loop_op.independent = acc_device_none
+
+ loop_block = Block.create_at_start(
+ parent=loop_op.region, arg_types=[types.i64()]
+ )
+
+ with InsertionPoint(loop_block):
+ idx = arith.index_cast(out=IndexType.get(), in_=loop_block.arguments[0])
+ val = memref.load(memref=copied, indices=[idx])
+ memref.store(value=val, memref=created, indices=[idx])
+ openacc.YieldOp([])
+
+ openacc.YieldOp([])
+
+ deleted = openacc.delete(
+ acc_var=copied,
+ bounds=[],
+ async_operands=[],
+ implicit=False,
+ structured=True,
+ )
+ copied = openacc.copyout(
+ acc_var=created,
+ var=arg1,
+ var_type=types.f32(),
+ bounds=[],
+ async_operands=[],
+ implicit=False,
+ structured=True,
+ )
+ func.ReturnOp([])
+
+ print(module)
+
+ # CHECK: TEST: testParallelMemcpy
+ # CHECK-LABEL: func.func public @memcpy_idiom(
+ # CHECK-SAME: %[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: i64) {
+ # CHECK: %[[CONSTANT_0:.*]] = arith.constant 1024 : i32
+ # CHECK: %[[CONSTANT_1:.*]] = arith.constant 128 : i32
+ # CHECK: %[[COPYIN_0:.*]] = acc.copyin varPtr(%[[ARG0]] : memref<?xf32>) -> memref<?xf32>
+ # CHECK: %[[CREATE_0:.*]] = acc.create varPtr(%[[ARG1]] : memref<?xf32>) -> memref<?xf32>
+ # CHECK: acc.parallel num_gangs({%[[CONSTANT_0]] : i32}) vector_length(%[[CONSTANT_1]] : i32) {
+ # CHECK: %[[CONSTANT_2:.*]] = arith.constant 0 : i64
+ # CHECK: %[[CONSTANT_3:.*]] = arith.constant 1 : i64
+ # CHECK: acc.loop gang control(%[[VAL_0:.*]] : i64) = (%[[CONSTANT_2]] : i64) to (%[[ARG2]] : i64) step (%[[CONSTANT_3]] : i64) {
+ # CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[VAL_0]] : i64 to index
+ # CHECK: %[[LOAD_0:.*]] = memref.load %[[COPYIN_0]]{{\[}}%[[INDEX_CAST_0]]] : memref<?xf32>
+ # CHECK: memref.store %[[LOAD_0]], %[[CREATE_0]]{{\[}}%[[INDEX_CAST_0]]] : memref<?xf32>
+ # CHECK: acc.yield
+ # CHECK: } attributes {independent = [#acc.device_type<none>]}
+ # CHECK: acc.yield
+ # CHECK: }
+ # CHECK: acc.delete accPtr(%[[COPYIN_0]] : memref<?xf32>)
+ # CHECK: acc.copyout accPtr(%[[CREATE_0]] : memref<?xf32>) to varPtr(%[[ARG1]] : memref<?xf32>)
+ # CHECK: return
+ # CHECK: }
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 8785d6d..d6b70dc 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -109,13 +109,29 @@ def testFuseOpCompact(target):
)
# CHECK-LABEL: TEST: testFuseOpCompact
# CHECK: transform.sequence
- # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
- # CHECK-SAME: interchange [0, 1] apply_cleanup = true
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}} tile_sizes [4, 8]
+ # CHECK-SAME: interchange [0, 1] {apply_cleanup}
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
@run
@create_sequence
+def testFuseOpCompactForall(target):
+ structured.FuseOp(
+ target,
+ tile_sizes=[4, 8],
+ apply_cleanup=True,
+ use_forall=True,
+ )
+ # CHECK-LABEL: TEST: testFuseOpCompact
+ # CHECK: transform.sequence
+ # CHECK: %{{.+}}, %{{.+}} = transform.structured.fuse %{{.*}} tile_sizes [4, 8]
+ # CHECK-SAME: {apply_cleanup, use_forall}
+ # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+
+@run
+@create_sequence
def testFuseOpNoArg(target):
structured.FuseOp(target)
# CHECK-LABEL: TEST: testFuseOpNoArg
@@ -126,13 +142,51 @@ def testFuseOpNoArg(target):
@run
@create_sequence
+def testFuseOpParams(target):
+ structured.FuseOp(
+ target,
+ tile_sizes=[constant_param(4), Attribute.parse("8")],
+ tile_interchange=[constant_param(0), Attribute.parse("1")],
+ )
+ # CHECK-LABEL: TEST: testFuseOpParams
+ # CHECK: transform.sequence
+ # CHECK-DAG: %[[P:.*]] = transform.param.constant 4
+ # CHECK-DAG: %[[I:.*]] = transform.param.constant 0
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse
+ # CHECK-SAME: tile_sizes [%[[P]], 8]
+ # CHECK-SAME: interchange [%[[I]], 1]
+ # CHECK-SAME: (!transform.any_op, !transform.param<i64>, !transform.param<i64>) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+
+@run
+@create_sequence
+def testFuseOpHandles(target):
+ size1 = structured.MatchOp.match_op_names(target, ["arith.constant"])
+ ichange1 = structured.MatchOp.match_op_names(target, ["arith.constant"])
+ structured.FuseOp(
+ target,
+ tile_sizes=[size1, 8],
+ tile_interchange=[ichange1, 1],
+ )
+ # CHECK-LABEL: TEST: testFuseOpHandles
+ # CHECK: transform.sequence
+ # CHECK: %[[H:.*]] = transform.structured.match
+ # CHECK: %[[I:.*]] = transform.structured.match
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse
+ # CHECK-SAME: tile_sizes [%[[H]], 8]
+ # CHECK-SAME: interchange [%[[I]], 1]
+ # CHECK-SAME: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+
+@run
+@create_sequence
def testFuseOpAttributes(target):
attr = DenseI64ArrayAttr.get([4, 8])
ichange = DenseI64ArrayAttr.get([0, 1])
structured.FuseOp(target, tile_sizes=attr, tile_interchange=ichange)
# CHECK-LABEL: TEST: testFuseOpAttributes
# CHECK: transform.sequence
- # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}} tile_sizes [4, 8]
# CHECK-SAME: interchange [0, 1]
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index cb4cfc8c..1d4ede1 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -569,12 +569,30 @@ def testOperationAttributes():
# CHECK: Attribute value b'text'
print(f"Attribute value {sattr.value_bytes}")
+ # Python dict-style iteration
# We don't know in which order the attributes are stored.
- # CHECK-DAG: NamedAttribute(dependent="text")
- # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
- # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
- for attr in op.attributes:
- print(str(attr))
+ # CHECK-DAG: dependent
+ # CHECK-DAG: other.attribute
+ # CHECK-DAG: some.attribute
+ for name in op.attributes:
+ print(name)
+
+ # Basic dict-like introspection
+ # CHECK: True
+ print("some.attribute" in op.attributes)
+ # CHECK: False
+ print("missing" in op.attributes)
+ # CHECK: Keys: ['dependent', 'other.attribute', 'some.attribute']
+ print("Keys:", sorted(op.attributes.keys()))
+ # CHECK: Values count 3
+ print("Values count", len(op.attributes.values()))
+ # CHECK: Items count 3
+ print("Items count", len(op.attributes.items()))
+
+ # Dict() conversion test
+ d = {k: v.value for k, v in dict(op.attributes).items()}
+ # CHECK: Dict mapping {'dependent': 'text', 'other.attribute': 3.0, 'some.attribute': 1}
+ print("Dict mapping", d)
# Check that exceptions are raised as expected.
try:
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index acf7db2..821e470 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -17,15 +17,16 @@ def run(f):
def testRewritePattern():
def to_muli(op, rewriter):
with rewriter.ip:
- new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
+ assert isinstance(op, arith.AddIOp)
+ new_op = arith.muli(op.lhs, op.rhs, loc=op.location)
rewriter.replace_op(op, new_op.owner)
def constant_1_to_2(op, rewriter):
- c = op.attributes["value"].value
+ c = op.value.value
if c != 1:
return True # failed to match
with rewriter.ip:
- new_op = arith.constant(op.result.type, 2, loc=op.location)
+ new_op = arith.constant(op.type, 2, loc=op.location)
rewriter.replace_op(op, [new_op])
with Context():