Diffstat (limited to 'mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp')
-rw-r--r--  mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp | 342
1 file changed, 185 insertions(+), 157 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index cf9bb3a..8230591 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -49,8 +49,9 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;
Chipset chipset;
- ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
- : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+ ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset,
+ PatternBenefit benefit)
+ : OpRewritePattern::OpRewritePattern(ctx, benefit), chipset(chipset) {}
LogicalResult matchAndRewrite(arith::ExtFOp op,
PatternRewriter &rewriter) const override;
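
Aside: as context for the constructor changes in these hunks, a minimal sketch of the benefit-forwarding shape the patterns now share (MyPattern is a hypothetical name; OpRewritePattern's constructor does accept a PatternBenefit as its second argument):

    struct MyPattern final : OpRewritePattern<arith::ExtFOp> {
      Chipset chipset;
      MyPattern(MLIRContext *ctx, Chipset chipset, PatternBenefit benefit)
          : OpRewritePattern(ctx, benefit), chipset(chipset) {}
      // matchAndRewrite is unchanged; only construction learns about benefit.
    };

Forwarding the benefit lets callers rank these rewrites against other patterns registered in the same RewritePatternSet.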
@@ -59,9 +60,9 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
bool saturateFP8 = false;
TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
- Chipset chipset)
- : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
- chipset(chipset) {}
+ Chipset chipset, PatternBenefit benefit)
+ : OpRewritePattern::OpRewritePattern(ctx, benefit),
+ saturateFP8(saturateFP8), chipset(chipset) {}
Chipset chipset;
LogicalResult matchAndRewrite(arith::TruncFOp op,
@@ -81,9 +82,6 @@ struct ScalingExtFRewritePattern final
: OpRewritePattern<arith::ScalingExtFOp> {
using OpRewritePattern::OpRewritePattern;
- ScalingExtFRewritePattern(MLIRContext *ctx)
- : OpRewritePattern::OpRewritePattern(ctx) {}
-
LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
PatternRewriter &rewriter) const override;
};
@@ -92,9 +90,6 @@ struct ScalingTruncFRewritePattern final
: OpRewritePattern<arith::ScalingTruncFOp> {
using OpRewritePattern::OpRewritePattern;
- ScalingTruncFRewritePattern(MLIRContext *ctx)
- : OpRewritePattern::OpRewritePattern(ctx) {}
-
LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
PatternRewriter &rewriter) const override;
};
@@ -115,9 +110,9 @@ static Value castF32To(Type desType, Value f32, Location loc,
if (elementType.isF32())
return f32;
if (elementType.getIntOrFloatBitWidth() < 32)
- return rewriter.create<arith::TruncFOp>(loc, desType, f32);
+ return arith::TruncFOp::create(rewriter, loc, desType, f32);
if (elementType.getIntOrFloatBitWidth() > 32)
- return rewriter.create<arith::ExtFOp>(loc, desType, f32);
+ return arith::ExtFOp::create(rewriter, loc, desType, f32);
llvm_unreachable("The only 32-bit float type is f32");
}
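
Aside: the bulk of this diff is a mechanical migration from the rewriter's member-template builder to the ops' static create methods. Both spellings construct the same operation; a side-by-side sketch, assuming a rewriter and loc in scope:

    // Old spelling: builder member template.
    Value a = rewriter.create<arith::TruncFOp>(loc, desType, f32);
    // New spelling: static create on the op class, taking the builder first.
    Value b = arith::TruncFOp::create(rewriter, loc, desType, f32);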
@@ -139,64 +134,64 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
VectorType extResType = VectorType::get(2, rewriter.getF32Type());
if (!inVecType) {
- Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
- loc, rewriter.getF32Type(), in, 0);
+ Value asFloat = amdgpu::ExtPackedFp8Op::create(
+ rewriter, loc, rewriter.getF32Type(), in, 0);
Value result = castF32To(outElemType, asFloat, loc, rewriter);
rewriter.replaceOp(op, result);
return success();
}
int64_t numElements = inVecType.getNumElements();
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+ Value zero = arith::ConstantOp::create(
+ rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
VectorType outType = cast<VectorType>(op.getOut().getType());
if (inVecType.getShape().empty()) {
Value zerodSplat =
- rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
+ rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
Value scalarIn =
- rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
+ vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
Value scalarExt =
- rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
- Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat,
- ArrayRef<int64_t>{});
+ arith::ExtFOp::create(rewriter, loc, outElemType, scalarIn);
+ Value result = vector::InsertOp::create(rewriter, loc, scalarExt,
+ zerodSplat, ArrayRef<int64_t>{});
rewriter.replaceOp(op, result);
return success();
}
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
outType.getElementType());
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+ Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);
if (inVecType.getRank() > 1) {
inVecType = VectorType::get(SmallVector<int64_t>{numElements},
inVecType.getElementType());
- in = rewriter.create<vector::ShapeCastOp>(loc, inVecType, in);
+ in = vector::ShapeCastOp::create(rewriter, loc, inVecType, in);
}
for (int64_t i = 0; i < numElements; i += 4) {
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
- Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, in, i, elemsThisOp, 1);
+ Value inSlice = vector::ExtractStridedSliceOp::create(rewriter, loc, in, i,
+ elemsThisOp, 1);
for (int64_t j = 0; j < elemsThisOp; j += 2) {
if (i + j + 1 < numElements) { // Convert two 8-bit elements
- Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(
- loc, extResType, inSlice, j / 2);
+ Value asFloats = amdgpu::ExtPackedFp8Op::create(
+ rewriter, loc, extResType, inSlice, j / 2);
Type desType = VectorType::get(2, outElemType);
Value asType = castF32To(desType, asFloats, loc, rewriter);
- result = rewriter.create<vector::InsertStridedSliceOp>(
- loc, asType, result, i + j, 1);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, asType,
+ result, i + j, 1);
} else { // Convert an 8-bit element
- Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
- loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
+ Value asFloat = amdgpu::ExtPackedFp8Op::create(
+ rewriter, loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
Value asType = castF32To(outElemType, asFloat, loc, rewriter);
- result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
+ result = vector::InsertOp::create(rewriter, loc, asType, result, i + j);
}
}
}
if (inVecType.getRank() != outType.getRank()) {
- result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
+ result = vector::ShapeCastOp::create(rewriter, loc, outType, result);
}
rewriter.replaceOp(op, result);
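
Aside: this hunk also swaps vector::SplatOp for vector::BroadcastOp when materializing the zero-filled result vector; vector.broadcast subsumes the scalar-to-vector splat case. A sketch of the equivalence, with vecTy and zero assumed in scope:

    // Previously: splat a scalar across the vector.
    //   rewriter.createOrFold<vector::SplatOp>(loc, vecTy, zero);
    // Now: broadcast expresses the same scalar-to-vector expansion.
    Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, vecTy, zero);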
@@ -208,9 +203,9 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
if (type.isF32())
return value;
if (type.getIntOrFloatBitWidth() < 32)
- return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
+ return arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), value);
if (type.getIntOrFloatBitWidth() > 32)
- return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
+ return arith::TruncFOp::create(rewriter, loc, rewriter.getF32Type(), value);
llvm_unreachable("The only 32-bit float type is f32");
}
@@ -250,13 +245,15 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
loc, arith::CmpFPredicate::OEQ, source, negInf);
Value isNan = rewriter.createOrFold<arith::CmpFOp>(
loc, arith::CmpFPredicate::UNO, source, source);
- Value isNonFinite = rewriter.create<arith::OrIOp>(
- loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
+ Value isNonFinite = arith::OrIOp::create(
+ rewriter, loc, arith::OrIOp::create(rewriter, loc, isInf, isNegInf),
+ isNan);
- Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
- Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
+ Value clampedBelow = arith::MaximumFOp::create(rewriter, loc, source, minCst);
+ Value clamped =
+ arith::MinimumFOp::create(rewriter, loc, clampedBelow, maxCst);
Value res =
- rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
+ arith::SelectOp::create(rewriter, loc, isNonFinite, source, clamped);
return res;
}
@@ -290,62 +287,62 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
VectorType truncResType = VectorType::get(4, outElemType);
if (!inVectorTy) {
Value asFloat = castToF32(in, loc, rewriter);
- Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
- loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
+ Value asF8s = amdgpu::PackedTrunc2xFp8Op::create(
+ rewriter, loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
/*existing=*/nullptr);
- Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
+ Value result = vector::ExtractOp::create(rewriter, loc, asF8s, 0);
rewriter.replaceOp(op, result);
return success();
}
int64_t numElements = outVecType.getNumElements();
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+ Value zero = arith::ConstantOp::create(
+ rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
if (outVecType.getShape().empty()) {
Value scalarIn =
- rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
+ vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
// Recurse to send the 0-D vector case to the 1-D vector case
Value scalarTrunc =
- rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
- Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
- ArrayRef<int64_t>{});
+ arith::TruncFOp::create(rewriter, loc, outElemType, scalarIn);
+ Value result = vector::InsertOp::create(rewriter, loc, scalarTrunc, zero,
+ ArrayRef<int64_t>{});
rewriter.replaceOp(op, result);
return success();
}
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
outVecType.getElementType());
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+ Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);
if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
inVectorTy.getElementType());
- in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
+ in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
}
for (int64_t i = 0; i < numElements; i += 4) {
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
Value thisResult = nullptr;
for (int64_t j = 0; j < elemsThisOp; j += 2) {
- Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
+ Value elemA = vector::ExtractOp::create(rewriter, loc, in, i + j);
Value asFloatA = castToF32(elemA, loc, rewriter);
Value asFloatB = nullptr;
if (j + 1 < elemsThisOp) {
- Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
+ Value elemB = vector::ExtractOp::create(rewriter, loc, in, i + j + 1);
asFloatB = castToF32(elemB, loc, rewriter);
}
- thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
- loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
+ thisResult = amdgpu::PackedTrunc2xFp8Op::create(
+ rewriter, loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
}
if (elemsThisOp < 4)
- thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, thisResult, 0, elemsThisOp, 1);
- result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
- result, i, 1);
+ thisResult = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, thisResult, 0, elemsThisOp, 1);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
+ result, i, 1);
}
if (inVectorTy.getRank() != outVecType.getRank()) {
- result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
+ result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
}
rewriter.replaceOp(op, result);
@@ -373,22 +370,23 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
// Handle the case where input type is not a vector type
if (!inVectorTy) {
- auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
+ auto sourceB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
Value asF16s =
- rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
- Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
+ ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, in, sourceB);
+ Value result = vector::ExtractOp::create(rewriter, loc, asF16s, 0);
rewriter.replaceOp(op, result);
return success();
}
int64_t numElements = outVecType.getNumElements();
Value zero = rewriter.createOrFold<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+ Value result =
+ rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
inVectorTy.getElementType());
- in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
+ in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
}
// Handle the vector case. We also handle the (uncommon) case where the vector
@@ -396,25 +394,25 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
for (int64_t i = 0; i < numElements; i += 2) {
int64_t elemsThisOp = std::min(numElements, i + 2) - i;
Value thisResult = nullptr;
- Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
- Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
+ Value elemA = vector::ExtractOp::create(rewriter, loc, in, i);
+ Value elemB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
if (elemsThisOp == 2) {
- elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
+ elemB = vector::ExtractOp::create(rewriter, loc, in, i + 1);
}
thisResult =
- rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
+ ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, elemA, elemB);
// Place back the truncated result into the possibly larger vector. If we
// are operating on a size 2 vector, these operations should be folded away
- thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, thisResult, 0, elemsThisOp, 1);
- result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
- result, i, 1);
+ thisResult = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, thisResult, 0, elemsThisOp, 1);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
+ result, i, 1);
}
if (inVectorTy.getRank() != outVecType.getRank()) {
- result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
+ result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
}
rewriter.replaceOp(op, result);
@@ -451,7 +449,7 @@ LogicalResult
ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
- constexpr int64_t opWidth = 2;
+ constexpr int64_t opOutWidth = 2;
Value in = op.getIn();
Value scale = op.getScale();
@@ -462,6 +460,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
Type scaleType = getElementTypeOrSelf(scale);
Type outType = getElementTypeOrSelf(out);
+ int64_t opInWidth = 32 / inType.getIntOrFloatBitWidth();
+
VectorType outVecType = dyn_cast<VectorType>(out.getType());
VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
@@ -471,28 +471,29 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
Type scaleF32Type =
scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
if (scaleType.getIntOrFloatBitWidth() < 32)
- scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+ scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
else if (scaleType.getIntOrFloatBitWidth() > 32)
- scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+ scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
- VectorType extScaleResultType = VectorType::get(opWidth, outType);
+ VectorType extScaleResultType = VectorType::get(opOutWidth, outType);
if (!outVecType) {
- Value inCast =
- rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+ Value inCast = vector::BroadcastOp::create(rewriter, loc,
+ VectorType::get(1, inType), in);
// TODO: replace this with non-packed ScaledExtOp
- Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
- loc, extScaleResultType, inCast, scale, 0);
+ Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+ rewriter, loc, extScaleResultType, inCast, scale, 0);
scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
return success();
}
VectorType inVecType = cast<VectorType>(in.getType());
Value origScale = getOriginalVectorValue(op.getScale());
+ VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
ArrayRef<int64_t> inShape = inVecType.getShape();
SmallVector<int64_t> originalScaleShape;
- if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+ if (origScaleVecType)
llvm::append_range(originalScaleShape, origScaleVecType.getShape());
originalScaleShape.insert(originalScaleShape.end(),
@@ -507,44 +508,52 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
int64_t blockSize = computeProduct(ratio);
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, outType, rewriter.getFloatAttr(outType, 0.0));
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+ Value zero = arith::ConstantOp::create(rewriter, loc, outType,
+ rewriter.getFloatAttr(outType, 0.0));
+ Value result =
+ rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
SmallVector<int64_t> strides(offsets.size(), 1);
- Value block = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, in, offsets, ratio, strides);
+ Value block = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, in, offsets, ratio, strides);
VectorType block1DType = VectorType::get(blockSize, inType);
Value block1D =
- rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+ vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
Value uniformScale =
- rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+ vector::ExtractOp::create(rewriter, loc, scale, offsets);
VectorType blockResultType = VectorType::get(blockSize, outType);
Value blockResult =
- rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+ rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
- for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+ for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i);
i < blockSize;
- i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
- Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, block1D, i, sliceWidth, 1);
- // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
- Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
- loc, extScaleResultType, slice, uniformScale, 0);
- if (sliceWidth != opWidth)
- scaleExt = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, scaleExt, 0, sliceWidth, 1);
- blockResult = rewriter.create<vector::InsertStridedSliceOp>(
- loc, scaleExt, blockResult, i, 1);
+ i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) {
+ Value inSlice = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, block1D, i, inSliceWidth, 1);
+ for (int64_t j = 0,
+ outSliceWidth = std::min(opOutWidth, inSliceWidth - j);
+ j < inSliceWidth; j += outSliceWidth,
+ outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) {
+ // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
+ Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+ rewriter, loc, extScaleResultType, inSlice, uniformScale,
+ j / opOutWidth);
+ if (outSliceWidth < opOutWidth) {
+ scaleExt = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, scaleExt, 0, outSliceWidth, 1);
+ }
+ blockResult = vector::InsertStridedSliceOp::create(
+ rewriter, loc, scaleExt, blockResult, i + j, 1);
+ }
}
VectorType resultType = VectorType::get(ratio, outType);
Value cast =
- rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
- result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
- offsets, strides);
+ vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
+ offsets, strides);
}
rewriter.replaceOp(op, result);
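
Aside: the rename from opWidth to opInWidth/opOutWidth reflects that the packed-ext path now distinguishes how many source elements fit in one 32-bit word (opInWidth = 32 / input bit width) from how many extended results one amdgpu.scaled_ext_packed produces (opOutWidth = 2). A minimal sketch of the resulting two-level slice walk, with the op construction reduced to a comment (walkExtSlices is a hypothetical helper, not part of the patch):

    #include <algorithm>
    #include <cstdint>

    static void walkExtSlices(int64_t blockSize, int64_t opInWidth,
                              int64_t opOutWidth) {
      for (int64_t i = 0; i < blockSize; i += opInWidth) {
        int64_t inSliceWidth = std::min(opInWidth, blockSize - i);
        for (int64_t j = 0; j < inSliceWidth; j += opOutWidth) {
          // One packed ext covering elements [i + j, i + j + outSliceWidth),
          // selecting its packed pair within the slice via index j / opOutWidth.
        }
      }
    }

For example, with an 8-bit fp8 input, opInWidth = 32 / 8 = 4, so each extracted 4-element slice feeds two packed ops (indices 0 and 1), each yielding two extended results.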
@@ -556,7 +565,7 @@ LogicalResult
ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
- constexpr int64_t opWidth = 2;
+ constexpr int64_t opInWidth = 2;
Value in = op.getIn();
Value scale = op.getScale();
@@ -569,28 +578,28 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
VectorType outVecType = dyn_cast<VectorType>(out.getType());
VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
-
if (outVecType && outVecType.isScalable())
return failure();
Type scaleF32Type =
scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
if (scaleType.getIntOrFloatBitWidth() < 32)
- scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+ scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
else if (scaleType.getIntOrFloatBitWidth() > 32)
- scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+ scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, outType, rewriter.getFloatAttr(outType, 0.0));
- unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
- VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);
+ Value zero = arith::ConstantOp::create(rewriter, loc, outType,
+ rewriter.getFloatAttr(outType, 0.0));
+ int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth();
+ VectorType truncScaleResultType = VectorType::get(opOutWidth, outType);
if (!outVecType) {
Type inVecType = VectorType::get(1, inType);
- Value inCast = rewriter.create<vector::SplatOp>(loc, inVecType, in);
+ Value inCast = vector::BroadcastOp::create(rewriter, loc, inVecType, in);
// TODO: replace this with non-packed ScaledTruncOp
- Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
- loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
+ Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+ rewriter, loc, truncScaleResultType, inCast, scale, 0,
+ /*existing=*/nullptr);
scaleTrunc =
rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleTrunc, 0);
return success();
@@ -598,16 +607,16 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
VectorType inVecType = cast<VectorType>(in.getType());
Value origScale = getOriginalVectorValue(op.getScale());
+ VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
ArrayRef<int64_t> inShape = inVecType.getShape();
- SmallVector<int64_t> originalScaleShape;
- if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
- llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+ SmallVector<int64_t> scaleShape;
+ if (origScaleVecType)
+ llvm::append_range(scaleShape, origScaleVecType.getShape());
- originalScaleShape.insert(originalScaleShape.end(),
- inShape.size() - originalScaleShape.size(), 1);
+ scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1);
- auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+ auto maybeRatio = computeShapeRatio(inShape, scaleShape);
assert(maybeRatio &&
"failed to derive block size from broadcast or splat operation");
@@ -616,45 +625,62 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
int64_t blockSize = computeProduct(ratio);
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+ Value result =
+ rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
SmallVector<int64_t> strides(offsets.size(), 1);
- Value block = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, in, offsets, ratio, strides);
+ Value block = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, in, offsets, ratio, strides);
VectorType block1DType = VectorType::get(blockSize, inType);
Value block1D =
- rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+ vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
Value uniformScale =
- rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+ vector::ExtractOp::create(rewriter, loc, scale, offsets);
VectorType blockResultType = VectorType::get(blockSize, outType);
Value blockResult =
- rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
-
- for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
- i < blockSize;
- i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
- Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, block1D, i, sliceWidth, 1);
- // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
- Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
- loc, truncScaleResultType, slice, uniformScale, 0,
- /*existing=*/nullptr);
- int64_t packedWidth =
- cast<VectorType>(scaleTrunc.getType()).getNumElements();
- if (packedWidth != opWidth)
- scaleTrunc = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, scaleTrunc, 0, sliceWidth, 1);
- blockResult = rewriter.create<vector::InsertStridedSliceOp>(
- loc, scaleTrunc, blockResult, i, 1);
+ rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
+
+ for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i);
+ i < blockSize; i += outSliceWidth,
+ outSliceWidth = std::min(opOutWidth, blockSize - i)) {
+ Value scaleTrunc;
+ // Case where <= 2 elements are being truncated.
+ if (outSliceWidth <= opInWidth) {
+ Value slice = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, block1D, i, outSliceWidth, 1);
+ // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
+ scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+ rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
+ /*existing=*/nullptr);
+ } else {
+ scaleTrunc = vector::BroadcastOp::create(rewriter, loc,
+ truncScaleResultType, zero);
+ for (int64_t j = 0,
+ inSliceWidth = std::min(opInWidth, outSliceWidth - j);
+ j < outSliceWidth; j += opInWidth,
+ inSliceWidth = std::min(opInWidth, outSliceWidth - j)) {
+ Value slice = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, block1D, i + j, inSliceWidth, 1);
+ scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+ rewriter, loc, truncScaleResultType, slice, uniformScale,
+ j / opInWidth, scaleTrunc);
+ }
+ }
+ if (outSliceWidth != opOutWidth) {
+ scaleTrunc = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, scaleTrunc, 0, outSliceWidth, 1);
+ }
+ blockResult = vector::InsertStridedSliceOp::create(
+ rewriter, loc, scaleTrunc, blockResult, i, 1);
}
VectorType resultType = VectorType::get(ratio, outType);
Value cast =
- rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
- result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
- offsets, strides);
+ vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
+ offsets, strides);
}
rewriter.replaceOp(op, result);
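
Aside: on the trunc side the roles flip: opOutWidth = 32 / output bit width is how many packed results fill one 32-bit word, while each amdgpu.packed_scaled_trunc still consumes at most opInWidth = 2 inputs. When an output slice needs more than two inputs, the inner loop above threads the partially filled word back through the op's existing operand. A worked example with illustrative numbers, not taken from the patch:

    // Hypothetical case: truncating f32 to a 4-bit float element type.
    int64_t opOutWidth = 32 / 4; // 8 packed results per 32-bit word
    int64_t opInWidth = 2;       // inputs per packed_scaled_trunc op
    // => opOutWidth / opInWidth == 4 chained ops accumulate into one word,
    //    writing slots j / opInWidth = 0..3 before it lands in blockResult.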
@@ -664,19 +690,21 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
void mlir::arith::populateArithToAMDGPUConversionPatterns(
RewritePatternSet &patterns, bool convertFP8Arithmetic,
- bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
+ bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset,
+ PatternBenefit benefit) {
if (convertFP8Arithmetic) {
- patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
- patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
- saturateFP8Truncf, chipset);
+ patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
+ benefit);
+ patterns.add<TruncFToFloat8RewritePattern>(
+ patterns.getContext(), saturateFP8Truncf, chipset, benefit);
}
if (allowPackedF16Rtz)
- patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
+ patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit);
if (chipset >= kGfx950) {
- patterns.add<ScalingExtFRewritePattern>(patterns.getContext());
- patterns.add<ScalingTruncFRewritePattern>(patterns.getContext());
+ patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
+ patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
}
}
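
Aside: a hedged usage sketch of the extended entry point (the benefit value of 2 is illustrative only; ctx and chipset are assumed to be in scope):

    RewritePatternSet patterns(ctx);
    arith::populateArithToAMDGPUConversionPatterns(
        patterns, /*convertFP8Arithmetic=*/true, /*saturateFP8Truncf=*/false,
        /*allowPackedF16Rtz=*/true, chipset, /*benefit=*/PatternBenefit(2));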