aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorMikhail Goncharov <goncharov.mikhail@gmail.com>2023-12-07 10:28:35 +0100
committerMikhail Goncharov <goncharov.mikhail@gmail.com>2023-12-07 10:28:35 +0100
commit10879403e56b0ba4fde4676ed20ae658d32e3356 (patch)
tree683441c6f0d26c3a685f8f9bc612eeb6d5a63cc2 /mlir
parent9e8a7377421a13d06e496eaa9dca900e189e3d69 (diff)
downloadllvm-10879403e56b0ba4fde4676ed20ae658d32e3356.zip
llvm-10879403e56b0ba4fde4676ed20ae658d32e3356.tar.gz
llvm-10879403e56b0ba4fde4676ed20ae658d32e3356.tar.bz2
Revert "[MLIR][Transform] Add attribute in MatchOp to filter by operand type (#67994)"
This reverts commit c4399130ae403acf4e6325b8b46a51bb6abf222f. Test fails https://lab.llvm.org/buildbot/#/builders/272/builds/2757
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td9
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp36
-rw-r--r--mlir/test/Dialect/Linalg/transform-op-match.mlir40
3 files changed, 1 insertions, 84 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 77ed9db..de65f31 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -574,11 +574,6 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
- attribute: the matched op must have all specified attributes (with their
specified values).
- filter_result_type: the matched op must return exactly this one type.
- - filter_operand_types: all the operands of the matched op must must be of
- this type. If more than a type is specified, then the length of the list
- must be equal to the number of operands in the matched op, and the match
- will succeed only if the operand types match all the types in the list
- in the order in which they are specified.
Note: Only ops that satisfy all specified constraints are matched.
@@ -600,8 +595,7 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
OptionalAttr<StrArrayAttr>:$ops,
OptionalAttr<MatchInterfaceEnum>:$interface,
OptionalAttr<DictionaryAttr>:$op_attrs,
- OptionalAttr<TypeAttr>:$filter_result_type,
- OptionalAttr<TypeArrayAttr>:$filter_operand_types);
+ OptionalAttr<TypeAttr>:$filter_result_type);
// TODO: variadic results when needed.
let results = (outs TransformHandleTypeInterface:$results);
@@ -615,7 +609,6 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
(`interface` `{` $interface^ `}`)?
(`attributes` $op_attrs^)?
(`filter_result_type` `=` $filter_result_type^)?
- (`filter_operand_types` `=` $filter_operand_types^)?
`in` $target attr-dict
`:` functional-type($target, results)
}];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 54055ae..e371345 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1171,7 +1171,6 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
}
SmallVector<Operation *> res;
- bool incorrectNumOperandTypes = false;
auto matchFun = [&](Operation *op) {
if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
return;
@@ -1211,47 +1210,12 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
return;
}
- if (getFilterOperandTypes().has_value()) {
- mlir::ArrayAttr types = getFilterOperandTypes().value();
- auto operandTypes = op->getOperandTypes();
-
- if (types.size() == 1) {
- // All the operands must must be equal to the specified type
- auto typeattr =
- dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
- Type t = typeattr.getValue().cast<::mlir::Type>();
- if (!llvm::all_of(op->getOperandTypes(),
- [&](Type operandType) { return operandType == t; }))
- return;
- } else {
- // The operand types must match all the types in the list (in the same
- // order in with they are specified)
- if (types.size() != operandTypes.size()) {
- incorrectNumOperandTypes = true;
- return;
- }
-
- for (auto [attr, operandType] :
- llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
- auto typeattr = cast<mlir::TypeAttr>(attr);
- Type type = typeattr.getValue().cast<::mlir::Type>();
-
- if (type != operandType)
- return;
- }
- }
- }
-
// All constraints are satisfied.
res.push_back(op);
return;
};
(*payloadOps.begin())->walk(matchFun);
- if (incorrectNumOperandTypes)
- return emitDefiniteFailure("If filter_operand_types contains more than a "
- "type, then it must contain as much types as "
- "the number of operands in the target ops");
results.set(cast<OpResult>(getResult()), res);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 3b30c18..7d48b1f 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -43,46 +43,6 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @by_operand_type() {
- %c2 = arith.constant 2.0: f32
- %v = arith.constant 8: i32
- %r1 = math.fpowi %c2, %v : f32, i32
- // expected-remark @below {{matched op name}}
- %r2 = arith.addf %c2, %c2 : f32
- // expected-remark @below {{matched op name}}
- %r3 = arith.fptoui %r2 : f32 to i32
- return
-}
-
-transform.sequence failures(propagate) {
-^bb1(%arg1: !transform.any_op):
- %match_name1 = transform.structured.match
- ops{["arith.fptoui"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.test_print_remark_at_operand %match_name1, "matched op name" : !transform.any_op
- transform.test_consume_operand %match_name1 : !transform.any_op
-
- %match_name2 = transform.structured.match
- ops{["arith.addf"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.test_print_remark_at_operand %match_name2, "matched op name" : !transform.any_op
- transform.test_consume_operand %match_name2 : !transform.any_op
-
- %no_match_name1 = transform.structured.match
- ops{["arith.fptoui"]} filter_operand_types = [i32] in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.test_print_remark_at_operand %no_match_name1, "should not match" : !transform.any_op
- transform.test_consume_operand %no_match_name1 : !transform.any_op
-
- %no_match_name2 = transform.structured.match
- ops{["math.fpowi"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.test_print_remark_at_operand %no_match_name2, "should not match" : !transform.any_op
- transform.test_consume_operand %no_match_name2 : !transform.any_op
-
- // expected-error @+1 {{If filter_operand_types contains more than a type, then it must contain as much types as the number of operands in the target ops}}
- %failure_match = transform.structured.match
- ops{["arith.fptoui"]} filter_operand_types = [i32, i32] in %arg1 : (!transform.any_op) -> !transform.any_op
-}
-
-// -----
-
func.func @foo(%a: tensor<4x4xf32>, %b: tensor<4x4xf32>, %c: tensor<4x4xf32>) {
%c0 = arith.constant 0.0 : f32
// expected-remark @below {{tileable}}