diff options
author | Alex Zinenko <zinenko@google.com> | 2022-10-05 14:23:19 +0000 |
---|---|---|
committer | Alex Zinenko <zinenko@google.com> | 2022-10-11 09:55:13 +0000 |
commit | 6fe03096025329add1b4e72273d9f631aa3acfda (patch) | |
tree | e70c433d0133b76775f531ec3a52e8b16bd2cdad /mlir/python | |
parent | b586d56c7b8b0188355a4d7f5f8b09f8b3847757 (diff) | |
download | llvm-6fe03096025329add1b4e72273d9f631aa3acfda.zip llvm-6fe03096025329add1b4e72273d9f631aa3acfda.tar.gz llvm-6fe03096025329add1b4e72273d9f631aa3acfda.tar.bz2 |
[mlir] switch transform dialect ops to use TransformTypeInterface
Use the recently introduced TransformTypeInterface instead of hardcoding
the PDLOperationType. This will allow the operations to use more
specific transform types to express pre/post-conditions in the future.
It requires the syntax and Python op construction API to be updated.
Dialect extensions will be switched separately.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D135584
Diffstat (limited to 'mlir/python')
-rw-r--r-- | mlir/python/mlir/dialects/_transform_ops_ext.py | 46 |
1 files changed, 18 insertions, 28 deletions
diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py index 992139f..18cd3ad 100644 --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -5,7 +5,6 @@ try: from ..ir import * from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values - from ..dialects import pdl except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e @@ -21,9 +20,9 @@ def _get_symbol_ref_attr(value: Union[Attribute, str]): class GetClosestIsolatedParentOp: - def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): + def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None): super().__init__( - pdl.OperationType.get(), + result_type, _get_op_result_or_value(target), loc=loc, ip=ip) @@ -38,7 +37,7 @@ class MergeHandlesOp: loc=None, ip=None): super().__init__( - pdl.OperationType.get(), [_get_op_result_or_value(h) for h in handles], + [_get_op_result_or_value(h) for h in handles], deduplicate=deduplicate, loc=loc, ip=ip) @@ -47,13 +46,14 @@ class MergeHandlesOp: class PDLMatchOp: def __init__(self, + result_type: Type, target: Union[Operation, Value], pattern_name: Union[Attribute, str], *, loc=None, ip=None): super().__init__( - pdl.OperationType.get(), + result_type, _get_op_result_or_value(target), _get_symbol_ref_attr(pattern_name), loc=loc, @@ -69,7 +69,7 @@ class ReplicateOp: loc=None, ip=None): super().__init__( - [pdl.OperationType.get()] * len(handles), + [_get_op_result_or_value(h).type for h in handles], _get_op_result_or_value(pattern), [_get_op_result_or_value(h) for h in handles], loc=loc, @@ -78,24 +78,11 @@ class ReplicateOp: class SequenceOp: - @overload - def __init__(self, failure_propagation_mode, - resultsOrRoot: Sequence[Type], - optionalRoot: Optional[Union[Operation, Value]]): - ... - - @overload - def __init__(self, failure_propagation_mode, - resultsOrRoot: Optional[Union[Operation, - Value]], optionalRoot: NoneType): - ... - - def __init__(self, failure_propagation_mode, resultsOrRoot=None, optionalRoot=None): - results = resultsOrRoot if isinstance(resultsOrRoot, Sequence) else [] - root = ( - resultsOrRoot - if not isinstance(resultsOrRoot, Sequence) else optionalRoot) - root = _get_op_result_or_value(root) if root else None + def __init__(self, failure_propagation_mode, results: Sequence[Type], + target: Union[Operation, Value, Type]): + root = _get_op_result_or_value(target) if isinstance( + target, (Operation, Value)) else None + root_type = root.type if not isinstance(target, Type) else target if not isinstance(failure_propagation_mode, Attribute): failure_propagation_mode_attr = IntegerAttr.get( IntegerType.get_signless(32), failure_propagation_mode._as_int()) @@ -104,7 +91,7 @@ class SequenceOp: super().__init__(results_=results, failure_propagation_mode=failure_propagation_mode_attr, root=root) - self.regions[0].blocks.append(pdl.OperationType.get()) + self.regions[0].blocks.append(root_type) @property def body(self) -> Block: @@ -118,15 +105,18 @@ class SequenceOp: class WithPDLPatternsOp: def __init__(self, - target: Optional[Union[Operation, Value]] = None, + target: Union[Operation, Value, Type], *, loc=None, ip=None): + root = _get_op_result_or_value(target) if not isinstance(target, + Type) else None + root_type = target if isinstance(target, Type) else root.type super().__init__( - root=_get_op_result_or_value(target) if target else None, + root=root, loc=loc, ip=ip) - self.regions[0].blocks.append(pdl.OperationType.get()) + self.regions[0].blocks.append(root_type) @property def body(self) -> Block: |