From 8827ff92b96d78ef455157574061d745df2909af Mon Sep 17 00:00:00 2001 From: Victor Perez Date: Mon, 1 Apr 2024 11:57:14 +0200 Subject: [MLIR][Arith] Add rounding mode attribute to `truncf` (#86152) Add rounding mode attribute to `arith`. This attribute can be used in different FP `arith` operations to control rounding mode. Rounding modes correspond to IEEE 754-specified rounding modes. Use in `arith.truncf` folding. As this is not supported in dialects other than LLVM, conversion should fail for now in case this attribute is present. --------- Signed-off-by: Victor Perez --- .../Conversion/ArithCommon/AttrToLLVMConverter.h | 48 ++++++++++++++++++++++ mlir/include/mlir/Dialect/Arith/IR/ArithBase.td | 25 +++++++++++ mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 21 ++++++++-- .../mlir/Dialect/Arith/IR/ArithOpsInterfaces.td | 33 +++++++++++++++ .../Conversion/ArithCommon/AttrToLLVMConverter.cpp | 31 ++++++++++++++ .../lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp | 3 ++ mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 32 ++++++++++++++- mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 9 ++++ mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 46 ++++++++++++++++----- mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 5 +++ .../test/Conversion/ArithToLLVM/arith-to-llvm.mlir | 15 +++++++ mlir/test/Dialect/Arith/canonicalize.mlir | 45 ++++++++++++++++++++ mlir/test/Dialect/Arith/ops.mlir | 10 +++++ 13 files changed, 309 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h index 32d7979..0891e2b 100644 --- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h +++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h @@ -36,6 +36,20 @@ convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags); LLVM::IntegerOverflowFlagsAttr convertArithOverflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr); +/// Creates an LLVM rounding mode enum value from a given arithmetic rounding +/// mode enum value. +LLVM::RoundingMode +convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode); + +/// Creates an LLVM rounding mode attribute from a given arithmetic rounding +/// mode attribute. +LLVM::RoundingModeAttr +convertArithRoundingModeAttrToLLVM(arith::RoundingModeAttr roundingModeAttr); + +/// Returns an attribute for the default LLVM FP exception behavior. +LLVM::FPExceptionBehaviorAttr +getLLVMDefaultFPExceptionBehavior(MLIRContext &context); + // Attribute converter that populates a NamedAttrList by removing the fastmath // attribute from the source operation attributes, and replacing it with an // equivalent LLVM fastmath attribute. @@ -89,6 +103,40 @@ public: private: NamedAttrList convertedAttr; }; + +template +class AttrConverterConstrainedFPToLLVM { + static_assert(TargetOp::template hasTrait< + LLVM::FPExceptionBehaviorOpInterface::Trait>(), + "Target constrained FP operations must implement " + "LLVM::FPExceptionBehaviorOpInterface"); + +public: + AttrConverterConstrainedFPToLLVM(SourceOp srcOp) { + // Copy the source attributes. + convertedAttr = NamedAttrList{srcOp->getAttrs()}; + + if constexpr (TargetOp::template hasTrait< + LLVM::RoundingModeOpInterface::Trait>()) { + // Get the name of the rounding mode attribute. + StringRef arithAttrName = srcOp.getRoundingModeAttrName(); + // Remove the source attribute. + auto arithAttr = + cast(convertedAttr.erase(arithAttrName)); + // Set the target attribute. + convertedAttr.set(TargetOp::getRoundingModeAttrName(), + convertArithRoundingModeAttrToLLVM(arithAttr)); + } + convertedAttr.set(TargetOp::getFPExceptionBehaviorAttrName(), + getLLVMDefaultFPExceptionBehavior(*srcOp->getContext())); + } + + ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } + +private: + NamedAttrList convertedAttr; +}; + } // namespace arith } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td index c8a42c4..19a2ade 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td @@ -156,4 +156,29 @@ def Arith_IntegerOverflowAttr : let assemblyFormat = "`<` $value `>`"; } +//===----------------------------------------------------------------------===// +// Arith_RoundingMode +//===----------------------------------------------------------------------===// + +// These correspond to LLVM's values defined in: +// llvm/include/llvm/ADT/FloatingPointMode.h + +def Arith_RToNearestTiesToEven // Round to nearest, ties to even + : I32EnumAttrCase<"to_nearest_even", 0>; +def Arith_RDownward // Round toward -inf + : I32EnumAttrCase<"downward", 1>; +def Arith_RUpward // Round toward +inf + : I32EnumAttrCase<"upward", 2>; +def Arith_RTowardZero // Round toward 0 + : I32EnumAttrCase<"toward_zero", 3>; +def Arith_RToNearestTiesAwayFromZero // Round to nearest, ties away from zero + : I32EnumAttrCase<"to_nearest_away", 4>; + +def Arith_RoundingModeAttr : I32EnumAttr< + "RoundingMode", "Floating point rounding mode", + [Arith_RToNearestTiesToEven, Arith_RDownward, Arith_RUpward, + Arith_RTowardZero, Arith_RToNearestTiesAwayFromZero]> { + let cppNamespace = "::mlir::arith"; +} + #endif // ARITH_BASE diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index c9df50d..ead19c6 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -1227,17 +1227,32 @@ def Arith_TruncIOp : Arith_IToICastOp<"trunci"> { // TruncFOp //===----------------------------------------------------------------------===// -def Arith_TruncFOp : Arith_FToFCastOp<"truncf"> { +def Arith_TruncFOp : + Arith_Op<"truncf", + [Pure, SameOperandsAndResultShape, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]>, + Arguments<(ins FloatLike:$in, + OptionalAttr:$roundingmode)>, + Results<(outs FloatLike:$out)> { let summary = "cast from floating-point to narrower floating-point"; let description = [{ Truncate a floating-point value to a smaller floating-point-typed value. The destination type must be strictly narrower than the source type. - If the value cannot be exactly represented, it is rounded using the default - rounding mode. When operating on vectors, casts elementwise. + If the value cannot be exactly represented, it is rounded using the + provided rounding mode or the default one if no rounding mode is provided. + When operating on vectors, casts elementwise. }]; + let builders = [ + OpBuilder<(ins "Type":$out, "Value":$in), [{ + $_state.addOperands(in); + $_state.addTypes(out); + }]> + ]; let hasFolder = 1; let hasVerifier = 1; + let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)"; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td index 73a5d9c..82d6c9a 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td @@ -106,4 +106,37 @@ def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsI ]; } +def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> { + let description = [{ + Access to op rounding mode. + }]; + + let cppNamespace = "::mlir::arith"; + + let methods = [ + InterfaceMethod< + /*desc=*/ "Returns a RoundingModeAttr attribute for the operation", + /*returnType=*/ "RoundingModeAttr", + /*methodName=*/ "getRoundingModeAttr", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + auto op = cast(this->getOperation()); + return op.getRoundingmodeAttr(); + }] + >, + StaticInterfaceMethod< + /*desc=*/ [{Returns the name of the RoundingModeAttr attribute for + the operation}], + /*returnType=*/ "StringRef", + /*methodName=*/ "getRoundingModeAttrName", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + return "roundingmode"; + }] + > + ]; +} + #endif // ARITH_OPS_INTERFACES diff --git a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp index dab064a..f12eba9 100644 --- a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp +++ b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp @@ -55,3 +55,34 @@ LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOverflowAttrToLLVM( return LLVM::IntegerOverflowFlagsAttr::get( flagsAttr.getContext(), convertArithOverflowFlagsToLLVM(arithFlags)); } + +LLVM::RoundingMode +mlir::arith::convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode) { + switch (roundingMode) { + case arith::RoundingMode::downward: + return LLVM::RoundingMode::TowardNegative; + case arith::RoundingMode::to_nearest_away: + return LLVM::RoundingMode::NearestTiesToAway; + case arith::RoundingMode::to_nearest_even: + return LLVM::RoundingMode::NearestTiesToEven; + case arith::RoundingMode::toward_zero: + return LLVM::RoundingMode::TowardZero; + case arith::RoundingMode::upward: + return LLVM::RoundingMode::TowardPositive; + } + llvm_unreachable("Unhandled rounding mode"); +} + +LLVM::RoundingModeAttr mlir::arith::convertArithRoundingModeAttrToLLVM( + arith::RoundingModeAttr roundingModeAttr) { + assert(roundingModeAttr && "Expecting valid attribute"); + return LLVM::RoundingModeAttr::get( + roundingModeAttr.getContext(), + convertArithRoundingModeToLLVM(roundingModeAttr.getValue())); +} + +LLVM::FPExceptionBehaviorAttr +mlir::arith::getLLVMDefaultFPExceptionBehavior(MLIRContext &context) { + return LLVM::FPExceptionBehaviorAttr::get(&context, + LLVM::FPExceptionBehavior::Ignore); +} diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index b51a13a..0113a3d 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -175,6 +175,9 @@ static Value clampInput(PatternRewriter &rewriter, Location loc, } LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const { + // Only supporting default rounding mode as of now. + if (op.getRoundingmodeAttr()) + return failure(); Type outType = op.getOut().getType(); if (auto outVecType = outType.dyn_cast()) { if (outVecType.isScalable()) diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index 1f01f4a..d882f11 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -28,6 +28,31 @@ using namespace mlir; namespace { +/// Operations whose conversion will depend on whether they are passed a +/// rounding mode attribute or not. +/// +/// `SourceOp` is the source operation; `TargetOp`, the operation it will lower +/// to; `AttrConvert` is the attribute conversion to convert the rounding mode +/// attribute. +template typename AttrConvert = + AttrConvertPassThrough> +struct ConstrainedVectorConvertToLLVMPattern + : public VectorConvertToLLVMPattern { + using VectorConvertToLLVMPattern::VectorConvertToLLVMPattern; + + LogicalResult + matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (Constrained != static_cast(op.getRoundingModeAttr())) + return failure(); + return VectorConvertToLLVMPattern::matchAndRewrite(op, adaptor, + rewriter); + } +}; + //===----------------------------------------------------------------------===// // Straightforward Op Lowerings //===----------------------------------------------------------------------===// @@ -112,7 +137,11 @@ using SubIOpLowering = VectorConvertToLLVMPattern; using TruncFOpLowering = - VectorConvertToLLVMPattern; + ConstrainedVectorConvertToLLVMPattern; +using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern< + arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true, + arith::AttrConverterConstrainedFPToLLVM>; using TruncIOpLowering = VectorConvertToLLVMPattern; using UIToFPOpLowering = @@ -537,6 +566,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns( SubFOpLowering, SubIOpLowering, TruncFOpLowering, + ConstrainedTruncFOpLowering, TruncIOpLowering, UIToFPOpLowering, XOrIOpLowering diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index 7456bf7..8069817 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -805,6 +805,15 @@ struct TypeCastingOpPattern final : public OpConversionPattern { } else { rewriter.template replaceOpWithNewOp(op, dstType, adaptor.getOperands()); + if (auto roundingModeOp = + dyn_cast(*op)) { + if (arith::RoundingModeAttr roundingMode = + roundingModeOp.getRoundingModeAttr()) { + // TODO: Perform rounding mode attribute conversion and attach to new + // operation when defined in the dialect. + return failure(); + } + } } return success(); } diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 2f32d9a..0d46679 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -91,6 +91,29 @@ arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) { llvm_unreachable("unknown cmpi predicate kind"); } +/// Equivalent to +/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)). +/// +/// Not possible to implement as chain of calls as this would introduce a +/// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend +/// on the LLVM dialect and on translation to LLVM. +static llvm::RoundingMode +convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) { + switch (roundingMode) { + case RoundingMode::downward: + return llvm::RoundingMode::TowardNegative; + case RoundingMode::to_nearest_away: + return llvm::RoundingMode::NearestTiesToAway; + case RoundingMode::to_nearest_even: + return llvm::RoundingMode::NearestTiesToEven; + case RoundingMode::toward_zero: + return llvm::RoundingMode::TowardZero; + case RoundingMode::upward: + return llvm::RoundingMode::TowardPositive; + } + llvm_unreachable("Unhandled rounding mode"); +} + static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { return arith::CmpIPredicateAttr::get(pred.getContext(), invertPredicate(pred.getValue())); @@ -1233,13 +1256,12 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { } /// Attempts to convert `sourceValue` to an APFloat value with -/// `targetSemantics`, without any information loss or rounding. -static FailureOr -convertFloatValue(APFloat sourceValue, - const llvm::fltSemantics &targetSemantics) { +/// `targetSemantics` and `roundingMode`, without any information loss. +static FailureOr convertFloatValue( + APFloat sourceValue, const llvm::fltSemantics &targetSemantics, + llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) { bool losesInfo = false; - auto status = sourceValue.convert( - targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo); + auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo); if (losesInfo || status != APFloat::opOK) return failure(); @@ -1391,15 +1413,19 @@ LogicalResult arith::TruncIOp::verify() { //===----------------------------------------------------------------------===// /// Perform safe const propagation for truncf, i.e., only propagate if FP value -/// can be represented without precision loss or rounding. This is because the -/// semantics of `arith.truncf` do not assume a specific rounding mode. +/// can be represented without precision loss. OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) { auto resElemType = cast(getElementTypeOrSelf(getType())); const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics(); return constFoldCastOp( adaptor.getOperands(), getType(), - [&targetSemantics](const APFloat &a, bool &castStatus) { - FailureOr result = convertFloatValue(a, targetSemantics); + [this, &targetSemantics](const APFloat &a, bool &castStatus) { + RoundingMode roundingMode = + getRoundingmode().value_or(RoundingMode::to_nearest_even); + llvm::RoundingMode llvmRoundingMode = + convertArithRoundingModeToLLVMIR(roundingMode); + FailureOr result = + convertFloatValue(a, targetSemantics, llvmRoundingMode); if (failed(result)) { castStatus = false; return a; diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index 71e14a1..dd04a59 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -253,6 +253,11 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16."); } + if (op.getRoundingmodeAttr()) { + return rewriter.notifyMatchFailure( + op, "only applicable to default rounding mode."); + } + Type i16Ty = b.getI16Type(); Type i32Ty = b.getI32Type(); Type f32Ty = b.getF32Type(); diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index 29268ee..56ae930 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -289,6 +289,21 @@ func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) { return } +// CHECK-LABEL: experimental_constrained_fptrunc +func.func @experimental_constrained_fptrunc(%arg0 : f64) { +// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore : f64 to f32 + %0 = arith.truncf %arg0 to_nearest_even : f64 to f32 +// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} downward ignore : f64 to f32 + %1 = arith.truncf %arg0 downward : f64 to f32 +// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} upward ignore : f64 to f32 + %2 = arith.truncf %arg0 upward : f64 to f32 +// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} towardzero ignore : f64 to f32 + %3 = arith.truncf %arg0 toward_zero : f64 to f32 +// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearestaway ignore : f64 to f32 + %4 = arith.truncf %arg0 to_nearest_away : f64 to f32 + return +} + // Check sign and zero extension and truncation of integers. // CHECK-LABEL: @integer_extension_and_truncation func.func @integer_extension_and_truncation(%arg0 : i3) { diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index bdc6c91..79a3185 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -757,6 +757,51 @@ func.func @truncFPConstant() -> bf16 { return %0 : bf16 } +// CHECK-LABEL: @truncFPToNearestEvenConstant +// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16 +// CHECK: return %[[cres]] +func.func @truncFPToNearestEvenConstant() -> bf16 { + %cst = arith.constant 1.000000e+00 : f32 + %0 = arith.truncf %cst to_nearest_even : f32 to bf16 + return %0 : bf16 +} + +// CHECK-LABEL: @truncFPDownwardConstant +// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16 +// CHECK: return %[[cres]] +func.func @truncFPDownwardConstant() -> bf16 { + %cst = arith.constant 1.000000e+00 : f32 + %0 = arith.truncf %cst downward : f32 to bf16 + return %0 : bf16 +} + +// CHECK-LABEL: @truncFPUpwardConstant +// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16 +// CHECK: return %[[cres]] +func.func @truncFPUpwardConstant() -> bf16 { + %cst = arith.constant 1.000000e+00 : f32 + %0 = arith.truncf %cst upward : f32 to bf16 + return %0 : bf16 +} + +// CHECK-LABEL: @truncFPTowardZeroConstant +// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16 +// CHECK: return %[[cres]] +func.func @truncFPTowardZeroConstant() -> bf16 { + %cst = arith.constant 1.000000e+00 : f32 + %0 = arith.truncf %cst toward_zero : f32 to bf16 + return %0 : bf16 +} + +// CHECK-LABEL: @truncFPToNearestAwayConstant +// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16 +// CHECK: return %[[cres]] +func.func @truncFPToNearestAwayConstant() -> bf16 { + %cst = arith.constant 1.000000e+00 : f32 + %0 = arith.truncf %cst to_nearest_away : f32 to bf16 + return %0 : bf16 +} + // CHECK-LABEL: @truncFPVectorConstant // CHECK: %[[cres:.+]] = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xbf16> // CHECK: return %[[cres]] diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir index e499573..f684e02 100644 --- a/mlir/test/Dialect/Arith/ops.mlir +++ b/mlir/test/Dialect/Arith/ops.mlir @@ -703,6 +703,16 @@ func.func @test_truncf_scalable_vector(%arg0 : vector<[8]xf32>) -> vector<[8]xbf return %0 : vector<[8]xbf16> } +// CHECK-LABEL: test_truncf_rounding_mode +func.func @test_truncf_rounding_mode(%arg0 : f64) -> (f32, f32, f32, f32, f32) { + %0 = arith.truncf %arg0 to_nearest_even : f64 to f32 + %1 = arith.truncf %arg0 downward : f64 to f32 + %2 = arith.truncf %arg0 upward : f64 to f32 + %3 = arith.truncf %arg0 toward_zero : f64 to f32 + %4 = arith.truncf %arg0 to_nearest_away : f64 to f32 + return %0, %1, %2, %3, %4 : f32, f32, f32, f32, f32 +} + // CHECK-LABEL: test_uitofp func.func @test_uitofp(%arg0 : i32) -> f32 { %0 = arith.uitofp %arg0 : i32 to f32 -- cgit v1.1