aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/AsmParser/TypeParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/AsmParser/TypeParser.cpp')
-rw-r--r--mlir/lib/AsmParser/TypeParser.cpp211
1 files changed, 3 insertions, 208 deletions
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 5da931b..95df69b 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -11,12 +11,9 @@
//===----------------------------------------------------------------------===//
#include "Parser.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
@@ -123,29 +120,6 @@ ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
return success();
}
-/// Parse a complex type.
-///
-/// complex-type ::= `complex` `<` type `>`
-///
-Type Parser::parseComplexType() {
- consumeToken(Token::kw_complex);
-
- // Parse the '<'.
- if (parseToken(Token::less, "expected '<' in complex type"))
- return nullptr;
-
- SMLoc elementTypeLoc = getToken().getLoc();
- auto elementType = parseType();
- if (!elementType ||
- parseToken(Token::greater, "expected '>' in complex type"))
- return nullptr;
- if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
- return emitError(elementTypeLoc, "invalid element type for complex"),
- nullptr;
-
- return ComplexType::get(elementType);
-}
-
/// Parse a function type.
///
/// function-type ::= type-list-parens `->` function-result-type
@@ -162,95 +136,6 @@ Type Parser::parseFunctionType() {
return builder.getFunctionType(arguments, results);
}
-/// Parse a memref type.
-///
-/// memref-type ::= ranked-memref-type | unranked-memref-type
-///
-/// ranked-memref-type ::= `memref` `<` dimension-list-ranked type
-/// (`,` layout-specification)? (`,` memory-space)? `>`
-///
-/// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
-///
-/// stride-list ::= `[` (dimension (`,` dimension)*)? `]`
-/// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
-/// layout-specification ::= semi-affine-map | strided-layout | attribute
-/// memory-space ::= integer-literal | attribute
-///
-Type Parser::parseMemRefType() {
- SMLoc loc = getToken().getLoc();
- consumeToken(Token::kw_memref);
-
- if (parseToken(Token::less, "expected '<' in memref type"))
- return nullptr;
-
- bool isUnranked;
- SmallVector<int64_t, 4> dimensions;
-
- if (consumeIf(Token::star)) {
- // This is an unranked memref type.
- isUnranked = true;
- if (parseXInDimensionList())
- return nullptr;
-
- } else {
- isUnranked = false;
- if (parseDimensionListRanked(dimensions))
- return nullptr;
- }
-
- // Parse the element type.
- auto typeLoc = getToken().getLoc();
- auto elementType = parseType();
- if (!elementType)
- return nullptr;
-
- // Check that memref is formed from allowed types.
- if (!BaseMemRefType::isValidElementType(elementType))
- return emitError(typeLoc, "invalid memref element type"), nullptr;
-
- MemRefLayoutAttrInterface layout;
- Attribute memorySpace;
-
- auto parseElt = [&]() -> ParseResult {
- // Either it is MemRefLayoutAttrInterface or memory space attribute.
- Attribute attr = parseAttribute();
- if (!attr)
- return failure();
-
- if (isa<MemRefLayoutAttrInterface>(attr)) {
- layout = cast<MemRefLayoutAttrInterface>(attr);
- } else if (memorySpace) {
- return emitError("multiple memory spaces specified in memref type");
- } else {
- memorySpace = attr;
- return success();
- }
-
- if (isUnranked)
- return emitError("cannot have affine map for unranked memref type");
- if (memorySpace)
- return emitError("expected memory space to be last in memref type");
-
- return success();
- };
-
- // Parse a list of mappings and address space if present.
- if (!consumeIf(Token::greater)) {
- // Parse comma separated list of affine maps, followed by memory space.
- if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
- parseCommaSeparatedListUntil(Token::greater, parseElt,
- /*allowEmptyList=*/false)) {
- return nullptr;
- }
- }
-
- if (isUnranked)
- return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
-
- return getChecked<MemRefType>(loc, dimensions, elementType, layout,
- memorySpace);
-}
-
/// Parse any type except the function type.
///
/// non-function-type ::= integer-type
@@ -272,14 +157,12 @@ Type Parser::parseNonFunctionType() {
switch (getToken().getKind()) {
default:
return (emitWrongTokenError("expected non-function type"), nullptr);
- case Token::kw_memref:
- return parseMemRefType();
case Token::kw_tensor:
- return parseTensorType();
+ case Token::kw_memref:
case Token::kw_complex:
- return parseComplexType();
case Token::kw_tuple:
- return parseTupleType();
+ case Token::kw_none:
+ return parseExtendedBuiltinType();
case Token::kw_vector:
return parseVectorType();
// integer-type
@@ -344,11 +227,6 @@ Type Parser::parseNonFunctionType() {
consumeToken(Token::kw_index);
return builder.getIndexType();
- // none-type
- case Token::kw_none:
- consumeToken(Token::kw_none);
- return builder.getNoneType();
-
// extended type
case Token::exclamation_identifier:
return parseExtendedType();
@@ -361,89 +239,6 @@ Type Parser::parseNonFunctionType() {
}
}
-/// Parse a tensor type.
-///
-/// tensor-type ::= `tensor` `<` dimension-list type `>`
-/// dimension-list ::= dimension-list-ranked | `*x`
-///
-Type Parser::parseTensorType() {
- consumeToken(Token::kw_tensor);
-
- if (parseToken(Token::less, "expected '<' in tensor type"))
- return nullptr;
-
- bool isUnranked;
- SmallVector<int64_t, 4> dimensions;
-
- if (consumeIf(Token::star)) {
- // This is an unranked tensor type.
- isUnranked = true;
-
- if (parseXInDimensionList())
- return nullptr;
-
- } else {
- isUnranked = false;
- if (parseDimensionListRanked(dimensions))
- return nullptr;
- }
-
- // Parse the element type.
- auto elementTypeLoc = getToken().getLoc();
- auto elementType = parseType();
-
- // Parse an optional encoding attribute.
- Attribute encoding;
- if (consumeIf(Token::comma)) {
- auto parseResult = parseOptionalAttribute(encoding);
- if (parseResult.has_value()) {
- if (failed(parseResult.value()))
- return nullptr;
- if (auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) {
- if (failed(v.verifyEncoding(dimensions, elementType,
- [&] { return emitError(); })))
- return nullptr;
- }
- }
- }
-
- if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
- return nullptr;
- if (!TensorType::isValidElementType(elementType))
- return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;
-
- if (isUnranked) {
- if (encoding)
- return emitError("cannot apply encoding to unranked tensor"), nullptr;
- return UnrankedTensorType::get(elementType);
- }
- return RankedTensorType::get(dimensions, elementType, encoding);
-}
-
-/// Parse a tuple type.
-///
-/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
-///
-Type Parser::parseTupleType() {
- consumeToken(Token::kw_tuple);
-
- // Parse the '<'.
- if (parseToken(Token::less, "expected '<' in tuple type"))
- return nullptr;
-
- // Check for an empty tuple by directly parsing '>'.
- if (consumeIf(Token::greater))
- return TupleType::get(getContext());
-
- // Parse the element types and the '>'.
- SmallVector<Type, 4> types;
- if (parseTypeListNoParens(types) ||
- parseToken(Token::greater, "expected '>' in tuple type"))
- return nullptr;
-
- return TupleType::get(getContext(), types);
-}
-
/// Parse a vector type.
///
/// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`