aboutsummaryrefslogtreecommitdiff
path: root/mlir/python
diff options
context:
space:
mode:
authorAlex Zinenko <zinenko@google.com>2022-07-07 15:56:06 +0200
committerAlex Zinenko <zinenko@google.com>2022-07-12 12:36:28 +0000
commit3963b4d0dc5bf2bb92eedbab91e2c11653cd8f4e (patch)
tree68c87c71b2d30c659d546d2edccfb08babdc768e /mlir/python
parentcc309721d20c8e544ae7a10a66735ccf4981a11c (diff)
downloadllvm-3963b4d0dc5bf2bb92eedbab91e2c11653cd8f4e.zip
llvm-3963b4d0dc5bf2bb92eedbab91e2c11653cd8f4e.tar.gz
llvm-3963b4d0dc5bf2bb92eedbab91e2c11653cd8f4e.tar.bz2
[mlir] Transform op for multitile size generation
Introduce a structured transform op that emits IR computing the multi-tile sizes with requested parameters (target size and divisor) for the given structured op. The sizes may fold to arithmetic constant operations when the shape is constant. These operations may then be used to call the existing tiling transformation with a single non-zero dynamic size (i.e. perform strip-mining) for each of the dimensions separately, thus achieving multi-size tiling with optional loop interchange. A separate test exercises the entire script. Depends On D129217 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D129287
Diffstat (limited to 'mlir/python')
-rw-r--r--mlir/python/mlir/dialects/_structured_transform_ops_ext.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index b6e078f..95bf2cc 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -110,6 +110,29 @@ class InterchangeOp:
ip=ip)
+class MultiTileSizesOp:
+ """Specialization for MultitileSizesOp class."""
+
+ def __init__(self,
+ target: Union[Operation, Value],
+ *,
+ dimension: Union[int, IntegerAttr],
+ target_size: Union[int, IntegerAttr],
+ divisor: Optional[Union[int, IntegerAttr]] = None,
+ loc=None,
+ ip=None):
+ super().__init__(
+ pdl.OperationType.get(),
+ pdl.OperationType.get(),
+ pdl.OperationType.get(),
+ _get_op_result_or_value(target),
+ dimension=_get_int64_attr(dimension),
+ target_size=_get_int64_attr(target_size),
+ divisor=_get_int64_attr(divisor if divisor else 1),
+ loc=loc,
+ ip=ip)
+
+
class PadOp:
"""Specialization for PadOp class."""