diff options
-rw-r--r-- | mlir/include/mlir/IR/BuiltinDialect.td | 2 | ||||
-rw-r--r-- | mlir/include/mlir/IR/BuiltinTypes.td | 25 | ||||
-rw-r--r-- | mlir/include/mlir/IR/OpImplementation.h | 5 | ||||
-rw-r--r-- | mlir/lib/AsmParser/DialectSymbolParser.cpp | 24 | ||||
-rw-r--r-- | mlir/lib/AsmParser/Parser.h | 24 | ||||
-rw-r--r-- | mlir/lib/AsmParser/TypeParser.cpp | 211 | ||||
-rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 72 | ||||
-rw-r--r-- | mlir/lib/IR/BuiltinTypes.cpp | 150 | ||||
-rw-r--r-- | mlir/test/IR/invalid-builtin-types.mlir | 10 | ||||
-rw-r--r-- | mlir/test/IR/invalid.mlir | 4 | ||||
-rw-r--r-- | mlir/test/IR/qualified-builtin.mlir | 11 |
11 files changed, 249 insertions, 289 deletions
diff --git a/mlir/include/mlir/IR/BuiltinDialect.td b/mlir/include/mlir/IR/BuiltinDialect.td index c131107..a862717 100644 --- a/mlir/include/mlir/IR/BuiltinDialect.td +++ b/mlir/include/mlir/IR/BuiltinDialect.td @@ -22,7 +22,7 @@ def Builtin_Dialect : Dialect { let name = "builtin"; let cppNamespace = "::mlir"; let useDefaultAttributePrinterParser = 0; - let useDefaultTypePrinterParser = 0; + let useDefaultTypePrinterParser = 1; let extraClassDeclaration = [{ private: // Register the builtin Attributes. diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index 4cade83d..f3a51d2 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -25,7 +25,8 @@ include "mlir/IR/BuiltinTypeInterfaces.td" // Base class for Builtin dialect types. class Builtin_Type<string name, string typeMnemonic, list<Trait> traits = [], string baseCppClass = "::mlir::Type"> - : TypeDef<Builtin_Dialect, name, traits, baseCppClass> { + : TypeDef<Builtin_Dialect, name, !listconcat(traits, [PrintTypeQualified]), + baseCppClass> { let mnemonic = ?; let typeName = "builtin." # typeMnemonic; } @@ -62,6 +63,9 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> { ]; let skipDefaultBuilders = 1; let genVerifyDecl = 1; + + let mnemonic = "complex"; + let assemblyFormat = "`<` $elementType `>`"; } //===----------------------------------------------------------------------===// @@ -668,6 +672,9 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [ }]; let skipDefaultBuilders = 1; let genVerifyDecl = 1; + + let mnemonic = "memref"; + let hasCustomAssemblyFormat = 1; } //===----------------------------------------------------------------------===// @@ -698,6 +705,8 @@ def Builtin_None : Builtin_Type<"None", "none"> { let extraClassDeclaration = [{ static NoneType get(MLIRContext *context); }]; + + let mnemonic = "none"; } //===----------------------------------------------------------------------===// @@ -849,6 +858,9 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [ }]; let skipDefaultBuilders = 1; let genVerifyDecl = 1; + + let mnemonic = "tensor"; + let hasCustomAssemblyFormat = 1; } //===----------------------------------------------------------------------===// @@ -884,7 +896,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> { tuple<i32, f32, tensor<i1>, i5> ``` }]; - let parameters = (ins "ArrayRef<Type>":$types); + let parameters = (ins OptionalArrayRefParameter<"Type">:$types); let builders = [ TypeBuilder<(ins "TypeRange":$elementTypes), [{ return $_get($_ctxt, elementTypes); @@ -916,6 +928,9 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> { return getTypes()[index]; } }]; + + let mnemonic = "tuple"; + let assemblyFormat = "`<` (`>`) : ($types^ `>`)?"; } //===----------------------------------------------------------------------===// @@ -994,6 +1009,9 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [ }]; let skipDefaultBuilders = 1; let genVerifyDecl = 1; + + let mnemonic = "memref"; + let hasCustomAssemblyFormat = 1; } //===----------------------------------------------------------------------===// @@ -1043,6 +1061,9 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [ }]; let skipDefaultBuilders = 1; let genVerifyDecl = 1; + + let mnemonic = "tensor"; + let hasCustomAssemblyFormat = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 50e6cc5..2a5587d 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -187,6 +187,11 @@ public: /// provide a valid type for the attribute. virtual void printAttributeWithoutType(Attribute attr); + /// Print the given attribute without its type if and only if the type is the + /// default type for the given attribute. + /// E.g. '1 : i64' is printed as just '1'. + virtual void printAttributeWithoutDefaultType(Attribute attr); + /// Print the alias for the given attribute, return failure if no alias could /// be printed. virtual LogicalResult printAlias(Attribute attr); diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp index 80cce7e..400d263 100644 --- a/mlir/lib/AsmParser/DialectSymbolParser.cpp +++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp @@ -322,6 +322,30 @@ Type Parser::parseExtendedType() { }); } +Type Parser::parseExtendedBuiltinType() { + // Initially set to just the mnemonic of the type. + llvm::StringRef symbolData = getToken().getSpelling(); + const char *startOfTypePos = symbolData.data(); + consumeToken(); + // Extend 'symbolData' to include the body if it is not a singleton type. + // Note that all types in the builtin type always use the pretty dialect form + // aka 'dialect.mnemonic<body>'. + if (getToken().is(Token::less)) + if (failed(parseDialectSymbolBody(symbolData))) + return nullptr; + + const char *endOfTypePos = getToken().getLoc().getPointer(); + + // With the body of the type captured, hand it off to the dialect parser. + resetToken(startOfTypePos); + CustomDialectAsmParser customParser(symbolData, *this); + Type type = builtinDialect->parseType(customParser); + + // Move the lexer past the type. + resetToken(endOfTypePos); + return type; +} + //===----------------------------------------------------------------------===// // mlir::parseAttribute/parseType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h index b959e67..73080c8 100644 --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -11,6 +11,7 @@ #include "ParserState.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/OpImplementation.h" #include <optional> @@ -28,9 +29,14 @@ public: using Delimiter = OpAsmParser::Delimiter; Builder builder; + /// Cached instance of the builtin dialect for parsing builtins. + Dialect *builtinDialect; Parser(ParserState &state) - : builder(state.config.getContext()), state(state) {} + : builder(state.config.getContext()), + builtinDialect( + builder.getContext()->getLoadedDialect<BuiltinDialect>()), + state(state) {} // Helper methods to get stuff from the parser-global state. ParserState &getState() const { return state; } @@ -192,27 +198,19 @@ public: /// Parse an arbitrary type. Type parseType(); - /// Parse a complex type. - Type parseComplexType(); - /// Parse an extended type. Type parseExtendedType(); + /// Parse an extended type from the builtin dialect where the '!builtin' + /// prefix is missing. + Type parseExtendedBuiltinType(); + /// Parse a function type. Type parseFunctionType(); - /// Parse a memref type. - Type parseMemRefType(); - /// Parse a non function type. Type parseNonFunctionType(); - /// Parse a tensor type. - Type parseTensorType(); - - /// Parse a tuple type. - Type parseTupleType(); - /// Parse a vector type. VectorType parseVectorType(); ParseResult parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions, diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp index 5da931b..95df69b 100644 --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -11,12 +11,9 @@ //===----------------------------------------------------------------------===// #include "Parser.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" -#include "mlir/IR/TensorEncoding.h" #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" @@ -123,29 +120,6 @@ ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) { return success(); } -/// Parse a complex type. -/// -/// complex-type ::= `complex` `<` type `>` -/// -Type Parser::parseComplexType() { - consumeToken(Token::kw_complex); - - // Parse the '<'. - if (parseToken(Token::less, "expected '<' in complex type")) - return nullptr; - - SMLoc elementTypeLoc = getToken().getLoc(); - auto elementType = parseType(); - if (!elementType || - parseToken(Token::greater, "expected '>' in complex type")) - return nullptr; - if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType)) - return emitError(elementTypeLoc, "invalid element type for complex"), - nullptr; - - return ComplexType::get(elementType); -} - /// Parse a function type. /// /// function-type ::= type-list-parens `->` function-result-type @@ -162,95 +136,6 @@ Type Parser::parseFunctionType() { return builder.getFunctionType(arguments, results); } -/// Parse a memref type. -/// -/// memref-type ::= ranked-memref-type | unranked-memref-type -/// -/// ranked-memref-type ::= `memref` `<` dimension-list-ranked type -/// (`,` layout-specification)? (`,` memory-space)? `>` -/// -/// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>` -/// -/// stride-list ::= `[` (dimension (`,` dimension)*)? `]` -/// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list -/// layout-specification ::= semi-affine-map | strided-layout | attribute -/// memory-space ::= integer-literal | attribute -/// -Type Parser::parseMemRefType() { - SMLoc loc = getToken().getLoc(); - consumeToken(Token::kw_memref); - - if (parseToken(Token::less, "expected '<' in memref type")) - return nullptr; - - bool isUnranked; - SmallVector<int64_t, 4> dimensions; - - if (consumeIf(Token::star)) { - // This is an unranked memref type. - isUnranked = true; - if (parseXInDimensionList()) - return nullptr; - - } else { - isUnranked = false; - if (parseDimensionListRanked(dimensions)) - return nullptr; - } - - // Parse the element type. - auto typeLoc = getToken().getLoc(); - auto elementType = parseType(); - if (!elementType) - return nullptr; - - // Check that memref is formed from allowed types. - if (!BaseMemRefType::isValidElementType(elementType)) - return emitError(typeLoc, "invalid memref element type"), nullptr; - - MemRefLayoutAttrInterface layout; - Attribute memorySpace; - - auto parseElt = [&]() -> ParseResult { - // Either it is MemRefLayoutAttrInterface or memory space attribute. - Attribute attr = parseAttribute(); - if (!attr) - return failure(); - - if (isa<MemRefLayoutAttrInterface>(attr)) { - layout = cast<MemRefLayoutAttrInterface>(attr); - } else if (memorySpace) { - return emitError("multiple memory spaces specified in memref type"); - } else { - memorySpace = attr; - return success(); - } - - if (isUnranked) - return emitError("cannot have affine map for unranked memref type"); - if (memorySpace) - return emitError("expected memory space to be last in memref type"); - - return success(); - }; - - // Parse a list of mappings and address space if present. - if (!consumeIf(Token::greater)) { - // Parse comma separated list of affine maps, followed by memory space. - if (parseToken(Token::comma, "expected ',' or '>' in memref type") || - parseCommaSeparatedListUntil(Token::greater, parseElt, - /*allowEmptyList=*/false)) { - return nullptr; - } - } - - if (isUnranked) - return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace); - - return getChecked<MemRefType>(loc, dimensions, elementType, layout, - memorySpace); -} - /// Parse any type except the function type. /// /// non-function-type ::= integer-type @@ -272,14 +157,12 @@ Type Parser::parseNonFunctionType() { switch (getToken().getKind()) { default: return (emitWrongTokenError("expected non-function type"), nullptr); - case Token::kw_memref: - return parseMemRefType(); case Token::kw_tensor: - return parseTensorType(); + case Token::kw_memref: case Token::kw_complex: - return parseComplexType(); case Token::kw_tuple: - return parseTupleType(); + case Token::kw_none: + return parseExtendedBuiltinType(); case Token::kw_vector: return parseVectorType(); // integer-type @@ -344,11 +227,6 @@ Type Parser::parseNonFunctionType() { consumeToken(Token::kw_index); return builder.getIndexType(); - // none-type - case Token::kw_none: - consumeToken(Token::kw_none); - return builder.getNoneType(); - // extended type case Token::exclamation_identifier: return parseExtendedType(); @@ -361,89 +239,6 @@ Type Parser::parseNonFunctionType() { } } -/// Parse a tensor type. -/// -/// tensor-type ::= `tensor` `<` dimension-list type `>` -/// dimension-list ::= dimension-list-ranked | `*x` -/// -Type Parser::parseTensorType() { - consumeToken(Token::kw_tensor); - - if (parseToken(Token::less, "expected '<' in tensor type")) - return nullptr; - - bool isUnranked; - SmallVector<int64_t, 4> dimensions; - - if (consumeIf(Token::star)) { - // This is an unranked tensor type. - isUnranked = true; - - if (parseXInDimensionList()) - return nullptr; - - } else { - isUnranked = false; - if (parseDimensionListRanked(dimensions)) - return nullptr; - } - - // Parse the element type. - auto elementTypeLoc = getToken().getLoc(); - auto elementType = parseType(); - - // Parse an optional encoding attribute. - Attribute encoding; - if (consumeIf(Token::comma)) { - auto parseResult = parseOptionalAttribute(encoding); - if (parseResult.has_value()) { - if (failed(parseResult.value())) - return nullptr; - if (auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) { - if (failed(v.verifyEncoding(dimensions, elementType, - [&] { return emitError(); }))) - return nullptr; - } - } - } - - if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) - return nullptr; - if (!TensorType::isValidElementType(elementType)) - return emitError(elementTypeLoc, "invalid tensor element type"), nullptr; - - if (isUnranked) { - if (encoding) - return emitError("cannot apply encoding to unranked tensor"), nullptr; - return UnrankedTensorType::get(elementType); - } - return RankedTensorType::get(dimensions, elementType, encoding); -} - -/// Parse a tuple type. -/// -/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` -/// -Type Parser::parseTupleType() { - consumeToken(Token::kw_tuple); - - // Parse the '<'. - if (parseToken(Token::less, "expected '<' in tuple type")) - return nullptr; - - // Check for an empty tuple by directly parsing '>'. - if (consumeIf(Token::greater)) - return TupleType::get(getContext()); - - // Parse the element types and the '>'. - SmallVector<Type, 4> types; - if (parseTypeListNoParens(types) || - parseToken(Token::greater, "expected '>' in tuple type")) - return nullptr; - - return TupleType::get(getContext(), types); -} - /// Parse a vector type. /// /// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>` diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 6b8b747..0679d41 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2132,6 +2132,13 @@ static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) { /// Print the given dialect symbol to the stream. static void printDialectSymbol(raw_ostream &os, StringRef symPrefix, StringRef dialectName, StringRef symString) { + // Treat the builtin dialect special by eliding the '<symPrefix>builtin' + // prefix. + if (dialectName == "builtin") { + os << symString; + return; + } + os << symPrefix << dialectName; // If this symbol name is simple enough, print it directly in pretty form, @@ -2599,64 +2606,6 @@ void AsmPrinter::Impl::printTypeImpl(Type type) { printType(vectorTy.getElementType()); os << '>'; }) - .Case<RankedTensorType>([&](RankedTensorType tensorTy) { - os << "tensor<"; - printDimensionList(tensorTy.getShape()); - if (!tensorTy.getShape().empty()) - os << 'x'; - printType(tensorTy.getElementType()); - // Only print the encoding attribute value if set. - if (tensorTy.getEncoding()) { - os << ", "; - printAttribute(tensorTy.getEncoding()); - } - os << '>'; - }) - .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) { - os << "tensor<*x"; - printType(tensorTy.getElementType()); - os << '>'; - }) - .Case<MemRefType>([&](MemRefType memrefTy) { - os << "memref<"; - printDimensionList(memrefTy.getShape()); - if (!memrefTy.getShape().empty()) - os << 'x'; - printType(memrefTy.getElementType()); - MemRefLayoutAttrInterface layout = memrefTy.getLayout(); - if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) { - os << ", "; - printAttribute(memrefTy.getLayout(), AttrTypeElision::May); - } - // Only print the memory space if it is the non-default one. - if (memrefTy.getMemorySpace()) { - os << ", "; - printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); - } - os << '>'; - }) - .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) { - os << "memref<*x"; - printType(memrefTy.getElementType()); - // Only print the memory space if it is the non-default one. - if (memrefTy.getMemorySpace()) { - os << ", "; - printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); - } - os << '>'; - }) - .Case<ComplexType>([&](ComplexType complexTy) { - os << "complex<"; - printType(complexTy.getElementType()); - os << '>'; - }) - .Case<TupleType>([&](TupleType tupleTy) { - os << "tuple<"; - interleaveComma(tupleTy.getTypes(), - [&](Type type) { printType(type); }); - os << '>'; - }) - .Case<NoneType>([&](Type) { os << "none"; }) .Default([&](Type type) { return printDialectType(type); }); } @@ -2799,6 +2748,13 @@ void AsmPrinter::printAttributeWithoutType(Attribute attr) { impl->printAttribute(attr, Impl::AttrTypeElision::Must); } +void AsmPrinter::printAttributeWithoutDefaultType(Attribute attr) { + assert( + impl && + "expected AsmPrinter::printAttributeWithoutDefaultType to be overriden"); + impl->printAttribute(attr, Impl::AttrTypeElision::May); +} + void AsmPrinter::printKeywordOrString(StringRef keyword) { assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden"); ::printKeywordOrString(keyword, impl->getStream()); diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index 9b8ee3d..e160c0f 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -10,10 +10,13 @@ #include "TypeDetail.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/IR/TensorEncoding.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/APFloat.h" @@ -26,6 +29,52 @@ using namespace mlir; using namespace mlir::detail; //===----------------------------------------------------------------------===// +// Custom printing and parsing +//===----------------------------------------------------------------------===// + +static ParseResult parseMemRefDimension(AsmParser &parser, + SmallVectorImpl<int64_t> &dimension, + bool &isUnranked) { + if (succeeded(parser.parseOptionalStar())) { + isUnranked = true; + return parser.parseXInDimensionList(); + } + + isUnranked = false; + return parser.parseDimensionList(dimension); +} + +static ParseResult parseMemRefSpaceAndLayout(AsmParser &parser, + MemRefLayoutAttrInterface &layout, + Attribute &memorySpace, + bool isUnranked) { + while (succeeded(parser.parseOptionalComma())) { + SMLoc loc = parser.getCurrentLocation(); + Attribute attr; + if (parser.parseAttribute(attr)) + return failure(); + + if (auto memRefLayout = dyn_cast<MemRefLayoutAttrInterface>(attr)) { + layout = memRefLayout; + } else if (memorySpace) { + return parser.emitError( + loc, "multiple memory spaces specified in memref type"); + } else { + memorySpace = attr; + continue; + } + + if (isUnranked) + return parser.emitError( + loc, "cannot have affine map for unranked memref type"); + if (memorySpace) + return parser.emitError( + loc, "expected memory space to be last in memref type"); + } + return success(); +} + +//===----------------------------------------------------------------------===// /// Tablegen Type Definitions //===----------------------------------------------------------------------===// @@ -340,6 +389,46 @@ RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, return checkTensorElementType(emitError, elementType); } +Type RankedTensorType::parse(AsmParser &parser) { + SmallVector<int64_t> dimension; + Type elementType; + bool isUnranked; + if (parser.parseLess() || + parseMemRefDimension(parser, dimension, isUnranked) || + parser.parseType(elementType)) + return nullptr; + + Attribute encoding; + if (succeeded(parser.parseOptionalComma())) { + SMLoc loc = parser.getCurrentLocation(); + if (parser.parseAttribute(encoding)) + return nullptr; + + if (isUnranked) { + parser.emitError(loc, "cannot apply encoding to unranked tensor"); + return nullptr; + } + } + + if (failed(parser.parseGreater())) + return nullptr; + + if (isUnranked) + return parser.getChecked<UnrankedTensorType>(elementType); + return parser.getChecked<RankedTensorType>(dimension, elementType, encoding); +} + +void RankedTensorType::print(AsmPrinter &printer) const { + printer << '<'; + printer.printDimensionList(getShape()); + if (!getShape().empty()) + printer << 'x'; + printer << getElementType(); + if (getEncoding()) + printer << ", " << getEncoding(); + printer << '>'; +} + //===----------------------------------------------------------------------===// // UnrankedTensorType //===----------------------------------------------------------------------===// @@ -350,6 +439,14 @@ UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, return checkTensorElementType(emitError, elementType); } +Type UnrankedTensorType::parse(AsmParser &parser) { + return RankedTensorType::parse(parser); +} + +void UnrankedTensorType::print(AsmPrinter &printer) const { + printer << "<*x" << getElementType() << ">"; +} + //===----------------------------------------------------------------------===// // BaseMemRefType //===----------------------------------------------------------------------===// @@ -652,6 +749,44 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError, return success(); } +Type MemRefType::parse(AsmParser &parser) { + SmallVector<int64_t> dimension; + Type elementType; + MemRefLayoutAttrInterface layout; + Attribute memorySpace; + bool isUnranked; + if (parser.parseLess() || + parseMemRefDimension(parser, dimension, isUnranked) || + parser.parseType(elementType) || + parseMemRefSpaceAndLayout(parser, layout, memorySpace, isUnranked) || + parser.parseGreater()) + return nullptr; + + if (isUnranked) + return parser.getChecked<UnrankedMemRefType>(elementType, memorySpace); + return parser.getChecked<MemRefType>(dimension, elementType, layout, + memorySpace); +} + +void MemRefType::print(AsmPrinter &printer) const { + printer << '<'; + printer.printDimensionList(getShape()); + if (!getShape().empty()) + printer << 'x'; + printer << getElementType(); + MemRefLayoutAttrInterface layout = getLayout(); + if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) { + printer << ", "; + printer.printAttributeWithoutDefaultType(getLayout()); + } + // Only print the memory space if it is the non-default one. + if (getMemorySpace()) { + printer << ", "; + printer.printAttributeWithoutDefaultType(getMemorySpace()); + } + printer << '>'; +} + //===----------------------------------------------------------------------===// // UnrankedMemRefType //===----------------------------------------------------------------------===// @@ -672,6 +807,21 @@ UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError, return success(); } +Type UnrankedMemRefType::parse(AsmParser &parser) { + return MemRefType::parse(parser); +} + +void UnrankedMemRefType::print(AsmPrinter &printer) const { + printer << "<*x"; + printer << getElementType(); + // Only print the memory space if it is the non-default one. + if (getMemorySpace()) { + printer << ", "; + printer.printAttributeWithoutDefaultType(getMemorySpace()); + } + printer << '>'; +} + // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( // i.e. single term). Accumulate the AffineExpr into the existing one. static void extractStridesFromTerm(AffineExpr e, diff --git a/mlir/test/IR/invalid-builtin-types.mlir b/mlir/test/IR/invalid-builtin-types.mlir index 9884212..04995bf 100644 --- a/mlir/test/IR/invalid-builtin-types.mlir +++ b/mlir/test/IR/invalid-builtin-types.mlir @@ -27,7 +27,7 @@ func.func @illegalunrankedmemrefelementtype(memref<*xtensor<i8>>) -> () // expec // ----- // Test no map in memref type. -func.func @memrefs(memref<2x4xi8, >) // expected-error {{expected list element}} +func.func @memrefs(memref<2x4xi8, >) // expected-error {{expected attribute}} // ----- // Test non-existent map in memref type. @@ -74,7 +74,7 @@ func.func private @memref_unfinished_strided() -> memref<?x?xf32, strided<>> // ----- -// expected-error @below {{expected a 64-bit signed integer or '?'}} +// expected-error @below {{unbalanced '[' character in pretty dialect name}} func.func private @memref_unfinished_stride_list() -> memref<?x?xf32, strided<[>> // ----- @@ -94,7 +94,7 @@ func.func private @memref_missing_offset_value() -> memref<?x?xf32, strided<[], // ----- -// expected-error @below {{expected '>'}} +// expected-error @below {{unbalanced '<' character in pretty dialect name}} func.func private @memref_incorrect_strided_ending() -> memref<?x?xf32, strided<[], offset: 32)> // ----- @@ -170,12 +170,12 @@ func.func @bad_complex(complex<memref<2x4xi8>>) // ----- -// expected-error @+1 {{expected '<' in complex type}} +// expected-error @+1 {{expected '<'}} func.func @bad_complex(complex memref<2x4xi8>>) // ----- -// expected-error @+1 {{expected '>' in complex type}} +// expected-error @+1 {{unbalanced '<' character in pretty dialect name}} func.func @bad_complex(complex<i32) // ----- diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index 861f4ef..1e01b47 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -419,12 +419,12 @@ func.func @invalid_unknown_type_dialect_name() -> !invalid.dialect<!x@#]!@#> // ----- -// expected-error @+1 {{expected '<' in tuple type}} +// expected-error @+1 {{expected '<'}} func.func @invalid_tuple_missing_less(tuple i32>) // ----- -// expected-error @+1 {{expected '>' in tuple type}} +// expected-error @+1 {{unbalanced '<' character in pretty dialect name}} func.func @invalid_tuple_missing_greater(tuple<i32) // ----- diff --git a/mlir/test/IR/qualified-builtin.mlir b/mlir/test/IR/qualified-builtin.mlir new file mode 100644 index 0000000..a2f9e63 --- /dev/null +++ b/mlir/test/IR/qualified-builtin.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-opt %s | FileCheck %s + +// CHECK-LABEL: @test1 +// CHECK: -> tuple<> +func.func private @test1() -> !builtin.tuple<> + +// CHECK-LABEL: @test2 +// CHECK: -> none +func.func private @test2() -> !builtin.none + + |