diff options
author | Matthias Springer <mspringer@nvidia.com> | 2024-11-14 07:50:26 +0100 |
---|---|---|
committer | Matthias Springer <mspringer@nvidia.com> | 2024-11-14 07:50:26 +0100 |
commit | 78c9896e7c526c81b34db32f19d92ca7eb84bb76 (patch) | |
tree | afe9ca67ab969c608692f6c2609010111d300f1e | |
parent | e9aee4fd80874f80556456f64c303ffb957bd614 (diff) | |
download | llvm-users/matthias-springer/tmp_token.zip llvm-users/matthias-springer/tmp_token.tar.gz llvm-users/matthias-springer/tmp_token.tar.bz2 |
tmp commitusers/matthias-springer/tmp_token
-rw-r--r-- | mlir/lib/AsmParser/AttributeParser.cpp | 90 | ||||
-rw-r--r-- | mlir/lib/AsmParser/TokenKinds.def | 2 | ||||
-rw-r--r-- | mlir/test/Dialect/Arith/canonicalize.mlir | 10 | ||||
-rw-r--r-- | mlir/test/IR/attribute.mlir | 54 | ||||
-rw-r--r-- | mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir | 36 | ||||
-rw-r--r-- | mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir | 4 |
6 files changed, 149 insertions, 47 deletions
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index efa65e4..68929d3 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -21,8 +21,10 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/IntegerSet.h" +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Endian.h" +#include <cmath> #include <optional> using namespace mlir; @@ -121,6 +123,8 @@ Attribute Parser::parseAttribute(Type type) { // Parse floating point and integer attributes. case Token::floatliteral: + case Token::kw_inf: + case Token::kw_nan: return parseFloatAttr(type, /*isNegative=*/false); case Token::integer: return parseDecOrHexAttr(type, /*isNegative=*/false); @@ -128,7 +132,8 @@ Attribute Parser::parseAttribute(Type type) { consumeToken(Token::minus); if (getToken().is(Token::integer)) return parseDecOrHexAttr(type, /*isNegative=*/true); - if (getToken().is(Token::floatliteral)) + if (getToken().is(Token::floatliteral) || getToken().is(Token::kw_inf) || + getToken().is(Token::kw_nan)) return parseFloatAttr(type, /*isNegative=*/true); return (emitWrongTokenError( @@ -340,12 +345,47 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) { " in attribute dictionary"); } +static FailureOr<APFloat> +buildAPFloat(function_ref<InFlightDiagnostic()> emitError, const Token &literal, + FloatType type, bool isNegative) { + // Check type semantics to see if the parsed value is supported. Then build + // the attribute. + const auto &sem = type.getFloatSemantics(); + if (literal.is(Token::kw_inf)) { + if (!APFloat::semanticsHasInf(sem)) { + emitError() << "floating point type does not support infinity"; + return failure(); + } + return APFloat::getInf(sem, isNegative); + } else if (literal.is(Token::kw_nan)) { + if (!APFloat::semanticsHasNan(sem)) { + emitError() << "floating point type does not support NaN"; + return failure(); + } + return APFloat::getNaN(sem, isNegative); + } + auto val = literal.getFloatingPointValue(); + if (!val) { + emitError() << "floating point value too large for attribute"; + return failure(); + } + if (std::fpclassify(*val) == FP_ZERO && !APFloat::semanticsHasZero(sem)) { + emitError() << "floating point type does not support zero"; + return failure(); + } + bool unused; + APFloat apFloat(isNegative ? -*val : *val); + apFloat.convert(sem, APFloat::rmNearestTiesToEven, &unused); + return apFloat; +} + /// Parse a float attribute. Attribute Parser::parseFloatAttr(Type type, bool isNegative) { - auto val = getToken().getFloatingPointValue(); - if (!val) - return (emitError("floating point value too large for attribute"), nullptr); - consumeToken(Token::floatliteral); + // Get value token. + const Token literal = getToken(); + consumeToken(); + + // Optionally parse type. if (!type) { // Default to F64 when no type is specified. if (!consumeIf(Token::colon)) @@ -353,10 +393,18 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) { else if (!(type = parseType())) return nullptr; } - if (!isa<FloatType>(type)) - return (emitError("floating point value not valid for specified type"), + auto floatType = dyn_cast<FloatType>(type); + if (!floatType) + return (emitError(literal.getLoc(), + "floating point value not valid for specified type"), nullptr); - return FloatAttr::get(type, isNegative ? -*val : *val); + + auto emitErrorAtToken = [&]() { return emitError(literal.getLoc()); }; + FailureOr<APFloat> apFloat = + buildAPFloat(emitErrorAtToken, literal, floatType, isNegative); + if (failed(apFloat)) + return nullptr; + return FloatAttr::get(type, *apFloat); } /// Construct an APint from a parsed value, a known attribute type and @@ -623,7 +671,7 @@ TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy, } // Check to see if floating point values were parsed. - if (token.is(Token::floatliteral)) { + if (token.isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan)) { return p.emitError(tokenLoc) << "expected integer elements, but parsed floating-point"; } @@ -672,22 +720,17 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy, } // Check to see if any decimal integers or booleans were parsed. - if (!token.is(Token::floatliteral)) + if (!token.isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan)) return p.emitError() << "expected floating-point elements, but parsed integer"; // Build the float values from tokens. - auto val = token.getFloatingPointValue(); - if (!val) - return p.emitError("floating point value too large for attribute"); - - APFloat apVal(isNegative ? -*val : *val); - if (!eltTy.isF64()) { - bool unused; - apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven, - &unused); - } - floatValues.push_back(apVal); + auto emitErrorAtToken = [&]() { return p.emitError(token.getLoc()); }; + FailureOr<APFloat> apFloat = + buildAPFloat(emitErrorAtToken, token, eltTy, isNegative); + if (failed(apFloat)) + return failure(); + floatValues.push_back(*apFloat); } return success(); } @@ -755,6 +798,8 @@ ParseResult TensorLiteralParser::parseElement() { // Parse a boolean element. case Token::kw_true: case Token::kw_false: + case Token::kw_inf: + case Token::kw_nan: case Token::floatliteral: case Token::integer: storage.emplace_back(/*isNegative=*/false, p.getToken()); @@ -764,7 +809,8 @@ ParseResult TensorLiteralParser::parseElement() { // Parse a signed integer or a negative floating-point element. case Token::minus: p.consumeToken(Token::minus); - if (!p.getToken().isAny(Token::floatliteral, Token::integer)) + if (!p.getToken().isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan, + Token::integer)) return p.emitError("expected integer or floating point literal"); storage.emplace_back(/*isNegative=*/true, p.getToken()); p.consumeToken(); diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def index 49da8c3..9208c8a 100644 --- a/mlir/lib/AsmParser/TokenKinds.def +++ b/mlir/lib/AsmParser/TokenKinds.def @@ -111,11 +111,13 @@ TOK_KEYWORD(floordiv) TOK_KEYWORD(for) TOK_KEYWORD(func) TOK_KEYWORD(index) +TOK_KEYWORD(inf) TOK_KEYWORD(loc) TOK_KEYWORD(max) TOK_KEYWORD(memref) TOK_KEYWORD(min) TOK_KEYWORD(mod) +TOK_KEYWORD(nan) TOK_KEYWORD(none) TOK_KEYWORD(offset) TOK_KEYWORD(size) diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index a386a17..c86b2b5 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -1880,7 +1880,7 @@ func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) { // CHECK-NEXT: %[[X:.+]] = arith.minimumf %arg0, %[[C0]] // CHECK-NEXT: return %[[X]], %arg0, %arg0 %c0 = arith.constant 0.0 : f32 - %inf = arith.constant 0x7F800000 : f32 + %inf = arith.constant inf : f32 %0 = arith.minimumf %c0, %arg0 : f32 %1 = arith.minimumf %arg0, %arg0 : f32 %2 = arith.minimumf %inf, %arg0 : f32 @@ -1895,7 +1895,7 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) { // CHECK-NEXT: %[[X:.+]] = arith.maximumf %arg0, %[[C0]] // CHECK-NEXT: return %[[X]], %arg0, %arg0 %c0 = arith.constant 0.0 : f32 - %-inf = arith.constant 0xFF800000 : f32 + %-inf = arith.constant -inf : f32 %0 = arith.maximumf %c0, %arg0 : f32 %1 = arith.maximumf %arg0, %arg0 : f32 %2 = arith.maximumf %-inf, %arg0 : f32 @@ -1910,7 +1910,7 @@ func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) { // CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]] // CHECK-NEXT: return %[[X]], %arg0, %arg0 %c0 = arith.constant 0.0 : f32 - %inf = arith.constant 0x7F800000 : f32 + %inf = arith.constant inf : f32 %0 = arith.minnumf %c0, %arg0 : f32 %1 = arith.minnumf %arg0, %arg0 : f32 %2 = arith.minnumf %inf, %arg0 : f32 @@ -1925,7 +1925,7 @@ func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) { // CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]] // CHECK-NEXT: return %[[X]], %arg0, %arg0 %c0 = arith.constant 0.0 : f32 - %-inf = arith.constant 0xFF800000 : f32 + %-inf = arith.constant -inf : f32 %0 = arith.maxnumf %c0, %arg0 : f32 %1 = arith.maxnumf %arg0, %arg0 : f32 %2 = arith.maxnumf %-inf, %arg0 : f32 @@ -2024,7 +2024,7 @@ func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) { // CHECK-DAG: %[[T:.*]] = arith.constant true // CHECK-DAG: %[[F:.*]] = arith.constant false // CHECK: return %[[F]], %[[F]], %[[T]], %[[T]] - %nan = arith.constant 0x7fffffff : f32 + %nan = arith.constant nan : f32 %0 = arith.cmpf olt, %nan, %arg0 : f32 %1 = arith.cmpf olt, %arg0, %nan : f32 %2 = arith.cmpf ugt, %nan, %arg0 : f32 diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir index a62de3f..78cfdbe 100644 --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -108,9 +108,63 @@ func.func @float_attrs_pass() { // CHECK: float_attr = 2.000000e+00 : f128 float_attr = 2. : f128 } : () -> () + "test.float_attrs"() { + // Note: nan/inf are printed in binary format because there may be multiple + // nan/inf representations. + // CHECK: float_attr = 0x7FC00000 : f32 + float_attr = nan : f32 + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 0x7C : f32 + float_attr = nan : f8E4M3 + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 0xFFC00000 : f32 + float_attr = -nan : f32 + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 0xFC : f32 + float_attr = -nan : f8E4M3 + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 0x7F800000 : f32 + float_attr = inf : f32 + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 0x78 : f32 + float_attr = inf : f8E4M3 + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 0xFF800000 : f32 + float_attr = -inf : f32 + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 0xF8 : f32 + float_attr = -inf : f8E4M3 + } : () -> () return } +// ----- + +func.func @float_nan_unsupported() { + "test.float_attrs"() { + // expected-error @below{{floating point type does not support NaN}} + float_attr = nan : f4E2M1FN + } : () -> () +} + +// ----- + +func.func @float_inf_unsupported() { + "test.float_attrs"() { + // expected-error @below{{floating point type does not support infinity}} + float_attr = inf : f4E2M1FN + } : () -> () +} + +// ----- + //===----------------------------------------------------------------------===// // Test integer attributes //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir index b886119..28b656b 100644 --- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir +++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir @@ -41,7 +41,7 @@ func.func @tanh() { call @tanh_8xf32(%v2) : (vector<8xf32>) -> () // CHECK: nan - %nan = arith.constant 0x7fc00000 : f32 + %nan = arith.constant nan : f32 call @tanh_f32(%nan) : (f32) -> () return @@ -87,15 +87,15 @@ func.func @log() { call @log_f32(%zero) : (f32) -> () // CHECK: nan - %nan = arith.constant 0x7fc00000 : f32 + %nan = arith.constant nan : f32 call @log_f32(%nan) : (f32) -> () // CHECK: inf - %inf = arith.constant 0x7f800000 : f32 + %inf = arith.constant inf : f32 call @log_f32(%inf) : (f32) -> () // CHECK: -inf, nan, inf, 0.693147 - %special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 2.0]> : vector<4xf32> + %special_vec = arith.constant dense<[0.0, -1.0, inf, 2.0]> : vector<4xf32> call @log_4xf32(%special_vec) : (vector<4xf32>) -> () return @@ -141,11 +141,11 @@ func.func @log2() { call @log2_f32(%neg_one) : (f32) -> () // CHECK: inf - %inf = arith.constant 0x7f800000 : f32 + %inf = arith.constant inf : f32 call @log2_f32(%inf) : (f32) -> () // CHECK: -inf, nan, inf, 1.58496 - %special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 3.0]> : vector<4xf32> + %special_vec = arith.constant dense<[0.0, -1.0, inf, 3.0]> : vector<4xf32> call @log2_4xf32(%special_vec) : (vector<4xf32>) -> () return @@ -192,11 +192,11 @@ func.func @log1p() { call @log1p_f32(%neg_two) : (f32) -> () // CHECK: inf - %inf = arith.constant 0x7f800000 : f32 + %inf = arith.constant inf : f32 call @log1p_f32(%inf) : (f32) -> () // CHECK: -inf, nan, inf, 9.99995e-06 - %special_vec = arith.constant dense<[-1.0, -1.1, 0x7f800000, 0.00001]> : vector<4xf32> + %special_vec = arith.constant dense<[-1.0, -1.1, inf, 0.00001]> : vector<4xf32> call @log1p_4xf32(%special_vec) : (vector<4xf32>) -> () return @@ -247,7 +247,7 @@ func.func @erf() { call @erf_f32(%val7) : (f32) -> () // CHECK: -1 - %negativeInf = arith.constant 0xff800000 : f32 + %negativeInf = arith.constant -inf : f32 call @erf_f32(%negativeInf) : (f32) -> () // CHECK: -1, -1, -0.913759, -0.731446 @@ -263,11 +263,11 @@ func.func @erf() { call @erf_4xf32(%vecVals3) : (vector<4xf32>) -> () // CHECK: 1 - %inf = arith.constant 0x7f800000 : f32 + %inf = arith.constant inf : f32 call @erf_f32(%inf) : (f32) -> () // CHECK: nan - %nan = arith.constant 0x7fc00000 : f32 + %nan = arith.constant nan : f32 call @erf_f32(%nan) : (f32) -> () return @@ -306,15 +306,15 @@ func.func @exp() { call @exp_4xf32(%special_vec) : (vector<4xf32>) -> () // CHECK: inf - %inf = arith.constant 0x7f800000 : f32 + %inf = arith.constant inf : f32 call @exp_f32(%inf) : (f32) -> () // CHECK: 0 - %negative_inf = arith.constant 0xff800000 : f32 + %negative_inf = arith.constant -inf : f32 call @exp_f32(%negative_inf) : (f32) -> () // CHECK: nan - %nan = arith.constant 0x7fc00000 : f32 + %nan = arith.constant nan : f32 call @exp_f32(%nan) : (f32) -> () return @@ -358,19 +358,19 @@ func.func @expm1() { call @expm1_8xf32(%v2) : (vector<8xf32>) -> () // CHECK: -1 - %neg_inf = arith.constant 0xff800000 : f32 + %neg_inf = arith.constant -inf : f32 call @expm1_f32(%neg_inf) : (f32) -> () // CHECK: inf - %inf = arith.constant 0x7f800000 : f32 + %inf = arith.constant inf : f32 call @expm1_f32(%inf) : (f32) -> () // CHECK: -1, inf, 1e-10 - %special_vec = arith.constant dense<[0xff800000, 0x7f800000, 1.0e-10]> : vector<3xf32> + %special_vec = arith.constant dense<[-inf, inf, 1.0e-10]> : vector<3xf32> call @expm1_3xf32(%special_vec) : (vector<3xf32>) -> () // CHECK: nan - %nan = arith.constant 0x7fc00000 : f32 + %nan = arith.constant nan : f32 call @expm1_f32(%nan) : (f32) -> () return diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir index 80d559c..a574858 100644 --- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir +++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir @@ -647,7 +647,7 @@ func.func @sinh() { call @sinh_8xf32(%v3) : (vector<8xf32>) -> () // CHECK: nan - %nan = arith.constant 0x7fc00000 : f32 + %nan = arith.constant nan : f32 call @sinh_f32(%nan) : (f32) -> () return @@ -693,7 +693,7 @@ func.func @cosh() { call @cosh_8xf32(%v3) : (vector<8xf32>) -> () // CHECK: nan - %nan = arith.constant 0x7fc00000 : f32 + %nan = arith.constant nan : f32 call @cosh_f32(%nan) : (f32) -> () return |