aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVictor Perez <victor.perez@codeplay.com>2024-04-01 11:57:14 +0200
committerGitHub <noreply@github.com>2024-04-01 11:57:14 +0200
commit8827ff92b96d78ef455157574061d745df2909af (patch)
tree2a2885f19db51baeba13b6a8594975d396084f34
parentda1d3d8fb9e7dba1cc89327f5119fa7c0cadef81 (diff)
downloadllvm-8827ff92b96d78ef455157574061d745df2909af.zip
llvm-8827ff92b96d78ef455157574061d745df2909af.tar.gz
llvm-8827ff92b96d78ef455157574061d745df2909af.tar.bz2
[MLIR][Arith] Add rounding mode attribute to `truncf` (#86152)
Add rounding mode attribute to `arith`. This attribute can be used in different FP `arith` operations to control rounding mode. Rounding modes correspond to IEEE 754-specified rounding modes. Use in `arith.truncf` folding. As this is not supported in dialects other than LLVM, conversion should fail for now in case this attribute is present. --------- Signed-off-by: Victor Perez <victor.perez@codeplay.com>
-rw-r--r--mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h48
-rw-r--r--mlir/include/mlir/Dialect/Arith/IR/ArithBase.td25
-rw-r--r--mlir/include/mlir/Dialect/Arith/IR/ArithOps.td21
-rw-r--r--mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td33
-rw-r--r--mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp31
-rw-r--r--mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp3
-rw-r--r--mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp32
-rw-r--r--mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp9
-rw-r--r--mlir/lib/Dialect/Arith/IR/ArithOps.cpp46
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp5
-rw-r--r--mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir15
-rw-r--r--mlir/test/Dialect/Arith/canonicalize.mlir45
-rw-r--r--mlir/test/Dialect/Arith/ops.mlir10
13 files changed, 309 insertions, 14 deletions
diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index 32d7979..0891e2b 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -36,6 +36,20 @@ convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
LLVM::IntegerOverflowFlagsAttr
convertArithOverflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
+/// Creates an LLVM rounding mode enum value from a given arithmetic rounding
+/// mode enum value.
+LLVM::RoundingMode
+convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode);
+
+/// Creates an LLVM rounding mode attribute from a given arithmetic rounding
+/// mode attribute.
+LLVM::RoundingModeAttr
+convertArithRoundingModeAttrToLLVM(arith::RoundingModeAttr roundingModeAttr);
+
+/// Returns an attribute for the default LLVM FP exception behavior.
+LLVM::FPExceptionBehaviorAttr
+getLLVMDefaultFPExceptionBehavior(MLIRContext &context);
+
// Attribute converter that populates a NamedAttrList by removing the fastmath
// attribute from the source operation attributes, and replacing it with an
// equivalent LLVM fastmath attribute.
@@ -89,6 +103,40 @@ public:
private:
NamedAttrList convertedAttr;
};
+
+template <typename SourceOp, typename TargetOp>
+class AttrConverterConstrainedFPToLLVM {
+ static_assert(TargetOp::template hasTrait<
+ LLVM::FPExceptionBehaviorOpInterface::Trait>(),
+ "Target constrained FP operations must implement "
+ "LLVM::FPExceptionBehaviorOpInterface");
+
+public:
+ AttrConverterConstrainedFPToLLVM(SourceOp srcOp) {
+ // Copy the source attributes.
+ convertedAttr = NamedAttrList{srcOp->getAttrs()};
+
+ if constexpr (TargetOp::template hasTrait<
+ LLVM::RoundingModeOpInterface::Trait>()) {
+ // Get the name of the rounding mode attribute.
+ StringRef arithAttrName = srcOp.getRoundingModeAttrName();
+ // Remove the source attribute.
+ auto arithAttr =
+ cast<arith::RoundingModeAttr>(convertedAttr.erase(arithAttrName));
+ // Set the target attribute.
+ convertedAttr.set(TargetOp::getRoundingModeAttrName(),
+ convertArithRoundingModeAttrToLLVM(arithAttr));
+ }
+ convertedAttr.set(TargetOp::getFPExceptionBehaviorAttrName(),
+ getLLVMDefaultFPExceptionBehavior(*srcOp->getContext()));
+ }
+
+ ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+
+private:
+ NamedAttrList convertedAttr;
+};
+
} // namespace arith
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index c8a42c4..19a2ade 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -156,4 +156,29 @@ def Arith_IntegerOverflowAttr :
let assemblyFormat = "`<` $value `>`";
}
+//===----------------------------------------------------------------------===//
+// Arith_RoundingMode
+//===----------------------------------------------------------------------===//
+
+// These correspond to LLVM's values defined in:
+// llvm/include/llvm/ADT/FloatingPointMode.h
+
+def Arith_RToNearestTiesToEven // Round to nearest, ties to even
+ : I32EnumAttrCase<"to_nearest_even", 0>;
+def Arith_RDownward // Round toward -inf
+ : I32EnumAttrCase<"downward", 1>;
+def Arith_RUpward // Round toward +inf
+ : I32EnumAttrCase<"upward", 2>;
+def Arith_RTowardZero // Round toward 0
+ : I32EnumAttrCase<"toward_zero", 3>;
+def Arith_RToNearestTiesAwayFromZero // Round to nearest, ties away from zero
+ : I32EnumAttrCase<"to_nearest_away", 4>;
+
+def Arith_RoundingModeAttr : I32EnumAttr<
+ "RoundingMode", "Floating point rounding mode",
+ [Arith_RToNearestTiesToEven, Arith_RDownward, Arith_RUpward,
+ Arith_RTowardZero, Arith_RToNearestTiesAwayFromZero]> {
+ let cppNamespace = "::mlir::arith";
+}
+
#endif // ARITH_BASE
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index c9df50d..ead19c6 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1227,17 +1227,32 @@ def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {
// TruncFOp
//===----------------------------------------------------------------------===//
-def Arith_TruncFOp : Arith_FToFCastOp<"truncf"> {
+def Arith_TruncFOp :
+ Arith_Op<"truncf",
+ [Pure, SameOperandsAndResultShape,
+ DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+ DeclareOpInterfaceMethods<CastOpInterface>]>,
+ Arguments<(ins FloatLike:$in,
+ OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
+ Results<(outs FloatLike:$out)> {
let summary = "cast from floating-point to narrower floating-point";
let description = [{
Truncate a floating-point value to a smaller floating-point-typed value.
The destination type must be strictly narrower than the source type.
- If the value cannot be exactly represented, it is rounded using the default
- rounding mode. When operating on vectors, casts elementwise.
+ If the value cannot be exactly represented, it is rounded using the
+ provided rounding mode or the default one if no rounding mode is provided.
+ When operating on vectors, casts elementwise.
}];
+ let builders = [
+ OpBuilder<(ins "Type":$out, "Value":$in), [{
+ $_state.addOperands(in);
+ $_state.addTypes(out);
+ }]>
+ ];
let hasFolder = 1;
let hasVerifier = 1;
+ let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)";
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 73a5d9c..82d6c9a 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -106,4 +106,37 @@ def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsI
];
}
+def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
+ let description = [{
+ Access to op rounding mode.
+ }];
+
+ let cppNamespace = "::mlir::arith";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns a RoundingModeAttr attribute for the operation",
+ /*returnType=*/ "RoundingModeAttr",
+ /*methodName=*/ "getRoundingModeAttr",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ return op.getRoundingmodeAttr();
+ }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/ [{Returns the name of the RoundingModeAttr attribute for
+ the operation}],
+ /*returnType=*/ "StringRef",
+ /*methodName=*/ "getRoundingModeAttrName",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ return "roundingmode";
+ }]
+ >
+ ];
+}
+
#endif // ARITH_OPS_INTERFACES
diff --git a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
index dab064a..f12eba9 100644
--- a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
+++ b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
@@ -55,3 +55,34 @@ LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOverflowAttrToLLVM(
return LLVM::IntegerOverflowFlagsAttr::get(
flagsAttr.getContext(), convertArithOverflowFlagsToLLVM(arithFlags));
}
+
+LLVM::RoundingMode
+mlir::arith::convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode) {
+ switch (roundingMode) {
+ case arith::RoundingMode::downward:
+ return LLVM::RoundingMode::TowardNegative;
+ case arith::RoundingMode::to_nearest_away:
+ return LLVM::RoundingMode::NearestTiesToAway;
+ case arith::RoundingMode::to_nearest_even:
+ return LLVM::RoundingMode::NearestTiesToEven;
+ case arith::RoundingMode::toward_zero:
+ return LLVM::RoundingMode::TowardZero;
+ case arith::RoundingMode::upward:
+ return LLVM::RoundingMode::TowardPositive;
+ }
+ llvm_unreachable("Unhandled rounding mode");
+}
+
+LLVM::RoundingModeAttr mlir::arith::convertArithRoundingModeAttrToLLVM(
+ arith::RoundingModeAttr roundingModeAttr) {
+ assert(roundingModeAttr && "Expecting valid attribute");
+ return LLVM::RoundingModeAttr::get(
+ roundingModeAttr.getContext(),
+ convertArithRoundingModeToLLVM(roundingModeAttr.getValue()));
+}
+
+LLVM::FPExceptionBehaviorAttr
+mlir::arith::getLLVMDefaultFPExceptionBehavior(MLIRContext &context) {
+ return LLVM::FPExceptionBehaviorAttr::get(&context,
+ LLVM::FPExceptionBehavior::Ignore);
+}
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index b51a13a..0113a3d 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -175,6 +175,9 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
}
LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
+ // Only supporting default rounding mode as of now.
+ if (op.getRoundingmodeAttr())
+ return failure();
Type outType = op.getOut().getType();
if (auto outVecType = outType.dyn_cast<VectorType>()) {
if (outVecType.isScalable())
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 1f01f4a..d882f11 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -28,6 +28,31 @@ using namespace mlir;
namespace {
+/// Operations whose conversion will depend on whether they are passed a
+/// rounding mode attribute or not.
+///
+/// `SourceOp` is the source operation; `TargetOp`, the operation it will lower
+/// to; `AttrConvert` is the attribute conversion to convert the rounding mode
+/// attribute.
+template <typename SourceOp, typename TargetOp, bool Constrained,
+ template <typename, typename> typename AttrConvert =
+ AttrConvertPassThrough>
+struct ConstrainedVectorConvertToLLVMPattern
+ : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
+ using VectorConvertToLLVMPattern<SourceOp, TargetOp,
+ AttrConvert>::VectorConvertToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
+ return failure();
+ return VectorConvertToLLVMPattern<SourceOp, TargetOp,
+ AttrConvert>::matchAndRewrite(op, adaptor,
+ rewriter);
+ }
+};
+
//===----------------------------------------------------------------------===//
// Straightforward Op Lowerings
//===----------------------------------------------------------------------===//
@@ -112,7 +137,11 @@ using SubIOpLowering =
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
arith::AttrConvertOverflowToLLVM>;
using TruncFOpLowering =
- VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
+ ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
+ false>;
+using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+ arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
+ arith::AttrConverterConstrainedFPToLLVM>;
using TruncIOpLowering =
VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
using UIToFPOpLowering =
@@ -537,6 +566,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
SubFOpLowering,
SubIOpLowering,
TruncFOpLowering,
+ ConstrainedTruncFOpLowering,
TruncIOpLowering,
UIToFPOpLowering,
XOrIOpLowering
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 7456bf7..8069817 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -805,6 +805,15 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
} else {
rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
adaptor.getOperands());
+ if (auto roundingModeOp =
+ dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
+ if (arith::RoundingModeAttr roundingMode =
+ roundingModeOp.getRoundingModeAttr()) {
+ // TODO: Perform rounding mode attribute conversion and attach to new
+ // operation when defined in the dialect.
+ return failure();
+ }
+ }
}
return success();
}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 2f32d9a..0d46679 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -91,6 +91,29 @@ arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
llvm_unreachable("unknown cmpi predicate kind");
}
+/// Equivalent to
+/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
+///
+/// Not possible to implement as chain of calls as this would introduce a
+/// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
+/// on the LLVM dialect and on translation to LLVM.
+static llvm::RoundingMode
+convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
+ switch (roundingMode) {
+ case RoundingMode::downward:
+ return llvm::RoundingMode::TowardNegative;
+ case RoundingMode::to_nearest_away:
+ return llvm::RoundingMode::NearestTiesToAway;
+ case RoundingMode::to_nearest_even:
+ return llvm::RoundingMode::NearestTiesToEven;
+ case RoundingMode::toward_zero:
+ return llvm::RoundingMode::TowardZero;
+ case RoundingMode::upward:
+ return llvm::RoundingMode::TowardPositive;
+ }
+ llvm_unreachable("Unhandled rounding mode");
+}
+
static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
return arith::CmpIPredicateAttr::get(pred.getContext(),
invertPredicate(pred.getValue()));
@@ -1233,13 +1256,12 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
}
/// Attempts to convert `sourceValue` to an APFloat value with
-/// `targetSemantics`, without any information loss or rounding.
-static FailureOr<APFloat>
-convertFloatValue(APFloat sourceValue,
- const llvm::fltSemantics &targetSemantics) {
+/// `targetSemantics` and `roundingMode`, without any information loss.
+static FailureOr<APFloat> convertFloatValue(
+ APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
+ llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
bool losesInfo = false;
- auto status = sourceValue.convert(
- targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
+ auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
if (losesInfo || status != APFloat::opOK)
return failure();
@@ -1391,15 +1413,19 @@ LogicalResult arith::TruncIOp::verify() {
//===----------------------------------------------------------------------===//
/// Perform safe const propagation for truncf, i.e., only propagate if FP value
-/// can be represented without precision loss or rounding. This is because the
-/// semantics of `arith.truncf` do not assume a specific rounding mode.
+/// can be represented without precision loss.
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
return constFoldCastOp<FloatAttr, FloatAttr>(
adaptor.getOperands(), getType(),
- [&targetSemantics](const APFloat &a, bool &castStatus) {
- FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
+ [this, &targetSemantics](const APFloat &a, bool &castStatus) {
+ RoundingMode roundingMode =
+ getRoundingmode().value_or(RoundingMode::to_nearest_even);
+ llvm::RoundingMode llvmRoundingMode =
+ convertArithRoundingModeToLLVMIR(roundingMode);
+ FailureOr<APFloat> result =
+ convertFloatValue(a, targetSemantics, llvmRoundingMode);
if (failed(result)) {
castStatus = false;
return a;
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 71e14a1..dd04a59 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -253,6 +253,11 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
}
+ if (op.getRoundingmodeAttr()) {
+ return rewriter.notifyMatchFailure(
+ op, "only applicable to default rounding mode.");
+ }
+
Type i16Ty = b.getI16Type();
Type i32Ty = b.getI32Type();
Type f32Ty = b.getF32Type();
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 29268ee..56ae930 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -289,6 +289,21 @@ func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
return
}
+// CHECK-LABEL: experimental_constrained_fptrunc
+func.func @experimental_constrained_fptrunc(%arg0 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore : f64 to f32
+ %0 = arith.truncf %arg0 to_nearest_even : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} downward ignore : f64 to f32
+ %1 = arith.truncf %arg0 downward : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} upward ignore : f64 to f32
+ %2 = arith.truncf %arg0 upward : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} towardzero ignore : f64 to f32
+ %3 = arith.truncf %arg0 toward_zero : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearestaway ignore : f64 to f32
+ %4 = arith.truncf %arg0 to_nearest_away : f64 to f32
+ return
+}
+
// Check sign and zero extension and truncation of integers.
// CHECK-LABEL: @integer_extension_and_truncation
func.func @integer_extension_and_truncation(%arg0 : i3) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index bdc6c91..79a3185 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -757,6 +757,51 @@ func.func @truncFPConstant() -> bf16 {
return %0 : bf16
}
+// CHECK-LABEL: @truncFPToNearestEvenConstant
+// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
+// CHECK: return %[[cres]]
+func.func @truncFPToNearestEvenConstant() -> bf16 {
+ %cst = arith.constant 1.000000e+00 : f32
+ %0 = arith.truncf %cst to_nearest_even : f32 to bf16
+ return %0 : bf16
+}
+
+// CHECK-LABEL: @truncFPDownwardConstant
+// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
+// CHECK: return %[[cres]]
+func.func @truncFPDownwardConstant() -> bf16 {
+ %cst = arith.constant 1.000000e+00 : f32
+ %0 = arith.truncf %cst downward : f32 to bf16
+ return %0 : bf16
+}
+
+// CHECK-LABEL: @truncFPUpwardConstant
+// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
+// CHECK: return %[[cres]]
+func.func @truncFPUpwardConstant() -> bf16 {
+ %cst = arith.constant 1.000000e+00 : f32
+ %0 = arith.truncf %cst upward : f32 to bf16
+ return %0 : bf16
+}
+
+// CHECK-LABEL: @truncFPTowardZeroConstant
+// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
+// CHECK: return %[[cres]]
+func.func @truncFPTowardZeroConstant() -> bf16 {
+ %cst = arith.constant 1.000000e+00 : f32
+ %0 = arith.truncf %cst toward_zero : f32 to bf16
+ return %0 : bf16
+}
+
+// CHECK-LABEL: @truncFPToNearestAwayConstant
+// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
+// CHECK: return %[[cres]]
+func.func @truncFPToNearestAwayConstant() -> bf16 {
+ %cst = arith.constant 1.000000e+00 : f32
+ %0 = arith.truncf %cst to_nearest_away : f32 to bf16
+ return %0 : bf16
+}
+
// CHECK-LABEL: @truncFPVectorConstant
// CHECK: %[[cres:.+]] = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xbf16>
// CHECK: return %[[cres]]
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index e499573..f684e02 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -703,6 +703,16 @@ func.func @test_truncf_scalable_vector(%arg0 : vector<[8]xf32>) -> vector<[8]xbf
return %0 : vector<[8]xbf16>
}
+// CHECK-LABEL: test_truncf_rounding_mode
+func.func @test_truncf_rounding_mode(%arg0 : f64) -> (f32, f32, f32, f32, f32) {
+ %0 = arith.truncf %arg0 to_nearest_even : f64 to f32
+ %1 = arith.truncf %arg0 downward : f64 to f32
+ %2 = arith.truncf %arg0 upward : f64 to f32
+ %3 = arith.truncf %arg0 toward_zero : f64 to f32
+ %4 = arith.truncf %arg0 to_nearest_away : f64 to f32
+ return %0, %1, %2, %3, %4 : f32, f32, f32, f32, f32
+}
+
// CHECK-LABEL: test_uitofp
func.func @test_uitofp(%arg0 : i32) -> f32 {
%0 = arith.uitofp %arg0 : i32 to f32