aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKrzysztof Drewniak <Krzysztof.Drewniak@amd.com>2024-01-23 16:52:21 -0600
committerGitHub <noreply@github.com>2024-01-23 16:52:21 -0600
commit750e90e4403df23d6b271afb90e6b4d463739965 (patch)
tree9c89cd12c1c4a803eb02b08c747b7e73cad41687
parent575568de4166bf69e0a5bc68978580afbe936878 (diff)
downloadllvm-750e90e4403df23d6b271afb90e6b4d463739965.zip
llvm-750e90e4403df23d6b271afb90e6b4d463739965.tar.gz
llvm-750e90e4403df23d6b271afb90e6b4d463739965.tar.bz2
[mlir][ArithToAMDGPU] Add option for saturating truncation to fp8 (#74153)
Many machine-learning applications (and most software written at AMD) expect the operation that truncates floats to 8-bit floats to be saturatinng. That is, they expect `truncf 256.0 : f32 to f8E4M3FNUZ` to yield `240.0`, not `NaN`, and similarly for negative numbers. However, the underlying hardware instruction that can be used for this truncation implements overflow-to-NaN semantics. To enable handling this usecase, we add the saturate-fp8-truncf option to ArithToAMDGPU (off by default), which causes the requisite clamping code to be emitted. Said clamping code ensures that Inf and NaN are passed through exactly (and thus trancate to NaN). Per review feedback, this commit efactors createScalarOrSplatConstant() to the Arith dialect utilities and uses it in this code. It also fixes naming of existing patterns and switches from vector.extractelement/insertelement to vector.extract/insert.
-rw-r--r--mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h8
-rw-r--r--mlir/include/mlir/Conversion/Passes.td6
-rw-r--r--mlir/include/mlir/Dialect/Arith/Utils/Utils.h11
-rw-r--r--mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp108
-rw-r--r--mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp30
-rw-r--r--mlir/lib/Dialect/Arith/Utils/Utils.cpp34
-rw-r--r--mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir54
-rw-r--r--mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir33
9 files changed, 208 insertions, 77 deletions
diff --git a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
index 7f445fe..78c79c9 100644
--- a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
+++ b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
@@ -20,7 +20,13 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"
namespace arith {
-void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns);
+/// Add patterns for rewriting `arith.extf` and `arith.truncf` on FP8 types
+/// to wrappers around AMDGPU--specific intrinsics. If `saturateFP8TruncF`
+/// is set, values outside the range of the destination type are clamped
+/// to the largest value of that type instead of being rewritten to Inf (aka
+/// NaN).
+void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
+ bool saturateFP8TruncF);
} // namespace arith
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 3467e04..ec0a628 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -125,6 +125,12 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
}];
let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
+
+ let options = [
+ Option<"saturateFP8Truncf", "saturate-fp8-truncf", "bool",
+ /*default=*/"false",
+ "Use saturating truncation for 8-bit float types">,
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 62a84ee..402bd19 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -53,6 +53,17 @@ Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
Type toType, bool isUnsignedCast);
+/// Create a constant of type `type` at location `loc` whose value is `value`
+/// (an APInt or APFloat whose type must match the element type of `type`).
+/// If `type` is a shaped type, create a splat constant of the given value.
+/// Constants are folded if possible.
+Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
+ const APInt &value);
+Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
+ int64_t value);
+Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
+ const APFloat &value);
+
/// Helper struct to build simple arithmetic quantities with minimal type
/// inference support.
struct ArithBuilder {
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 7785405..c625a30 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
@@ -34,17 +35,17 @@ struct ArithToAMDGPUConversionPass final
void runOnOperation() override;
};
-struct ExtfOnFloat8RewritePattern final
- : public OpRewritePattern<arith::ExtFOp> {
- using OpRewritePattern<arith::ExtFOp>::OpRewritePattern;
+struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
LogicalResult match(arith::ExtFOp op) const override;
void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
};
-struct TruncfToFloat8RewritePattern final
- : public OpRewritePattern<arith::TruncFOp> {
- using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
+struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
+ bool saturateFP8 = false;
+ TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
+ : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {}
LogicalResult match(arith::TruncFOp op) const override;
void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
@@ -62,7 +63,7 @@ static Value castF32To(Type elementType, Value f32, Location loc,
llvm_unreachable("The only 32-bit float type is f32");
}
-LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op) const {
+LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
Type inType = op.getIn().getType();
if (auto inVecType = inType.dyn_cast<VectorType>()) {
if (inVecType.isScalable())
@@ -75,7 +76,7 @@ LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op) const {
return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
}
-void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
+void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value in = op.getIn();
@@ -93,11 +94,13 @@ void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
Value result =
rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
if (inType.getShape().empty()) {
- Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
+ Value scalarIn =
+ rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
// Recurse to send the 0-D vector case to the 1-D vector case
Value scalarExt =
rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
- result = rewriter.create<vector::InsertElementOp>(loc, scalarExt, zero);
+ result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
+ ArrayRef<int64_t>{});
return rewriter.replaceOp(op, result);
}
for (int64_t i = 0; i < numElements; i += 4) {
@@ -108,9 +111,7 @@ void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
loc, rewriter.getF32Type(), inSlice, j);
Value asType = castF32To(outElemType, asFloat, loc, rewriter);
- result = rewriter.create<vector::InsertElementOp>(
- loc, asType, result,
- rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j));
+ result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
}
}
rewriter.replaceOp(op, result);
@@ -127,7 +128,53 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
llvm_unreachable("The only 32-bit float type is f32");
}
-LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const {
+// If `in` is a finite value, clamp it between the maximum and minimum values
+// of `outElemType` so that subsequent conversion instructions don't
+// overflow those out-of-range values to NaN. These semantics are commonly
+// used in machine-learning contexts where failure to clamp would lead to
+// excessive NaN production.
+static Value clampInput(PatternRewriter &rewriter, Location loc,
+ Type outElemType, Value source) {
+ Type sourceType = source.getType();
+ const llvm::fltSemantics &sourceSem =
+ cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
+ const llvm::fltSemantics &targetSem =
+ cast<FloatType>(outElemType).getFloatSemantics();
+
+ APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
+ APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
+ bool ignoredLosesInfo = false;
+ // We can ignore conversion failures here because this conversion promotes
+ // from a smaller type to a larger one - ex. there can be no loss of precision
+ // when casting fp8 to f16.
+ (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
+ (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
+
+ Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
+ Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);
+
+ Value inf = createScalarOrSplatConstant(
+ rewriter, loc, sourceType,
+ APFloat::getInf(sourceSem, /*Negative=*/false));
+ Value negInf = createScalarOrSplatConstant(
+ rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true));
+ Value isInf = rewriter.createOrFold<arith::CmpFOp>(
+ loc, arith::CmpFPredicate::OEQ, source, inf);
+ Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
+ loc, arith::CmpFPredicate::OEQ, source, negInf);
+ Value isNan = rewriter.createOrFold<arith::CmpFOp>(
+ loc, arith::CmpFPredicate::UNO, source, source);
+ Value isNonFinite = rewriter.create<arith::OrIOp>(
+ loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
+
+ Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
+ Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
+ Value res =
+ rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
+ return res;
+}
+
+LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
Type outType = op.getOut().getType();
if (auto outVecType = outType.dyn_cast<VectorType>()) {
if (outVecType.isScalable())
@@ -137,22 +184,27 @@ LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const {
return failure();
outType = outVecType.getElementType();
}
+ auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
+ if (inType && inType.getWidth() <= 8 && saturateFP8)
+ // Conversion between 8-bit floats is not supported with truncation enabled.
+ return failure();
return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
}
-void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
+void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
+ if (saturateFP8)
+ in = clampInput(rewriter, loc, outElemType, in);
VectorType truncResType = VectorType::get(4, outElemType);
if (!in.getType().isa<VectorType>()) {
Value asFloat = castToF32(in, loc, rewriter);
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
/*existing=*/nullptr);
- Value result = rewriter.create<vector::ExtractElementOp>(
- loc, asF8s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
+ Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
return rewriter.replaceOp(op, result);
}
VectorType outType = op.getOut().getType().cast<VectorType>();
@@ -161,11 +213,13 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
if (outType.getShape().empty()) {
- Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
+ Value scalarIn =
+ rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
// Recurse to send the 0-D vector case to the 1-D vector case
Value scalarTrunc =
rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
- result = rewriter.create<vector::InsertElementOp>(loc, scalarTrunc, zero);
+ result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
+ ArrayRef<int64_t>{});
return rewriter.replaceOp(op, result);
}
@@ -173,14 +227,11 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
Value thisResult = nullptr;
for (int64_t j = 0; j < elemsThisOp; j += 2) {
- Value elemA = rewriter.create<vector::ExtractElementOp>(
- loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i + j));
+ Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
Value asFloatA = castToF32(elemA, loc, rewriter);
Value asFloatB = nullptr;
if (j + 1 < elemsThisOp) {
- Value elemB = rewriter.create<vector::ExtractElementOp>(
- loc, in,
- rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j + 1));
+ Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
asFloatB = castToF32(elemB, loc, rewriter);
}
thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -196,15 +247,16 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
}
void mlir::arith::populateArithToAMDGPUConversionPatterns(
- RewritePatternSet &patterns) {
- patterns.add<ExtfOnFloat8RewritePattern, TruncfToFloat8RewritePattern>(
- patterns.getContext());
+ RewritePatternSet &patterns, bool saturateFP8TruncF) {
+ patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
+ patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
+ saturateFP8TruncF);
}
void ArithToAMDGPUConversionPass::runOnOperation() {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
- arith::populateArithToAMDGPUConversionPatterns(patterns);
+ arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
return signalPassFailure();
}
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt b/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
index 359015b..e2c951b 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRArithToAMDGPU
LINK_LIBS PUBLIC
MLIRAMDGPUDialect
MLIRArithDialect
+ MLIRArithUtils
MLIRVectorDialect
MLIRPass
MLIRTransforms
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 9e783c5..8a4080e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -58,35 +59,6 @@ static Type reduceInnermostDim(VectorType type) {
return VectorType::get(newShape, type.getElementType());
}
-/// Returns a constant of integer of vector type filled with (repeated) `value`.
-static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
- Location loc, Type type,
- const APInt &value) {
- TypedAttr attr;
- if (dyn_cast<IntegerType>(type)) {
- attr = rewriter.getIntegerAttr(type, value);
- } else {
- auto vecTy = cast<VectorType>(type);
- attr = SplatElementsAttr::get(vecTy, value);
- }
-
- return rewriter.create<arith::ConstantOp>(loc, attr);
-}
-
-/// Returns a constant of integer of vector type filled with (repeated) `value`.
-static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
- Location loc, Type type,
- int64_t value) {
- unsigned elementBitWidth = 0;
- if (auto intTy = dyn_cast<IntegerType>(type))
- elementBitWidth = intTy.getWidth();
- else
- elementBitWidth = cast<VectorType>(type).getElementTypeBitWidth();
-
- return createScalarOrSplatConstant(rewriter, loc, type,
- APInt(elementBitWidth, value));
-}
-
/// Extracts the `input` vector slice with elements at the last dimension offset
/// by `lastOffset`. Returns a value of vector type with the last dimension
/// reduced to x1 or fully scalarized, e.g.:
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 0f39c24..bf274d4 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -197,6 +197,40 @@ mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
}));
}
+Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
+ Type type, const APInt &value) {
+ TypedAttr attr;
+ if (isa<IntegerType>(type)) {
+ attr = builder.getIntegerAttr(type, value);
+ } else {
+ auto vecTy = cast<ShapedType>(type);
+ attr = SplatElementsAttr::get(vecTy, value);
+ }
+
+ return builder.create<arith::ConstantOp>(loc, attr);
+}
+
+Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
+ Type type, int64_t value) {
+ unsigned elementBitWidth = 0;
+ if (auto intTy = dyn_cast<IntegerType>(type))
+ elementBitWidth = intTy.getWidth();
+ else
+ elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth();
+
+ return createScalarOrSplatConstant(builder, loc, type,
+ APInt(elementBitWidth, value));
+}
+
+Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
+ Type type, const APFloat &value) {
+ if (isa<FloatType>(type))
+ return builder.createOrFold<arith::ConstantOp>(
+ loc, type, builder.getFloatAttr(type, value));
+ TypedAttr splat = SplatElementsAttr::get(cast<ShapedType>(type), value);
+ return builder.createOrFold<arith::ConstantOp>(loc, type, splat);
+}
+
Value ArithBuilder::_and(Value lhs, Value rhs) {
return b.create<arith::AndIOp>(loc, lhs, rhs);
}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
new file mode 100644
index 0000000..c7f3944
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt --split-input-file %s \
+// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{saturate-fp8-truncf=true}))' \
+// RUN: | FileCheck %s
+
+// CHECK-LABEL: func.func @scalar_trunc
+// CHECK-SAME: ([[V:%.+]]: f16)
+// CHECK-DAG: [[CMin:%.+]] = arith.constant -5.734400e+04 : f16
+// CHECK-DAG: [[CMax:%.+]] = arith.constant 5.734400e+04 : f16
+// CHECK-DAG: [[CInf:%.+]] = arith.constant 0x7C00 : f16
+// CHECK-DAG: [[CNegInf:%.+]] = arith.constant 0xFC00 : f16
+// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]]
+// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]]
+// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]]
+// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]]
+// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]]
+// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]]
+// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]]
+// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]]
+// CHECK: [[FLOAT:%.+]] = arith.extf [[SATURATED]] : f16 to f32
+// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2FNUZ>
+// CHECK: [[W:%.+]] = vector.extract [[TRUNCV]][0] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ>
+// CHECK: return [[W]] : f8E5M2FNUZ
+func.func @scalar_trunc(%v: f16) -> f8E5M2FNUZ {
+ %w = arith.truncf %v : f16 to f8E5M2FNUZ
+ return %w : f8E5M2FNUZ
+}
+
+// No 0-D test because arith.truncf hasn't been extended to support it.
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc
+// CHECK-SAME: ([[V:%.+]]: vector<2xf32>) -> vector<2xf8E4M3FNUZ> {
+// CHECK-DAG: [[CMin:%.+]] = arith.constant dense<-2.400000e+02> : vector<2xf32>
+// CHECK-DAG: [[CMax:%.+]] = arith.constant dense<2.400000e+02> : vector<2xf32>
+// CHECK-DAG: [[CInf:%.+]] = arith.constant dense<0x7F800000> : vector<2xf32>
+// CHECK-DAG: [[CNegInf:%.+]] = arith.constant dense<0xFF800000> : vector<2xf32>
+// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]]
+// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]]
+// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]]
+// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]]
+// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]]
+// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]]
+// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]]
+// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]]
+// CHECK: [[F0:%.+]] = vector.extract [[SATURATED]][0]
+// CHECK: [[F1:%.+]] = vector.extract [[SATURATED]][1]
+// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E4M3FNUZ>
+// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E4M3FNUZ> to vector<2xf8E4M3FNUZ>
+// CHECK: return [[W]] : vector<2xf8E4M3FNUZ>
+func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf8E4M3FNUZ> {
+ %w = arith.truncf %v : vector<2xf32> to vector<2xf8E4M3FNUZ>
+ return %w : vector<2xf8E4M3FNUZ>
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
index a6c11d02..159a2f0 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
@@ -17,14 +17,12 @@ func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 {
// CHECK-LABEL: func.func @vector_ext_short
// CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2FNUZ>)
// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64>
-// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
-// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2FNUZ> to f32
// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64
-// CHECK: [[W0:%.+]] = vector.insertelement [[EXT0]], [[ZEROES]]{{\[}}[[C0]]
+// CHECK: [[W0:%.+]] = vector.insert [[EXT0]], [[ZEROES]] [0]
// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2FNUZ> to f32
// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]]
-// CHECK: [[W1:%.+]] = vector.insertelement [[EXT1]], [[W0]]{{\[}}[[C1]]
+// CHECK: [[W1:%.+]] = vector.insert [[EXT1]], [[W0]] [1]
// CHECK: return [[W1]] : vector<2xf64>
func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> {
@@ -38,27 +36,27 @@ func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> {
// CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FNUZ>)
// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]}
// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
-// CHECK: [[W0:%.+]] = vector.insertelement [[F0]]
+// CHECK: [[W0:%.+]] = vector.insert [[F0]]
// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
-// CHECK: [[W1:%.+]] = vector.insertelement [[F1]], [[W0]]
+// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
-// CHECK: [[W2:%.+]] = vector.insertelement [[F2]], [[W1]]
+// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
-// CHECK: [[W3:%.+]] = vector.insertelement [[F3]], [[W2]]
+// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
-// CHECK: [[W4:%.+]] = vector.insertelement [[F4]], [[W3]]
+// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
-// CHECK: [[W5:%.+]] = vector.insertelement [[F5]], [[W4]]
+// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
-// CHECK: [[W6:%.+]] = vector.insertelement [[F6]], [[W5]]
+// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
-// CHECK: [[W7:%.+]] = vector.insertelement [[F7]], [[W6]]
+// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
-// CHECK: [[W8:%.+]] = vector.insertelement [[F8]], [[W7]]
+// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
// CHECK: return [[W8]]
func.func @vector_ext_long(%v: vector<9xf8E4M3FNUZ>) -> vector<9xf32> {
%w = arith.extf %v : vector<9xf8E4M3FNUZ> to vector<9xf32>
@@ -69,10 +67,9 @@ func.func @vector_ext_long(%v: vector<9xf8E4M3FNUZ>) -> vector<9xf32> {
// CHECK-LABEL: func.func @scalar_trunc
// CHECK-SAME: ([[V:%.+]]: f16)
-// CHECK: [[C0:%.+]] = arith.constant 0 : index
// CHECK: [[FLOAT:%.+]] = arith.extf [[V]] : f16 to f32
// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2FNUZ>
-// CHECK: [[W:%.+]] = vector.extractelement [[TRUNCV]]{{\[}}[[C0]] : index] : vector<4xf8E5M2FNUZ>
+// CHECK: [[W:%.+]] = vector.extract [[TRUNCV]][0] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ>
// CHECK: return [[W]] : f8E5M2FNUZ
func.func @scalar_trunc(%v: f16) -> f8E5M2FNUZ {
%w = arith.truncf %v : f16 to f8E5M2FNUZ
@@ -85,11 +82,9 @@ func.func @scalar_trunc(%v: f16) -> f8E5M2FNUZ {
// CHECK-LABEL: func.func @vector_trunc_short
// CHECK-SAME: ([[V:%.+]]: vector<2xf64>) -> vector<2xf8E5M2FNUZ> {
-// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
-// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
-// CHECK: [[V0:%.+]] = vector.extractelement [[V]]{{\[}}[[C0]] : index]
+// CHECK: [[V0:%.+]] = vector.extract [[V]][0]
// CHECK: [[F0:%.+]] = arith.truncf [[V0]] : f64 to f32
-// CHECK: [[V1:%.+]] = vector.extractelement [[V]]{{\[}}[[C1]] : index]
+// CHECK: [[V1:%.+]] = vector.extract [[V]][1]
// CHECK: [[F1:%.+]] = arith.truncf [[V1]] : f64 to f32
// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E5M2FNUZ>
// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2FNUZ> to vector<2xf8E5M2FNUZ>