aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorIngo Müller <ingomueller@google.com>2023-07-12 10:38:24 +0000
committerIngo Müller <ingomueller@google.com>2023-07-19 14:02:29 +0000
commitbe6e9df11f880ab128aef6550c6911d9f091e7d7 (patch)
tree3fa4122bc7a8256b0b90ca69d1e9c55c5e8530b3 /mlir
parentbd253a6a039937fbd0f2ba1f7dd5338a2920e24d (diff)
downloadllvm-be6e9df11f880ab128aef6550c6911d9f091e7d7.zip
llvm-be6e9df11f880ab128aef6550c6911d9f091e7d7.tar.gz
llvm-be6e9df11f880ab128aef6550c6911d9f091e7d7.tar.bz2
[mlir][transform][linalg][python] Add extended TileToForallOp.
This patch adds a mixin for TileToForallOp to _structured_transform_ops_ext.py with syntactic sugar for constructing such ops. First, the types of the results are made optional and filled with common default values if omitted. Second, for num_threads and tile_sizes, the three possible forms (static, dynamic, or packed) can now all be given through the same respective argument, which gets dispatched to the correct form-specific argument automatically. Reviewed By: nicolasvasilache, ftynse Differential Revision: https://reviews.llvm.org/D155090
Diffstat (limited to 'mlir')
-rw-r--r--mlir/python/mlir/dialects/_structured_transform_ops_ext.py135
-rw-r--r--mlir/test/python/dialects/transform_structured_ext.py92
2 files changed, 226 insertions, 1 deletion
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index 6407309..7f90a46 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -9,7 +9,7 @@ try:
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
-from typing import List, Optional, Sequence, Union, overload
+from typing import List, Optional, Sequence, Tuple, Union, overload
IntOrAttrList = Sequence[Union[IntegerAttr, int]]
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
@@ -17,6 +17,47 @@ OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
+MixedValues = Union[
+    Sequence[Union[int, IntegerAttr, Operation, Value, OpView]],
+    ArrayAttr,
+    Operation,
+    Value,
+    OpView,
+]
+
+
+# Dispatches `MixedValues` that all represent integers in various forms into
+# the following three categories:
+#   - `dynamic_values`: a list of `Value`s, potentially from op results;
+#   - `packed_values`: a value handle, potentially from an op result,
+#     associated to one or more payload operations of integer type;
+#   - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
+#     `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
+# If the input is in the `packed_values` form, only that result is set and the
+# other two are empty. Otherwise, the input can be a mix of the other two
+# forms, and for each dynamic value a sentinel (`ShapedType.get_dynamic_size()`)
+# is appended to `static_values` at the corresponding position.
+def _dispatch_mixed_values(
+    values: MixedValues,
+) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]:
+    dynamic_values = []
+    packed_values = None
+    static_values = None
+    if isinstance(values, ArrayAttr):
+        # Already an attribute: use it as the static form directly.
+        static_values = values
+    elif isinstance(values, (Operation, Value, OpView)):
+        # A single handle: the packed form.
+        packed_values = values
+    else:
+        # A (possibly None/empty) sequence mixing ints and value handles.
+        static_values = []
+        for size in values or []:
+            if isinstance(size, int):
+                static_values.append(size)
+            else:
+                # Sentinel marking a dynamic entry at this position; the
+                # actual value is carried in `dynamic_values`.
+                static_values.append(ShapedType.get_dynamic_size())
+                dynamic_values.append(_get_op_result_or_value(size))
+        static_values = DenseI64ArrayAttr.get(static_values)
+
+    return (dynamic_values, packed_values, static_values)
+
def _get_int_int_array_attr(
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
@@ -354,6 +395,98 @@ class TileOp:
return [element for element in attr]
+class TileToForallOp:
+    """Specialization for TileToForallOp class.
+
+    Adds syntactic sugar for construction: the result types are optional and
+    default to `transform.AnyOpType`, and `num_threads`/`tile_sizes` accept
+    any `MixedValues` form (static, dynamic, or packed), which is dispatched
+    to the corresponding form-specific operand/attribute automatically.
+    """
+
+    @overload
+    def __init__(
+        self,
+        loops_type: Type,
+        tiled_op_type: Type,
+        target: Union[Operation, Value, OpView],
+        *,
+        num_threads: Optional[MixedValues] = None,
+        tile_sizes: MixedValues = None,
+        mapping=None,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self,
+        target: Union[Operation, Value, OpView],
+        *,
+        num_threads: Optional[MixedValues] = None,
+        tile_sizes: MixedValues = None,
+        mapping=None,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    def __init__(
+        self,
+        loops_type_or_target: Union[
+            Type, Union[Operation, Value, OpView]  # `loops_type` (1st overload)...
+        ],  # ...or `target` (2nd overload)
+        tiled_op_type_or_none: Optional[Type] = None,
+        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
+        *,
+        num_threads: MixedValues = None,
+        tile_sizes: MixedValues = None,
+        mapping=None,
+        loc=None,
+        ip=None,
+    ):
+        """Creates a TileToForallOp, dispatching between the two overloads.
+
+        If the first positional argument is a `Type`, the first overload is
+        assumed: `tiled_op_type_or_none` must then also be a `Type` and
+        `target_or_none` carries the target handle. Otherwise the first
+        argument is the target itself and both result types default to
+        `transform.AnyOpType`.
+
+        Raises:
+            TypeError: if `loops_type_or_target` is a `Type` but
+                `tiled_op_type_or_none` is not.
+        """
+        # `Type` arguments in the front are optional: add default values to front.
+        if isinstance(loops_type_or_target, Type):
+            # First overload: type arguments provided.
+            if not isinstance(tiled_op_type_or_none, Type):
+                raise TypeError(
+                    "If 'loops_type_or_target' is a type, then "
+                    "'tiled_op_type_or_none' is expected to be one as well."
+                )
+            loops_type = loops_type_or_target
+            tiled_op_type = tiled_op_type_or_none
+            target = target_or_none
+        else:
+            # Last overload: type arguments missing, use common defaults.
+            loops_type = transform.AnyOpType.get()
+            tiled_op_type = transform.AnyOpType.get()
+            target = loops_type_or_target
+
+        # Unpack mixed num_threads into its dynamic/packed/static forms.
+        (
+            dynamic_num_threads,
+            packed_num_threads,
+            num_threads_attr,
+        ) = _dispatch_mixed_values(num_threads)
+
+        # Unpack mixed tile_sizes into its dynamic/packed/static forms.
+        (
+            dynamic_tile_sizes,
+            packed_tile_sizes,
+            tile_sizes_attr,
+        ) = _dispatch_mixed_values(tile_sizes)
+
+        super().__init__(
+            loops_type,
+            tiled_op_type,
+            target=target,
+            tile_sizes=dynamic_tile_sizes,
+            packed_tile_sizes=packed_tile_sizes,
+            static_tile_sizes=tile_sizes_attr,
+            num_threads=dynamic_num_threads,
+            packed_num_threads=packed_num_threads,
+            static_num_threads=num_threads_attr,
+            mapping=mapping,
+            loc=loc,
+            ip=ip,
+        )
+
+
class VectorizeOp:
"""Specialization for VectorizeOp class."""
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 03a4716..1663ea3 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -256,6 +256,98 @@ def testTileExplicitLoopTypeAll():
@run
+def testTileToForallCompact():
+    # Compact form: result types omitted, both default to !transform.any_op.
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("linalg.matmul"),
+    )
+    with InsertionPoint(sequence.body):
+        structured.TileToForallOp(sequence.bodyTarget, num_threads=[2, 3, 4])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileToForallCompact
+    # CHECK: = transform.structured.tile_to_forall_op
+    # CHECK-SAME: num_threads [2, 3, 4] tile_sizes []
+    # CHECK-SAME: (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op)
+
+
+@run
+def testTileToForallLoopsAndTileOpTypes():
+    # First overload: loops_type and tiled_op_type given explicitly as the
+    # leading positional arguments.
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.TileToForallOp(
+            transform.OperationType.get("scf.forall"),  # loops_type
+            transform.OperationType.get("linalg.matmul"),  # tiled_op_type
+            sequence.bodyTarget,
+            num_threads=[2, 3, 4],
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileToForallLoopsAndTileOpTypes
+    # CHECK: = transform.structured.tile_to_forall_op
+    # CHECK-SAME: num_threads [2, 3, 4] tile_sizes []
+    # CHECK-SAME: (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.matmul">)
+
+
+@run
+def testTileToForallTileSizes():
+    # Plain ints in tile_sizes lower to the static form; num_threads stays
+    # empty.
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.TileToForallOp(sequence.bodyTarget, tile_sizes=[2, 3, 4])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileToForallTileSizes
+    # CHECK: = transform.structured.tile_to_forall_op
+    # CHECK-SAME: num_threads [] tile_sizes [2, 3, 4]
+
+
+@run
+def testTileToForallMixedDynamic():
+    # num_threads mixes a dynamic handle (from MatchOp) with static ints.
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        n = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
+        structured.TileToForallOp(sequence.bodyTarget, num_threads=[n, 3, 4])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileToForallMixedDynamic
+    # CHECK: = transform.structured.tile_to_forall_op
+    # CHECK-SAME: num_threads [%{{.*}} : !pdl.operation, 3, 4]
+
+
+@run
+def testTileToForallMPackedDynamic():
+    # NOTE(review): the stray "M" in the name looks like a typo for
+    # "...PackedDynamic"; it matches the CHECK-LABEL below, so renaming
+    # would require updating both together.
+    # Passing a single handle as num_threads selects the packed form.
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        n = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
+        structured.TileToForallOp(sequence.bodyTarget, num_threads=n)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileToForallMPackedDynamic
+    # CHECK: = transform.structured.tile_to_forall_op
+    # CHECK-SAME: num_threads *(%0 : !pdl.operation)
+
+
+@run
+def testTileToForallMapping():
+    # The `mapping` attribute is forwarded verbatim to the constructed op.
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        mapping = Attribute.parse("[ #gpu.thread<y>, #gpu.thread<x> ]")
+        structured.TileToForallOp(
+            sequence.bodyTarget, num_threads=[2, 3], mapping=mapping
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileToForallMapping
+    # CHECK: = transform.structured.tile_to_forall_op
+    # CHECK-SAME: mapping = [#gpu.thread<y>, #gpu.thread<x>]
+
+
+@run
def testVectorize():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()