diff options
-rw-r--r-- | mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 109 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/Arith/Transforms/Passes.h | 3 | ||||
-rw-r--r-- | mlir/include/mlir/IR/Builders.h | 1 | ||||
-rw-r--r-- | mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 26 | ||||
-rw-r--r-- | mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 105 | ||||
-rw-r--r-- | mlir/lib/IR/Builders.cpp | 2 | ||||
-rw-r--r-- | mlir/test/Dialect/Arith/expand-ops.mlir | 178 |
7 files changed, 411 insertions, 13 deletions
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 599b3b9..adc27ae 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -1216,6 +1216,58 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast } //===----------------------------------------------------------------------===// +// Scaling ExtFOp +//===----------------------------------------------------------------------===// +def Arith_ScalingExtFOp + : Arith_Op< + "scaling_extf", [Pure, SameInputOutputTensorDims, + DeclareOpInterfaceMethods<ArithFastMathInterface>, + DeclareOpInterfaceMethods<CastOpInterface>]>, + Arguments<(ins FloatLike:$in, FloatLike:$scale, + OptionalAttr<Arith_FastMathAttr>:$fastmath)>, + Results<(outs FloatLike:$out)> { + let summary = "Upcasts input floats using provided scales values following " + "OCP MXFP Spec"; + let description = [{ + This operation upcasts input floating-point values using provided scale + values. It expects both scales and the input operand to be of the same shape, + making the operation elementwise. Scales are usually calculated per block + following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537. + + If scales are calculated per block where blockSize != 1, then scales may + require broadcasting to make this operation elementwise. For example, let's + say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and + assuming quantization happens on the last axis, the input can be reshaped to + `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated + per block on the last axis. Therefore, scales will be of shape + `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other + shape as long as it is broadcast compatible with the input, e.g., + `<1 x 1 x ... (dimN/blockSize) x 1>`. + + In this example, before calling into `arith.scaling_extf`, scales must be + broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note + that there could be multiple quantization axes. Internally, + `arith.scaling_extf` would perform the following: + + ``` + resultTy = get_type(result) + scaleTy = get_type(scale) + inputTy = get_type(input) + scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0 + scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy + input.extf = arith.extf(input) : inputTy to resultTy + result = arith.mulf(scale.extf, input.extf) + ``` + It propagates NaN values. Therefore, if either scale or the input element + contains NaN, then the output element value will also be a NaN. + }]; + let hasVerifier = 1; + let assemblyFormat = + [{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:` + type($in) `,` type($scale) `to` type($out)}]; +} + +//===----------------------------------------------------------------------===// // TruncIOp //===----------------------------------------------------------------------===// @@ -1281,6 +1333,63 @@ def Arith_TruncFOp : } //===----------------------------------------------------------------------===// +// Scaling TruncFOp +//===----------------------------------------------------------------------===// + +def Arith_ScalingTruncFOp + : Arith_Op<"scaling_truncf", + [Pure, SameInputOutputTensorDims, + DeclareOpInterfaceMethods<ArithRoundingModeInterface>, + DeclareOpInterfaceMethods<ArithFastMathInterface>, + DeclareOpInterfaceMethods<CastOpInterface>]>, + Arguments<(ins FloatLike:$in, FloatLike:$scale, + OptionalAttr<Arith_RoundingModeAttr>:$roundingmode, + OptionalAttr<Arith_FastMathAttr>:$fastmath)>, + Results<(outs FloatLike:$out)> { + let summary = "Downcasts input floating point values using provided scales " + "values following OCP MXFP Spec"; + let description = [{ + This operation downcasts input using the provided scale values. It expects + both scales and the input operand to be of the same shape and, therefore, + makes the operation elementwise. Scales are usually calculated per block + following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537. + Users are required to normalize and clamp the scales as necessary before calling + passing them to this operation. OCP MXFP spec also does the flushing of denorms + on the input operand, which should be handled during lowering by passing appropriate + fastMath flag to this operation. + + If scales are calculated per block where blockSize != 1, scales may require + broadcasting to make this operation elementwise. For example, let's say the + input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and + assuming quantization happens on the last axis, the input can be reshaped to + `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated + per block on the last axis. Therefore, scales will be of shape + `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other + shape as long as it is broadcast compatible with the input, e.g., + `<1 x 1 x ... (dimN/blockSize) x 1>`. + + In this example, before calling into `arith.scaling_truncf`, scales must be + broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note + that there could be multiple quantization axes. Internally, + `arith.scaling_truncf` would perform the following: + + ``` + scaleTy = get_type(scale) + inputTy = get_type(input) + resultTy = get_type(result) + scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0 + scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy + result = arith.divf(input, scale.extf) + result.cast = arith.truncf(result, resultTy) + ``` + }]; + let hasVerifier = 1; + let assemblyFormat = + [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:` + type($in) `,` type($scale) `to` type($out)}]; +} + +//===----------------------------------------------------------------------===// // UIToFPOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h index 5aaac8d..e0a4567 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h @@ -62,6 +62,9 @@ void populateExpandBFloat16Patterns(RewritePatternSet &patterns); /// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts. void populateExpandF8E8M0Patterns(RewritePatternSet &patterns); +/// Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops +void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns); + /// Add patterns to expand Arith ops. void populateArithExpandOpsPatterns(RewritePatternSet &patterns); diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 3f7b326..d68dbdb 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -60,6 +60,7 @@ public: Attribute metadata = Attribute()); // Types. + FloatType getF8E8M0Type(); FloatType getBF16Type(); FloatType getF16Type(); FloatType getTF32Type(); diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 41f2d0f..9e53e19 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1452,6 +1452,19 @@ bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); } //===----------------------------------------------------------------------===// +// ScalingExtFOp +//===----------------------------------------------------------------------===// + +bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs, + TypeRange outputs) { + return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs); +} + +LogicalResult arith::ScalingExtFOp::verify() { + return verifyExtOp<FloatType>(*this); +} + +//===----------------------------------------------------------------------===// // TruncIOp //===----------------------------------------------------------------------===// @@ -1566,6 +1579,19 @@ LogicalResult arith::TruncFOp::verify() { } //===----------------------------------------------------------------------===// +// ScalingTruncFOp +//===----------------------------------------------------------------------===// + +bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs, + TypeRange outputs) { + return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs); +} + +LogicalResult arith::ScalingTruncFOp::verify() { + return verifyTruncateOp<FloatType>(*this); +} + +//===----------------------------------------------------------------------===// // AndIOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index 95546bb..534aff9 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -6,10 +6,10 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Arith/Transforms/Passes.h" - #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" @@ -31,7 +31,6 @@ static Value createConst(Location loc, Type type, int value, return rewriter.create<arith::ConstantOp>( loc, DenseElementsAttr::get(shapedTy, attr)); } - return rewriter.create<arith::ConstantOp>(loc, attr); } @@ -357,9 +356,10 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> { f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits); Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits); if (resultETy.getIntOrFloatBitWidth() < 32) { - result = b.create<arith::TruncFOp>(resultTy, result); + result = b.create<arith::TruncFOp>(resultTy, result, nullptr, + op.getFastmathAttr()); } else if (resultETy.getIntOrFloatBitWidth() > 32) { - result = b.create<arith::ExtFOp>(resultTy, result); + result = b.create<arith::ExtFOp>(resultTy, result, op.getFastmathAttr()); } rewriter.replaceOp(op, result); return success(); @@ -395,9 +395,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> { Type f32Ty = cloneToShapedType(operandTy, b.getF32Type()); if (operandETy.getIntOrFloatBitWidth() < 32) { - operand = b.create<arith::ExtFOp>(f32Ty, operand); + operand = b.create<arith::ExtFOp>(f32Ty, operand, op.getFastmathAttr()); } else if (operandETy.getIntOrFloatBitWidth() > 32) { - operand = b.create<arith::TruncFOp>(f32Ty, operand); + operand = b.create<arith::TruncFOp>( + f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr()); } Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand); Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); @@ -409,6 +410,83 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> { } }; +struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::ScalingExtFOp op, + PatternRewriter &rewriter) const final { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + Value inputOperand = op.getIn(); + Value scaleOperand = op.getScale(); + Type scaleTy = scaleOperand.getType(); + Type scaleETy = getElementTypeOrSelf(scaleOperand); + // allow implicit exponent extraction from 16/32 bits floats + if (scaleETy.getIntOrFloatBitWidth() >= 16) { + scaleETy = b.getF8E8M0Type(); + scaleTy = cloneToShapedType(scaleTy, scaleETy); + scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr, + op.getFastmathAttr()); + } + if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) { + return rewriter.notifyMatchFailure( + op, "scaling_extf is using scales of type which can not be converted " + "to f8E8M0FNU"); + } + Type resultTy = op.getType(); + // extf on scale will essentially create floating point number + // of type resulTy that is 2^scale and will also propagate NaNs + Value scaleExt = + b.create<arith::ExtFOp>(resultTy, scaleOperand, op.getFastmathAttr()); + Value inputExt = + b.create<arith::ExtFOp>(resultTy, inputOperand, op.getFastmathAttr()); + Value result = + b.create<arith::MulFOp>(inputExt, scaleExt, op.getFastmathAttr()); + rewriter.replaceOp(op, result); + return success(); + } +}; + +/* +Expands arith.ScalingTruncFOp(in, scale) into + scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU + result = arith.truncf(in / (2^scale)) + */ +struct ScalingTruncFOpConverter + : public OpRewritePattern<arith::ScalingTruncFOp> { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::ScalingTruncFOp op, + PatternRewriter &rewriter) const final { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + Value inputOperand = op.getIn(); + Value scaleOperand = op.getScale(); + Type scaleTy = scaleOperand.getType(); + Type scaleETy = getElementTypeOrSelf(scaleOperand); + // allow implicit exponent extraction from 16/32 bits floats + if (scaleETy.getIntOrFloatBitWidth() >= 16) { + scaleETy = b.getF8E8M0Type(); + scaleTy = cloneToShapedType(scaleTy, scaleETy); + scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr, + op.getFastmathAttr()); + } + if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) { + return rewriter.notifyMatchFailure( + op, "scaling_truncf is using scales type which can not be converted " + "to f8E8M0FNU"); + } + Type resultTy = op.getType(); + Type inputTy = inputOperand.getType(); + // this will create a floating point number of type + // inputTy that is 2^scale and will also propagate NaNs + scaleOperand = + b.create<arith::ExtFOp>(inputTy, scaleOperand, op.getFastmathAttr()); + Value result = b.create<arith::DivFOp>(inputOperand, scaleOperand, + op.getFastmathAttr()); + Value resultCast = b.create<arith::TruncFOp>( + resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr()); + rewriter.replaceOp(op, resultCast); + return success(); + } +}; + struct ArithExpandOpsPass : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> { using ArithExpandOpsPassBase::ArithExpandOpsPassBase; @@ -432,7 +510,9 @@ struct ArithExpandOpsPass arith::MaximumFOp, arith::MinimumFOp, arith::MaxNumFOp, - arith::MinNumFOp + arith::MinNumFOp, + arith::ScalingExtFOp, + arith::ScalingTruncFOp >(); if (includeBf16) { @@ -492,8 +572,15 @@ void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) { patterns.getContext()); } +void mlir::arith::populateExpandScalingExtTruncPatterns( + RewritePatternSet &patterns) { + patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>( + patterns.getContext()); +} + void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) { populateCeilFloorDivExpandOpsPatterns(patterns); + populateExpandScalingExtTruncPatterns(patterns); // clang-format off patterns.add< MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>, @@ -503,7 +590,7 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) { MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>, MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>, MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>, - MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT> + MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT> >(patterns.getContext()); // clang-format on } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 8910211..5f7bc50 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -34,6 +34,8 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) { // Types. //===----------------------------------------------------------------------===// +FloatType Builder::getF8E8M0Type() { return Float8E8M0FNUType::get(context); } + FloatType Builder::getBF16Type() { return BFloat16Type::get(context); } FloatType Builder::getF16Type() { return Float16Type::get(context); } diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir index 5b6badf..db1349f 100644 --- a/mlir/test/Dialect/Arith/expand-ops.mlir +++ b/mlir/test/Dialect/Arith/expand-ops.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -split-input-file | FileCheck %s +// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -verify-diagnostics -split-input-file | FileCheck %s +// RUN: mlir-opt %s -arith-expand -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=SCHECK // Test ceil divide with signed integer // CHECK-LABEL: func @ceildivi @@ -253,7 +254,7 @@ func.func @truncf_f32_to_f8E8M0FNU(%arg0 : f32) -> f8E8M0FNU { %0 = arith.truncf %arg0 : f32 to f8E8M0FNU return %0 : f8E8M0FNU } -// CHECK-LABLE: @truncf_f32_to_f8E8M0FNU +// CHECK-LABEL: @truncf_f32_to_f8E8M0FNU // CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32 // CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32 // CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32 @@ -267,7 +268,7 @@ func.func @truncf_f16_to_f8E8M0FNU(%arg0 : f16) -> f8E8M0FNU { %0 = arith.truncf %arg0 : f16 to f8E8M0FNU return %0 : f8E8M0FNU } -// CHECK-LABLE: @truncf_f16_to_f8E8M0FNU +// CHECK-LABEL: @truncf_f16_to_f8E8M0FNU // CHECK: %[[EXTF:.+]] = arith.extf %arg0 : f16 to f32 // CHECK: %[[BITCAST:.+]] = arith.bitcast %[[EXTF]] : f32 to i32 // CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32 @@ -305,9 +306,76 @@ func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf // CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU // CHECK-NOT: arith.truncf +// CHECK: return +// ----- + +func.func @scaling_truncf_f32_to_f4E2M1FN(%arg0 : f32, %arg1: f8E8M0FNU) -> f4E2M1FN { + %0 = arith.scaling_truncf %arg0, %arg1 : f32, f8E8M0FNU to f4E2M1FN + return %0 : f4E2M1FN +} + +// SCHECK-LABEL: @scaling_truncf_f32_to_f4E2M1FN +// SCHECK: %[[SCALEF32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32 +// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF32]] : f32 +// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : f32 to f4E2M1FN +// SCHECK: return %[[RESULT]] + +// ----- + +func.func @scaling_truncf_vector_f16_to_f6E3M2FN(%arg0 : vector<4xf16>, %arg1: vector<4xf8E8M0FNU>) -> vector<4xf6E3M2FN> { + %0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf8E8M0FNU> to vector<4xf6E3M2FN> + return %0 : vector<4xf6E3M2FN> +} + +// SCHECK-LABEL: @scaling_truncf_vector_f16_to_f6E3M2FN +// SCHECK: %[[SCALEF16:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16> +// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF16]] : vector<4xf16> +// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : vector<4xf16> to vector<4xf6E3M2FN> +// SCHECK: return %[[RESULT]] : vector<4xf6E3M2FN> // ----- + +func.func @scaling_truncf_propagate_rounding_mode_fast_math(%arg0 : vector<4xf16>, %arg1: vector<4xf16>) -> vector<4xf6E3M2FN> { + %0 = arith.scaling_truncf %arg0, %arg1 to_nearest_even fastmath<fast> : vector<4xf16>, vector<4xf16> to vector<4xf6E3M2FN> + return %0 : vector<4xf6E3M2FN> +} +// SCHECK-LABEL: @scaling_truncf_propagate_rounding_mode_fast_math +// SCHECK: %[[SCALEF8:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU> +// SCHECK: %[[SCALEINTY:.+]] = arith.extf %[[SCALEF8]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf16> +// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEINTY]] fastmath<fast> : vector<4xf16> +// SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even fastmath<fast> : vector<4xf16> to vector<4xf6E3M2FN> +// SCHECK: return %[[TRUNCF]] : vector<4xf6E3M2FN> + +// ----- + +func.func @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales(%arg0: f16, %arg1 : f16) -> f4E2M1FN { + %0 = arith.scaling_truncf %arg0, %arg1 : f16, f16 to f4E2M1FN + return %0 : f4E2M1FN +} +// SCHECK-LABEL: @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales +// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : f16 to f8E8M0FN +// SCHECK: return + +// ----- +func.func @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales(%arg0: vector<4xf16>, %arg1 : vector<4xf16>) -> vector<4xf4E2M1FN> { + %0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf16> to vector<4xf4E2M1FN> + return %0 : vector<4xf4E2M1FN> +} +// SCHECK-LABEL: @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales +// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU> +// SCHECK: return + +// ----- + +func.func @invalid_scaling_truncf_to_f4E2M1FN(%arg0: f16, %arg1 : f8E5M2FNUZ) -> f4E2M1FN { + // expected-error@+1 {{failed to legalize operation 'arith.scaling_truncf' that was explicitly marked illegal}} + %0 = arith.scaling_truncf %arg0, %arg1 : f16, f8E5M2FNUZ to f4E2M1FN + return %0 : f4E2M1FN +} + +// ----- + func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 { %0 = arith.extf %arg0 : f8E8M0FNU to f32 return %0 : f32 @@ -332,7 +400,7 @@ func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 { return %0 : f16 } -// CHECK-LABLE: @extf_f8E8M0FNU_to_f16 +// CHECK-LABEL: @extf_f8E8M0FNU_to_f16 // CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8 // CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8 // CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32 @@ -374,7 +442,109 @@ func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector< // CHECK-LABEL: @extf_vector_f8E8M0FNU_to_bf16 // CHECK-NOT: arith.extf +// CHECK: return + +// ----- + +func.func @scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E8M0FNU) -> f32 { + %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E8M0FNU to f32 + return %0 : f32 +} + +// SCHECK-LABEL: @scaling_extf_to_f32 +// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : f8E8M0FNU to f32 +// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32 +// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32 +// SCHECK: return %[[RESULT]] + +// ----- + +func.func @scaling_extf_to_f32_using_f16_scales(%arg0: f4E2M1FN, %arg1 : f16) -> f32 { + %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f16 to f32 + return %0 : f32 +} + +// SCHECK-LABEL: @scaling_extf_to_f32_using_f16_scales +// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : f16 to f8E8M0FNU +// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : f8E8M0FNU to f32 +// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32 +// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32 +// SCHECK: return %[[RESULT]] + +// ----- + +func.func @invalid_scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E5M2FNUZ) -> f32 { + // expected-error@+1 {{failed to legalize operation 'arith.scaling_extf' that was explicitly marked illegal}} + %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E5M2FNUZ to f32 + return %0 : f32 +} + +// ----- + +func.func @scaling_extf_vector_to_f32(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf32> { + %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf32> + return %0 : vector<4xf32> +} + +// SCHECK-LABEL: @scaling_extf_vector_to_f32 +// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf32> +// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32> +// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32> +// SCHECK: return %[[RESULT]] + +// ----- + +func.func @scaling_extf_vector_to_f16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf16> { + %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf16> + return %0 : vector<4xf16> +} + +// SCHECK-LABEL: @scaling_extf_vector_to_f16 +// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16> +// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf16> +// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf16> +// SCHECK: return %[[RESULT]] + +// ----- + +func.func @scaling_extf_vector_to_bf16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xbf16> { + %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xbf16> + return %0 : vector<4xbf16> +} + +// SCHECK-LABEL: @scaling_extf_vector_to_bf16 +// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xbf16> +// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xbf16> +// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xbf16> +// SCHECK: return %[[RESULT]] + +// ----- + +func.func @scaling_extf_vector_to_f32_using_f16_scales(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> { + %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32> + return %0 : vector<4xf32> +} + +// SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales +// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU> +// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : vector<4xf8E8M0FNU> to vector<4xf32> +// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32> +// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32> +// SCHECK: return %[[RESULT]] + +// ----- + +func.func @scaling_extf_vector_to_f32_using_f16_scales_fastmath(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> { + %0 = arith.scaling_extf %arg0, %arg1 fastmath<fast> : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32> + return %0 : vector<4xf32> +} +// SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales_fastmath +// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU> +// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf32> +// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 fastmath<fast> : vector<4xf4E2M1FN> to vector<4xf32> +// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] fastmath<fast> : vector<4xf32> +// SCHECK: return %[[RESULT]] // ----- |