Diffstat (limited to 'mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp')
-rw-r--r--   mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp | 342
1 file changed, 185 insertions, 157 deletions
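In this patch, op construction moves from the member form rewriter.create<OpTy>(loc, ...) to the equivalent static form OpTy::create(rewriter, loc, ...), zero vectors are materialized with vector::BroadcastOp instead of vector::SplatOp, a PatternBenefit parameter is threaded through the pattern constructors, and the scaled ext/trunc rewrites distinguish the input packing width (opInWidth) from the output packing width (opOutWidth). As a minimal sketch of the builder migration alone (illustrative values, not taken from the patch), both spellings build the same arith.constant:

    // Assumes `rewriter`, `loc`, and a float `outType` are in scope.
    Value oldStyle = rewriter.create<arith::ConstantOp>(
        loc, outType, rewriter.getFloatAttr(outType, 0.0));
    Value newStyle = arith::ConstantOp::create(
        rewriter, loc, outType, rewriter.getFloatAttr(outType, 0.0));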
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index cf9bb3a..8230591 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -49,8 +49,9 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
   using OpRewritePattern::OpRewritePattern;

   Chipset chipset;
-  ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
-      : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+  ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset,
+                             PatternBenefit benefit)
+      : OpRewritePattern::OpRewritePattern(ctx, benefit), chipset(chipset) {}

   LogicalResult matchAndRewrite(arith::ExtFOp op,
                                 PatternRewriter &rewriter) const override;
@@ -59,9 +60,9 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
 struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
   bool saturateFP8 = false;
   TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
-                               Chipset chipset)
-      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
-        chipset(chipset) {}
+                               Chipset chipset, PatternBenefit benefit)
+      : OpRewritePattern::OpRewritePattern(ctx, benefit),
+        saturateFP8(saturateFP8), chipset(chipset) {}
   Chipset chipset;

   LogicalResult matchAndRewrite(arith::TruncFOp op,
@@ -81,9 +82,6 @@ struct ScalingExtFRewritePattern final
     : OpRewritePattern<arith::ScalingExtFOp> {
   using OpRewritePattern::OpRewritePattern;

-  ScalingExtFRewritePattern(MLIRContext *ctx)
-      : OpRewritePattern::OpRewritePattern(ctx) {}
-
   LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
                                 PatternRewriter &rewriter) const override;
 };
@@ -92,9 +90,6 @@ struct ScalingTruncFRewritePattern final
     : OpRewritePattern<arith::ScalingTruncFOp> {
   using OpRewritePattern::OpRewritePattern;

-  ScalingTruncFRewritePattern(MLIRContext *ctx)
-      : OpRewritePattern::OpRewritePattern(ctx) {}
-
   LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
                                 PatternRewriter &rewriter) const override;
 };
@@ -115,9 +110,9 @@ static Value castF32To(Type desType, Value f32, Location loc,
   if (elementType.isF32())
     return f32;
   if (elementType.getIntOrFloatBitWidth() < 32)
-    return rewriter.create<arith::TruncFOp>(loc, desType, f32);
+    return arith::TruncFOp::create(rewriter, loc, desType, f32);
   if (elementType.getIntOrFloatBitWidth() > 32)
-    return rewriter.create<arith::ExtFOp>(loc, desType, f32);
+    return arith::ExtFOp::create(rewriter, loc, desType, f32);
   llvm_unreachable("The only 32-bit float type is f32");
 }
@@ -139,64 +134,64 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
   VectorType extResType = VectorType::get(2, rewriter.getF32Type());
   if (!inVecType) {
-    Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
-        loc, rewriter.getF32Type(), in, 0);
+    Value asFloat = amdgpu::ExtPackedFp8Op::create(
+        rewriter, loc, rewriter.getF32Type(), in, 0);
     Value result = castF32To(outElemType, asFloat, loc, rewriter);
     rewriter.replaceOp(op, result);
     return success();
   }
   int64_t numElements = inVecType.getNumElements();
-  Value zero = rewriter.create<arith::ConstantOp>(
-      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+  Value zero = arith::ConstantOp::create(
+      rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
   VectorType outType = cast<VectorType>(op.getOut().getType());

   if (inVecType.getShape().empty()) {
     Value zerodSplat =
-        rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
+        rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
     Value scalarIn =
-        rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
+        vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
     Value scalarExt =
-        rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
-    Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat,
-                                                     ArrayRef<int64_t>{});
+        arith::ExtFOp::create(rewriter, loc, outElemType, scalarIn);
+    Value result = vector::InsertOp::create(rewriter, loc, scalarExt,
+                                            zerodSplat, ArrayRef<int64_t>{});
     rewriter.replaceOp(op, result);
     return success();
   }

   VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                       outType.getElementType());
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+  Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);

   if (inVecType.getRank() > 1) {
     inVecType = VectorType::get(SmallVector<int64_t>{numElements},
                                 inVecType.getElementType());
-    in = rewriter.create<vector::ShapeCastOp>(loc, inVecType, in);
+    in = vector::ShapeCastOp::create(rewriter, loc, inVecType, in);
   }

   for (int64_t i = 0; i < numElements; i += 4) {
     int64_t elemsThisOp = std::min(numElements, i + 4) - i;
-    Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
-        loc, in, i, elemsThisOp, 1);
+    Value inSlice = vector::ExtractStridedSliceOp::create(rewriter, loc, in, i,
+                                                          elemsThisOp, 1);
     for (int64_t j = 0; j < elemsThisOp; j += 2) {
       if (i + j + 1 < numElements) { // Convert two 8-bit elements
-        Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(
-            loc, extResType, inSlice, j / 2);
+        Value asFloats = amdgpu::ExtPackedFp8Op::create(
+            rewriter, loc, extResType, inSlice, j / 2);
         Type desType = VectorType::get(2, outElemType);
         Value asType = castF32To(desType, asFloats, loc, rewriter);
-        result = rewriter.create<vector::InsertStridedSliceOp>(
-            loc, asType, result, i + j, 1);
+        result = vector::InsertStridedSliceOp::create(rewriter, loc, asType,
+                                                      result, i + j, 1);
       } else { // Convert a 8-bit element
-        Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
-            loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
+        Value asFloat = amdgpu::ExtPackedFp8Op::create(
+            rewriter, loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
         Value asType = castF32To(outElemType, asFloat, loc, rewriter);
-        result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
+        result = vector::InsertOp::create(rewriter, loc, asType, result, i + j);
       }
     }
   }

   if (inVecType.getRank() != outType.getRank()) {
-    result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
+    result = vector::ShapeCastOp::create(rewriter, loc, outType, result);
   }

   rewriter.replaceOp(op, result);
@@ -208,9 +203,9 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
   if (type.isF32())
     return value;
   if (type.getIntOrFloatBitWidth() < 32)
-    return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
+    return arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), value);
   if (type.getIntOrFloatBitWidth() > 32)
-    return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
+    return arith::TruncFOp::create(rewriter, loc, rewriter.getF32Type(), value);
   llvm_unreachable("The only 32-bit float type is f32");
 }
@@ -250,13 +245,15 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
       loc, arith::CmpFPredicate::OEQ, source, negInf);
   Value isNan = rewriter.createOrFold<arith::CmpFOp>(
       loc, arith::CmpFPredicate::UNO, source, source);
-  Value isNonFinite = rewriter.create<arith::OrIOp>(
-      loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
+  Value isNonFinite = arith::OrIOp::create(
+      rewriter, loc, arith::OrIOp::create(rewriter, loc, isInf, isNegInf),
+      isNan);

-  Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
-  Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
+  Value clampedBelow = arith::MaximumFOp::create(rewriter, loc, source, minCst);
+  Value clamped =
+      arith::MinimumFOp::create(rewriter, loc, clampedBelow, maxCst);
   Value res =
-      rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
+      arith::SelectOp::create(rewriter, loc, isNonFinite, source, clamped);
   return res;
 }
@@ -290,62 +287,62 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
   VectorType truncResType = VectorType::get(4, outElemType);
   if (!inVectorTy) {
     Value asFloat = castToF32(in, loc, rewriter);
-    Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
-        loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
+    Value asF8s = amdgpu::PackedTrunc2xFp8Op::create(
+        rewriter, loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
         /*existing=*/nullptr);
-    Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
+    Value result = vector::ExtractOp::create(rewriter, loc, asF8s, 0);
     rewriter.replaceOp(op, result);
     return success();
   }

   int64_t numElements = outVecType.getNumElements();
-  Value zero = rewriter.create<arith::ConstantOp>(
-      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+  Value zero = arith::ConstantOp::create(
+      rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
   if (outVecType.getShape().empty()) {
     Value scalarIn =
-        rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
+        vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
     // Recurse to send the 0-D vector case to the 1-D vector case
     Value scalarTrunc =
-        rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
-    Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
-                                                     ArrayRef<int64_t>{});
+        arith::TruncFOp::create(rewriter, loc, outElemType, scalarIn);
+    Value result = vector::InsertOp::create(rewriter, loc, scalarTrunc, zero,
+                                            ArrayRef<int64_t>{});
     rewriter.replaceOp(op, result);
     return success();
   }

   VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                       outVecType.getElementType());
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+  Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);

   if (inVectorTy.getRank() > 1) {
     inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                  inVectorTy.getElementType());
-    in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
+    in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
   }

   for (int64_t i = 0; i < numElements; i += 4) {
     int64_t elemsThisOp = std::min(numElements, i + 4) - i;
     Value thisResult = nullptr;
     for (int64_t j = 0; j < elemsThisOp; j += 2) {
-      Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
+      Value elemA = vector::ExtractOp::create(rewriter, loc, in, i + j);
       Value asFloatA = castToF32(elemA, loc, rewriter);
       Value asFloatB = nullptr;
       if (j + 1 < elemsThisOp) {
-        Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
+        Value elemB = vector::ExtractOp::create(rewriter, loc, in, i + j + 1);
         asFloatB = castToF32(elemB, loc, rewriter);
       }
-      thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
-          loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
+      thisResult = amdgpu::PackedTrunc2xFp8Op::create(
+          rewriter, loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
     }
     if (elemsThisOp < 4)
-      thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, thisResult, 0, elemsThisOp, 1);
-    result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
-                                                           result, i, 1);
+      thisResult = vector::ExtractStridedSliceOp::create(
+          rewriter, loc, thisResult, 0, elemsThisOp, 1);
+    result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
+                                                  result, i, 1);
   }

   if (inVectorTy.getRank() != outVecType.getRank()) {
-    result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
+    result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
   }

   rewriter.replaceOp(op, result);
@@ -373,22 +370,23 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   // Handle the case where input type is not a vector type
   if (!inVectorTy) {
-    auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
+    auto sourceB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
     Value asF16s =
-        rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
-    Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
+        ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, in, sourceB);
+    Value result = vector::ExtractOp::create(rewriter, loc, asF16s, 0);
     rewriter.replaceOp(op, result);
     return success();
   }
   int64_t numElements = outVecType.getNumElements();
   Value zero = rewriter.createOrFold<arith::ConstantOp>(
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+  Value result =
+      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

   if (inVectorTy.getRank() > 1) {
     inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                  inVectorTy.getElementType());
-    in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
+    in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
   }

   // Handle the vector case. We also handle the (uncommon) case where the vector
@@ -396,25 +394,25 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   // length is odd
   for (int64_t i = 0; i < numElements; i += 2) {
     int64_t elemsThisOp = std::min(numElements, i + 2) - i;
     Value thisResult = nullptr;
-    Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
-    Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
+    Value elemA = vector::ExtractOp::create(rewriter, loc, in, i);
+    Value elemB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());

     if (elemsThisOp == 2) {
-      elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
+      elemB = vector::ExtractOp::create(rewriter, loc, in, i + 1);
     }

     thisResult =
-        rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
+        ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, elemA, elemB);
     // Place back the truncated result into the possibly larger vector. If we
     // are operating on a size 2 vector, these operations should be folded away
-    thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
-        loc, thisResult, 0, elemsThisOp, 1);
-    result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
-                                                           result, i, 1);
+    thisResult = vector::ExtractStridedSliceOp::create(
+        rewriter, loc, thisResult, 0, elemsThisOp, 1);
+    result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
+                                                  result, i, 1);
   }

   if (inVectorTy.getRank() != outVecType.getRank()) {
-    result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
+    result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
   }

   rewriter.replaceOp(op, result);
@@ -451,7 +449,7 @@ LogicalResult
 ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
                                            PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  constexpr int64_t opWidth = 2;
+  constexpr int64_t opOutWidth = 2;

   Value in = op.getIn();
   Value scale = op.getScale();
@@ -462,6 +460,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
   Type scaleType = getElementTypeOrSelf(scale);
   Type outType = getElementTypeOrSelf(out);

+  int64_t opInWidth = 32 / inType.getIntOrFloatBitWidth();
+
   VectorType outVecType = dyn_cast<VectorType>(out.getType());
   VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
@@ -471,28 +471,29 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
   Type scaleF32Type =
       scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
   if (scaleType.getIntOrFloatBitWidth() < 32)
-    scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+    scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
   else if (scaleType.getIntOrFloatBitWidth() > 32)
-    scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+    scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);

-  VectorType extScaleResultType = VectorType::get(opWidth, outType);
+  VectorType extScaleResultType = VectorType::get(opOutWidth, outType);

   if (!outVecType) {
-    Value inCast =
-        rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+    Value inCast = vector::BroadcastOp::create(rewriter, loc,
+                                               VectorType::get(1, inType), in);
     // TODO: replace this with non-packed ScaledExtOp
-    Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
-        loc, extScaleResultType, inCast, scale, 0);
+    Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+        rewriter, loc, extScaleResultType, inCast, scale, 0);
     scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
     return success();
   }

   VectorType inVecType = cast<VectorType>(in.getType());
   Value origScale = getOriginalVectorValue(op.getScale());
+  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());

   ArrayRef<int64_t> inShape = inVecType.getShape();
   SmallVector<int64_t> originalScaleShape;
-  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+  if (origScaleVecType)
     llvm::append_range(originalScaleShape, origScaleVecType.getShape());

   originalScaleShape.insert(originalScaleShape.end(),
@@ -507,44 +508,52 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
   int64_t blockSize = computeProduct(ratio);

-  Value zero = rewriter.create<arith::ConstantOp>(
-      loc, outType, rewriter.getFloatAttr(outType, 0.0));
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+  Value zero = arith::ConstantOp::create(rewriter, loc, outType,
+                                         rewriter.getFloatAttr(outType, 0.0));
+  Value result =
+      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

   for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
     SmallVector<int64_t> strides(offsets.size(), 1);
-    Value block = rewriter.create<vector::ExtractStridedSliceOp>(
-        loc, in, offsets, ratio, strides);
+    Value block = vector::ExtractStridedSliceOp::create(
+        rewriter, loc, in, offsets, ratio, strides);
     VectorType block1DType = VectorType::get(blockSize, inType);
     Value block1D =
-        rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+        vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
     Value uniformScale =
-        rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+        vector::ExtractOp::create(rewriter, loc, scale, offsets);

     VectorType blockResultType = VectorType::get(blockSize, outType);
     Value blockResult =
-        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+        rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);

-    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+    for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i);
          i < blockSize;
-         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
-      Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, block1D, i, sliceWidth, 1);
-      // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
-      Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
-          loc, extScaleResultType, slice, uniformScale, 0);
-      if (sliceWidth != opWidth)
-        scaleExt = rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, scaleExt, 0, sliceWidth, 1);
-      blockResult = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, scaleExt, blockResult, i, 1);
+         i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) {
+      Value inSlice = vector::ExtractStridedSliceOp::create(
+          rewriter, loc, block1D, i, inSliceWidth, 1);
+      for (int64_t j = 0,
+                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j);
+           j < inSliceWidth; j += outSliceWidth,
+                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) {
+        // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
+        Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+            rewriter, loc, extScaleResultType, inSlice, uniformScale,
+            j / opOutWidth);
+        if (outSliceWidth < opOutWidth) {
+          scaleExt = vector::ExtractStridedSliceOp::create(
+              rewriter, loc, scaleExt, 0, outSliceWidth, 1);
+        }
+        blockResult = vector::InsertStridedSliceOp::create(
+            rewriter, loc, scaleExt, blockResult, i + j, 1);
+      }
     }

     VectorType resultType = VectorType::get(ratio, outType);
     Value cast =
-        rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
-    result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
-                                                           offsets, strides);
+        vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
+    result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
+                                                  offsets, strides);
   }

   rewriter.replaceOp(op, result);
@@ -556,7 +565,7 @@ LogicalResult
 ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
                                              PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  constexpr int64_t opWidth = 2;
+  constexpr int64_t opInWidth = 2;

   Value in = op.getIn();
   Value scale = op.getScale();
@@ -569,28 +578,28 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
   VectorType outVecType = dyn_cast<VectorType>(out.getType());
   VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
-
   if (outVecType && outVecType.isScalable())
     return failure();

   Type scaleF32Type =
       scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
   if (scaleType.getIntOrFloatBitWidth() < 32)
-    scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+    scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
   else if (scaleType.getIntOrFloatBitWidth() > 32)
-    scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+    scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);

-  Value zero = rewriter.create<arith::ConstantOp>(
-      loc, outType, rewriter.getFloatAttr(outType, 0.0));
-  unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
-  VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);
+  Value zero = arith::ConstantOp::create(rewriter, loc, outType,
+                                         rewriter.getFloatAttr(outType, 0.0));
+  int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth();
+  VectorType truncScaleResultType = VectorType::get(opOutWidth, outType);

   if (!outVecType) {
     Type inVecType = VectorType::get(1, inType);
-    Value inCast = rewriter.create<vector::SplatOp>(loc, inVecType, in);
+    Value inCast = vector::BroadcastOp::create(rewriter, loc, inVecType, in);
     // TODO: replace this with non-packed ScaledTruncOp
-    Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
-        loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
+    Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+        rewriter, loc, truncScaleResultType, inCast, scale, 0,
+        /*existing=*/nullptr);
     scaleTrunc =
         rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleTrunc, 0);
     return success();
@@ -598,16 +607,16 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
   VectorType inVecType = cast<VectorType>(in.getType());
   Value origScale = getOriginalVectorValue(op.getScale());
+  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());

   ArrayRef<int64_t> inShape = inVecType.getShape();
-  SmallVector<int64_t> originalScaleShape;
-  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
-    llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+  SmallVector<int64_t> scaleShape;
+  if (origScaleVecType)
+    llvm::append_range(scaleShape, origScaleVecType.getShape());

-  originalScaleShape.insert(originalScaleShape.end(),
-                            inShape.size() - originalScaleShape.size(), 1);
+  scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1);

-  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+  auto maybeRatio = computeShapeRatio(inShape, scaleShape);
   assert(maybeRatio &&
          "failed to derive block size from broadcast or splat operation");
@@ -616,45 +625,62 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
   int64_t blockSize = computeProduct(ratio);

-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+  Value result =
+      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

   for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
     SmallVector<int64_t> strides(offsets.size(), 1);
-    Value block = rewriter.create<vector::ExtractStridedSliceOp>(
-        loc, in, offsets, ratio, strides);
+    Value block = vector::ExtractStridedSliceOp::create(
+        rewriter, loc, in, offsets, ratio, strides);
     VectorType block1DType = VectorType::get(blockSize, inType);
     Value block1D =
-        rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+        vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
     Value uniformScale =
-        rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+        vector::ExtractOp::create(rewriter, loc, scale, offsets);

     VectorType blockResultType = VectorType::get(blockSize, outType);
     Value blockResult =
-        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
-
-    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
-         i < blockSize;
-         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
-      Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, block1D, i, sliceWidth, 1);
-      // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
-      Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
-          loc, truncScaleResultType, slice, uniformScale, 0,
-          /*existing=*/nullptr);
-      int64_t packedWidth =
-          cast<VectorType>(scaleTrunc.getType()).getNumElements();
-      if (packedWidth != opWidth)
-        scaleTrunc = rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, scaleTrunc, 0, sliceWidth, 1);
-      blockResult = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, scaleTrunc, blockResult, i, 1);
+        rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
+
+    for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i);
+         i < blockSize; i += outSliceWidth,
+                 outSliceWidth = std::min(opOutWidth, blockSize - i)) {
+      Value scaleTrunc;
+      // Case where <= 2 elements are being truncated.
+      if (outSliceWidth <= opInWidth) {
+        Value slice = vector::ExtractStridedSliceOp::create(
+            rewriter, loc, block1D, i, outSliceWidth, 1);
+        // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
+        scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+            rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
+            /*existing=*/nullptr);
+      } else {
+        scaleTrunc = vector::BroadcastOp::create(rewriter, loc,
+                                                 truncScaleResultType, zero);
+        for (int64_t j = 0,
+                     inSliceWidth = std::min(opInWidth, outSliceWidth - j);
+             j < outSliceWidth; j += opInWidth,
+                     inSliceWidth = std::min(opInWidth, outSliceWidth - j)) {
+          Value slice = vector::ExtractStridedSliceOp::create(
+              rewriter, loc, block1D, i + j, inSliceWidth, 1);
+          scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+              rewriter, loc, truncScaleResultType, slice, uniformScale,
+              j / opInWidth, scaleTrunc);
+        }
+      }
+      if (outSliceWidth != opOutWidth) {
+        scaleTrunc = vector::ExtractStridedSliceOp::create(
+            rewriter, loc, scaleTrunc, 0, outSliceWidth, 1);
+      }
+      blockResult = vector::InsertStridedSliceOp::create(
+          rewriter, loc, scaleTrunc, blockResult, i, 1);
     }

     VectorType resultType = VectorType::get(ratio, outType);
     Value cast =
-        rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
-    result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
-                                                           offsets, strides);
+        vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
+    result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
+                                                  offsets, strides);
   }

   rewriter.replaceOp(op, result);
@@ -664,19 +690,21 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 void mlir::arith::populateArithToAMDGPUConversionPatterns(
     RewritePatternSet &patterns, bool convertFP8Arithmetic,
-    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
+    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset,
+    PatternBenefit benefit) {
   if (convertFP8Arithmetic) {
-    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
-    patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
-                                               saturateFP8Truncf, chipset);
+    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
+                                             benefit);
+    patterns.add<TruncFToFloat8RewritePattern>(
+        patterns.getContext(), saturateFP8Truncf, chipset, benefit);
   }
   if (allowPackedF16Rtz)
-    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
+    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit);

   if (chipset >= kGfx950) {
-    patterns.add<ScalingExtFRewritePattern>(patterns.getContext());
-    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext());
+    patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
+    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
   }
 }
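For callers, the visible API change is the trailing PatternBenefit parameter on the populate function. A minimal usage sketch, assuming a valid MLIRContext *ctx and amdgpu::Chipset chipset are in scope (whether the parameter has a default value is not visible in this diff):

    RewritePatternSet patterns(ctx);
    // PatternBenefit is implicitly constructible from an unsigned;
    // 1 is the conventional "no special priority" value.
    arith::populateArithToAMDGPUConversionPatterns(
        patterns, /*convertFP8Arithmetic=*/true, /*saturateFP8Truncf=*/false,
        /*allowPackedF16Rtz=*/true, chipset, /*benefit=*/1);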