diff options
Diffstat (limited to 'mlir/lib/AsmParser/AttributeParser.cpp')
-rw-r--r-- | mlir/lib/AsmParser/AttributeParser.cpp | 90 |
1 files changed, 68 insertions, 22 deletions
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index efa65e4..68929d3 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -21,8 +21,10 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/IntegerSet.h" +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Endian.h" +#include <cmath> #include <optional> using namespace mlir; @@ -121,6 +123,8 @@ Attribute Parser::parseAttribute(Type type) { // Parse floating point and integer attributes. case Token::floatliteral: + case Token::kw_inf: + case Token::kw_nan: return parseFloatAttr(type, /*isNegative=*/false); case Token::integer: return parseDecOrHexAttr(type, /*isNegative=*/false); @@ -128,7 +132,8 @@ Attribute Parser::parseAttribute(Type type) { consumeToken(Token::minus); if (getToken().is(Token::integer)) return parseDecOrHexAttr(type, /*isNegative=*/true); - if (getToken().is(Token::floatliteral)) + if (getToken().is(Token::floatliteral) || getToken().is(Token::kw_inf) || + getToken().is(Token::kw_nan)) return parseFloatAttr(type, /*isNegative=*/true); return (emitWrongTokenError( @@ -340,12 +345,47 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) { " in attribute dictionary"); } +static FailureOr<APFloat> +buildAPFloat(function_ref<InFlightDiagnostic()> emitError, const Token &literal, + FloatType type, bool isNegative) { + // Check type semantics to see if the parsed value is supported. Then build + // the attribute. + const auto &sem = type.getFloatSemantics(); + if (literal.is(Token::kw_inf)) { + if (!APFloat::semanticsHasInf(sem)) { + emitError() << "floating point type does not support infinity"; + return failure(); + } + return APFloat::getInf(sem, isNegative); + } else if (literal.is(Token::kw_nan)) { + if (!APFloat::semanticsHasNan(sem)) { + emitError() << "floating point type does not support NaN"; + return failure(); + } + return APFloat::getNaN(sem, isNegative); + } + auto val = literal.getFloatingPointValue(); + if (!val) { + emitError() << "floating point value too large for attribute"; + return failure(); + } + if (std::fpclassify(*val) == FP_ZERO && !APFloat::semanticsHasZero(sem)) { + emitError() << "floating point type does not support zero"; + return failure(); + } + bool unused; + APFloat apFloat(isNegative ? -*val : *val); + apFloat.convert(sem, APFloat::rmNearestTiesToEven, &unused); + return apFloat; +} + /// Parse a float attribute. Attribute Parser::parseFloatAttr(Type type, bool isNegative) { - auto val = getToken().getFloatingPointValue(); - if (!val) - return (emitError("floating point value too large for attribute"), nullptr); - consumeToken(Token::floatliteral); + // Get value token. + const Token literal = getToken(); + consumeToken(); + + // Optionally parse type. if (!type) { // Default to F64 when no type is specified. if (!consumeIf(Token::colon)) @@ -353,10 +393,18 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) { else if (!(type = parseType())) return nullptr; } - if (!isa<FloatType>(type)) - return (emitError("floating point value not valid for specified type"), + auto floatType = dyn_cast<FloatType>(type); + if (!floatType) + return (emitError(literal.getLoc(), + "floating point value not valid for specified type"), nullptr); - return FloatAttr::get(type, isNegative ? -*val : *val); + + auto emitErrorAtToken = [&]() { return emitError(literal.getLoc()); }; + FailureOr<APFloat> apFloat = + buildAPFloat(emitErrorAtToken, literal, floatType, isNegative); + if (failed(apFloat)) + return nullptr; + return FloatAttr::get(type, *apFloat); } /// Construct an APint from a parsed value, a known attribute type and @@ -623,7 +671,7 @@ TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy, } // Check to see if floating point values were parsed. - if (token.is(Token::floatliteral)) { + if (token.isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan)) { return p.emitError(tokenLoc) << "expected integer elements, but parsed floating-point"; } @@ -672,22 +720,17 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy, } // Check to see if any decimal integers or booleans were parsed. - if (!token.is(Token::floatliteral)) + if (!token.isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan)) return p.emitError() << "expected floating-point elements, but parsed integer"; // Build the float values from tokens. - auto val = token.getFloatingPointValue(); - if (!val) - return p.emitError("floating point value too large for attribute"); - - APFloat apVal(isNegative ? -*val : *val); - if (!eltTy.isF64()) { - bool unused; - apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven, - &unused); - } - floatValues.push_back(apVal); + auto emitErrorAtToken = [&]() { return p.emitError(token.getLoc()); }; + FailureOr<APFloat> apFloat = + buildAPFloat(emitErrorAtToken, token, eltTy, isNegative); + if (failed(apFloat)) + return failure(); + floatValues.push_back(*apFloat); } return success(); } @@ -755,6 +798,8 @@ ParseResult TensorLiteralParser::parseElement() { // Parse a boolean element. case Token::kw_true: case Token::kw_false: + case Token::kw_inf: + case Token::kw_nan: case Token::floatliteral: case Token::integer: storage.emplace_back(/*isNegative=*/false, p.getToken()); @@ -764,7 +809,8 @@ ParseResult TensorLiteralParser::parseElement() { // Parse a signed integer or a negative floating-point element. case Token::minus: p.consumeToken(Token::minus); - if (!p.getToken().isAny(Token::floatliteral, Token::integer)) + if (!p.getToken().isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan, + Token::integer)) return p.emitError("expected integer or floating point literal"); storage.emplace_back(/*isNegative=*/true, p.getToken()); p.consumeToken(); |