From 2dbaf265255a5fa9643a8092ec2dffa881d2cf93 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Wed, 13 Mar 2024 00:12:37 -0400 Subject: [mlir][ods] Fix generation of optional custom parsers (#84821) We need to generate `.has_value` for `OptionalParseResult`, also ensure that `auto result` doesn't conflict with `result` which is the variable name for `OperationState`. --- mlir/test/IR/custom-print-parse.mlir | 5 +++++ mlir/test/IR/invalid-custom-print-parse.mlir | 5 +++++ mlir/test/lib/Dialect/Test/TestDialect.cpp | 17 +++++++++++++++++ mlir/test/lib/Dialect/Test/TestOps.td | 11 +++++++++++ mlir/test/mlir-tblgen/attr-or-type-format.td | 2 +- mlir/test/mlir-tblgen/op-format.td | 8 ++++---- mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp | 2 +- mlir/tools/mlir-tblgen/OpFormatGen.cpp | 8 ++++---- 8 files changed, 48 insertions(+), 10 deletions(-) diff --git a/mlir/test/IR/custom-print-parse.mlir b/mlir/test/IR/custom-print-parse.mlir index b157fd1..0eadc2e 100644 --- a/mlir/test/IR/custom-print-parse.mlir +++ b/mlir/test/IR/custom-print-parse.mlir @@ -14,4 +14,9 @@ module @dimension_list { test.custom_dimension_list_attr dimension_list = ? // CHECK: test.custom_dimension_list_attr dimension_list = ?x? test.custom_dimension_list_attr dimension_list = ?x? + + // CHECK: test.optional_custom_attr + test.optional_custom_attr bar + // CHECK: test.optional_custom_attr foo false + test.optional_custom_attr foo false } diff --git a/mlir/test/IR/invalid-custom-print-parse.mlir b/mlir/test/IR/invalid-custom-print-parse.mlir index 456b16c..00da145e 100644 --- a/mlir/test/IR/invalid-custom-print-parse.mlir +++ b/mlir/test/IR/invalid-custom-print-parse.mlir @@ -14,3 +14,8 @@ test.custom_dimension_list_attr dimension_list = -1 // expected-error@+2 {{expected ']'}} // expected-error@+1 {{custom op 'test.custom_dimension_list_attr' Failed parsing dimension list.}} test.custom_dimension_list_attr dimension_list = [2x3] + +// ----- + +// expected-error @below {{expected attribute value}} +test.optional_custom_attr foo diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 1ee52fc..380c74a 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -500,6 +500,23 @@ void AffineScopeOp::print(OpAsmPrinter &p) { } //===----------------------------------------------------------------------===// +// Test OptionalCustomAttrOp +//===----------------------------------------------------------------------===// + +static OptionalParseResult parseOptionalCustomParser(AsmParser &p, + IntegerAttr &result) { + if (succeeded(p.parseOptionalKeyword("foo"))) + return p.parseAttribute(result); + return {}; +} + +static void printOptionalCustomParser(AsmPrinter &p, Operation *, + IntegerAttr result) { + p << "foo "; + p.printAttribute(result); +} + +//===----------------------------------------------------------------------===// // Test removing op with inner ops. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index dfd2f21..e6c3601 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2048,6 +2048,17 @@ def CustomDimensionListAttrOp : TEST_Op<"custom_dimension_list_attr"> { }]; } +def OptionalCustomAttrOp : TEST_Op<"optional_custom_attr"> { + let description = [{ + Test using a custom directive as the optional group anchor and the first + element to parse. It is expected to return an `OptionalParseResult`. + }]; + let arguments = (ins OptionalAttr:$attr); + let assemblyFormat = [{ + attr-dict (custom($attr)^) : (`bar`)? + }]; +} + //===----------------------------------------------------------------------===// // Test OpAsmInterface. diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td index b9041e4..2884c4e 100644 --- a/mlir/test/mlir-tblgen/attr-or-type-format.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format.td @@ -648,7 +648,7 @@ def TypeN : TestType<"TestP"> { // TYPE-LABEL: TestQType::parse // TYPE: if (auto result = [&]() -> ::mlir::OptionalParseResult { // TYPE: auto odsCustomResult = parseAB(odsParser -// TYPE: if (!odsCustomResult) return {}; +// TYPE: if (!odsCustomResult.has_value()) return {}; // TYPE: if (::mlir::failed(*odsCustomResult)) return ::mlir::failure(); // TYPE: return ::mlir::success(); // TYPE: }(); result.has_value() && ::mlir::failed(*result)) { diff --git a/mlir/test/mlir-tblgen/op-format.td b/mlir/test/mlir-tblgen/op-format.td index 3250589..4a19ffb 100644 --- a/mlir/test/mlir-tblgen/op-format.td +++ b/mlir/test/mlir-tblgen/op-format.td @@ -93,14 +93,14 @@ def OptionalGroupC : TestFormat_Op<[{ }]>, Arguments<(ins DefaultValuedStrAttr:$a)>; // CHECK-LABEL: OptionalGroupD::parse -// CHECK: if (auto result = [&]() -> ::mlir::OptionalParseResult { +// CHECK: if (auto optResult = [&]() -> ::mlir::OptionalParseResult { // CHECK: auto odsResult = parseCustom(parser, aOperand, bOperand); -// CHECK: if (!odsResult) return {}; +// CHECK: if (!odsResult.has_value()) return {}; // CHECK: if (::mlir::failed(*odsResult)) return ::mlir::failure(); // CHECK: return ::mlir::success(); -// CHECK: }(); result.has_value() && ::mlir::failed(*result)) { +// CHECK: }(); optResult.has_value() && ::mlir::failed(*optResult)) { // CHECK: return ::mlir::failure(); -// CHECK: } else if (result.has_value()) { +// CHECK: } else if (optResult.has_value()) { // CHECK-LABEL: OptionalGroupD::print // CHECK-NEXT: if (((getA()) || (getB()))) { diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp index f8e0c83..6098808 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -622,7 +622,7 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx, } os.unindent() << ");\n"; if (isOptional) { - os << "if (!odsCustomResult) return {};\n"; + os << "if (!odsCustomResult.has_value()) return {};\n"; os << "if (::mlir::failed(*odsCustomResult)) return ::mlir::failure();\n"; } else { os << "if (::mlir::failed(odsCustomResult)) return {};\n"; diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index eb8c0ab..1ffac05 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -1025,7 +1025,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body, body << ");\n"; if (isOptional) { - body << " if (!odsResult) return {};\n" + body << " if (!odsResult.has_value()) return {};\n" << " if (::mlir::failed(*odsResult)) return ::mlir::failure();\n"; } else { body << " if (odsResult) return ::mlir::failure();\n"; @@ -1285,13 +1285,13 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body, region->name); } } else if (auto *custom = dyn_cast(firstElement)) { - body << " if (auto result = [&]() -> ::mlir::OptionalParseResult {\n"; + body << " if (auto optResult = [&]() -> ::mlir::OptionalParseResult {\n"; genCustomDirectiveParser(custom, body, useProperties, opCppClassName, /*isOptional=*/true); body << " return ::mlir::success();\n" - << " }(); result.has_value() && ::mlir::failed(*result)) {\n" + << " }(); optResult.has_value() && ::mlir::failed(*optResult)) {\n" << " return ::mlir::failure();\n" - << " } else if (result.has_value()) {\n"; + << " } else if (optResult.has_value()) {\n"; } genElementParsers(firstElement, thenElements.drop_front(), -- cgit v1.1