| field | value | date |
|---|---|---|
| author | Ingo Müller <ingomueller@google.com> | 2023-07-12 10:38:24 +0000 |
| committer | Ingo Müller <ingomueller@google.com> | 2023-07-19 14:02:29 +0000 |
| commit | be6e9df11f880ab128aef6550c6911d9f091e7d7 | |
| tree | 3fa4122bc7a8256b0b90ca69d1e9c55c5e8530b3 /mlir | |
| parent | bd253a6a039937fbd0f2ba1f7dd5338a2920e24d | |
[mlir][transform][linalg][python] Add extended TileToForallOp.
This patch adds a mixin for TileToForallOp to
_structured_transform_ops_ext.py with syntactic sugar for constructing
such ops. First, the types of the results are made optional and filled
with common default values if omitted. Second, num_threads and
tile_sizes can each be given in any of their three possible forms
(static, dynamic, or packed) through the same respective argument, which
gets dispatched to the correct form-specific argument automatically.
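For illustration, here is a minimal sketch of the construction forms this enables, adapted from the tests added in this patch. It assumes the upstream Python bindings layout of the time (including `mlir.dialects.transform.structured`) and reuses one handle across several ops purely for brevity:

```python
from mlir.ir import Context, InsertionPoint, Location, Module
from mlir.dialects import transform
from mlir.dialects.transform import structured

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        sequence = transform.SequenceOp(
            transform.FailurePropagationMode.PROPAGATE,
            [],
            transform.AnyOpType.get(),
        )
        with InsertionPoint(sequence.body):
            target = sequence.bodyTarget
            # Static form: plain ints end up in the static attribute; both
            # result types default to !transform.any_op.
            structured.TileToForallOp(target, num_threads=[2, 3, 4])
            # Mixed form: a handle inside the list becomes a dynamic operand,
            # with a sentinel left at its position in the static attribute.
            n = structured.MatchOp.match_op_names(target, ["test.dummy"])
            structured.TileToForallOp(target, num_threads=[n, 3, 4])
            # Packed form: a single handle provides all values at once.
            structured.TileToForallOp(target, num_threads=n)
            # Explicit result types via the leading `Type` arguments.
            structured.TileToForallOp(
                transform.OperationType.get("scf.forall"),
                transform.OperationType.get("linalg.matmul"),
                target,
                tile_sizes=[2, 3, 4],
            )
            transform.YieldOp()
    print(module)
```

Note that tile_to_forall_op consumes its target handle, so a real transform script would not reuse `target` across several tiling ops the way this sketch does.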
Reviewed By: nicolasvasilache, ftynse
Differential Revision: https://reviews.llvm.org/D155090
Diffstat (limited to 'mlir')
| file | lines changed |
|---|---|
| mlir/python/mlir/dialects/_structured_transform_ops_ext.py | 135 |
| mlir/test/python/dialects/transform_structured_ext.py | 92 |

2 files changed, 226 insertions(+), 1 deletion(-)
```diff
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index 6407309..7f90a46 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -9,7 +9,7 @@ try:
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
-from typing import List, Optional, Sequence, Union, overload
+from typing import List, Optional, Sequence, Tuple, Union, overload
 
 IntOrAttrList = Sequence[Union[IntegerAttr, int]]
 OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
@@ -17,6 +17,47 @@ OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
 BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
 OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
 
+MixedValues = Union[
+    Sequence[Union[int, IntegerAttr, Operation, Value, OpView]],
+    ArrayAttr,
+    Operation,
+    Value,
+    OpView,
+]
+
+
+# Dispatches `MixedValues` that all represent integers in various forms into
+# the following three categories:
+#   - `dynamic_values`: a list of `Value`s, potentially from op results;
+#   - `packed_values`: a value handle, potentially from an op result,
+#     associated to one or more payload operations of integer type;
+#   - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
+#     `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
+# If the input is in the form for `packed_values`, only that result is set and
+# the other two are empty. Otherwise, the input can be a mix of the other two
+# forms, and for each dynamic value, a special value is added to the
+# `static_values`.
+def _dispatch_mixed_values(
+    values: MixedValues,
+) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]:
+    dynamic_values = []
+    packed_values = None
+    static_values = None
+    if isinstance(values, ArrayAttr):
+        static_values = values
+    elif isinstance(values, (Operation, Value, OpView)):
+        packed_values = values
+    else:
+        static_values = []
+        for size in values or []:
+            if isinstance(size, int):
+                static_values.append(size)
+            else:
+                static_values.append(ShapedType.get_dynamic_size())
+                dynamic_values.append(_get_op_result_or_value(size))
+        static_values = DenseI64ArrayAttr.get(static_values)
+
+    return (dynamic_values, packed_values, static_values)
+
 
 def _get_int_int_array_attr(
     values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
@@ -354,6 +395,98 @@ class TileOp:
         return [element for element in attr]
 
 
+class TileToForallOp:
+    """Specialization for TileToForallOp class."""
+
+    @overload
+    def __init__(
+        self,
+        loops_type: Type,
+        tiled_op_type: Type,
+        target: Union[Operation, Value, OpView],
+        *,
+        num_threads: Optional[MixedValues] = None,
+        tile_sizes: MixedValues = None,
+        mapping=None,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self,
+        target: Union[Operation, Value, OpView],
+        *,
+        num_threads: Optional[MixedValues] = None,
+        tile_sizes: MixedValues = None,
+        mapping=None,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    def __init__(
+        self,
+        loops_type_or_target: Union[
+            Type, Union[Operation, Value, OpView]  # loops_type
+        ],  # target
+        tiled_op_type_or_none: Optional[Type] = None,
+        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
+        *,
+        num_threads: MixedValues = None,
+        tile_sizes: MixedValues = None,
+        mapping=None,
+        loc=None,
+        ip=None,
+    ):
+        # `Type` arguments in the front are optional: add default values to front.
+        if isinstance(loops_type_or_target, Type):
+            # First overload: type arguments provided.
+            if not isinstance(tiled_op_type_or_none, Type):
+                raise TypeError(
+                    "If 'loops_type_or_target' is a type, then "
+                    "'tiled_op_type_or_none' is expected to be one as well."
+                )
+            loops_type = loops_type_or_target
+            tiled_op_type = tiled_op_type_or_none
+            target = target_or_none
+        else:
+            # Last overload: type arguments missing.
+            loops_type = transform.AnyOpType.get()
+            tiled_op_type = transform.AnyOpType.get()
+            target = loops_type_or_target
+
+        # Unpack mixed num_threads.
+        (
+            dynamic_num_threads,
+            packed_num_threads,
+            num_threads_attr,
+        ) = _dispatch_mixed_values(num_threads)
+
+        # Unpack mixed tile_sizes.
+        (
+            dynamic_tile_sizes,
+            packed_tile_sizes,
+            tile_sizes_attr,
+        ) = _dispatch_mixed_values(tile_sizes)
+
+        super().__init__(
+            loops_type,
+            tiled_op_type,
+            target=target,
+            tile_sizes=dynamic_tile_sizes,
+            packed_tile_sizes=packed_tile_sizes,
+            static_tile_sizes=tile_sizes_attr,
+            num_threads=dynamic_num_threads,
+            packed_num_threads=packed_num_threads,
+            static_num_threads=num_threads_attr,
+            mapping=mapping,
+            loc=loc,
+            ip=ip,
+        )
+
+
 class VectorizeOp:
     """Specialization for VectorizeOp class."""
 
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 03a4716..1663ea3 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -256,6 +256,98 @@ def testTileExplicitLoopTypeAll():
 
 
 @run
+def testTileToForallCompact():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("linalg.matmul"),
+    )
+    with InsertionPoint(sequence.body):
+        structured.TileToForallOp(sequence.bodyTarget, num_threads=[2, 3, 4])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileToForallCompact
+    # CHECK: = transform.structured.tile_to_forall_op
+    # CHECK-SAME: num_threads [2, 3, 4] tile_sizes []
+    # CHECK-SAME: (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op)
+
+
+@run
+def testTileToForallLoopsAndTileOpTypes():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.TileToForallOp(
+            transform.OperationType.get("scf.forall"),  # loops_type
+            transform.OperationType.get("linalg.matmul"),  # tiled_op_type
+            sequence.bodyTarget,
+            num_threads=[2, 3, 4],
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileToForallLoopsAndTileOpTypes
+    # CHECK: = transform.structured.tile_to_forall_op
+    # CHECK-SAME: num_threads [2, 3, 4] tile_sizes []
+    # CHECK-SAME: (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.matmul">)
+
+
+@run
+def testTileToForallTileSizes():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.TileToForallOp(sequence.bodyTarget, tile_sizes=[2, 3, 4])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileToForallTileSizes
+    # CHECK: = transform.structured.tile_to_forall_op
+    # CHECK-SAME: num_threads [] tile_sizes [2, 3, 4]
+
+
+@run
+def testTileToForallMixedDynamic():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        n = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
+        structured.TileToForallOp(sequence.bodyTarget, num_threads=[n, 3, 4])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileToForallMixedDynamic
+    # CHECK: = transform.structured.tile_to_forall_op
+    # CHECK-SAME: num_threads [%{{.*}} : !pdl.operation, 3, 4]
+
+
+@run
+def testTileToForallPackedDynamic():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        n = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
+        structured.TileToForallOp(sequence.bodyTarget, num_threads=n)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileToForallPackedDynamic
+    # CHECK: = transform.structured.tile_to_forall_op
+    # CHECK-SAME: num_threads *(%0 : !pdl.operation)
+
+
+@run
+def testTileToForallMapping():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        mapping = Attribute.parse("[ #gpu.thread<y>, #gpu.thread<x> ]")
+        structured.TileToForallOp(
+            sequence.bodyTarget, num_threads=[2, 3], mapping=mapping
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileToForallMapping
+    # CHECK: = transform.structured.tile_to_forall_op
+    # CHECK-SAME: mapping = [#gpu.thread<y>, #gpu.thread<x>]
+
+
+@run
 def testVectorize():
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
```
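To round this off, here is a rough probe of the dispatch helper itself, showing what `_dispatch_mixed_values` returns for each input form. This is a sketch under stated assumptions: it imports the private module directly, purely for demonstration, and assumes an active MLIR context is all that attribute construction needs.

```python
from mlir.ir import Context, DenseI64ArrayAttr
from mlir.dialects._structured_transform_ops_ext import _dispatch_mixed_values

with Context():
    # All-static input: no dynamic operands, no packed handle; the plain
    # ints land in a DenseI64ArrayAttr.
    dynamic, packed, static = _dispatch_mixed_values([2, 3, 4])
    assert dynamic == [] and packed is None
    assert static == DenseI64ArrayAttr.get([2, 3, 4])

    # The other two categories, per the code above:
    #   _dispatch_mixed_values(handle)         -> ([], handle, None)  (packed)
    #   _dispatch_mixed_values([handle, 3, 4])                        (mixed)
    #       -> ([value], None, dense array of [kDynamic, 3, 4]), where
    #          kDynamic is ShapedType.get_dynamic_size().
```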