diff options
Diffstat (limited to 'mlir/lib/Dialect/GPU')
| -rw-r--r-- | mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp | 6 | ||||
| -rw-r--r-- | mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp | 28 |
3 files changed, 31 insertions, 5 deletions
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index b5f8dda..6c6d8d2 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -2399,7 +2399,7 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser, void WarpExecuteOnLane0Op::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { if (!point.isParent()) { - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp index d2c2138..025d1ac 100644 --- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp @@ -330,7 +330,7 @@ static Value getBase(Value v) { v = op.getSrc(); return true; }) - .Default([](Operation *) { return false; }); + .Default(false); if (!shouldContinue) break; } @@ -354,7 +354,7 @@ static Value propagatesCapture(Operation *op) { .Case([](memref::TransposeOp transpose) { return transpose.getIn(); }) .Case<memref::ExpandShapeOp, memref::CollapseShapeOp>( [](auto op) { return op.getSrc(); }) - .Default([](Operation *) { return Value(); }); + .Default(nullptr); } /// Returns `true` if the given operation is known to capture the given value, @@ -371,7 +371,7 @@ static std::optional<bool> getKnownCapturingStatus(Operation *op, Value v) { // These operations are known not to capture. .Case([](memref::DeallocOp) { return false; }) // By default, we don't know anything. - .Default([](Operation *) { return std::nullopt; }); + .Default(std::nullopt); } /// Returns `true` if the value may be captured by any of its users, i.e., if diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp index 81c3069..ec1571a 100644 --- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp @@ -416,13 +416,39 @@ createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op, if (ci.clusterSize >= 32) { if (chipset.majorVersion <= 9) { // Broadcast last value from each row to next row. - // Use row mask to avoid polluting rows 1 and 3. + // Use row mask to avoid polluting row 0 (and row 2 if wave-64). dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_15, rewriter.getUnitAttr(), 0xa, allBanks, /*bound_ctrl*/ false); res = vector::makeArithReduction( rewriter, loc, gpu::convertReductionKind(mode), res, dpp); + + // For subgroupSize = 64, at this point lanes [16, 32) contain the full + // reduction over lanes [0, 32), but lanes [0, 16) do not. Similarly, + // lanes [48, 64) contain the full reduction over lanes [32, 64), but + // lanes [32, 48) do not. + // + // If subgroup size is 64 and cluster size is 64, we don't need lanes [0, + // 16) and [32, 48) to have the correct cluster-32 reduction values at + // this point, because only lane 63's value will ultimately be read in + // this full-cluster case. + // + // If subgroup size is 64 and cluster size is 32, we need to ensure that + // lanes [0, 16) and [32, 48) have the correct final cluster-32 reduction + // values (subgroup_reduce guarantees that all lanes within each cluster + // contain the final reduction value). We do this by broadcasting lane + // 31's value to lanes [0, 16) and lanes 63's value to lanes [32, 48). + // + // See https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations + // for an illustration of how this within-cluster broadcast works with a + // swizzle. + if (ci.subgroupSize == 64 && ci.clusterSize == 32) { + res = + amdgpu::SwizzleBitModeOp::create(rewriter, loc, res, /*and_mask=*/0, + /*or_mask=*/31, + /*xor_mask=*/0); + } } else if (chipset.majorVersion <= 12) { // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2). Value uint32Max = arith::ConstantOp::create( |
