diff options
Diffstat (limited to 'mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp')
-rw-r--r-- | mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp | 101 |
1 files changed, 63 insertions, 38 deletions
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index 8c68b57..8230591 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -449,7 +449,7 @@ LogicalResult ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, PatternRewriter &rewriter) const { Location loc = op.getLoc(); - constexpr int64_t opWidth = 2; + constexpr int64_t opOutWidth = 2; Value in = op.getIn(); Value scale = op.getScale(); @@ -460,6 +460,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, Type scaleType = getElementTypeOrSelf(scale); Type outType = getElementTypeOrSelf(out); + int64_t opInWidth = 32 / inType.getIntOrFloatBitWidth(); + VectorType outVecType = dyn_cast<VectorType>(out.getType()); VectorType scaleVecType = dyn_cast<VectorType>(scale.getType()); @@ -473,7 +475,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, else if (scaleType.getIntOrFloatBitWidth() > 32) scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale); - VectorType extScaleResultType = VectorType::get(opWidth, outType); + VectorType extScaleResultType = VectorType::get(opOutWidth, outType); if (!outVecType) { Value inCast = vector::BroadcastOp::create(rewriter, loc, @@ -487,10 +489,11 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, VectorType inVecType = cast<VectorType>(in.getType()); Value origScale = getOriginalVectorValue(op.getScale()); + VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType()); ArrayRef<int64_t> inShape = inVecType.getShape(); SmallVector<int64_t> originalScaleShape; - if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType())) + if (origScaleVecType) llvm::append_range(originalScaleShape, origScaleVecType.getShape()); originalScaleShape.insert(originalScaleShape.end(), @@ -524,19 +527,26 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, Value blockResult = rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero); - for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i); + for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i); i < blockSize; - i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) { - Value slice = vector::ExtractStridedSliceOp::create( - rewriter, loc, block1D, i, sliceWidth, 1); - // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1 - Value scaleExt = amdgpu::ScaledExtPackedOp::create( - rewriter, loc, extScaleResultType, slice, uniformScale, 0); - if (sliceWidth != opWidth) - scaleExt = vector::ExtractStridedSliceOp::create( - rewriter, loc, scaleExt, 0, sliceWidth, 1); - blockResult = vector::InsertStridedSliceOp::create( - rewriter, loc, scaleExt, blockResult, i, 1); + i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) { + Value inSlice = vector::ExtractStridedSliceOp::create( + rewriter, loc, block1D, i, inSliceWidth, 1); + for (int64_t j = 0, + outSliceWidth = std::min(opOutWidth, inSliceWidth - j); + j < inSliceWidth; j += outSliceWidth, + outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) { + // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1 + Value scaleExt = amdgpu::ScaledExtPackedOp::create( + rewriter, loc, extScaleResultType, inSlice, uniformScale, + j / opOutWidth); + if (outSliceWidth < opOutWidth) { + scaleExt = vector::ExtractStridedSliceOp::create( + rewriter, loc, scaleExt, 0, outSliceWidth, 1); + } + blockResult = vector::InsertStridedSliceOp::create( + rewriter, loc, scaleExt, blockResult, i + j, 1); + } } VectorType resultType = VectorType::get(ratio, outType); @@ -555,7 +565,7 @@ LogicalResult ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, PatternRewriter &rewriter) const { Location loc = op.getLoc(); - constexpr int64_t opWidth = 2; + constexpr int64_t opInWidth = 2; Value in = op.getIn(); Value scale = op.getScale(); @@ -568,7 +578,6 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, VectorType outVecType = dyn_cast<VectorType>(out.getType()); VectorType scaleVecType = dyn_cast<VectorType>(scale.getType()); - if (outVecType && outVecType.isScalable()) return failure(); @@ -581,8 +590,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, Value zero = arith::ConstantOp::create(rewriter, loc, outType, rewriter.getFloatAttr(outType, 0.0)); - unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth(); - VectorType truncScaleResultType = VectorType::get(numPackedElem, outType); + int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth(); + VectorType truncScaleResultType = VectorType::get(opOutWidth, outType); if (!outVecType) { Type inVecType = VectorType::get(1, inType); @@ -598,16 +607,16 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, VectorType inVecType = cast<VectorType>(in.getType()); Value origScale = getOriginalVectorValue(op.getScale()); + VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType()); ArrayRef<int64_t> inShape = inVecType.getShape(); - SmallVector<int64_t> originalScaleShape; - if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType())) - llvm::append_range(originalScaleShape, origScaleVecType.getShape()); + SmallVector<int64_t> scaleShape; + if (origScaleVecType) + llvm::append_range(scaleShape, origScaleVecType.getShape()); - originalScaleShape.insert(originalScaleShape.end(), - inShape.size() - originalScaleShape.size(), 1); + scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1); - auto maybeRatio = computeShapeRatio(inShape, originalScaleShape); + auto maybeRatio = computeShapeRatio(inShape, scaleShape); assert(maybeRatio && "failed to derive block size from broadcast or splat operation"); @@ -633,20 +642,36 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, Value blockResult = rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero); - for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i); - i < blockSize; - i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) { - Value slice = vector::ExtractStridedSliceOp::create( - rewriter, loc, block1D, i, sliceWidth, 1); - // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1 - Value scaleTrunc = amdgpu::PackedScaledTruncOp::create( - rewriter, loc, truncScaleResultType, slice, uniformScale, 0, - /*existing=*/nullptr); - int64_t packedWidth = - cast<VectorType>(scaleTrunc.getType()).getNumElements(); - if (packedWidth != opWidth) + for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i); + i < blockSize; i += outSliceWidth, + outSliceWidth = std::min(opOutWidth, blockSize - i)) { + Value scaleTrunc; + // Case where <= 2 elements are being truncated. + if (outSliceWidth <= opInWidth) { + Value slice = vector::ExtractStridedSliceOp::create( + rewriter, loc, block1D, i, outSliceWidth, 1); + // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1 + scaleTrunc = amdgpu::PackedScaledTruncOp::create( + rewriter, loc, truncScaleResultType, slice, uniformScale, 0, + /*existing=*/nullptr); + } else { + scaleTrunc = vector::BroadcastOp::create(rewriter, loc, + truncScaleResultType, zero); + for (int64_t j = 0, + inSliceWidth = std::min(opInWidth, outSliceWidth - j); + j < outSliceWidth; j += opInWidth, + inSliceWidth = std::min(opInWidth, outSliceWidth - j)) { + Value slice = vector::ExtractStridedSliceOp::create( + rewriter, loc, block1D, i + j, inSliceWidth, 1); + scaleTrunc = amdgpu::PackedScaledTruncOp::create( + rewriter, loc, truncScaleResultType, slice, uniformScale, + j / opInWidth, scaleTrunc); + } + } + if (outSliceWidth != opOutWidth) { scaleTrunc = vector::ExtractStridedSliceOp::create( - rewriter, loc, scaleTrunc, 0, sliceWidth, 1); + rewriter, loc, scaleTrunc, 0, outSliceWidth, 1); + } blockResult = vector::InsertStridedSliceOp::create( rewriter, loc, scaleTrunc, blockResult, i, 1); } |