diff options
Diffstat (limited to 'mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp')
-rw-r--r-- | mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp | 174 |
1 file changed, 173 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp index 61edce5..b00c65c 100644 --- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp @@ -13,13 +13,17 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Dialect/GPU/Transforms/Utils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LogicalResult.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h" #include <cassert> +#include <cstdint> using namespace mlir; @@ -58,7 +62,7 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> { unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth(); if (elemBitwidth >= maxShuffleBitwidth) return rewriter.notifyMatchFailure( - op, llvm::formatv("element type too large {0}, cannot break down " + op, llvm::formatv("element type too large ({0}), cannot break down " "into vectors of bitwidth {1} or less", elemBitwidth, maxShuffleBitwidth)); @@ -139,6 +143,167 @@ struct ScalarizeSingleElementReduce final } }; +/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn` +/// and `unpackFn` to convert to the native shuffle type and to the reduction +/// type, respectively. For example, with `input` of type `f16`, `packFn` could +/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn` +/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that +/// the subgroup is `subgroupSize` lanes wide and reduces across all of them. 
+static Value createSubgroupShuffleReduction( + OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode, + unsigned subgroupSize, function_ref<Value(Value)> packFn, + function_ref<Value(Value)> unpackFn) { + assert(llvm::isPowerOf2_32(subgroupSize)); + // Lane value always stays in the original type. We use it to perform arith + // reductions. + Value laneVal = input; + // Parallel reduction using butterfly shuffles. + for (unsigned i = 1; i < subgroupSize; i <<= 1) { + Value shuffled = builder + .create<gpu::ShuffleOp>(loc, packFn(laneVal), i, + /*width=*/subgroupSize, + /*mode=*/gpu::ShuffleMode::XOR) + .getShuffleResult(); + laneVal = vector::makeArithReduction(builder, loc, + gpu::convertReductionKind(mode), + laneVal, unpackFn(shuffled)); + assert(laneVal.getType() == input.getType()); + } + + return laneVal; +} + +/// Lowers scalar gpu subgroup reductions to a series of shuffles. +struct ScalarSubgroupReduceToShuffles final + : OpRewritePattern<gpu::SubgroupReduceOp> { + ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize, + unsigned shuffleBitwidth, + PatternBenefit benefit) + : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize), + shuffleBitwidth(shuffleBitwidth) {} + + LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, + PatternRewriter &rewriter) const override { + Type valueTy = op.getType(); + unsigned elemBitwidth = + getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth(); + if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth) + return rewriter.notifyMatchFailure( + op, "value type is not a compatible scalar"); + + Location loc = op.getLoc(); + // Since this is already a native shuffle scalar, no packing is necessary. 
+ if (elemBitwidth == shuffleBitwidth) { + auto identityFn = [](Value v) { return v; }; + rewriter.replaceOp(op, createSubgroupShuffleReduction( + rewriter, loc, op.getValue(), op.getOp(), + subgroupSize, identityFn, identityFn)); + return success(); + } + + auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth); + auto equivIntType = rewriter.getIntegerType(elemBitwidth); + auto packFn = [loc, &rewriter, equivIntType, + shuffleIntType](Value unpackedVal) -> Value { + auto asInt = + rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal); + return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt); + }; + auto unpackFn = [loc, &rewriter, equivIntType, + valueTy](Value packedVal) -> Value { + auto asInt = + rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal); + return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt); + }; + + rewriter.replaceOp(op, createSubgroupShuffleReduction( + rewriter, loc, op.getValue(), op.getOp(), + subgroupSize, packFn, unpackFn)); + return success(); + } + +private: + unsigned subgroupSize = 0; + unsigned shuffleBitwidth = 0; +}; + +/// Lowers vector gpu subgroup reductions to a series of shuffles. 
+struct VectorSubgroupReduceToShuffles final + : OpRewritePattern<gpu::SubgroupReduceOp> { + VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize, + unsigned shuffleBitwidth, + PatternBenefit benefit) + : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize), + shuffleBitwidth(shuffleBitwidth) {} + + LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, + PatternRewriter &rewriter) const override { + auto vecTy = dyn_cast<VectorType>(op.getType()); + if (!vecTy) + return rewriter.notifyMatchFailure(op, "value type is not a vector"); + + unsigned vecBitwidth = + vecTy.getNumElements() * vecTy.getElementTypeBitWidth(); + if (vecBitwidth > shuffleBitwidth) + return rewriter.notifyMatchFailure( + op, + llvm::formatv("vector type bitwidth too large ({0}), cannot lower " + "to shuffles of size {1}", + vecBitwidth, shuffleBitwidth)); + + unsigned elementsPerShuffle = + shuffleBitwidth / vecTy.getElementTypeBitWidth(); + if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth) + return rewriter.notifyMatchFailure( + op, "shuffle bitwidth is not a multiple of the element bitwidth"); + + Location loc = op.getLoc(); + + // If the reduced type is smaller than the native shuffle size, extend it, + // perform the shuffles, and extract at the end. 
+ auto extendedVecTy = VectorType::get( + static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType()); + Value extendedInput = op.getValue(); + if (vecBitwidth < shuffleBitwidth) { + auto zero = rewriter.create<arith::ConstantOp>( + loc, rewriter.getZeroAttr(extendedVecTy)); + extendedInput = rewriter.create<vector::InsertStridedSliceOp>( + loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1); + } + + auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth); + auto shuffleVecType = VectorType::get(1, shuffleIntType); + + auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value { + auto asIntVec = + rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal); + return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0); + }; + auto unpackFn = [loc, &rewriter, shuffleVecType, + extendedVecTy](Value packedVal) -> Value { + auto asIntVec = + rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal); + return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec); + }; + + Value res = + createSubgroupShuffleReduction(rewriter, loc, extendedInput, op.getOp(), + subgroupSize, packFn, unpackFn); + + if (vecBitwidth < shuffleBitwidth) { + res = rewriter.create<vector::ExtractStridedSliceOp>( + loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(), + /*strides=*/1); + } + + rewriter.replaceOp(op, res); + return success(); + } + +private: + unsigned subgroupSize = 0; + unsigned shuffleBitwidth = 0; +}; } // namespace void mlir::populateGpuBreakDownSubgrupReducePatterns( @@ -148,3 +313,10 @@ void mlir::populateGpuBreakDownSubgrupReducePatterns( maxShuffleBitwidth, benefit); patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit); } + +void mlir::populateGpuLowerSubgroupReduceToShufflePattenrs( + RewritePatternSet &patterns, unsigned subgroupSize, + unsigned shuffleBitwidth, PatternBenefit benefit) { + patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>( + 
patterns.getContext(), subgroupSize, shuffleBitwidth, benefit); +} |