diff options
Diffstat (limited to 'mlir/lib/AsmParser/TypeParser.cpp')
-rw-r--r-- | mlir/lib/AsmParser/TypeParser.cpp | 211 |
1 files changed, 3 insertions, 208 deletions
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp index 5da931b..95df69b 100644 --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -11,12 +11,9 @@ //===----------------------------------------------------------------------===// #include "Parser.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" -#include "mlir/IR/TensorEncoding.h" #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" @@ -123,29 +120,6 @@ ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) { return success(); } -/// Parse a complex type. -/// -/// complex-type ::= `complex` `<` type `>` -/// -Type Parser::parseComplexType() { - consumeToken(Token::kw_complex); - - // Parse the '<'. - if (parseToken(Token::less, "expected '<' in complex type")) - return nullptr; - - SMLoc elementTypeLoc = getToken().getLoc(); - auto elementType = parseType(); - if (!elementType || - parseToken(Token::greater, "expected '>' in complex type")) - return nullptr; - if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType)) - return emitError(elementTypeLoc, "invalid element type for complex"), - nullptr; - - return ComplexType::get(elementType); -} - /// Parse a function type. /// /// function-type ::= type-list-parens `->` function-result-type @@ -162,95 +136,6 @@ Type Parser::parseFunctionType() { return builder.getFunctionType(arguments, results); } -/// Parse a memref type. -/// -/// memref-type ::= ranked-memref-type | unranked-memref-type -/// -/// ranked-memref-type ::= `memref` `<` dimension-list-ranked type -/// (`,` layout-specification)? (`,` memory-space)? `>` -/// -/// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>` -/// -/// stride-list ::= `[` (dimension (`,` dimension)*)? `]` -/// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list -/// layout-specification ::= semi-affine-map | strided-layout | attribute -/// memory-space ::= integer-literal | attribute -/// -Type Parser::parseMemRefType() { - SMLoc loc = getToken().getLoc(); - consumeToken(Token::kw_memref); - - if (parseToken(Token::less, "expected '<' in memref type")) - return nullptr; - - bool isUnranked; - SmallVector<int64_t, 4> dimensions; - - if (consumeIf(Token::star)) { - // This is an unranked memref type. - isUnranked = true; - if (parseXInDimensionList()) - return nullptr; - - } else { - isUnranked = false; - if (parseDimensionListRanked(dimensions)) - return nullptr; - } - - // Parse the element type. - auto typeLoc = getToken().getLoc(); - auto elementType = parseType(); - if (!elementType) - return nullptr; - - // Check that memref is formed from allowed types. - if (!BaseMemRefType::isValidElementType(elementType)) - return emitError(typeLoc, "invalid memref element type"), nullptr; - - MemRefLayoutAttrInterface layout; - Attribute memorySpace; - - auto parseElt = [&]() -> ParseResult { - // Either it is MemRefLayoutAttrInterface or memory space attribute. - Attribute attr = parseAttribute(); - if (!attr) - return failure(); - - if (isa<MemRefLayoutAttrInterface>(attr)) { - layout = cast<MemRefLayoutAttrInterface>(attr); - } else if (memorySpace) { - return emitError("multiple memory spaces specified in memref type"); - } else { - memorySpace = attr; - return success(); - } - - if (isUnranked) - return emitError("cannot have affine map for unranked memref type"); - if (memorySpace) - return emitError("expected memory space to be last in memref type"); - - return success(); - }; - - // Parse a list of mappings and address space if present. - if (!consumeIf(Token::greater)) { - // Parse comma separated list of affine maps, followed by memory space. - if (parseToken(Token::comma, "expected ',' or '>' in memref type") || - parseCommaSeparatedListUntil(Token::greater, parseElt, - /*allowEmptyList=*/false)) { - return nullptr; - } - } - - if (isUnranked) - return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace); - - return getChecked<MemRefType>(loc, dimensions, elementType, layout, - memorySpace); -} - /// Parse any type except the function type. /// /// non-function-type ::= integer-type @@ -272,14 +157,12 @@ Type Parser::parseNonFunctionType() { switch (getToken().getKind()) { default: return (emitWrongTokenError("expected non-function type"), nullptr); - case Token::kw_memref: - return parseMemRefType(); case Token::kw_tensor: - return parseTensorType(); + case Token::kw_memref: case Token::kw_complex: - return parseComplexType(); case Token::kw_tuple: - return parseTupleType(); + case Token::kw_none: + return parseExtendedBuiltinType(); case Token::kw_vector: return parseVectorType(); // integer-type @@ -344,11 +227,6 @@ Type Parser::parseNonFunctionType() { consumeToken(Token::kw_index); return builder.getIndexType(); - // none-type - case Token::kw_none: - consumeToken(Token::kw_none); - return builder.getNoneType(); - // extended type case Token::exclamation_identifier: return parseExtendedType(); @@ -361,89 +239,6 @@ Type Parser::parseNonFunctionType() { } } -/// Parse a tensor type. -/// -/// tensor-type ::= `tensor` `<` dimension-list type `>` -/// dimension-list ::= dimension-list-ranked | `*x` -/// -Type Parser::parseTensorType() { - consumeToken(Token::kw_tensor); - - if (parseToken(Token::less, "expected '<' in tensor type")) - return nullptr; - - bool isUnranked; - SmallVector<int64_t, 4> dimensions; - - if (consumeIf(Token::star)) { - // This is an unranked tensor type. - isUnranked = true; - - if (parseXInDimensionList()) - return nullptr; - - } else { - isUnranked = false; - if (parseDimensionListRanked(dimensions)) - return nullptr; - } - - // Parse the element type. - auto elementTypeLoc = getToken().getLoc(); - auto elementType = parseType(); - - // Parse an optional encoding attribute. - Attribute encoding; - if (consumeIf(Token::comma)) { - auto parseResult = parseOptionalAttribute(encoding); - if (parseResult.has_value()) { - if (failed(parseResult.value())) - return nullptr; - if (auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) { - if (failed(v.verifyEncoding(dimensions, elementType, - [&] { return emitError(); }))) - return nullptr; - } - } - } - - if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) - return nullptr; - if (!TensorType::isValidElementType(elementType)) - return emitError(elementTypeLoc, "invalid tensor element type"), nullptr; - - if (isUnranked) { - if (encoding) - return emitError("cannot apply encoding to unranked tensor"), nullptr; - return UnrankedTensorType::get(elementType); - } - return RankedTensorType::get(dimensions, elementType, encoding); -} - -/// Parse a tuple type. -/// -/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` -/// -Type Parser::parseTupleType() { - consumeToken(Token::kw_tuple); - - // Parse the '<'. - if (parseToken(Token::less, "expected '<' in tuple type")) - return nullptr; - - // Check for an empty tuple by directly parsing '>'. - if (consumeIf(Token::greater)) - return TupleType::get(getContext()); - - // Parse the element types and the '>'. - SmallVector<Type, 4> types; - if (parseTypeListNoParens(types) || - parseToken(Token::greater, "expected '>' in tuple type")) - return nullptr; - - return TupleType::get(getContext(), types); -} - /// Parse a vector type. /// /// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>` |