aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/AsmParser/AttributeParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/AsmParser/AttributeParser.cpp')
-rw-r--r--mlir/lib/AsmParser/AttributeParser.cpp90
1 files changed, 68 insertions, 22 deletions
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index efa65e4..68929d3 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -21,8 +21,10 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
+#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Endian.h"
+#include <cmath>
#include <optional>
using namespace mlir;
@@ -121,6 +123,8 @@ Attribute Parser::parseAttribute(Type type) {
// Parse floating point and integer attributes.
case Token::floatliteral:
+ case Token::kw_inf:
+ case Token::kw_nan:
return parseFloatAttr(type, /*isNegative=*/false);
case Token::integer:
return parseDecOrHexAttr(type, /*isNegative=*/false);
@@ -128,7 +132,8 @@ Attribute Parser::parseAttribute(Type type) {
consumeToken(Token::minus);
if (getToken().is(Token::integer))
return parseDecOrHexAttr(type, /*isNegative=*/true);
- if (getToken().is(Token::floatliteral))
+ if (getToken().is(Token::floatliteral) || getToken().is(Token::kw_inf) ||
+ getToken().is(Token::kw_nan))
return parseFloatAttr(type, /*isNegative=*/true);
return (emitWrongTokenError(
@@ -340,12 +345,47 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
" in attribute dictionary");
}
+static FailureOr<APFloat>
+buildAPFloat(function_ref<InFlightDiagnostic()> emitError, const Token &literal,
+ FloatType type, bool isNegative) {
+ // Check type semantics to see if the parsed value is supported. Then build
+ // the attribute.
+ const auto &sem = type.getFloatSemantics();
+ if (literal.is(Token::kw_inf)) {
+ if (!APFloat::semanticsHasInf(sem)) {
+ emitError() << "floating point type does not support infinity";
+ return failure();
+ }
+ return APFloat::getInf(sem, isNegative);
+ } else if (literal.is(Token::kw_nan)) {
+ if (!APFloat::semanticsHasNan(sem)) {
+ emitError() << "floating point type does not support NaN";
+ return failure();
+ }
+ return APFloat::getNaN(sem, isNegative);
+ }
+ auto val = literal.getFloatingPointValue();
+ if (!val) {
+ emitError() << "floating point value too large for attribute";
+ return failure();
+ }
+ if (std::fpclassify(*val) == FP_ZERO && !APFloat::semanticsHasZero(sem)) {
+ emitError() << "floating point type does not support zero";
+ return failure();
+ }
+ bool unused;
+ APFloat apFloat(isNegative ? -*val : *val);
+ apFloat.convert(sem, APFloat::rmNearestTiesToEven, &unused);
+ return apFloat;
+}
+
/// Parse a float attribute.
Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
- auto val = getToken().getFloatingPointValue();
- if (!val)
- return (emitError("floating point value too large for attribute"), nullptr);
- consumeToken(Token::floatliteral);
+ // Get value token.
+ const Token literal = getToken();
+ consumeToken();
+
+ // Optionally parse type.
if (!type) {
// Default to F64 when no type is specified.
if (!consumeIf(Token::colon))
@@ -353,10 +393,18 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
else if (!(type = parseType()))
return nullptr;
}
- if (!isa<FloatType>(type))
- return (emitError("floating point value not valid for specified type"),
+ auto floatType = dyn_cast<FloatType>(type);
+ if (!floatType)
+ return (emitError(literal.getLoc(),
+ "floating point value not valid for specified type"),
nullptr);
- return FloatAttr::get(type, isNegative ? -*val : *val);
+
+ auto emitErrorAtToken = [&]() { return emitError(literal.getLoc()); };
+ FailureOr<APFloat> apFloat =
+ buildAPFloat(emitErrorAtToken, literal, floatType, isNegative);
+ if (failed(apFloat))
+ return nullptr;
+ return FloatAttr::get(type, *apFloat);
}
/// Construct an APint from a parsed value, a known attribute type and
@@ -623,7 +671,7 @@ TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
}
// Check to see if floating point values were parsed.
- if (token.is(Token::floatliteral)) {
+ if (token.isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan)) {
return p.emitError(tokenLoc)
<< "expected integer elements, but parsed floating-point";
}
@@ -672,22 +720,17 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
}
// Check to see if any decimal integers or booleans were parsed.
- if (!token.is(Token::floatliteral))
+ if (!token.isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan))
return p.emitError()
<< "expected floating-point elements, but parsed integer";
// Build the float values from tokens.
- auto val = token.getFloatingPointValue();
- if (!val)
- return p.emitError("floating point value too large for attribute");
-
- APFloat apVal(isNegative ? -*val : *val);
- if (!eltTy.isF64()) {
- bool unused;
- apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
- &unused);
- }
- floatValues.push_back(apVal);
+ auto emitErrorAtToken = [&]() { return p.emitError(token.getLoc()); };
+ FailureOr<APFloat> apFloat =
+ buildAPFloat(emitErrorAtToken, token, eltTy, isNegative);
+ if (failed(apFloat))
+ return failure();
+ floatValues.push_back(*apFloat);
}
return success();
}
@@ -755,6 +798,8 @@ ParseResult TensorLiteralParser::parseElement() {
// Parse a boolean element.
case Token::kw_true:
case Token::kw_false:
+ case Token::kw_inf:
+ case Token::kw_nan:
case Token::floatliteral:
case Token::integer:
storage.emplace_back(/*isNegative=*/false, p.getToken());
@@ -764,7 +809,8 @@ ParseResult TensorLiteralParser::parseElement() {
// Parse a signed integer or a negative floating-point element.
case Token::minus:
p.consumeToken(Token::minus);
- if (!p.getToken().isAny(Token::floatliteral, Token::integer))
+ if (!p.getToken().isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan,
+ Token::integer))
return p.emitError("expected integer or floating point literal");
storage.emplace_back(/*isNegative=*/true, p.getToken());
p.consumeToken();