diff options
author | Mikhail Goncharov <goncharov.mikhail@gmail.com> | 2023-12-07 10:28:35 +0100 |
---|---|---|
committer | Mikhail Goncharov <goncharov.mikhail@gmail.com> | 2023-12-07 10:28:35 +0100 |
commit | 10879403e56b0ba4fde4676ed20ae658d32e3356 (patch) | |
tree | 683441c6f0d26c3a685f8f9bc612eeb6d5a63cc2 /mlir | |
parent | 9e8a7377421a13d06e496eaa9dca900e189e3d69 (diff) | |
download | llvm-10879403e56b0ba4fde4676ed20ae658d32e3356.zip llvm-10879403e56b0ba4fde4676ed20ae658d32e3356.tar.gz llvm-10879403e56b0ba4fde4676ed20ae658d32e3356.tar.bz2 |
Revert "[MLIR][Transform] Add attribute in MatchOp to filter by operand type (#67994)"
This reverts commit c4399130ae403acf4e6325b8b46a51bb6abf222f.
Test fails https://lab.llvm.org/buildbot/#/builders/272/builds/2757
Diffstat (limited to 'mlir')
3 files changed, 1 insertions, 84 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 77ed9db..de65f31 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -574,11 +574,6 @@ def MatchOp : Op<Transform_Dialect, "structured.match", - attribute: the matched op must have all specified attributes (with their specified values). - filter_result_type: the matched op must return exactly this one type. - - filter_operand_types: all the operands of the matched op must must be of - this type. If more than a type is specified, then the length of the list - must be equal to the number of operands in the matched op, and the match - will succeed only if the operand types match all the types in the list - in the order in which they are specified. Note: Only ops that satisfy all specified constraints are matched. @@ -600,8 +595,7 @@ def MatchOp : Op<Transform_Dialect, "structured.match", OptionalAttr<StrArrayAttr>:$ops, OptionalAttr<MatchInterfaceEnum>:$interface, OptionalAttr<DictionaryAttr>:$op_attrs, - OptionalAttr<TypeAttr>:$filter_result_type, - OptionalAttr<TypeArrayAttr>:$filter_operand_types); + OptionalAttr<TypeAttr>:$filter_result_type); // TODO: variadic results when needed. let results = (outs TransformHandleTypeInterface:$results); @@ -615,7 +609,6 @@ def MatchOp : Op<Transform_Dialect, "structured.match", (`interface` `{` $interface^ `}`)? (`attributes` $op_attrs^)? (`filter_result_type` `=` $filter_result_type^)? - (`filter_operand_types` `=` $filter_operand_types^)? `in` $target attr-dict `:` functional-type($target, results) }]; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 54055ae..e371345 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1171,7 +1171,6 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter, } SmallVector<Operation *> res; - bool incorrectNumOperandTypes = false; auto matchFun = [&](Operation *op) { if (getOps().has_value() && !strs.contains(op->getName().getStringRef())) return; @@ -1211,47 +1210,12 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter, return; } - if (getFilterOperandTypes().has_value()) { - mlir::ArrayAttr types = getFilterOperandTypes().value(); - auto operandTypes = op->getOperandTypes(); - - if (types.size() == 1) { - // All the operands must must be equal to the specified type - auto typeattr = - dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]); - Type t = typeattr.getValue().cast<::mlir::Type>(); - if (!llvm::all_of(op->getOperandTypes(), - [&](Type operandType) { return operandType == t; })) - return; - } else { - // The operand types must match all the types in the list (in the same - // order in with they are specified) - if (types.size() != operandTypes.size()) { - incorrectNumOperandTypes = true; - return; - } - - for (auto [attr, operandType] : - llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) { - auto typeattr = cast<mlir::TypeAttr>(attr); - Type type = typeattr.getValue().cast<::mlir::Type>(); - - if (type != operandType) - return; - } - } - } - // All constraints are satisfied. res.push_back(op); return; }; (*payloadOps.begin())->walk(matchFun); - if (incorrectNumOperandTypes) - return emitDefiniteFailure("If filter_operand_types contains more than a " - "type, then it must contain as much types as " - "the number of operands in the target ops"); results.set(cast<OpResult>(getResult()), res); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir index 3b30c18..7d48b1f 100644 --- a/mlir/test/Dialect/Linalg/transform-op-match.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir @@ -43,46 +43,6 @@ module attributes {transform.with_named_sequence} { // ----- -func.func @by_operand_type() { - %c2 = arith.constant 2.0: f32 - %v = arith.constant 8: i32 - %r1 = math.fpowi %c2, %v : f32, i32 - // expected-remark @below {{matched op name}} - %r2 = arith.addf %c2, %c2 : f32 - // expected-remark @below {{matched op name}} - %r3 = arith.fptoui %r2 : f32 to i32 - return -} - -transform.sequence failures(propagate) { -^bb1(%arg1: !transform.any_op): - %match_name1 = transform.structured.match - ops{["arith.fptoui"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op - transform.test_print_remark_at_operand %match_name1, "matched op name" : !transform.any_op - transform.test_consume_operand %match_name1 : !transform.any_op - - %match_name2 = transform.structured.match - ops{["arith.addf"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op - transform.test_print_remark_at_operand %match_name2, "matched op name" : !transform.any_op - transform.test_consume_operand %match_name2 : !transform.any_op - - %no_match_name1 = transform.structured.match - ops{["arith.fptoui"]} filter_operand_types = [i32] in %arg1 : (!transform.any_op) -> !transform.any_op - transform.test_print_remark_at_operand %no_match_name1, "should not match" : !transform.any_op - transform.test_consume_operand %no_match_name1 : !transform.any_op - - %no_match_name2 = transform.structured.match - ops{["math.fpowi"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op - transform.test_print_remark_at_operand %no_match_name2, "should not match" : !transform.any_op - transform.test_consume_operand %no_match_name2 : !transform.any_op - - // expected-error @+1 {{If filter_operand_types contains more than a type, then it must contain as much types as the number of operands in the target ops}} - %failure_match = transform.structured.match - ops{["arith.fptoui"]} filter_operand_types = [i32, i32] in %arg1 : (!transform.any_op) -> !transform.any_op -} - -// ----- - func.func @foo(%a: tensor<4x4xf32>, %b: tensor<4x4xf32>, %c: tensor<4x4xf32>) { %c0 = arith.constant 0.0 : f32 // expected-remark @below {{tileable}} |