aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <mspringer@nvidia.com>2024-11-14 07:50:26 +0100
committerMatthias Springer <mspringer@nvidia.com>2024-11-14 07:50:26 +0100
commit78c9896e7c526c81b34db32f19d92ca7eb84bb76 (patch)
treeafe9ca67ab969c608692f6c2609010111d300f1e
parente9aee4fd80874f80556456f64c303ffb957bd614 (diff)
downloadllvm-users/matthias-springer/tmp_token.zip
llvm-users/matthias-springer/tmp_token.tar.gz
llvm-users/matthias-springer/tmp_token.tar.bz2
-rw-r--r--mlir/lib/AsmParser/AttributeParser.cpp90
-rw-r--r--mlir/lib/AsmParser/TokenKinds.def2
-rw-r--r--mlir/test/Dialect/Arith/canonicalize.mlir10
-rw-r--r--mlir/test/IR/attribute.mlir54
-rw-r--r--mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir36
-rw-r--r--mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir4
6 files changed, 149 insertions, 47 deletions
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index efa65e4..68929d3 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -21,8 +21,10 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
+#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Endian.h"
+#include <cmath>
#include <optional>
using namespace mlir;
@@ -121,6 +123,8 @@ Attribute Parser::parseAttribute(Type type) {
// Parse floating point and integer attributes.
case Token::floatliteral:
+ case Token::kw_inf:
+ case Token::kw_nan:
return parseFloatAttr(type, /*isNegative=*/false);
case Token::integer:
return parseDecOrHexAttr(type, /*isNegative=*/false);
@@ -128,7 +132,8 @@ Attribute Parser::parseAttribute(Type type) {
consumeToken(Token::minus);
if (getToken().is(Token::integer))
return parseDecOrHexAttr(type, /*isNegative=*/true);
- if (getToken().is(Token::floatliteral))
+ if (getToken().is(Token::floatliteral) || getToken().is(Token::kw_inf) ||
+ getToken().is(Token::kw_nan))
return parseFloatAttr(type, /*isNegative=*/true);
return (emitWrongTokenError(
@@ -340,12 +345,47 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
" in attribute dictionary");
}
+static FailureOr<APFloat>
+buildAPFloat(function_ref<InFlightDiagnostic()> emitError, const Token &literal,
+ FloatType type, bool isNegative) {
+ // Check type semantics to see if the parsed value is supported. Then build
+ // the attribute.
+ const auto &sem = type.getFloatSemantics();
+ if (literal.is(Token::kw_inf)) {
+ if (!APFloat::semanticsHasInf(sem)) {
+ emitError() << "floating point type does not support infinity";
+ return failure();
+ }
+ return APFloat::getInf(sem, isNegative);
+ } else if (literal.is(Token::kw_nan)) {
+ if (!APFloat::semanticsHasNan(sem)) {
+ emitError() << "floating point type does not support NaN";
+ return failure();
+ }
+ return APFloat::getNaN(sem, isNegative);
+ }
+ auto val = literal.getFloatingPointValue();
+ if (!val) {
+ emitError() << "floating point value too large for attribute";
+ return failure();
+ }
+ if (std::fpclassify(*val) == FP_ZERO && !APFloat::semanticsHasZero(sem)) {
+ emitError() << "floating point type does not support zero";
+ return failure();
+ }
+ bool unused;
+ APFloat apFloat(isNegative ? -*val : *val);
+ apFloat.convert(sem, APFloat::rmNearestTiesToEven, &unused);
+ return apFloat;
+}
+
/// Parse a float attribute.
Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
- auto val = getToken().getFloatingPointValue();
- if (!val)
- return (emitError("floating point value too large for attribute"), nullptr);
- consumeToken(Token::floatliteral);
+ // Get value token.
+ const Token literal = getToken();
+ consumeToken();
+
+ // Optionally parse type.
if (!type) {
// Default to F64 when no type is specified.
if (!consumeIf(Token::colon))
@@ -353,10 +393,18 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
else if (!(type = parseType()))
return nullptr;
}
- if (!isa<FloatType>(type))
- return (emitError("floating point value not valid for specified type"),
+ auto floatType = dyn_cast<FloatType>(type);
+ if (!floatType)
+ return (emitError(literal.getLoc(),
+ "floating point value not valid for specified type"),
nullptr);
- return FloatAttr::get(type, isNegative ? -*val : *val);
+
+ auto emitErrorAtToken = [&]() { return emitError(literal.getLoc()); };
+ FailureOr<APFloat> apFloat =
+ buildAPFloat(emitErrorAtToken, literal, floatType, isNegative);
+ if (failed(apFloat))
+ return nullptr;
+ return FloatAttr::get(type, *apFloat);
}
/// Construct an APint from a parsed value, a known attribute type and
@@ -623,7 +671,7 @@ TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
}
// Check to see if floating point values were parsed.
- if (token.is(Token::floatliteral)) {
+ if (token.isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan)) {
return p.emitError(tokenLoc)
<< "expected integer elements, but parsed floating-point";
}
@@ -672,22 +720,17 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
}
// Check to see if any decimal integers or booleans were parsed.
- if (!token.is(Token::floatliteral))
+ if (!token.isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan))
return p.emitError()
<< "expected floating-point elements, but parsed integer";
// Build the float values from tokens.
- auto val = token.getFloatingPointValue();
- if (!val)
- return p.emitError("floating point value too large for attribute");
-
- APFloat apVal(isNegative ? -*val : *val);
- if (!eltTy.isF64()) {
- bool unused;
- apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
- &unused);
- }
- floatValues.push_back(apVal);
+ auto emitErrorAtToken = [&]() { return p.emitError(token.getLoc()); };
+ FailureOr<APFloat> apFloat =
+ buildAPFloat(emitErrorAtToken, token, eltTy, isNegative);
+ if (failed(apFloat))
+ return failure();
+ floatValues.push_back(*apFloat);
}
return success();
}
@@ -755,6 +798,8 @@ ParseResult TensorLiteralParser::parseElement() {
// Parse a boolean element.
case Token::kw_true:
case Token::kw_false:
+ case Token::kw_inf:
+ case Token::kw_nan:
case Token::floatliteral:
case Token::integer:
storage.emplace_back(/*isNegative=*/false, p.getToken());
@@ -764,7 +809,8 @@ ParseResult TensorLiteralParser::parseElement() {
// Parse a signed integer or a negative floating-point element.
case Token::minus:
p.consumeToken(Token::minus);
- if (!p.getToken().isAny(Token::floatliteral, Token::integer))
+ if (!p.getToken().isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan,
+ Token::integer))
return p.emitError("expected integer or floating point literal");
storage.emplace_back(/*isNegative=*/true, p.getToken());
p.consumeToken();
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 49da8c3..9208c8a 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -111,11 +111,13 @@ TOK_KEYWORD(floordiv)
TOK_KEYWORD(for)
TOK_KEYWORD(func)
TOK_KEYWORD(index)
+TOK_KEYWORD(inf)
TOK_KEYWORD(loc)
TOK_KEYWORD(max)
TOK_KEYWORD(memref)
TOK_KEYWORD(min)
TOK_KEYWORD(mod)
+TOK_KEYWORD(nan)
TOK_KEYWORD(none)
TOK_KEYWORD(offset)
TOK_KEYWORD(size)
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a386a17..c86b2b5 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1880,7 +1880,7 @@ func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.minimumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
- %inf = arith.constant 0x7F800000 : f32
+ %inf = arith.constant inf : f32
%0 = arith.minimumf %c0, %arg0 : f32
%1 = arith.minimumf %arg0, %arg0 : f32
%2 = arith.minimumf %inf, %arg0 : f32
@@ -1895,7 +1895,7 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.maximumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
- %-inf = arith.constant 0xFF800000 : f32
+ %-inf = arith.constant -inf : f32
%0 = arith.maximumf %c0, %arg0 : f32
%1 = arith.maximumf %arg0, %arg0 : f32
%2 = arith.maximumf %-inf, %arg0 : f32
@@ -1910,7 +1910,7 @@ func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
- %inf = arith.constant 0x7F800000 : f32
+ %inf = arith.constant inf : f32
%0 = arith.minnumf %c0, %arg0 : f32
%1 = arith.minnumf %arg0, %arg0 : f32
%2 = arith.minnumf %inf, %arg0 : f32
@@ -1925,7 +1925,7 @@ func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
- %-inf = arith.constant 0xFF800000 : f32
+ %-inf = arith.constant -inf : f32
%0 = arith.maxnumf %c0, %arg0 : f32
%1 = arith.maxnumf %arg0, %arg0 : f32
%2 = arith.maxnumf %-inf, %arg0 : f32
@@ -2024,7 +2024,7 @@ func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
// CHECK-DAG: %[[T:.*]] = arith.constant true
// CHECK-DAG: %[[F:.*]] = arith.constant false
// CHECK: return %[[F]], %[[F]], %[[T]], %[[T]]
- %nan = arith.constant 0x7fffffff : f32
+ %nan = arith.constant nan : f32
%0 = arith.cmpf olt, %nan, %arg0 : f32
%1 = arith.cmpf olt, %arg0, %nan : f32
%2 = arith.cmpf ugt, %nan, %arg0 : f32
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index a62de3f..78cfdbe 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -108,9 +108,63 @@ func.func @float_attrs_pass() {
// CHECK: float_attr = 2.000000e+00 : f128
float_attr = 2. : f128
} : () -> ()
+ "test.float_attrs"() {
+ // Note: nan/inf are printed in binary format because there may be multiple
+ // nan/inf representations.
+ // CHECK: float_attr = 0x7FC00000 : f32
+ float_attr = nan : f32
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0x7C : f32
+ float_attr = nan : f8E4M3
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0xFFC00000 : f32
+ float_attr = -nan : f32
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0xFC : f32
+ float_attr = -nan : f8E4M3
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0x7F800000 : f32
+ float_attr = inf : f32
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0x78 : f32
+ float_attr = inf : f8E4M3
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0xFF800000 : f32
+ float_attr = -inf : f32
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0xF8 : f32
+ float_attr = -inf : f8E4M3
+ } : () -> ()
return
}
+// -----
+
+func.func @float_nan_unsupported() {
+ "test.float_attrs"() {
+ // expected-error @below{{floating point type does not support NaN}}
+ float_attr = nan : f4E2M1FN
+ } : () -> ()
+}
+
+// -----
+
+func.func @float_inf_unsupported() {
+ "test.float_attrs"() {
+ // expected-error @below{{floating point type does not support infinity}}
+ float_attr = inf : f4E2M1FN
+ } : () -> ()
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// Test integer attributes
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index b886119..28b656b 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -41,7 +41,7 @@ func.func @tanh() {
call @tanh_8xf32(%v2) : (vector<8xf32>) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @tanh_f32(%nan) : (f32) -> ()
return
@@ -87,15 +87,15 @@ func.func @log() {
call @log_f32(%zero) : (f32) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @log_f32(%nan) : (f32) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @log_f32(%inf) : (f32) -> ()
// CHECK: -inf, nan, inf, 0.693147
- %special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 2.0]> : vector<4xf32>
+ %special_vec = arith.constant dense<[0.0, -1.0, inf, 2.0]> : vector<4xf32>
call @log_4xf32(%special_vec) : (vector<4xf32>) -> ()
return
@@ -141,11 +141,11 @@ func.func @log2() {
call @log2_f32(%neg_one) : (f32) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @log2_f32(%inf) : (f32) -> ()
// CHECK: -inf, nan, inf, 1.58496
- %special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 3.0]> : vector<4xf32>
+ %special_vec = arith.constant dense<[0.0, -1.0, inf, 3.0]> : vector<4xf32>
call @log2_4xf32(%special_vec) : (vector<4xf32>) -> ()
return
@@ -192,11 +192,11 @@ func.func @log1p() {
call @log1p_f32(%neg_two) : (f32) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @log1p_f32(%inf) : (f32) -> ()
// CHECK: -inf, nan, inf, 9.99995e-06
- %special_vec = arith.constant dense<[-1.0, -1.1, 0x7f800000, 0.00001]> : vector<4xf32>
+ %special_vec = arith.constant dense<[-1.0, -1.1, inf, 0.00001]> : vector<4xf32>
call @log1p_4xf32(%special_vec) : (vector<4xf32>) -> ()
return
@@ -247,7 +247,7 @@ func.func @erf() {
call @erf_f32(%val7) : (f32) -> ()
// CHECK: -1
- %negativeInf = arith.constant 0xff800000 : f32
+ %negativeInf = arith.constant -inf : f32
call @erf_f32(%negativeInf) : (f32) -> ()
// CHECK: -1, -1, -0.913759, -0.731446
@@ -263,11 +263,11 @@ func.func @erf() {
call @erf_4xf32(%vecVals3) : (vector<4xf32>) -> ()
// CHECK: 1
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @erf_f32(%inf) : (f32) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @erf_f32(%nan) : (f32) -> ()
return
@@ -306,15 +306,15 @@ func.func @exp() {
call @exp_4xf32(%special_vec) : (vector<4xf32>) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @exp_f32(%inf) : (f32) -> ()
// CHECK: 0
- %negative_inf = arith.constant 0xff800000 : f32
+ %negative_inf = arith.constant -inf : f32
call @exp_f32(%negative_inf) : (f32) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @exp_f32(%nan) : (f32) -> ()
return
@@ -358,19 +358,19 @@ func.func @expm1() {
call @expm1_8xf32(%v2) : (vector<8xf32>) -> ()
// CHECK: -1
- %neg_inf = arith.constant 0xff800000 : f32
+ %neg_inf = arith.constant -inf : f32
call @expm1_f32(%neg_inf) : (f32) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @expm1_f32(%inf) : (f32) -> ()
// CHECK: -1, inf, 1e-10
- %special_vec = arith.constant dense<[0xff800000, 0x7f800000, 1.0e-10]> : vector<3xf32>
+ %special_vec = arith.constant dense<[-inf, inf, 1.0e-10]> : vector<3xf32>
call @expm1_3xf32(%special_vec) : (vector<3xf32>) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @expm1_f32(%nan) : (f32) -> ()
return
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 80d559c..a574858 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -647,7 +647,7 @@ func.func @sinh() {
call @sinh_8xf32(%v3) : (vector<8xf32>) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @sinh_f32(%nan) : (f32) -> ()
return
@@ -693,7 +693,7 @@ func.func @cosh() {
call @cosh_8xf32(%v3) : (vector<8xf32>) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @cosh_f32(%nan) : (f32) -> ()
return