aboutsummaryrefslogtreecommitdiff
path: root/mlir/python
diff options
context:
space:
mode:
authorAlex Zinenko <zinenko@google.com>2022-10-05 14:23:19 +0000
committerAlex Zinenko <zinenko@google.com>2022-10-11 09:55:13 +0000
commit6fe03096025329add1b4e72273d9f631aa3acfda (patch)
treee70c433d0133b76775f531ec3a52e8b16bd2cdad /mlir/python
parentb586d56c7b8b0188355a4d7f5f8b09f8b3847757 (diff)
downloadllvm-6fe03096025329add1b4e72273d9f631aa3acfda.zip
llvm-6fe03096025329add1b4e72273d9f631aa3acfda.tar.gz
llvm-6fe03096025329add1b4e72273d9f631aa3acfda.tar.bz2
[mlir] switch transform dialect ops to use TransformTypeInterface
Use the recently introduced TransformTypeInterface instead of hardcoding the PDLOperationType. This will allow the operations to use more specific transform types to express pre/post-conditions in the future. It requires the syntax and Python op construction API to be updated. Dialect extensions will be switched separately. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D135584
Diffstat (limited to 'mlir/python')
-rw-r--r--mlir/python/mlir/dialects/_transform_ops_ext.py46
1 files changed, 18 insertions, 28 deletions
diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py
index 992139f..18cd3ad 100644
--- a/mlir/python/mlir/dialects/_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_transform_ops_ext.py
@@ -5,7 +5,6 @@
try:
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
- from ..dialects import pdl
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
@@ -21,9 +20,9 @@ def _get_symbol_ref_attr(value: Union[Attribute, str]):
class GetClosestIsolatedParentOp:
- def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+ def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None):
super().__init__(
- pdl.OperationType.get(),
+ result_type,
_get_op_result_or_value(target),
loc=loc,
ip=ip)
@@ -38,7 +37,7 @@ class MergeHandlesOp:
loc=None,
ip=None):
super().__init__(
- pdl.OperationType.get(), [_get_op_result_or_value(h) for h in handles],
+ [_get_op_result_or_value(h) for h in handles],
deduplicate=deduplicate,
loc=loc,
ip=ip)
@@ -47,13 +46,14 @@ class MergeHandlesOp:
class PDLMatchOp:
def __init__(self,
+ result_type: Type,
target: Union[Operation, Value],
pattern_name: Union[Attribute, str],
*,
loc=None,
ip=None):
super().__init__(
- pdl.OperationType.get(),
+ result_type,
_get_op_result_or_value(target),
_get_symbol_ref_attr(pattern_name),
loc=loc,
@@ -69,7 +69,7 @@ class ReplicateOp:
loc=None,
ip=None):
super().__init__(
- [pdl.OperationType.get()] * len(handles),
+ [_get_op_result_or_value(h).type for h in handles],
_get_op_result_or_value(pattern),
[_get_op_result_or_value(h) for h in handles],
loc=loc,
@@ -78,24 +78,11 @@ class ReplicateOp:
class SequenceOp:
- @overload
- def __init__(self, failure_propagation_mode,
- resultsOrRoot: Sequence[Type],
- optionalRoot: Optional[Union[Operation, Value]]):
- ...
-
- @overload
- def __init__(self, failure_propagation_mode,
- resultsOrRoot: Optional[Union[Operation,
- Value]], optionalRoot: NoneType):
- ...
-
- def __init__(self, failure_propagation_mode, resultsOrRoot=None, optionalRoot=None):
- results = resultsOrRoot if isinstance(resultsOrRoot, Sequence) else []
- root = (
- resultsOrRoot
- if not isinstance(resultsOrRoot, Sequence) else optionalRoot)
- root = _get_op_result_or_value(root) if root else None
+ def __init__(self, failure_propagation_mode, results: Sequence[Type],
+ target: Union[Operation, Value, Type]):
+ root = _get_op_result_or_value(target) if isinstance(
+ target, (Operation, Value)) else None
+ root_type = root.type if not isinstance(target, Type) else target
if not isinstance(failure_propagation_mode, Attribute):
failure_propagation_mode_attr = IntegerAttr.get(
IntegerType.get_signless(32), failure_propagation_mode._as_int())
@@ -104,7 +91,7 @@ class SequenceOp:
super().__init__(results_=results,
failure_propagation_mode=failure_propagation_mode_attr,
root=root)
- self.regions[0].blocks.append(pdl.OperationType.get())
+ self.regions[0].blocks.append(root_type)
@property
def body(self) -> Block:
@@ -118,15 +105,18 @@ class SequenceOp:
class WithPDLPatternsOp:
def __init__(self,
- target: Optional[Union[Operation, Value]] = None,
+ target: Union[Operation, Value, Type],
*,
loc=None,
ip=None):
+ root = _get_op_result_or_value(target) if not isinstance(target,
+ Type) else None
+ root_type = target if isinstance(target, Type) else root.type
super().__init__(
- root=_get_op_result_or_value(target) if target else None,
+ root=root,
loc=loc,
ip=ip)
- self.regions[0].blocks.append(pdl.OperationType.get())
+ self.regions[0].blocks.append(root_type)
@property
def body(self) -> Block: