aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRiver Riddle <riddleriver@gmail.com>2020-12-14 11:53:43 -0800
committerRiver Riddle <riddleriver@gmail.com>2020-12-14 12:00:43 -0800
commit6bc9439f59acbcc5e46a108c2f74a4d5ffe55a3b (patch)
tree83de8ef929354204259f0f757e1e3e1aa26ca57c
parentc234b65cef07b38c91b9ab7dec6a35f8b390e658 (diff)
downloadllvm-6bc9439f59acbcc5e46a108c2f74a4d5ffe55a3b.zip
llvm-6bc9439f59acbcc5e46a108c2f74a4d5ffe55a3b.tar.gz
llvm-6bc9439f59acbcc5e46a108c2f74a4d5ffe55a3b.tar.bz2
[mlir][OpAsmParser] Add support for parsing integer literals without going through IntegerAttr
Some operations use integer literals as part of their custom format that don't necessarily map to an internal IntegerAttr. This revision exposes the same `parseInteger` functions as the DialectAsmParser to allow for these operations to parse integer literals without incurring the otherwise unnecessary roundtrip through IntegerAttr. Differential Revision: https://reviews.llvm.org/D93152
-rw-r--r--mlir/include/mlir/IR/OpImplementation.h29
-rw-r--r--mlir/lib/Parser/DialectSymbolParser.cpp15
-rw-r--r--mlir/lib/Parser/Parser.cpp23
-rw-r--r--mlir/lib/Parser/Parser.h3
-rw-r--r--mlir/test/IR/parser.mlir15
-rw-r--r--mlir/test/lib/Dialect/Test/TestDialect.cpp28
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.td8
7 files changed, 98 insertions, 23 deletions
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 31d3b42..f74eb52 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -413,6 +413,35 @@ public:
/// Parse a `...` token if present;
virtual ParseResult parseOptionalEllipsis() = 0;
+ /// Parse an integer value from the stream.
+ template <typename IntT> ParseResult parseInteger(IntT &result) {
+ auto loc = getCurrentLocation();
+ OptionalParseResult parseResult = parseOptionalInteger(result);
+ if (!parseResult.hasValue())
+ return emitError(loc, "expected integer value");
+ return *parseResult;
+ }
+
+ /// Parse an optional integer value from the stream.
+ virtual OptionalParseResult parseOptionalInteger(uint64_t &result) = 0;
+
+ template <typename IntT>
+ OptionalParseResult parseOptionalInteger(IntT &result) {
+ auto loc = getCurrentLocation();
+
+ // Parse the unsigned variant.
+ uint64_t uintResult;
+ OptionalParseResult parseResult = parseOptionalInteger(uintResult);
+ if (!parseResult.hasValue() || failed(*parseResult))
+ return parseResult;
+
+ // Try to convert to the provided integer type.
+ result = IntT(uintResult);
+ if (uint64_t(result) != uintResult)
+ return emitError(loc, "integer value too large");
+ return success();
+ }
+
//===--------------------------------------------------------------------===//
// Attribute Parsing
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp
index 3bbc495..11e7e23 100644
--- a/mlir/lib/Parser/DialectSymbolParser.cpp
+++ b/mlir/lib/Parser/DialectSymbolParser.cpp
@@ -82,20 +82,7 @@ public:
/// Parse an optional integer value from the stream.
OptionalParseResult parseOptionalInteger(uint64_t &result) override {
- Token curToken = parser.getToken();
- if (curToken.isNot(Token::integer, Token::minus))
- return llvm::None;
-
- bool negative = parser.consumeIf(Token::minus);
- Token curTok = parser.getToken();
- if (parser.parseToken(Token::integer, "expected integer value"))
- return failure();
-
- auto val = curTok.getUInt64IntegerValue();
- if (!val)
- return emitError(curTok.getLoc(), "integer value too large");
- result = negative ? -*val : *val;
- return success();
+ return parser.parseOptionalInteger(result);
}
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 47fef1c..58ed900 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -94,6 +94,24 @@ ParseResult Parser::parseToken(Token::Kind expectedToken,
return emitError(message);
}
+/// Parse an optional integer value from the stream.
+OptionalParseResult Parser::parseOptionalInteger(uint64_t &result) {
+ Token curToken = getToken();
+ if (curToken.isNot(Token::integer, Token::minus))
+ return llvm::None;
+
+ bool negative = consumeIf(Token::minus);
+ Token curTok = getToken();
+ if (parseToken(Token::integer, "expected integer value"))
+ return failure();
+
+ auto val = curTok.getUInt64IntegerValue();
+ if (!val)
+ return emitError(curTok.getLoc(), "integer value too large");
+ result = negative ? -*val : *val;
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// OperationParser
//===----------------------------------------------------------------------===//
@@ -1109,6 +1127,11 @@ public:
return success(parser.consumeIf(Token::star));
}
+ /// Parse an optional integer value from the stream.
+ OptionalParseResult parseOptionalInteger(uint64_t &result) override {
+ return parser.parseOptionalInteger(result);
+ }
+
//===--------------------------------------------------------------------===//
// Attribute Parsing
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
index e32cca4..0e9e4ca 100644
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -127,6 +127,9 @@ public:
/// output a diagnostic and return failure.
ParseResult parseToken(Token::Kind expectedToken, const Twine &message);
+ /// Parse an optional integer value from the stream.
+ OptionalParseResult parseOptionalInteger(uint64_t &result);
+
//===--------------------------------------------------------------------===//
// Type Parsing
//===--------------------------------------------------------------------===//
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 8fcb7863..ca61bf2 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1174,10 +1174,17 @@ func private @ptr_to_function() -> !unreg.ptr<() -> ()>
// CHECK-LABEL: func private @escaped_string_char(i1 {foo.value = "\0A"})
func private @escaped_string_char(i1 {foo.value = "\n"})
-// CHECK-LABEL: func @wrapped_keyword_test
-func @wrapped_keyword_test() {
- // CHECK: test.wrapped_keyword foo.keyword
- test.wrapped_keyword foo.keyword
+// CHECK-LABEL: func @parse_integer_literal_test
+func @parse_integer_literal_test() {
+ // CHECK: test.parse_integer_literal : 5
+ test.parse_integer_literal : 5
+ return
+}
+
+// CHECK-LABEL: func @parse_wrapped_keyword_test
+func @parse_wrapped_keyword_test() {
+ // CHECK: test.parse_wrapped_keyword foo.keyword
+ test.parse_wrapped_keyword foo.keyword
return
}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 933b59d..5e9bae8 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -564,8 +564,28 @@ static void print(OpAsmPrinter &p, AffineScopeOp op) {
// Test parser.
//===----------------------------------------------------------------------===//
-static ParseResult parseWrappedKeywordOp(OpAsmParser &parser,
- OperationState &result) {
+static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
+ OperationState &result) {
+ if (parser.parseOptionalColon())
+ return success();
+ uint64_t numResults;
+ if (parser.parseInteger(numResults))
+ return failure();
+
+ IndexType type = parser.getBuilder().getIndexType();
+ for (unsigned i = 0; i < numResults; ++i)
+ result.addTypes(type);
+ return success();
+}
+
+static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
+ p << ParseIntegerLiteralOp::getOperationName();
+ if (unsigned numResults = op->getNumResults())
+ p << " : " << numResults;
+}
+
+static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
+ OperationState &result) {
StringRef keyword;
if (parser.parseKeyword(&keyword))
return failure();
@@ -573,8 +593,8 @@ static ParseResult parseWrappedKeywordOp(OpAsmParser &parser,
return success();
}
-static void print(OpAsmPrinter &p, WrappedKeywordOp op) {
- p << WrappedKeywordOp::getOperationName() << " " << op.keyword();
+static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
+ p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 9a7eb59..1fc419c 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1293,7 +1293,13 @@ def TestMergeBlocksOp : TEST_Op<"merge_blocks"> {
// Test parser.
//===----------------------------------------------------------------------===//
-def WrappedKeywordOp : TEST_Op<"wrapped_keyword"> {
+def ParseIntegerLiteralOp : TEST_Op<"parse_integer_literal"> {
+ let results = (outs Variadic<Index>:$results);
+ let parser = [{ return ::parse$cppClass(parser, result); }];
+ let printer = [{ return ::print(p, *this); }];
+}
+
+def ParseWrappedKeywordOp : TEST_Op<"parse_wrapped_keyword"> {
let arguments = (ins StrAttr:$keyword);
let parser = [{ return ::parse$cppClass(parser, result); }];
let printer = [{ return ::print(p, *this); }];