Diffstat (limited to 'mlir/lib/Dialect/XeGPU')
-rw-r--r--  mlir/lib/Dialect/XeGPU/CMakeLists.txt                         |   1
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp                    | 390
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp                        | 137
-rw-r--r--  mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt            |  17
-rw-r--r--  mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp     | 695
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt              |   1
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp | 490
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp    | 556
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 619
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp             |  27
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp   | 267
-rw-r--r--  mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp                   |  97
12 files changed, 2950 insertions, 347 deletions
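
Editor's note: the largest XeGPUDialect.cpp change in the diff below replaces the affine-based subgroup-ID delinearization with an explicit, order-aware loop (LayoutAttr::delinearizeId, including its linearId=22, sgLayout=[2,4,4], order=[2,1,0] walkthrough). As a reading aid only, here is a minimal standalone C++ sketch of that index math; it is not part of the patch and the helper name is hypothetical.

#include <cstdint>
#include <iostream>
#include <vector>

// Order-aware delinearization: order[0] names the fastest-changing dimension.
static std::vector<int64_t> delinearize(int64_t linearId,
                                        const std::vector<int64_t> &layout,
                                        const std::vector<int64_t> &order) {
  std::vector<int64_t> result(layout.size());
  int64_t remaining = linearId;
  for (int64_t dim : order) {
    result[dim] = remaining % layout[dim]; // coordinate within this dimension
    remaining /= layout[dim];              // complete groups consumed so far
  }
  return result;
}

int main() {
  // Matches the walkthrough in the patch: linearId=22, sgLayout=[2,4,4],
  // order=[2,1,0] -> subgroup coordinates [1, 1, 2].
  for (int64_t c : delinearize(22, {2, 4, 4}, {2, 1, 0}))
    std::cout << c << ' ';
  std::cout << '\n';
  return 0;
}
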
diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt index 31167e6..46b8251 100644 --- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(IR) add_subdirectory(Transforms) add_subdirectory(Utils) +add_subdirectory(TransformOps) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index f9aa28d5..1a19ab5 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -8,10 +8,8 @@ #include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" -#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h" #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" @@ -38,55 +36,61 @@ void XeGPUDialect::initialize() { >(); } -/// Generates instructions to compute offsets for a subgroup identified by -/// its multidimensional indices (sgId), using the specified subgroup layout -/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data -/// dimensions (sizePerWg). +// A `srcShape` consists of N distribution units, each being `subShapesLayout` x +// `subShape`. A `delinearizedId` is used to identify a particular `subShape` +// within each distribution unit. +// Example: +// WG data is 128x256. SG data is 16x32, in 4x2 layout, this gives a +// distribution unit of shape 64x64, we have 2x4 such distribution units. +// `delinearizedId` is used to identify a 16x32 of a subgroup in each +// distribution unit. static SmallVector<SmallVector<Value>> -genOffsetsComputingInsts(OpBuilder &builder, Location loc, - SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout, - ArrayRef<int64_t> sizePerSg, - ArrayRef<int64_t> sizePerWg) { - - SmallVector<SmallVector<Value>> offsets; +genCoordinates(OpBuilder &builder, Location loc, + SmallVector<Value> delinearizedId, + ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape, + ArrayRef<int64_t> srcShape) { + SmallVector<SmallVector<Value>> coordinates; + + // A distribution unit must be less than or equal to `srcShape` + SmallVector<int64_t> distUnitShape = llvm::map_to_vector( + llvm::zip_equal(srcShape, + computeElementwiseMul(subShapesLayout, subShape)), + [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); }); - // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i] - SmallVector<Value> localOffsets = llvm::map_to_vector( - llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value { - return builder.createOrFold<index::MulOp>( + // Get the offset of `subShape` within a distribution unit. 
+ SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector( + llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value { + return builder.createOrFold<arith::MulIOp>( loc, std::get<0>(t), builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t))); }); - // distUnit[i] is the minimum value between sizePerWg[i] and - // sgLayout[i] * sizePerSg[i] - SmallVector<int64_t> distUnit = llvm::map_to_vector( - llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)), - [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); }); - + // For each dist unit for (SmallVector<int64_t> unitOffs : - StaticTileOffsetRange(sizePerWg, distUnit)) { + StaticTileOffsetRange(srcShape, distUnitShape)) { + // Get dist unit offset within `srcShape`. SmallVector<Value> base = llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value { return arith::ConstantIndexOp::create(builder, loc, d); }); - - SmallVector<Value> adds = llvm::map_to_vector( - llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value { - return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t), - std::get<1>(t)); - }); - + // Calculate `subShape` offset within `srcShape`. + SmallVector<Value> adds = + llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset), + [&](const auto &t) -> Value { + return builder.createOrFold<arith::AddIOp>( + loc, std::get<0>(t), std::get<1>(t)); + }); + // Do not go beyond `srcShape` bounds. SmallVector<Value> mods = llvm::map_to_vector( - llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value { - return builder.createOrFold<index::RemUOp>( + llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value { + return builder.createOrFold<arith::RemUIOp>( loc, std::get<0>(t), arith::ConstantIndexOp::create(builder, loc, std::get<1>(t))); }); - offsets.push_back(mods); + coordinates.push_back(mods); } - return offsets; + return coordinates; } // Checks if the given shape can be evenly distributed based on the layout @@ -229,8 +233,10 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, } if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) { - return emitError() - << "expected inst_data and lane_layout to have the same rank"; + return emitError() << "expected inst_data and lane_layout to have the same " + "rank, got inst_data " + << inst_data.size() << ", lane_layout " + << lane_layout.size(); } // sg_data is optional for Workgroup layout, but its presence requires @@ -271,56 +277,197 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, } FailureOr<SmallVector<Value>> -LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, - Value linearId) { - // delinearizeSubgroupId is only available for - // workgroup-level layout attribute - if (!isForWorkgroup()) +LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) { + + SmallVector<int64_t> sgLayoutInt; + if (isForWorkgroup()) { + sgLayoutInt = getEffectiveSgLayoutAsInt(); + } else if (isForSubgroup()) { + sgLayoutInt = getEffectiveLaneLayoutAsInt(); + } else { return failure(); + } - // TODO: handle order attribute - auto hasDefaultOrder = [&]() { - DenseI32ArrayAttr order = getOrder(); - return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>( - llvm::reverse(order.asArrayRef()))); - }; - if (!hasDefaultOrder()) - return mlir::emitError(loc, "order attribute is currently not supported."); + DenseI32ArrayAttr orderAttr = getOrder(); - auto dims = - llvm::map_to_vector(getEffectiveSgLayoutAsInt(), 
[&](int64_t d) -> Value { - return builder.createOrFold<arith::ConstantIndexOp>(loc, d); - }); + // Handle order attribute + SmallVector<int64_t> order; + if (orderAttr && !orderAttr.empty()) { + order = llvm::to_vector( + llvm::map_range(orderAttr.asArrayRef(), + [](int32_t idx) { return static_cast<int64_t>(idx); })); + } else { + // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc. + order = llvm::to_vector( + llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size()))); + } - return affine::delinearizeIndex(builder, loc, linearId, dims); + if (order.size() != sgLayoutInt.size()) { + return failure(); + } + + SmallVector<Value> result(sgLayoutInt.size()); + Value remaining = linearId; + + /// Process dimensions in the order they appear in the order array + /// The first dimension in order is the fastest-changing + /// + /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]: + /// + /// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx], + /// result=[?,?,?] + /// + /// i=0 (process columns, dimIdx=2, dimSize=4): + /// result[2] = 22 % 4 = 2 (column coordinate) + /// remaining = 22 / 4 = 5 (5 complete groups of 4 columns processed) + /// + /// i=1 (process rows, dimIdx=1, dimSize=4): + /// result[1] = 5 % 4 = 1 (row coordinate) + /// remaining = 5 / 4 = 1 (1 complete group of 4 rows processed) + /// + /// i=2 (process layers, dimIdx=0, dimSize=2): + /// result[0] = 1 % 2 = 1 (layer coordinate) + /// (no remaining update - last iteration) + /// + /// Final result: [1,1,2] = Layer 1, Row 1, Column 2 + for (size_t i = 0; i < order.size(); ++i) { + int64_t dimIdx = order[i]; + int64_t dimSize = sgLayoutInt[dimIdx]; + + Value dimSizeVal = + builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize); + + /// Extract the coordinate for this dimension using modulo operation + /// This gives us "how far within this dimension" we are + /// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within + /// this dimension) + result[dimIdx] = + builder.createOrFold<arith::RemUIOp>(loc, remaining, dimSizeVal); + + /// Update remaining for the next dimension by removing what we've already + /// processed. Division tells us "how many complete groups of this dimension + /// we've gone through" e.g., linearId=22, dimSize=4: 22 / 4 = 5 (we've + /// completed 5 groups of 4) Skip this for the last iteration since there's + /// no next dimension to process + if (i < order.size() - 1) { + remaining = + builder.createOrFold<arith::DivUIOp>(loc, remaining, dimSizeVal); + } + } + return result; } -/// Implements DistributeLayoutAttr::getOffsets to generate +/// Implements DistributeLayoutAttr::computeDistributedCoords to generate /// instructions for computing multi-dimensional offsets when distributed by /// LayoutAttr. 
FailureOr<SmallVector<SmallVector<Value>>> -LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, - ArrayRef<int64_t> shape) { - if (!isForWorkgroup()) +LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc, + Value linearId, ArrayRef<int64_t> shape) { + SmallVector<int64_t> layout; + SmallVector<int64_t> subShape; + if (isForWorkgroup()) { + layout = getEffectiveSgLayoutAsInt(); + subShape = getEffectiveSgDataAsInt(); + } else if (isForSubgroup()) { + layout = getEffectiveLaneLayoutAsInt(); + subShape = getEffectiveLaneDataAsInt(); + } else { return failure(); - - SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt(); - SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt(); - if (sgShape.empty()) { - if (auto derivedShape = computeShapeRatio(shape, sgLayout)) - sgShape = derivedShape.value(); + } + if (subShape.empty()) { + if (auto derivedShape = computeShapeRatio(shape, layout)) + subShape = derivedShape.value(); else return failure(); } // delinearize Ids - auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); + auto maybeIds = delinearizeId(builder, loc, linearId); if (failed(maybeIds)) return failure(); - SmallVector<Value> sgIds = *maybeIds; + SmallVector<Value> ids = *maybeIds; + + return genCoordinates(builder, loc, ids, layout, subShape, shape); +} + +bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) { + if (dyn_cast<xegpu::SliceAttr>(other)) + return false; + + return *this == dyn_cast<xegpu::LayoutAttr>(other); +} + +// set the layout for unit dims: sg_data, inst_data and lane_data to 1 +DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) { + auto sgDataOpt = getSgData(); + auto instDataOpt = getInstData(); + auto laneDataOpt = getLaneData(); + + SmallVector<int32_t> sgData; + SmallVector<int32_t> instData; + SmallVector<int32_t> laneData; + + if (sgDataOpt) { + sgData = llvm::to_vector(sgDataOpt.asArrayRef()); + } + if (instDataOpt) { + instData = llvm::to_vector(instDataOpt.asArrayRef()); + } + if (laneDataOpt) { + laneData = llvm::to_vector(laneDataOpt.asArrayRef()); + } - return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape, - shape); + for (auto dim : unitDims) { + if (dim < static_cast<int64_t>(sgData.size())) + sgData[dim] = 1; + if (dim < static_cast<int64_t>(instData.size())) + instData[dim] = 1; + if (dim < static_cast<int64_t>(laneData.size())) + laneData[dim] = 1; + } + + return LayoutAttr::get( + getContext(), getSgLayout(), + sgData.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), sgData), + instData.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), instData), + getLaneLayout(), + laneData.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), laneData), + getOrder()); +} + +// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1 +DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) { + auto sgLayoutOpt = getSgLayout(); + auto laneLayoutOpt = getLaneLayout(); + + SmallVector<int32_t> sgLayout; + SmallVector<int32_t> laneLayout; + + if (sgLayoutOpt) { + sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef()); + } + if (laneLayoutOpt) { + laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef()); + } + + for (auto dim : unitDims) { + if (dim < static_cast<int64_t>(sgLayout.size())) + sgLayout[dim] = 1; + if (dim < static_cast<int64_t>(laneLayout.size())) + laneLayout[dim] = 1; + } + + return LayoutAttr::get( + getContext(), + sgLayout.empty() ? 
DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), sgLayout), + getSgData(), getInstData(), + laneLayout.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), laneLayout), + getLaneData(), getOrder()); } //===----------------------------------------------------------------------===// @@ -374,34 +521,43 @@ SliceAttr SliceAttr::flatten() const { } FailureOr<SmallVector<Value>> -SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, - Value linearId) { +SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) { SliceAttr attr = flatten(); auto parent = dyn_cast<LayoutAttr>(attr.getParent()); - return parent.delinearizeSubgroupId(builder, loc, linearId); + return parent.delinearizeId(builder, loc, linearId); } -/// Implements DistributeLayoutAttr::getOffsets to generate -/// instructions for computing multi-dimensional offsets when distributed by -/// SliceAttr. +// Implements DistributeLayoutAttr::computeDistributedCoords to generate +// instructions for computing multi-dimensional offsets when distributed by +// LayoutAttr. FailureOr<SmallVector<SmallVector<Value>>> -SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, - ArrayRef<int64_t> shape) { +SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc, + Value linearId, ArrayRef<int64_t> shape) { assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape."); if (!isForWorkgroup()) return failure(); - SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt(); - SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt(); - if (sgShape.empty()) { - if (auto derivedShape = computeShapeRatio(shape, sgLayout)) - sgShape = derivedShape.value(); + SmallVector<int64_t> layout; + SmallVector<int64_t> subShape; + if (isForWorkgroup()) { + layout = getEffectiveSgLayoutAsInt(); + subShape = getEffectiveSgDataAsInt(); + } else if (isForSubgroup()) { + layout = getEffectiveLaneLayoutAsInt(); + subShape = getEffectiveLaneDataAsInt(); + } else { + return failure(); + } + + if (subShape.empty()) { + if (auto derivedShape = computeShapeRatio(shape, layout)) + subShape = derivedShape.value(); else return failure(); } // delinearize Ids - auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); + auto maybeIds = delinearizeId(builder, loc, linearId); if (failed(maybeIds)) return failure(); @@ -411,8 +567,7 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, SmallVector<Value> sgIds = XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims); - return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape, - shape); + return genCoordinates(builder, loc, sgIds, layout, subShape, shape); } bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) { @@ -435,6 +590,69 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) { [&](int64_t dim) { return thisDims.contains(dim); }); } +bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) { + if (dyn_cast<xegpu::LayoutAttr>(other)) + return false; + + auto flattenedThis = flatten(); + auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten(); + + return ((flattenedThis.getParent() == flattenedOther.getParent()) && + (flattenedThis.getDims() == flattenedOther.getDims())); +} + +// Helper function to adjust unit dimensions from sliced space to parent space +static SetVector<int64_t> +adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims, + ArrayRef<int64_t> sliceDims) { + // Reconstruct parent's non-sliced dimensions + + int64_t 
parentRank = sliceDims.size() + unitDims.size(); + llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(), + sliceDims.end()); + SmallVector<int64_t> nonSlicedDims; + for (int64_t i = 0; i < parentRank; ++i) { + if (!slicedDimsSet.contains(i)) + nonSlicedDims.push_back(i); + } + + // Map unit dims from sliced space to parent space + SetVector<int64_t> adjustUnitDims; + for (auto dim : unitDims) { + if (dim < static_cast<int64_t>(nonSlicedDims.size())) { + adjustUnitDims.insert(nonSlicedDims[dim]); + } + } + + return adjustUnitDims; +} + +// set the layout for unit dims: sg_data, inst_data and lane_data to 1 +DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) { + SliceAttr attr = flatten(); + ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef(); + auto parent = dyn_cast<LayoutAttr>(attr.getParent()); + + SetVector<int64_t> adjustUnitDims = + adjustUnitDimsWithSliceDims(unitDims, sliceDims); + + return SliceAttr::get(getContext(), parent.setUnitDimData(adjustUnitDims), + attr.getDims()); +} + +// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1 +DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) { + SliceAttr attr = flatten(); + ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef(); + auto parent = dyn_cast<LayoutAttr>(attr.getParent()); + + SetVector<int64_t> adjustUnitDims = + adjustUnitDimsWithSliceDims(unitDims, sliceDims); + + return SliceAttr::get(getContext(), parent.setUnitDimLayout(adjustUnitDims), + attr.getDims()); +} + //===----------------------------------------------------------------------===// // XeGPU_RangeAttr //===----------------------------------------------------------------------===// @@ -569,8 +787,8 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError, // for gather and scatter ops, Low-precision types are packed in 32-bit units. unsigned bitWidth = elementType.getIntOrFloatBitWidth(); int chunkAlignmentFactor = - bitWidth < targetinfo::packedSizeInBitsForGatherScatter - ? targetinfo::packedSizeInBitsForGatherScatter / bitWidth + bitWidth < xegpu::uArch::generalPackedFormatBitSize + ? 
xegpu::uArch::generalPackedFormatBitSize / bitWidth : 1; auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding); if (scatterAttr) { diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index abd12e2..91ba07a 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -175,13 +175,13 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, LogicalResult IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, - UnitAttr subgroup_block_io, + UnitAttr subgroup_block_io, DistributeLayoutAttr layout, function_ref<InFlightDiagnostic()> emitError) { if (!dataTy) { if (subgroup_block_io) return emitError() << "subgroup_block_io " - "are only allowed when result is a 1D VectorType."; + "are only allowed when result is a VectorType."; else return success(); } @@ -192,15 +192,37 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, ArrayRef<int64_t> dataShape = dataTy.getShape(); ArrayRef<int64_t> mdescShape = mdescTy.getShape(); + SmallVector<int64_t> blockShape = mdescTy.getBlockShape(); + ArrayAttr strideAttr = mdescTy.getStrideAttr(); + SmallVector<int64_t> strides; + for (Attribute attr : strideAttr.getValue()) { + strides.push_back(cast<IntegerAttr>(attr).getInt()); + } + if (subgroup_block_io && layout) { + auto laneData = layout.getEffectiveLaneDataAsInt(); + auto laneLayout = layout.getEffectiveLaneLayoutAsInt(); + if (!laneData.empty()) { + bool isLaneDataContiguous = + std::all_of(laneData.begin(), std::prev(laneData.end()), + [](int x) { return x == 1; }); + if (!isLaneDataContiguous) + return emitError() << "With subgroup_block_io, accessed data must be " + "contiguous and coalesced."; + for (size_t i = 0; i < laneData.size(); ++i) { + if (laneLayout[i] != blockShape[i]) + return emitError() << "With subgroup_block_io, the block shape must " + "match the lane layout."; + if (laneLayout[i] != 1 && strides[i] != 1) + return emitError() << "With subgroup_block_io, the distributed " + "dimensions must be contiguous."; + } + } + } if (dataShape.size() == 2) { - if (subgroup_block_io) - return emitError() << "subgroup_block_io " - "are only allowed when result is a 1D VectorType."; if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), [](auto p) { return std::get<0>(p) > std::get<1>(p); })) return emitError() << "data shape must not exceed mem_desc shape."; } else { - SmallVector<int64_t> blockShape = mdescTy.getBlockShape(); // if the subgroup_block_io attribute is set, mdescTy must have block // attribute if (subgroup_block_io && !blockShape.size()) @@ -258,8 +280,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, auto [memrefStrides, _] = memrefTy.getStridesAndOffset(); // if shape and strides are from Memref, we don't need attributes for them - // to keep the IR print clean. - if (staticShape == memrefShape && staticStrides == memrefStrides) { + // to keep the IR print clean (only do so for full-static case, otherwise + // printer would fail trying to print empty array-attr). + if (staticShape == memrefShape && staticStrides == memrefStrides && + dynamicShape.empty() && dynamicStrides.empty()) { staticShapeAttr = DenseI64ArrayAttr(); staticStridesAttr = DenseI64ArrayAttr(); } @@ -320,8 +344,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, auto [memrefStrides, _] = memrefTy.getStridesAndOffset(); // if shape and strides are from Memref, we don't need attributes for them - // to keep the IR print clean. 
- if (staticShape == memrefShape && staticStrides == memrefStrides) { + // to keep the IR print clean (only do so for full-static case, otherwise + // printer would fail trying to print empty array-attr). + if (staticShape == memrefShape && staticStrides == memrefStrides && + dynamicShape.empty() && dynamicStrides.empty()) { staticShapeAttr = DenseI64ArrayAttr(); staticStridesAttr = DenseI64ArrayAttr(); } @@ -439,14 +465,15 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, xegpu::CachePolicyAttr l3_hint) { return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(), - l1_hint, l2_hint, l3_hint); + l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr); } void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, Value tensorDesc, ArrayRef<OpFoldResult> offsets, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, - xegpu::CachePolicyAttr l3_hint) { + xegpu::CachePolicyAttr l3_hint, + xegpu::DistributeLayoutAttr layout) { SmallVector<Value> dynamicOffsets; SmallVector<int64_t> staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); @@ -454,7 +481,7 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint, - l2_hint, l3_hint); + l2_hint, l3_hint, /*anchor_layout=*/layout); } LogicalResult PrefetchNdOp::verify() { @@ -472,11 +499,8 @@ LogicalResult PrefetchNdOp::verify() { return emitOpError("invalid l3_hint: ") << getL3HintAttr(); int64_t tDescRank = tdescTy.getRank(); - int64_t offsetSize = static_cast<int64_t>(getOffsets().size()); - int64_t constOffsetSize = - getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0; - if (((offsetSize != 0) && (offsetSize != tDescRank)) || - ((constOffsetSize != 0) && (constOffsetSize != tDescRank))) + int64_t offsetSize = getMixedOffsets().size(); + if (offsetSize != 0 && offsetSize != tDescRank) return emitOpError( "Mismatched ranks between offsets and tensor descriptor"); @@ -496,7 +520,7 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, return build(builder, state, retType, tensorDesc, ValueRange(), DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint, - l3_hint); + l3_hint, /*anchor_layout=*/nullptr); } void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, @@ -504,7 +528,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, UnitAttr packed, DenseI64ArrayAttr transpose, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, - xegpu::CachePolicyAttr l3_hint) { + xegpu::CachePolicyAttr l3_hint, + xegpu::DistributeLayoutAttr layout) { SmallVector<Value> dynamicOffsets; SmallVector<int64_t> staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); @@ -512,7 +537,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr, - packed, transpose, l1_hint, l2_hint, l3_hint); + packed, transpose, l1_hint, l2_hint, l3_hint, + /*anchor_layout=*/layout); } LogicalResult LoadNdOp::verify() { @@ -597,11 +623,8 @@ LogicalResult LoadNdOp::verify() { << tdescTy; int64_t tDescRank = tdescTy.getRank(); - int64_t offsetSize = static_cast<int64_t>(getOffsets().size()); - int64_t constOffsetSize = - getConstOffsetsAttr() ? 
getConstOffsetsAttr().size() : 0; - if (((offsetSize != 0) && (offsetSize != tDescRank)) || - ((constOffsetSize != 0) && (constOffsetSize != tDescRank))) + int64_t offsetSize = getMixedOffsets().size(); + if (offsetSize != 0 && offsetSize != tDescRank) return emitOpError( "Mismatched ranks between offsets and tensor descriptor"); @@ -618,14 +641,16 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, xegpu::CachePolicyAttr l3_hint) { return build(builder, state, value, tensorDesc, ValueRange(), - DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint); + DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint, + /*anchor_layout=*/nullptr); } void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, Value tensorDesc, ArrayRef<OpFoldResult> offsets, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, - xegpu::CachePolicyAttr l3_hint) { + xegpu::CachePolicyAttr l3_hint, + xegpu::DistributeLayoutAttr layout) { SmallVector<Value> dynamicOffsets; SmallVector<int64_t> staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); @@ -633,7 +658,7 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr, - l1_hint, l2_hint, l3_hint); + l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout); } LogicalResult StoreNdOp::verify() { @@ -691,11 +716,8 @@ LogicalResult StoreNdOp::verify() { << dstTy; int64_t tDescRank = dstTy.getRank(); - int64_t offsetSize = static_cast<int64_t>(getOffsets().size()); - int64_t constOffsetSize = - getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0; - if (((offsetSize != 0) && (offsetSize != tDescRank)) || - ((constOffsetSize != 0) && (constOffsetSize != tDescRank))) + int64_t offsetSize = getMixedOffsets().size(); + if (offsetSize != 0 && offsetSize != tDescRank) return emitOpError( "Mismatched ranks between offsets and tensor descriptor"); @@ -809,7 +831,7 @@ void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) { build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint, - IntegerAttr{}); + IntegerAttr{}, /*anchor_layout=*/nullptr); } //===----------------------------------------------------------------------===// @@ -859,7 +881,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) { build(builder, state, valueType, source, Value(), mask, IntegerAttr(), - l1_hint, l2_hint, l3_hint); + l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr); } void LoadGatherOp::build(OpBuilder &builder, OperationState &state, @@ -875,7 +897,24 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state, auto offset = vector::FromElementsOp::create(builder, loc, type, values); build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint, - l2_hint, l3_hint); + l2_hint, l3_hint, /*anchor_layout=*/nullptr); +} + +void LoadGatherOp::build(OpBuilder &builder, OperationState &state, + Type valueType, Value source, + ArrayRef<OpFoldResult> offsets, Value mask, + IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint, + DistributeLayoutAttr layout) { + auto loc = source.getLoc(); + int64_t size = static_cast<int64_t>(offsets.size()); + auto type = VectorType::get(size, builder.getIndexType()); + 
auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); + auto offset = vector::FromElementsOp::create(builder, loc, type, values); + + build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint, + l2_hint, l3_hint, layout); } //===----------------------------------------------------------------------===// @@ -926,7 +965,7 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) { build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint, - l2_hint, l3_hint); + l2_hint, l3_hint, /*anchor_layout=*/nullptr); } void StoreScatterOp::build(OpBuilder &builder, OperationState &state, @@ -944,7 +983,23 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state, // Call the correct builder overload that does not expect result types. build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint, - l3_hint); + l3_hint, /*anchor_layout=*/nullptr); +} + +void StoreScatterOp::build( + OpBuilder &builder, OperationState &state, Value value, Value dest, + ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size, + xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) { + auto loc = dest.getLoc(); + int64_t size = static_cast<int64_t>(offsets.size()); + auto type = VectorType::get(size, builder.getIndexType()); + auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); + auto offset = vector::FromElementsOp::create(builder, loc, type, values); + + // Call the correct builder overload that does not expect result types. + build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint, + l3_hint, layout); } //===----------------------------------------------------------------------===// @@ -1105,7 +1160,7 @@ LogicalResult LoadMatrixOp::verify() { MemDescType mdescTy = getMemDesc().getType(); return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io, - [&]() { return emitError(); }); + getLayoutAttr(), [&]() { return emitError(); }); } //===----------------------------------------------------------------------===// @@ -1129,7 +1184,7 @@ LogicalResult StoreMatrixOp::verify() { UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); MemDescType mdescTy = getMemDesc().getType(); return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io, - [&]() { return emitError(); }); + getLayoutAttr(), [&]() { return emitError(); }); } namespace mlir { diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt new file mode 100644 index 0000000..48fe841 --- /dev/null +++ b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_dialect_library(MLIRXeGPUTransformOps + XeGPUTransformOps.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/mlir/Dialect/XeGPU/TransformOps/ + + DEPENDS + MLIRXeGPUTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRXeGPUDialect + MLIRXeGPUTransforms + MLIRIR + MLIRTransformDialect + MLIRFuncDialect + MLIRSCFDialect +) diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp new file mode 100644 index 0000000..e6009d5 --- /dev/null +++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp @@ -0,0 +1,695 @@ +//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" + +#include <optional> + +#include "llvm/Support/DebugLog.h" +#define DEBUG_TYPE "xegpu-transforms" + +using namespace mlir; +using namespace mlir::transform; + +/// Assuming that `ofr` is an index attr or a param of index type +/// or a transform dialect handle mapped to exactly one op +/// with one index result, get that value and cast it to int type. +static DiagnosedSilenceableFailure convertMixedValuesToInt( + transform::TransformState &state, TransformOpInterface transformOp, + SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs) { + for (OpFoldResult ofr : ofrs) { + // Attribute case. + if (auto attr = dyn_cast<Attribute>(ofr)) { + if (auto intAttr = dyn_cast<IntegerAttr>(attr)) { + result.push_back(intAttr.getInt()); + continue; + } + return transformOp.emitDefiniteFailure() << "expected IntegerAttr"; + } + + // Transform param case. + Value transformValue = cast<Value>(ofr); + if (isa<TransformParamTypeInterface>(transformValue.getType())) { + ArrayRef<Attribute> params = state.getParams(transformValue); + if (params.size() != 1) + return transformOp.emitDefiniteFailure() + << "requires exactly one parameter associated"; + result.push_back( + cast<IntegerAttr>(params.front()).getValue().getSExtValue()); + continue; + } + + // Payload value case. + auto payloadOps = state.getPayloadOps(transformValue); + if (!llvm::hasSingleElement(payloadOps)) { + DiagnosedSilenceableFailure diag = + transformOp.emitSilenceableError() + << "handle must be mapped to exactly one payload op"; + diag.attachNote(transformValue.getLoc()) + << "mapped to " << llvm::range_size(payloadOps) << " payload ops"; + return diag; + } + + Operation *op = *payloadOps.begin(); + if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { + DiagnosedSilenceableFailure diag = + transformOp.emitSilenceableError() + << "payload op must have exactly 1 index result"; + diag.attachNote(op->getLoc()) + << "has " << op->getNumResults() << " results"; + return diag; + } + + IntegerAttr intAttr; + if (!matchPattern(op->getResult(0), m_Constant(&intAttr))) + return transformOp.emitSilenceableError() + << "requires param or handle to be the result of a constant like " + "op"; + + result.push_back(intAttr.getInt()); + } + return DiagnosedSilenceableFailure::success(); +} + +/// Find producer operation of type T for the given value. +/// It's assumed that producer ops are chained through their first operand. +/// Producer chain is traced trough loop block arguments (init values). +template <typename T> +static std::optional<T> findProducerOfType(Value val) { + Value currentValue = val; + if (!currentValue.getDefiningOp()) { + // Value may be a block argument initialized outside a loop. 
+ if (val.getNumUses() == 0) { + LDBG() << "Failed to find producer op, value has no uses."; + return std::nullopt; + } + auto userOp = val.getUsers().begin(); + auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>(); + if (!parentLoop) { + LDBG() << "Failed to find producer op, not in a loop."; + return std::nullopt; + } + int64_t iterArgIdx; + if (auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) { + auto numInductionVars = parentLoop.getLoopInductionVars()->size(); + iterArgIdx = iterArg.getArgNumber() - numInductionVars; + currentValue = parentLoop.getInits()[iterArgIdx]; + } else { + LDBG() << "Failed to find producer op, value not in init values."; + return std::nullopt; + } + } + Operation *producerOp = currentValue.getDefiningOp(); + + if (auto matchingOp = dyn_cast<T>(producerOp)) + return matchingOp; + + if (producerOp->getNumOperands() == 0) + return std::nullopt; + + return findProducerOfType<T>(producerOp->getOperand(0)); +} + +/// Create a layout attribute from the given parameters. +static xegpu::LayoutAttr +createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout, + ArrayRef<int32_t> sgData, + std::optional<ArrayRef<int32_t>> instData) { + return xegpu::LayoutAttr::get( + ctx, DenseI32ArrayAttr::get(ctx, sgLayout), + DenseI32ArrayAttr::get(ctx, sgData), + instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr, + /*lane_layout=*/nullptr, + /*lane_data=*/nullptr, + /*order=*/nullptr); +} + +/// Generate `xegpu::LayoutAttr` from op mixed layout values. +DiagnosedSilenceableFailure +getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state, + TransformOpInterface transformOp, + ArrayRef<::mlir::OpFoldResult> mixedSgLayout, + ArrayRef<::mlir::OpFoldResult> mixedSgData, + ArrayRef<::mlir::OpFoldResult> mixedInstData, + xegpu::LayoutAttr &layoutAttr) { + SmallVector<int32_t> sgLayout, sgData, instData; + auto status = + convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout); + if (!status.succeeded()) + return status; + + status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData); + if (!status.succeeded()) + return status; + + status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData); + if (!status.succeeded()) + return status; + auto maybeInstData = instData.empty() + ? std::nullopt + : std::optional<ArrayRef<int32_t>>(instData); + + layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData); + + return DiagnosedSilenceableFailure::success(); +} + +/// Replace xegpu.create_nd_desc op with a new one with the given layout. 
+static xegpu::CreateNdDescOp +setDescLayout(transform::TransformRewriter &rewriter, + xegpu::CreateNdDescOp descOp, + xegpu::DistributeLayoutAttr layout) { + assert(descOp.getMixedOffsets().size() == 0 && + "create desc op with offsets is not supported"); + auto oldTensorDesc = descOp.getType(); + auto descType = xegpu::TensorDescType::get( + oldTensorDesc.getShape(), oldTensorDesc.getElementType(), + /*array_length=*/oldTensorDesc.getArrayLength(), + /*boundary_check=*/oldTensorDesc.getBoundaryCheck(), + /*memory_space=*/oldTensorDesc.getMemorySpace(), + /*layout=*/layout); + + rewriter.setInsertionPointAfter(descOp); + auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>( + descOp, descType, descOp.getSource(), descOp.getMixedSizes(), + descOp.getMixedStrides()); + return newDescOp; +} + +DiagnosedSilenceableFailure +transform::GetDescOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetValues = state.getPayloadValues(getTarget()); + if (!llvm::hasSingleElement(targetValues)) { + return emitDefiniteFailure() + << "requires exactly one target value handle (got " + << llvm::range_size(targetValues) << ")"; + } + + auto maybeDescOp = + findProducerOfType<xegpu::CreateNdDescOp>(*targetValues.begin()); + if (!maybeDescOp) { + return emitSilenceableFailure(getLoc()) + << "Could not find a matching descriptor op when walking the " + "producer chain of the first operand."; + } + + results.set(llvm::cast<OpResult>(getResult()), {*maybeDescOp}); + return DiagnosedSilenceableFailure::success(); +} + +void transform::SetDescLayoutOp::build(OpBuilder &builder, + OperationState &result, Value target, + ArrayRef<OpFoldResult> mixedSgLayout, + ArrayRef<OpFoldResult> mixedSgData, + ArrayRef<OpFoldResult> mixedInstData, + ArrayRef<int64_t> sliceDims) { + SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData; + SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData; + dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout); + dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData); + dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData); + build(builder, result, target.getType(), + /*target=*/target, + /*sg_layout=*/dynamicSgLayout, + /*sg_data=*/dynamicSgData, + /*inst_data=*/dynamicInstData, + /*static_sg_layout=*/staticSgLayout, + /*static_sg_data=*/staticSgData, + /*static_inst_data=*/staticInstData, + /*slice_dims=*/sliceDims); +} + +DiagnosedSilenceableFailure +transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetOps = state.getPayloadOps(getTarget()); + if (!llvm::hasSingleElement(targetOps)) { + return emitDefiniteFailure() << "requires exactly one targetOp handle (got " + << llvm::range_size(targetOps) << ")"; + } + Operation *target = *targetOps.begin(); + + xegpu::LayoutAttr layoutAttr = nullptr; + auto status = getLayoutAttrFromOperands(getContext(), state, (*this), + getMixedSgLayout(), getMixedSgData(), + getMixedInstData(), layoutAttr); + if (!status.succeeded()) + return status; + + xegpu::DistributeLayoutAttr layout = layoutAttr; + auto sliceDims = getSliceDims(); + if (sliceDims.size() > 0) { + // Wrap layoutAttr in a slice attribute. + layout = xegpu::SliceAttr::get( + getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims)); + } + + // For now only create_nd_desc op is supported. 
+ auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target); + if (!descOp) { + auto diag = emitSilenceableFailure(getLoc()) + << "Expected a xegpu.create_nd_desc op, but got: " + << target->getName(); + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + + // Set layout attr in desc op's return type. Replaces old desc op. + auto newdescOp = setDescLayout(rewriter, descOp, layout); + + // Map result handles. + results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()}); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::SetDescLayoutOp::getEffects( + ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + consumesHandle(getTargetMutable(), effects); + onlyReadsHandle(getSgLayoutMutable(), effects); + onlyReadsHandle(getSgDataMutable(), effects); + onlyReadsHandle(getInstDataMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + +void transform::SetOpLayoutAttrOp::build( + OpBuilder &builder, OperationState &ostate, Value target, int64_t index, + ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData, + ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int64_t> sliceDims, + bool result) { + SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData; + SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData; + dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout); + dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData); + dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData); + build(builder, ostate, target.getType(), + /*target=*/target, + /*index=*/index, + /*sg_layout=*/dynamicSgLayout, + /*sg_data=*/dynamicSgData, + /*inst_data=*/dynamicInstData, + /*static_sg_layout=*/staticSgLayout, + /*static_sg_data=*/staticSgData, + /*static_inst_data=*/staticInstData, + /*slice_dims=*/sliceDims, + /*result=*/result); +} + +DiagnosedSilenceableFailure +transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetOps = state.getPayloadOps(getTarget()); + if (!llvm::hasSingleElement(targetOps)) { + return emitDefiniteFailure() << "Requires exactly one targetOp handle (got " + << llvm::range_size(targetOps) << ")"; + } + Operation *target = *targetOps.begin(); + + bool resultTarget = getResult(); + + int64_t index = getIndex(); + if (resultTarget && index >= target->getNumResults()) { + return emitSilenceableFailure(getLoc()) + << "Index exceeds the number of op results"; + } + if (!resultTarget && index >= target->getNumOperands()) { + return emitSilenceableFailure(getLoc()) + << "Index exceeds the number of op operands"; + } + + xegpu::LayoutAttr layoutAttr = nullptr; + auto status = getLayoutAttrFromOperands(getContext(), state, (*this), + getMixedSgLayout(), getMixedSgData(), + getMixedInstData(), layoutAttr); + if (!status.succeeded()) + return status; + + xegpu::DistributeLayoutAttr layout = layoutAttr; + auto sliceDims = getSliceDims(); + if (sliceDims.size() > 0) { + // Wrap layoutAttr in a slice attribute. 
+ layout = xegpu::SliceAttr::get( + getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims)); + } + + // Set layout attribute for the op result or operand + if (resultTarget) + xegpu::setDistributeLayoutAttr(target->getResult(index), layout); + else + xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout); + return DiagnosedSilenceableFailure::success(); +} + +void transform::SetOpLayoutAttrOp::getEffects( + ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsHandle(getSgLayoutMutable(), effects); + onlyReadsHandle(getSgDataMutable(), effects); + onlyReadsHandle(getInstDataMutable(), effects); + modifiesPayload(effects); +} + +void transform::SetGPULaunchThreadsOp::build( + OpBuilder &builder, OperationState &ostate, Value target, + ArrayRef<OpFoldResult> mixedThreads) { + SmallVector<int64_t> staticThreads; + SmallVector<Value> dynamicThreads; + dispatchIndexOpFoldResults(mixedThreads, dynamicThreads, staticThreads); + build(builder, ostate, target.getType(), + /*target=*/target, + /*threads=*/dynamicThreads, + /*static_threads=*/staticThreads); +} + +DiagnosedSilenceableFailure +transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetOps = state.getPayloadOps(getTarget()); + if (!llvm::hasSingleElement(targetOps)) { + return emitDefiniteFailure() << "Requires exactly one targetOp handle (got " + << llvm::range_size(targetOps) << ")"; + } + Operation *target = *targetOps.begin(); + + auto launchOp = dyn_cast<gpu::LaunchOp>(target); + if (!launchOp) { + auto diag = emitSilenceableFailure(getLoc()) + << "Expected a gpu.launch op, but got: " << target->getName(); + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + + SmallVector<int32_t> threads; + DiagnosedSilenceableFailure status = + convertMixedValuesToInt(state, (*this), threads, getMixedThreads()); + if (!status.succeeded()) + return status; + + if (threads.size() != 3) { + return emitSilenceableFailure(getLoc()) + << "Expected threads argument to consist of three values (got " + << threads.size() << ")"; + } + + rewriter.setInsertionPoint(launchOp); + auto createConstValue = [&](int value) { + return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value); + }; + + // Replace threads in-place. + launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0])); + launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1])); + launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2])); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::SetGPULaunchThreadsOp::getEffects( + ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsHandle(getThreadsMutable(), effects); + modifiesPayload(effects); +} + +DiagnosedSilenceableFailure +transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetValues = state.getPayloadValues(getTarget()); + if (!llvm::hasSingleElement(targetValues)) + return emitDefiniteFailure() + << "requires exactly one target value handle (got " + << llvm::range_size(targetValues) << ")"; + auto value = *targetValues.begin(); + + int64_t nbPrefetch = getStaticNbPrefetch(); + if (getDynamicNbPrefetch()) { + // Get dynamic prefetch count from transform param or handle. 
+ SmallVector<int32_t> dynamicNbPrefetch; + auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch, + {getDynamicNbPrefetch()}); + if (!status.succeeded()) + return status; + if (dynamicNbPrefetch.size() != 1) + return emitDefiniteFailure() + << "requires exactly one value for dynamic_nb_prefetch"; + nbPrefetch = dynamicNbPrefetch[0]; + } + if (nbPrefetch <= 0) + return emitSilenceableFailure(getLoc()) + << "nb_prefetch must be a positive integer."; + + // Find load operation of the operand. + auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value); + if (!maybeLoadOp) + return emitSilenceableFailure(getLoc()) << "Could not find load op."; + auto loadOp = *maybeLoadOp; + if (loadOp.getMixedOffsets().size() == 0) { + auto diag = emitSilenceableFailure(getLoc()) + << "Load op must have offsets."; + diag.attachNote(loadOp.getLoc()) << "load op"; + return diag; + } + + // Find the parent scf.for loop. + auto forOp = loadOp->getParentOfType<scf::ForOp>(); + if (!forOp) { + auto diag = emitSilenceableFailure(getLoc()) + << "Load op is not contained in a scf.for loop."; + diag.attachNote(loadOp.getLoc()) << "load op"; + return diag; + } + + // Find descriptor op. + auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value); + if (!maybeDescOp) + return emitSilenceableFailure(getLoc()) << "Could not find descriptor op."; + auto descOp = *maybeDescOp; + if (descOp.getMixedOffsets().size() > 0) { + auto diag = emitSilenceableFailure(getLoc()) + << "desc op with offsets is not supported."; + diag.attachNote(descOp.getLoc()) << "desc op"; + } + + // Clone desc op outside the loop. + rewriter.setInsertionPoint(forOp); + auto newDescOp = + cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation())); + + // Clone reduction loop to emit initial prefetches. + // Compute upper bound of the init loop: start + nbPrefetch * step. + auto nbPrefetchCst = + arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch); + auto nbStep = rewriter.createOrFold<arith::MulIOp>( + forOp.getLoc(), nbPrefetchCst, forOp.getStep()); + auto initUpBound = rewriter.createOrFold<arith::AddIOp>( + forOp.getLoc(), forOp.getLowerBound(), nbStep); + auto initForOp = + scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), + initUpBound, forOp.getStep()); + + auto ctx = rewriter.getContext(); + auto readCacheHint = + xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED); + + // Modify loadOp mixedOffsets by replacing the for loop induction variable + // with the given value. + auto getPrefetchOffsets = + [&](Value replacementVal) -> SmallVector<OpFoldResult> { + IRMapping mapping; + mapping.map(forOp.getInductionVar(), replacementVal); + SmallVector<Value> dynamicOffsets = + llvm::to_vector(llvm::map_range(loadOp.getOffsets(), [&](Value v) { + return mapping.lookupOrDefault(v); + })); + auto constOffsets = loadOp.getConstOffsets().value(); + return getMixedValues(constOffsets, dynamicOffsets, ctx); + }; + + // Insert prefetch op in init loop. + // Replace induction var with the init loop induction var. + rewriter.setInsertionPointToStart(initForOp.getBody()); + xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(), + newDescOp.getResult(), + getPrefetchOffsets(initForOp.getInductionVar()), + readCacheHint, readCacheHint, readCacheHint, + /*layout=*/nullptr); + + // Insert prefetch op in main loop. + // Calculate prefetch offset after the init prefetches have been issued. 
+ rewriter.setInsertionPointToStart(forOp.getBody()); + auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(), + forOp.getInductionVar(), nbStep); + // Replace induction var with correct offset. + xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(), + newDescOp.getResult(), + getPrefetchOffsets(prefetchOffset), readCacheHint, + readCacheHint, readCacheHint, /*layout=*/nullptr); + + // Unroll the init loop. + if (failed(loopUnrollFull(initForOp))) + return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop"; + + results.set(llvm::cast<OpResult>(getResult()), {newDescOp}); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::InsertPrefetchOp::getEffects( + ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsHandle(getDynamicNbPrefetchMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + +void transform::ConvertLayoutOp::build( + OpBuilder &builder, OperationState &ostate, Value target, + ArrayRef<OpFoldResult> mixedInputSgLayout, + ArrayRef<OpFoldResult> mixedInputSgData, + ArrayRef<OpFoldResult> mixedInputInstData, + ArrayRef<OpFoldResult> mixedTargetSgLayout, + ArrayRef<OpFoldResult> mixedTargetSgData, + ArrayRef<OpFoldResult> mixedTargetInstData) { + SmallVector<int64_t> staticInputSgLayout, staticInputSgData, + staticInputInstData; + SmallVector<Value> dynamicInputSgLayout, dynamicInputSgData, + dynamicInputInstData; + dispatchIndexOpFoldResults(mixedInputSgLayout, dynamicInputSgLayout, + staticInputSgLayout); + dispatchIndexOpFoldResults(mixedInputSgData, dynamicInputSgData, + staticInputSgData); + dispatchIndexOpFoldResults(mixedInputInstData, dynamicInputInstData, + staticInputInstData); + SmallVector<int64_t> staticTargetSgLayout, staticTargetSgData, + staticTargetInstData; + SmallVector<Value> dynamicTargetSgLayout, dynamicTargetSgData, + dynamicTargetInstData; + dispatchIndexOpFoldResults(mixedTargetSgLayout, dynamicTargetSgLayout, + staticTargetSgLayout); + dispatchIndexOpFoldResults(mixedTargetSgData, dynamicTargetSgData, + staticTargetSgData); + dispatchIndexOpFoldResults(mixedTargetInstData, dynamicTargetInstData, + staticTargetInstData); + build(builder, ostate, target.getType(), + /*target=*/target, + /*input_sg_layout=*/dynamicInputSgLayout, + /*input_sg_data=*/dynamicInputSgData, + /*input_inst_data=*/dynamicInputInstData, + /*target_sg_layout=*/dynamicTargetSgLayout, + /*target_sg_data=*/dynamicTargetSgData, + /*target_inst_data=*/dynamicTargetInstData, + /*static_input_sg_layout=*/staticInputSgLayout, + /*static_input_sg_data=*/staticInputSgData, + /*static_input_inst_data=*/staticInputInstData, + /*static_target_sg_layout=*/staticTargetSgLayout, + /*static_target_sg_data=*/staticTargetSgData, + /*static_target_inst_data=*/staticTargetInstData); +} + +DiagnosedSilenceableFailure +transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetValues = state.getPayloadValues(getTarget()); + if (!llvm::hasSingleElement(targetValues)) + return emitDefiniteFailure() + << "requires exactly one target value handle (got " + << llvm::range_size(targetValues) << ")"; + auto value = *targetValues.begin(); + + // Construct layout attributes. 
+ xegpu::LayoutAttr inputLayoutAttr = nullptr; + auto status = getLayoutAttrFromOperands( + getContext(), state, (*this), getMixedInputSgLayout(), + getMixedInputSgData(), getMixedInputInstData(), inputLayoutAttr); + if (!status.succeeded()) + return status; + + xegpu::LayoutAttr targetLayoutAttr = nullptr; + status = getLayoutAttrFromOperands( + getContext(), state, (*this), getMixedTargetSgLayout(), + getMixedTargetSgData(), getMixedTargetInstData(), targetLayoutAttr); + if (!status.succeeded()) + return status; + + // Find first user op to define insertion point for layout conversion. + if (value.use_empty()) + return emitSilenceableFailure(getLoc()) + << "Value has no users to insert layout conversion."; + Operation *userOp = *value.getUsers().begin(); + + // Emit convert_layout op. + rewriter.setInsertionPoint(userOp); + auto convLayoutOp = + xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(), + value, inputLayoutAttr, targetLayoutAttr); + // Replace load op result with the converted layout. + rewriter.replaceUsesWithIf( + value, convLayoutOp.getResult(), [&](OpOperand &use) { + return use.getOwner() != convLayoutOp.getOperation(); + }); + + results.set(llvm::cast<OpResult>(getResult()), {convLayoutOp}); + return DiagnosedSilenceableFailure::success(); +} + +void transform::ConvertLayoutOp::getEffects( + ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsHandle(getInputSgLayoutMutable(), effects); + onlyReadsHandle(getInputSgDataMutable(), effects); + onlyReadsHandle(getInputInstDataMutable(), effects); + onlyReadsHandle(getTargetSgLayoutMutable(), effects); + onlyReadsHandle(getTargetSgDataMutable(), effects); + onlyReadsHandle(getTargetInstDataMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + +namespace { +class XeGPUTransformDialectExtension + : public transform::TransformDialectExtension< + XeGPUTransformDialectExtension> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension) + + using Base::Base; + + void init(); +}; + +void XeGPUTransformDialectExtension::init() { + declareGeneratedDialect<scf::SCFDialect>(); + declareGeneratedDialect<arith::ArithDialect>(); + declareGeneratedDialect<xegpu::XeGPUDialect>(); + + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc" + >(); +} +} // namespace + +#define GET_OP_CLASSES +#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc" + +void mlir::xegpu::registerTransformDialectExtension(DialectRegistry ®istry) { + registry.addExtensions<XeGPUTransformDialectExtension>(); +} diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt index e6f7606..29b645f 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt @@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms XeGPUWgToSgDistribute.cpp XeGPUPropagateLayout.cpp XeGPUVectorLinearize.cpp + XeGPUOptimizeBlockLoads.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp new file mode 100644 index 0000000..ab41fe4 --- /dev/null +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp @@ -0,0 +1,490 @@ +//===- XeGPUOptimizeBlockLoads.cpp - XeGPU 
optimize block loads -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/Transforms/Passes.h" +#include "mlir/Dialect/XeGPU/Transforms/Transforms.h" +#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" +#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" +#include "mlir/Dialect/XeGPU/uArch/uArchBase.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include <optional> + +namespace mlir { +namespace xegpu { +#define GEN_PASS_DEF_XEGPUOPTIMIZEBLOCKLOADS +#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc" +} // namespace xegpu +} // namespace mlir + +#define DEBUG_TYPE "xegpu-optimize-block-loads" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") + +using namespace mlir; + +namespace { + +/// Get the 2D lane data from a tensor desc type if it exists. +static std::optional<SmallVector<int64_t>> +getMaybeLaneData(xegpu::TensorDescType tdescType) { + auto layout = tdescType.getLayoutAttr(); + if (!layout) + return std::nullopt; + auto laneData = layout.getEffectiveLaneDataAsInt(); + if (laneData.size() != 2) + return std::nullopt; + return laneData; +} + +/// Get the 2D lane layout from a tensor desc type if it exists. +static std::optional<SmallVector<int64_t>> +getMaybeLaneLayout(xegpu::TensorDescType tdescType) { + auto layout = tdescType.getLayoutAttr(); + if (!layout) + return std::nullopt; + auto laneLayout = layout.getEffectiveLaneLayoutAsInt(); + if (laneLayout.size() != 2) + return std::nullopt; + return laneLayout; +} + +/// A layout can be optimized if its lane layout is transposed (lane[0] != 1 && +/// lane[1] == 1), but inner lane data is not equal to [1, 1]. +/// Example: +/// !xegpu.tensor_desc<16x16xf16, +/// #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>> +/// In this case, lane layout is transposed (from the usual [1, SG_SIZE] form) +/// indicating that this is a load that requires transpose effect. However, +/// lane data is [1, 2], meaning that each lane must grab 2 f16 elements from +/// the inner dimension. We convert this to a optimized form by converting the +/// tensor_desc to i32 type such that lane data becomes [1, 1]. This makes the +/// later lowering easily use the load with transpose instruction. +static bool canBeOptimizedForTranspose(ArrayRef<int64_t> laneLayout, + ArrayRef<int64_t> laneData) { + if (laneLayout.size() != 2 || laneData.size() != 2) + return false; + if (laneLayout[0] == 1 || laneLayout[1] != 1) + return false; + if (laneData[0] != 1 || laneData[1] == 1) + return false; + return true; +} + +/// A tensor desc type can be optimized if its element type is less than 32 bits +/// and its layout can be optimized. 
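The repacking that makes such a descriptor optimizable is easiest to see as plain arithmetic. The standalone C++ sketch below is editorial (not part of the patch, all names illustrative) and models only the bitwidth/shape adjustment that tryOptimize further down performs before clamping the result to hardware-supported block sizes.

// Standalone illustration (not part of the patch): fold `innerLaneData`
// narrow elements into one wider element so that lane_data becomes [1, 1].
#include <cassert>
#include <cstdint>
#include <vector>

struct RepackedDesc {
  int bitWidth;               // new element bitwidth, e.g. 32
  std::vector<int64_t> shape; // shape in units of the new element type
};

RepackedDesc repackForTranspose(std::vector<int64_t> shape, int elemBitWidth,
                                int64_t innerLaneData, int arrayLength) {
  assert(elemBitWidth < 32 && innerLaneData > 1);
  shape.back() = shape.back() * arrayLength / innerLaneData;
  return {elemBitWidth * static_cast<int>(innerLaneData), shape};
}

int main() {
  // 16x16xf16 with lane_data = [1, 2] and array_length = 1 becomes 16x8xi32.
  RepackedDesc d = repackForTranspose({16, 16}, /*elemBitWidth=*/16,
                                      /*innerLaneData=*/2, /*arrayLength=*/1);
  assert(d.bitWidth == 32 && d.shape[1] == 8);
  return 0;
}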
+static bool canBeOptimizedForTranspose(xegpu::TensorDescType tdescType) { + // If the dtype is greater or equal to 32 bits, layout must be valid. + int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth(); + if (elementTyBitwidth >= 32) + return false; + auto maybeLaneLayout = getMaybeLaneLayout(tdescType); + auto maybeLaneData = getMaybeLaneData(tdescType); + if (!maybeLaneData || !maybeLaneLayout) + return false; + return canBeOptimizedForTranspose(*maybeLaneLayout, *maybeLaneData); +} + +/// Check if a tensor desc type can be optimized for transpose, if so return the +/// new optimized tensor desc type with a valid transpose layout. +static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType, + const uArch *targetuArch) { + if (!canBeOptimizedForTranspose(tdescType)) + return tdescType; + auto laneData = getMaybeLaneData(tdescType) + .value(); // Lane data must exist if we reach here. + int64_t innerLaneData = laneData[1]; + int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth(); + // Required shape is total shape of the vector result that this tensor desc + // must eventually load after adjusting for the new bitwidth and array + // length. + SmallVector<int64_t> requiredShape(tdescType.getShape()); + requiredShape.back() = + requiredShape.back() * tdescType.getArrayLength() / innerLaneData; + int newBitWidth = elementTyBitwidth * innerLaneData; + Type newElemTy = IntegerType::get(tdescType.getContext(), newBitWidth); + // Supported shape is the max transpose shape that can be supported by + // hardware that is less than or equal to required shape. + auto *blockLoadTarget = dyn_cast<Subgroup2DBlockLoadInstruction>( + targetuArch->getInstruction(InstructionKind::Subgroup2DBlockLoad)); + auto maybeHWParams = blockLoadTarget->getBlockWidthHeightCount( + newElemTy, /** has transform */ false, /** has transpose */ true); + // If no HW params found, return the original type. + if (!maybeHWParams) + return tdescType; + auto [widths, heights, counts] = maybeHWParams.value(); + // TODO: Currently we expect array length to be 1 for transpose case. + if (counts.size() != 1 || counts[0] != 1) + return tdescType; + int arrayLen = counts[0]; + int supportedHeight = + xegpu::getLargestDivisor(static_cast<int>(requiredShape[0]), heights); + int supportedWidth = + xegpu::getLargestDivisor(static_cast<int>(requiredShape[1]), widths); + // If no supported height or width found, return the original type. + if (supportedHeight == -1 || supportedWidth == -1) + return tdescType; + + SmallVector<int64_t> supportedShape = {supportedHeight, supportedWidth}; + xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get( + tdescType.getContext(), + tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1}); + // Array length can not be larger than 1 for transpose case. + return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen, + tdescType.getBoundaryCheck(), + tdescType.getMemorySpace(), newLayout); +} + +/// Helper to convert an OpFoldResult to Value. +static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc, + OpFoldResult ofr) { + std::optional<int64_t> mayBeInt = getConstantIntValue(ofr); + if (mayBeInt) + return arith::ConstantIndexOp::create(rewriter, loc, *mayBeInt).getResult(); + return llvm::cast<Value>(ofr); +} + +/// Helper to divide a Value by a constant integer. 
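The helper below strength-reduces division by a power of two into a logical right shift. A minimal standalone sketch of that equivalence (editorial, not part of the patch) for unsigned values:

#include <cassert>
#include <cstdint>

uint64_t divideByConstant(uint64_t val, uint64_t constant) {
  // A power of two has exactly one bit set.
  if (constant != 0 && (constant & (constant - 1)) == 0) {
    unsigned shift = 0;
    while ((constant >> shift) != 1)
      ++shift;             // shift == log2(constant)
    return val >> shift;   // same result as val / constant for unsigned val
  }
  return val / constant;
}

int main() {
  assert(divideByConstant(1024, 2) == 512);
  assert(divideByConstant(100, 8) == 100 / 8);
  assert(divideByConstant(100, 6) == 100 / 6);
  return 0;
}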
+static Value divideByConstant(ConversionPatternRewriter &rewriter, Location loc, + Value val, int64_t constant) { + // If the constant is a power of 2, use right shift for division. + if (llvm::isPowerOf2_64(constant)) { + int64_t shiftAmount = llvm::Log2_64(constant); + return arith::ShRUIOp::create( + rewriter, loc, val, + arith::ConstantIndexOp::create(rewriter, loc, shiftAmount) + .getResult()) + .getResult(); + } + auto constantOp = + arith::ConstantIndexOp::create(rewriter, loc, constant).getResult(); + return arith::DivUIOp::create(rewriter, loc, val, constantOp).getResult(); +} + +/// This function takes a larger register block `data` and generates multiple +/// smaller loads (size given by `newTensorDesc`) to fill in the `data` block +/// starting from `offsets`. +static Value generateLoads(ConversionPatternRewriter &rewriter, + TypedValue<VectorType> data, + SmallVector<OpFoldResult> offsets, + TypedValue<xegpu::TensorDescType> newTensorDesc, + xegpu::LoadNdOp origLoadOp) { + Location loc = data.getLoc(); + assert(offsets.size() >= 2 && "Expecting at least 2 offsets for 2D LoadNdOp"); + Value offsetDim0 = convertToValue(rewriter, loc, offsets[offsets.size() - 2]); + Value offsetDim1 = convertToValue(rewriter, loc, offsets[offsets.size() - 1]); + SmallVector<int64_t> supportedShape(newTensorDesc.getType().getShape()); + // Compute the ratio between original shape and supported shape. We need to + // generate loads in this ratio arrangement. + auto shapeRatio = computeShapeRatio(data.getType().getShape(), + supportedShape) + .value(); // `ratio` must be defined if we reach here. + for (int64_t h = 0; h < shapeRatio[0]; ++h) { + for (int64_t w = 0; w < shapeRatio[1]; ++w) { + int64_t localOffsetDim0 = h * supportedShape[0]; + int64_t localOffsetDim1 = w * supportedShape[1]; + Value loadOffsetX = arith::AddIOp::create( + rewriter, loc, offsetDim0, + arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim0) + .getResult()); + Value loadOffsetY = arith::AddIOp::create( + rewriter, loc, offsetDim1, + arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim1) + .getResult()); + auto loadOp = xegpu::LoadNdOp::create( + rewriter, loc, + VectorType::get(supportedShape, data.getType().getElementType()), + newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY}, + origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(), + origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(), + origLoadOp.getL3HintAttr(), origLoadOp.getLayoutAttr()); + // Set the layout for the loadOp. + auto layoutAttr = newTensorDesc.getType().getLayoutAttr(); + xegpu::setDistributeLayoutAttr(loadOp->getOpResult(0), layoutAttr); + // Insert the loaded block into the right position in data. + auto insertOp = vector::InsertStridedSliceOp::create( + rewriter, loc, loadOp.getResult(), data, + ArrayRef<int64_t>{localOffsetDim0, localOffsetDim1}, + ArrayRef<int64_t>{1, 1}); + // InsertOp must have the same layout as newTensorDesc. + xegpu::setDistributeLayoutAttr(insertOp->getOpResult(0), layoutAttr); + data = insertOp.getResult(); + } + } + return data; +} + +/// Checks if a CreateNdDescOp can be optimized for transpose, if so creates a +/// new CreateNdDescOp with optimized tensor desc type. This involves extracting +/// the base pointer from the original memory source and adjusting the shape and +/// strides of the tensor desc to fit with the new optimized transpose layout. 
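The shape/stride adjustment described above amounts to viewing a row-major buffer with a wider element type. The standalone sketch below is editorial (illustrative names, assumes a 2D row-major source) and mirrors the divisions applied by the pattern that follows.

// Editorial sketch: view a row-major 2D buffer with elements packed by
// `innerLaneData`; the inner size and the leading stride shrink by that factor.
#include <cassert>
#include <cstdint>
#include <vector>

struct View2D {
  std::vector<int64_t> sizes;   // {rows, cols}
  std::vector<int64_t> strides; // row-major: {cols, 1}
};

View2D repackView(View2D v, int64_t innerLaneData) {
  assert(v.strides.back() == 1 && "row-major source expected");
  v.sizes.back() /= innerLaneData;
  v.strides[v.strides.size() - 2] /= innerLaneData;
  return v;
}

int main() {
  // A 64x128 f16 buffer viewed as i32 (innerLaneData == 2): 64x64, strides {64, 1}.
  View2D v = repackView({{64, 128}, {128, 1}}, 2);
  assert(v.sizes[1] == 64 && v.strides[0] == 64);
  return 0;
}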
+class XeGPUCreateNdDescOpPattern final
+    : public OpConversionPattern<xegpu::CreateNdDescOp> {
+public:
+  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto tdescTy = createNdOp.getType();
+    // Get the target uArch info.
+    auto chipStr = xegpu::getChipStr(createNdOp);
+    // Check if the chip is supported.
+    assert(
+        chipStr && (chipStr.value() == "pvc" || chipStr.value() == "bmg") &&
+        "Expecting target chip to be pvc or bmg for transpose optimization.");
+    const uArch *targetuArch = xegpu::uArch::getUArch(chipStr.value());
+
+    auto convertType = tryOptimize(tdescTy, targetuArch);
+    if (convertType == tdescTy)
+      return failure();
+    auto strides = createNdOp.getMixedStrides();
+    auto maybeConstInnerStride = getConstantIntValue(strides.back());
+    // Only row-major memrefs are expected for now.
+    if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
+      return rewriter.notifyMatchFailure(
+          createNdOp, "Expecting row-major memref for transpose optimization.");
+    Value source = createNdOp.getSource();
+    auto optionalLaneData = getMaybeLaneData(tdescTy);
+    assert(optionalLaneData && "Expected 2D lane data");
+    auto laneData = optionalLaneData.value();
+    int64_t innerLaneData = laneData[1];
+    auto memrefType = dyn_cast<MemRefType>(source.getType());
+    // Inner dimension of the shape must be adjusted based on innerLaneData.
+    SmallVector<OpFoldResult> modifiedShape(createNdOp.getMixedSizes());
+    modifiedShape.back() = divideByConstant(
+        rewriter, createNdOp.getLoc(),
+        convertToValue(rewriter, createNdOp.getLoc(), modifiedShape.back()),
+        innerLaneData);
+    // Similarly, the second-to-last stride must be adjusted.
+    assert(strides.size() >= 2 &&
+           "Expected at least 2 strides for CreateNdDescOp");
+    SmallVector<OpFoldResult> modifiedStrides(strides);
+    modifiedStrides[modifiedStrides.size() - 2] = divideByConstant(
+        rewriter, createNdOp.getLoc(),
+        convertToValue(rewriter, createNdOp.getLoc(),
+                       modifiedStrides[modifiedStrides.size() - 2]),
+        innerLaneData);
+
+    // If the source is a static memref, we need to extract the pointer to the
+    // base address.
+    if (memrefType && memrefType.hasStaticShape()) {
+      auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
+          rewriter, createNdOp.getLoc(), source);
+      source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(),
+                                          rewriter.getI64Type(),
+                                          extractOp.getResult())
+                   .getResult();
+    }
+    // Create a new CreateNdDescOp with the modified shape and converted type.
+    auto newCreateNdDescOp = xegpu::CreateNdDescOp::create(
+        rewriter, createNdOp.getLoc(), convertType, source, modifiedShape,
+        modifiedStrides);
+    rewriter.replaceOp(createNdOp, newCreateNdDescOp.getResult());
+    return success();
+  }
+};
+
+/// Checks if a LoadNdOp consumes a tensor desc type that was rewritten for
+/// transpose optimization. If so, rewrites the LoadNdOp to align with the
+/// adjusted tensor desc type. This can result in multiple LoadNdOps being
+/// generated to fill in the original load shape.
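The index arithmetic behind that multi-load expansion can be sketched standalone. The editorial C++ below (illustrative names, 2D case only) mirrors the loop structure of generateLoads above: the original block is covered by (origShape / supportedShape) smaller loads, each placed at the original offset plus a tile-local offset.

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

std::vector<std::pair<int64_t, int64_t>>
tileOffsets(int64_t origRows, int64_t origCols, int64_t tileRows,
            int64_t tileCols, int64_t offRow, int64_t offCol) {
  assert(origRows % tileRows == 0 && origCols % tileCols == 0);
  std::vector<std::pair<int64_t, int64_t>> offsets;
  for (int64_t h = 0; h < origRows / tileRows; ++h)
    for (int64_t w = 0; w < origCols / tileCols; ++w)
      offsets.push_back({offRow + h * tileRows, offCol + w * tileCols});
  return offsets;
}

int main() {
  // A 32x16 block served by 16x8 hardware loads needs 2x2 = 4 loads.
  auto offs = tileOffsets(32, 16, 16, 8, /*offRow=*/0, /*offCol=*/64);
  assert(offs.size() == 4 && offs[3] == std::pair<int64_t, int64_t>(16, 72));
  return 0;
}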
+class XeGPULoadNdDescOpPattern final + : public OpConversionPattern<xegpu::LoadNdOp> { +public: + using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto origTensorDescType = loadNdOp.getTensorDescType(); + auto adaptorType = + cast<xegpu::TensorDescType>(adaptor.getTensorDesc().getType()); + if (adaptorType == origTensorDescType) + return failure(); + // Offsets must be adjusted based on innerLaneData. + auto laneData = getMaybeLaneData(loadNdOp.getTensorDescType()).value(); + int64_t innerLaneData = laneData[1]; + auto offsets = loadNdOp.getMixedOffsets(); + if (offsets.empty()) + return rewriter.notifyMatchFailure(loadNdOp, + "Expecting offsets in LoadNd"); + SmallVector<OpFoldResult> modifiedOffsets(offsets); + modifiedOffsets.back() = divideByConstant( + rewriter, loadNdOp.getLoc(), + convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()), + innerLaneData); + // Get the 2D data shape of this loadNdOp in its original type including + // array length. + SmallVector<int64_t> origDataShape(origTensorDescType.getShape()); + // Adjust the data shape based on innerLaneData. + origDataShape.back() /= innerLaneData; + // HW supported shape is the new tensor desc shape after conversion. + SmallVector<int64_t> hwSupportedShape(adaptorType.getShape()); + VectorType origVectorType = + VectorType::get(origDataShape, adaptorType.getElementType()); + Value data; + // Orig data shape is 3D for the array length case. + if (origTensorDescType.getArrayLength() > 1) { + SmallVector<Value> arraySlices; + for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) { + Value slice = arith::ConstantOp::create( + rewriter, loadNdOp->getLoc(), origVectorType, + rewriter.getZeroAttr(origVectorType)); + // Increase the Y offset for each array slice. + Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(), + modifiedOffsets.back()); + modifiedOffsets.back() = + arith::AddIOp::create( + rewriter, loadNdOp->getLoc(), offsetY, + arith::ConstantIndexOp::create(rewriter, loadNdOp->getLoc(), + i * origDataShape[1]) + .getResult()) + .getResult(); + slice = generateLoads( + rewriter, cast<TypedValue<VectorType>>(slice), modifiedOffsets, + cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()), + loadNdOp); + // BitCast back to original load shape without array length. + auto bitcastType = VectorType::get(origTensorDescType.getShape(), + origTensorDescType.getElementType()); + auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(), + bitcastType, slice); + // BitCastOp must have the same layout as the original loadNdOp. + xegpu::setDistributeLayoutAttr(bitCastOp->getOpResult(0), + origTensorDescType.getLayoutAttr()); + arraySlices.push_back(bitCastOp.getResult()); + } + rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices}); + return success(); + } + data = arith::ConstantOp::create( + rewriter, loadNdOp->getLoc(), + VectorType::get(origDataShape, adaptorType.getElementType()), + rewriter.getZeroAttr(origVectorType)); + data = generateLoads( + rewriter, cast<TypedValue<VectorType>>(data), modifiedOffsets, + cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()), + loadNdOp); + auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(), + loadNdOp.getType(), data); + // BitCastOp must have the same layout as the original loadNdOp. 
+ xegpu::setDistributeLayoutAttr(bitCastOp->getOpResult(0), + origTensorDescType.getLayoutAttr()); + rewriter.replaceOp(loadNdOp, bitCastOp); + return success(); + } +}; + +/// Vector ExtractOp must be processed if the original tensor desc type has +/// array length greater than 1. In this case, the LoadNdOp is replaced with +/// multiple LoadNdOps for each array slice making the extraction unnecessary. +/// In this case, we simply remove the ExtractOp. +class VectorExtractOpPattern final + : public OpConversionPattern<vector::ExtractOp> { +public: + using OpConversionPattern<vector::ExtractOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Check if the source of the extraction is split to multiple values. + if (adaptor.getSource().size() == 1) + return failure(); + auto mixedPos = extractOp.getMixedPosition(); + if (mixedPos.size() != 1) + return failure(); + auto mayBeInt = getConstantIntValue(mixedPos[0]); + if (!mayBeInt) + return failure(); + rewriter.replaceOp(extractOp, adaptor.getSource()[*mayBeInt]); + return success(); + } +}; + +} // namespace + +void xegpu::populateXeGPUOptimizeBlockLoadsPatterns( + RewritePatternSet &patterns) { + patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern, + VectorExtractOpPattern>(patterns.getContext()); +} + +namespace { + +struct XeGPUOptimizeBlockLoadsPass final + : public xegpu::impl::XeGPUOptimizeBlockLoadsBase< + XeGPUOptimizeBlockLoadsPass> { + void runOnOperation() override { + MLIRContext &context = getContext(); + TypeConverter converter; + RewritePatternSet patterns(&context); + ConversionTarget target(context); + + // This pass is only meant for PVC and BMG targets. If unsupported target + // is found, exit early. + bool isTargetSupported = false; + getOperation()->walk([&](gpu::GPUFuncOp funcOp) { + auto chipStr = xegpu::getChipStr(funcOp); + if (chipStr && (chipStr.value() == "pvc" || chipStr.value() == "bmg")) + isTargetSupported = true; + }); + + if (!isTargetSupported) { + DBGS() << "XeGPUOptimizeBlockLoadsPass only supports PVC and BMG targets." + << "\n"; + return; + } + + // CreateNdDescOp and LoadNdOp with optimizable tensor desc types must be + // converted. + target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>( + [&](xegpu::CreateNdDescOp createNdOp) { + return !canBeOptimizedForTranspose(createNdOp.getType()); + }); + target.addDynamicallyLegalOp<xegpu::LoadNdOp>( + [&](xegpu::LoadNdOp loadNdOp) { + return !canBeOptimizedForTranspose(loadNdOp.getTensorDescType()); + }); + // Vector ExtractOps can have optimizable layouts if they extract from + // LoadNdOps with array length greater than 1. These ExtractOps must be + // converted. 
+ target.addDynamicallyLegalOp<vector::ExtractOp>( + [&](vector::ExtractOp extractOp) { + auto layout = xegpu::getDistributeLayoutAttr(extractOp.getResult()); + if (!layout) + return true; + auto laneLayout = layout.getEffectiveLaneLayoutAsInt(); + auto laneData = layout.getEffectiveLaneDataAsInt(); + return !canBeOptimizedForTranspose(laneLayout, laneData); + }); + converter.addConversion([](Type type) { return type; }); + + target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect, + vector::VectorDialect>(); + scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, + target); + xegpu::populateXeGPUOptimizeBlockLoadsPatterns(patterns); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + DBGS() << "Optimize block loads pass failed.\n"; + return signalPassFailure(); + } + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp index 8fab255..dc9eb96 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp @@ -14,7 +14,6 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" -#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h" #include "mlir/Dialect/XeGPU/Transforms/Passes.h" #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "mlir/IR/Attributes.h" @@ -37,6 +36,8 @@ #include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" + namespace mlir { namespace xegpu { #define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT @@ -52,6 +53,8 @@ using namespace mlir::dataflow; namespace { +enum class LayoutKind { Lane, InstData }; + //===----------------------------------------------------------------------===// // LayoutInfo //===----------------------------------------------------------------------===// @@ -104,6 +107,8 @@ public: SmallVector<int> getLaneData() const; + SmallVector<int> getInstData() const; + bool isSliceLayout() const { if (!isAssigned()) return false; @@ -137,6 +142,13 @@ SmallVector<int> LayoutInfo::getLaneData() const { [](int64_t val) { return static_cast<int>(val); }); } +SmallVector<int> LayoutInfo::getInstData() const { + if (!isAssigned()) + return {}; + return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(), + [](int64_t val) { return static_cast<int>(val); }); +} + void LayoutInfo::print(raw_ostream &os) const { if (isAssigned()) { os << storage; @@ -156,7 +168,8 @@ LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) { llvm_unreachable("Join should not be triggered by layout propagation."); } -/// Construct a new layout with the transposed lane layout and lane data. +/// Construct a new layout with the transposed inst_data or lane_layout, +/// lane_data. 
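The permutation step itself is small; the standalone sketch below (editorial, illustrative names) shows the gather-by-index that the implementation which follows applies to whichever parameter arrays (lane_layout/lane_data or inst_data) are present.

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int32_t> permute(const std::vector<int32_t> &params,
                             const std::vector<int64_t> &permutation) {
  std::vector<int32_t> out;
  for (int64_t idx : permutation)
    out.push_back(params[idx]); // entry i of the result is params[permutation[i]]
  return out;
}

int main() {
  // lane_layout [1, 16] transposed with permutation [1, 0] becomes [16, 1].
  std::vector<int32_t> t = permute({1, 16}, {1, 0});
  assert(t[0] == 16 && t[1] == 1);
  return 0;
}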
LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const { if (!isAssigned()) return {}; @@ -174,12 +187,22 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const { SmallVector<int32_t> laneLayout; SmallVector<int32_t> laneData; + SmallVector<int32_t> instData; for (int64_t idx : permutation) { - laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx])); - laneData.push_back(static_cast<int32_t>(getLaneData()[idx])); + if (getLaneLayout().size()) { + laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx])); + laneData.push_back(static_cast<int32_t>(getLaneData()[idx])); + } + if (getInstData().size()) + instData.push_back(static_cast<int32_t>(getInstData()[idx])); } - return LayoutInfo( - xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData)); + xegpu::LayoutAttr layoutAttr; + if (getLaneLayout().size()) + layoutAttr = + xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData); + if (getInstData().size()) + layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData); + return LayoutInfo(layoutAttr); } //===----------------------------------------------------------------------===// @@ -200,18 +223,30 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> { /// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1]. /// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1]. static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, - unsigned rank) { + unsigned rank, + const xegpu::uArch::uArch *uArch) { assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector."); if (rank == 1) { return LayoutInfo( - xegpu::LayoutAttr::get(ctx, {xegpu::targetinfo::subgroupSize}, {1})); + xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1})); } - return LayoutInfo(xegpu::LayoutAttr::get( - ctx, {1, xegpu::targetinfo::subgroupSize}, {1, 1})); + return LayoutInfo( + xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1})); +} + +static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, + unsigned rank, int subgroupSize) { + assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector."); + if (rank == 1) { + return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1})); + } + return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1})); } /// Helper to get the default layout for a vector type. static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy, + const xegpu::uArch::uArch *uArch, + unsigned packingSize, bool isScattered = false) { // Expecting a 1D or 2D vector. assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) && @@ -221,28 +256,24 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy, "Expected int or float element type."); // If the rank is 1, then return default layout for 1D vector. if (vectorTy.getRank() == 1) - return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1); + return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch); // Packing factor is determined by the element type bitwidth. - int packingFactor = 1; unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth(); + int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1; if (isScattered) { - packingFactor = - bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter - ? 
xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth - : 1; - return LayoutInfo(xegpu::LayoutAttr::get( - vectorTy.getContext(), {xegpu::targetinfo::subgroupSize, 1}, - {1, packingFactor})); + return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), + {uArch->getSubgroupSize(), 1}, + {1, packingFactor})); } - if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault) - packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth; return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), - {1, xegpu::targetinfo::subgroupSize}, + {1, uArch->getSubgroupSize()}, {1, packingFactor})); } /// Helper to get the default layout for a vector type. static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy, + const xegpu::uArch::uArch *uArch, + unsigned packingSize, bool isScattered = false) { // Expecting a 1D or 2D vector. assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) && @@ -252,27 +283,18 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy, "Expected int or float element type."); // If the rank is 1, then return default layout for 1D vector. if (tdescTy.getRank() == 1) - return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1); + return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch); // Packing factor is determined by the element type bitwidth. unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth(); - + int subgroupSize = uArch->getSubgroupSize(); + int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1; if (isScattered) { - int packingFactor = - bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter - ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth - : 1; return LayoutInfo(xegpu::LayoutAttr::get( - tdescTy.getContext(), {xegpu::targetinfo::subgroupSize, 1}, - {1, packingFactor})); + tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor})); } - int packingFactor = - (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault) - ? xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth - : 1; - return LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), - {1, xegpu::targetinfo::subgroupSize}, - {1, packingFactor})); + return LayoutInfo(xegpu::LayoutAttr::get( + tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor})); } /// Helper Function to get the expected layouts for DPAS operands. `lane_data` @@ -281,25 +303,25 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy, /// `packedSizeInBitsForDefault` /// * For B operand, the data must be packed in minimum /// `packedSizeInBitsForDpasB` -static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, - unsigned operandNum) { +static LayoutInfo +getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum, + const xegpu::uArch::uArch *uArch, + unsigned packingSize) { Type elementTy = vectorTy.getElementType(); assert(elementTy.isIntOrFloat() && "Expected int or float type in DPAS operands"); - SmallVector<int32_t, 2> layout({1, xegpu::targetinfo::subgroupSize}); + SmallVector<int32_t, 2> layout({1, uArch->getSubgroupSize()}); // For B operand, data must be packed in minimum `packedDpasBSizeInBits` and // must have the VNNI format. 
- if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < - xegpu::targetinfo::packedSizeInBitsForDpasB) { + if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < packingSize) { SmallVector<int32_t, 2> data( - {static_cast<int32_t>(xegpu::targetinfo::packedSizeInBitsForDpasB / - elementTy.getIntOrFloatBitWidth()), + {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()), 1}); return LayoutInfo( xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data)); } // Otherwise, return the default layout for the vector type. - return getDefaultSIMTLayoutInfo(vectorTy); + return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize); } //===----------------------------------------------------------------------===// @@ -314,6 +336,7 @@ static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, class LayoutInfoPropagation : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> { private: + LayoutKind layoutKind; void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results); @@ -364,10 +387,14 @@ private: ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results); + bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout); + public: LayoutInfoPropagation(DataFlowSolver &solver, - SymbolTableCollection &symbolTable) - : SparseBackwardDataFlowAnalysis(solver, symbolTable) {} + SymbolTableCollection &symbolTable, + LayoutKind layoutKind) + : SparseBackwardDataFlowAnalysis(solver, symbolTable), + layoutKind(layoutKind) {} using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis; LogicalResult @@ -450,13 +477,71 @@ LogicalResult LayoutInfoPropagation::visitOperation( return success(); } +bool LayoutInfoPropagation::hasParamsOfLayoutKind( + xegpu::DistributeLayoutAttr anchorLayout) { + if (anchorLayout == nullptr) { + return false; + } + if (layoutKind == LayoutKind::InstData) { + return !(anchorLayout.getEffectiveInstDataAsInt().empty()); + } else if (layoutKind == LayoutKind::Lane) { + return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() || + anchorLayout.getEffectiveLaneDataAsInt().empty()); + } + return false; +} + void LayoutInfoPropagation::visitPrefetchNdOp( xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - // Here we assign the default layout to the tensor descriptor operand of - // prefetch. - auto tdescTy = prefetch.getTensorDescType(); - auto prefetchLayout = getDefaultSIMTLayoutInfo(tdescTy); + + LayoutInfo prefetchLayout; + xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr(); + if (hasParamsOfLayoutKind(anchorLayout)) { + prefetchLayout = LayoutInfo(anchorLayout); + } else { + // Here we assign the default layout to the tensor descriptor operand of + // prefetch. 
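    // (Editorial note, hedged): the derivation below picks, for each tensor
    // descriptor dimension, the largest block size reported by the uArch's 2D
    // block prefetch instruction that evenly divides that dimension, and uses
    // the resulting [height, width] (or [width] for 1D) as inst_data when the
    // pass runs in InstData mode.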
+ auto tdescTy = prefetch.getTensorDescType(); + + auto uArch = getUArch(getChipStr(prefetch).value_or("")); + const auto *uArchInstruction = + dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>( + uArch->getInstruction( + xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch)); + + auto blockWHC = + uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType()); + if (!blockWHC) + prefetch.emitWarning("No known block params found for the element type."); + auto [bWidth, bHeight, bCount] = blockWHC.value(); + SmallVector<int> instData; + int instWidth = xegpu::getLargestDivisor( + static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth); + if (instWidth == -1) + prefetch.emitWarning( + "No suitable instruction multiple found for the given shape."); + if (tdescTy.getRank() == 1) + instData = {instWidth}; + else { + int instHeight = xegpu::getLargestDivisor( + static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight); + if (instHeight == -1) + prefetch.emitWarning( + "No suitable instruction multiple found for the given shape."); + instData = {instHeight, instWidth}; + } + + if (layoutKind == LayoutKind::InstData) + prefetchLayout = + LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData)); + else + prefetchLayout = getDefaultSIMTLayoutInfo( + tdescTy, uArch, uArchInstruction->getPackedFormatBitSize()); + + prefetch.setLayoutAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get())); + } // Propagate the layout to the source tensor descriptor. propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout)); } @@ -475,10 +560,11 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp( reduction.emitWarning("Expecting output type to be 1D vector."); return; } + auto uArch = getUArch(xegpu::getChipStr(reduction).value_or("")); // Given that the result is 1D, the layout of the operand should be 2D with // default layout. - LayoutInfo operandLayout = - getDefaultSIMTLayoutInfo(reduction->getContext(), 2); + LayoutInfo operandLayout = getDefaultSIMTLayoutInfo( + reduction->getContext(), 2, uArch->getSubgroupSize()); propagateIfChanged(operands[0], operands[0]->meet(operandLayout)); // Accumulator should have the same layout as the result. propagateIfChanged(operands[1], operands[1]->meet(resultLayout)); @@ -494,23 +580,39 @@ void LayoutInfoPropagation::visitVectorBroadCastOp( // Only consider vector to vector broadcasts for now. VectorType resultTy = broadcast.getResultVectorType(); VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType()); - if (!sourceTy) { - broadcast.emitWarning("Expecting source type to be a vector type."); + // skip layout propagation for non-vector source operand. + if (!sourceTy) return; - } - // Only consider nD -> nD broadcast. + // Hanlding broadcast from low-rank to high-rank (e.g., 1D to 2D) case. 
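  // (Editorial worked example): broadcasting vector<16xf32> to
  // vector<8x16xf32> gives dimDiff = 1, so bcastDims = [0]; broadcasting
  // vector<1x16xf32> to vector<8x16xf32> gives dimDiff = 0 and still yields
  // bcastDims = [0], because the source size is 1 where the result size is 8.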
if (sourceTy.getRank() != resultTy.getRank()) { - broadcast.emitWarning("Expecting source and result to have same rank."); + auto sourceDims = sourceTy.getShape(); + auto resultDims = resultTy.getShape(); + SmallVector<int64_t> bcastDims; + auto dimDiff = resultTy.getRank() - sourceTy.getRank(); + // adding the missing leading dims + for (int i = 0; i < dimDiff; i++) + bcastDims.push_back(i); + + // for the rest dims in the resultTy, if sourceTy dim is 1, then it's + // broadcasted dim + for (size_t i = 0; i < sourceDims.size(); i++) + if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1)) + bcastDims.push_back(i + dimDiff); + + // create a slice layout for the source + xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get( + broadcast->getContext(), + cast<xegpu::DistributeLayoutAttr>(resultLayout.get()), + DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims)); + + propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout))); return; } + SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims(); - if (broadcastUnitDims.size() != 1) { - broadcast.emitWarning("Expecting source type to be nD vector only with " - "one broadcasted dimension."); - return; - } - // Propagate the result layout to the source operand. + resultLayout = cast<xegpu::DistributeLayoutAttr>(resultLayout.get()) + .setUnitDimData(broadcastUnitDims); propagateIfChanged(operands[0], operands[0]->meet(resultLayout)); } @@ -555,17 +657,97 @@ void LayoutInfoPropagation::visitUpdateNdOffsetOp( void LayoutInfoPropagation::visitDpasOp( xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - VectorType aTy = dpas.getLhsType(); - VectorType bTy = dpas.getRhsType(); - propagateIfChanged( - operands[0], operands[0]->meet(getSIMTLayoutInfoForDPASOperand(aTy, 0))); - propagateIfChanged( - operands[1], operands[1]->meet(getSIMTLayoutInfoForDPASOperand(bTy, 1))); + + LayoutInfo dpasALayout; + LayoutInfo dpasBLayout; + LayoutInfo dpasCDLayout; + + xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr(); + if (hasParamsOfLayoutKind(anchorLayoutCD)) { + xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr(); + xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr(); + assert(hasParamsOfLayoutKind(anchorLayoutA) && + "Expected anchor layout for DPAS A operand."); + assert(hasParamsOfLayoutKind(anchorLayoutB) && + "Expected anchor layout for DPAS B operand."); + dpasALayout = LayoutInfo(anchorLayoutA); + dpasBLayout = LayoutInfo(anchorLayoutB); + dpasCDLayout = LayoutInfo(anchorLayoutCD); + + } else { + + VectorType aTy = dpas.getLhsType(); + VectorType bTy = dpas.getRhsType(); + + auto uArch = getUArch(getChipStr(dpas).value_or("")); + const int subgroupSize = uArch->getSubgroupSize(); + const auto *uArchInstruction = + dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction( + xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)); + + const unsigned dataALen = aTy.getShape().front(); + auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType()); + const int maxALen = + xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen)); + if (maxALen == -1) + dpas.emitWarning( + "No suitable instruction multiple found for the given shape."); + + const unsigned dataBLen = bTy.getShape().back(); + auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType()); + + const int maxBLen = + xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen)); + + if 
(maxBLen == -1) + dpas.emitWarning( + "No suitable instruction multiple found for the given shape."); + SmallVector<int> instDataA = {maxALen, subgroupSize}; + SmallVector<int> instDataB = {subgroupSize, maxBLen}; + + if (layoutKind == LayoutKind::InstData) { + dpasALayout = + LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA)); + dpasBLayout = + LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB)); + } else { + dpasALayout = getSIMTLayoutInfoForDPASOperand( + aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA()); + dpasBLayout = getSIMTLayoutInfoForDPASOperand( + bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB()); + } + + if (operands.size() > 2) { + VectorType cTy = dpas.getAccType(); + if (layoutKind == LayoutKind::InstData) { + const unsigned dataCLen = bTy.getShape().back(); + auto supportedCLen = + uArchInstruction->getSupportedN(bTy.getElementType()); + const int maxCLen = xegpu::getLargestDivisor( + dataCLen, ArrayRef<unsigned>(supportedCLen)); + if (maxCLen == -1) + dpas.emitWarning( + "No suitable instruction multiple found for the given shape."); + SmallVector<int> instDataC = {maxALen, maxCLen}; + dpasCDLayout = + LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC)); + } else + dpasCDLayout = getSIMTLayoutInfoForDPASOperand( + cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB()); + + dpas.setLayoutCdAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get())); + } + dpas.setLayoutAAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(dpasALayout.get())); + dpas.setLayoutBAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(dpasBLayout.get())); + } + + propagateIfChanged(operands[0], operands[0]->meet(dpasALayout)); + propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout)); if (operands.size() > 2) { - VectorType cTy = dpas.getAccType(); - propagateIfChanged( - operands[2], - operands[2]->meet(getSIMTLayoutInfoForDPASOperand(cTy, 2))); + propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout)); } } @@ -573,7 +755,51 @@ void LayoutInfoPropagation::visitDpasOp( void LayoutInfoPropagation::visitStoreNdOp( xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(store.getValueType()); + + LayoutInfo storeLayout; + xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr(); + if (hasParamsOfLayoutKind(anchorLayout)) { + storeLayout = LayoutInfo(anchorLayout); + } else { + auto uArch = getUArch(getChipStr(store).value_or("")); + const auto *uArchInstruction = + dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>( + uArch->getInstruction( + xegpu::uArch::InstructionKind::Subgroup2DBlockStore)); + VectorType dataTy = store.getValueType(); + auto blockWHC = uArchInstruction->getBlockWidthHeightCount( + store.getValueType().getElementType()); + if (!blockWHC) + store.emitWarning("No known block params found for the element type."); + auto [bWidth, bHeight, bCount] = blockWHC.value(); + SmallVector<int> instData; + int instWidth = xegpu::getLargestDivisor( + static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth); + if (instWidth == -1) + store.emitWarning( + "No suitable instruction multiple found for the given shape."); + if (dataTy.getRank() == 1) + instData = {instWidth}; + else { + int instHeight = xegpu::getLargestDivisor( + static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight); + if (instHeight == -1) + store.emitWarning( + "No suitable instruction multiple found for the given 
shape."); + instData = {instHeight, instWidth}; + } + + if (layoutKind == LayoutKind::InstData) + storeLayout = + LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData)); + else + storeLayout = + getDefaultSIMTLayoutInfo(store.getValueType(), uArch, + uArchInstruction->getPackedFormatBitSize()); + store.setLayoutAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get())); + } + // Propagate the layout to the value operand. // Both operands should have the same layout for (LayoutInfoLattice *operand : operands) propagateIfChanged(operand, operand->meet(storeLayout)); @@ -584,21 +810,30 @@ void LayoutInfoPropagation::visitStoreNdOp( void LayoutInfoPropagation::visitLoadNdOp( xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - LayoutInfo valueLayout = results[0]->getValue(); - // Need the layout of the value to propagate to the tensor descriptor. - if (!valueLayout.isAssigned()) - return; - LayoutInfo tensorDescLayout = valueLayout; - // LoadNdOp has the transpose effect. However, at the stage of this analysis - // this effect is not expected and should be abstracted away. Emit a - // warning. - if (auto transpose = load.getTranspose()) { - load.emitWarning("Transpose effect is not expected for LoadNdOp at " - "LayoutInfoPropagation stage."); - tensorDescLayout = valueLayout.transpose(transpose.value()); + + LayoutInfo loadLayout; + xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr(); + if (hasParamsOfLayoutKind(anchorLayout)) { + loadLayout = LayoutInfo(anchorLayout); + } else { + + LayoutInfo valueLayout = results[0]->getValue(); + // Need the layout of the value to propagate to the tensor descriptor. + if (!valueLayout.isAssigned()) + return; + loadLayout = valueLayout; + // LoadNdOp has the transpose effect. However, at the stage of this analysis + // this effect is not expected and should be abstracted away. Emit a + // warning. + if (auto transpose = load.getTranspose()) { + load.emitWarning("Transpose effect is not expected for LoadNdOp at " + "LayoutInfoPropagation stage."); + loadLayout = valueLayout.transpose(transpose.value()); + } + load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get())); } // Propagate the new layout to the tensor descriptor operand. - propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout)); + propagateIfChanged(operands[0], operands[0]->meet(loadLayout)); } /// For vector::TransposeOp, the layout of the result is transposed and @@ -688,20 +923,48 @@ void LayoutInfoPropagation::visitVectorBitcastOp( void LayoutInfoPropagation::visitLoadGatherOp( xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - // The layout is strictly determined by the payload type. - auto payloadTy = dyn_cast<VectorType>(load.getValueType()); - if (!payloadTy) { - load.emitWarning("Not propagating, non-vector payload supplied."); - return; - } - LayoutInfo layout = getDefaultSIMTLayoutInfo(payloadTy, /*scattered*/ true); - // Mask operand should have 1D default layout. - LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1); + LayoutInfo loadLayout; + LayoutInfo maskLayout; + xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr(); + if (hasParamsOfLayoutKind(anchorLayout)) { + loadLayout = LayoutInfo(anchorLayout); + maskLayout = loadLayout; + } else { + + // The layout is strictly determined by the payload type. 
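    // (Editorial note, illustrative values): in InstData mode the code below
    // produces inst_data = [subgroupSize] for a plain gather and
    // [subgroupSize, chunkSize] when a chunk size larger than 1 is present,
    // e.g. a subgroup size of 16 with chunk_size 8 gives inst_data = [16, 8].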
+ VectorType payloadTy = load.getValueType(); + if (!payloadTy) { + load.emitWarning("Not propagating, non-vector payload supplied."); + return; + } + auto uArch = getUArch(getChipStr(load).value_or("")); + const int subgroupSize = uArch->getSubgroupSize(); + SmallVector<int> instData{subgroupSize}; + if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1) + instData.push_back(chunkSize); + else if (auto srcTdescTy = + dyn_cast<xegpu::TensorDescType>(load.getSourceType())) { + if (srcTdescTy.getChunkSizeAsInt() > 1) + instData.push_back(chunkSize); + } + + if (layoutKind == LayoutKind::InstData) + loadLayout = + LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData)); + else + loadLayout = getDefaultSIMTLayoutInfo( + payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(), + /*scattered*/ true); + // Mask operand should have 1D default layout. + maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize); + + load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get())); + } // Propagate the new layout to the tensor descriptor operand. if (isa<xegpu::TensorDescType>(load.getSourceType())) - propagateIfChanged(operands[0], operands[0]->meet(layout)); + propagateIfChanged(operands[0], operands[0]->meet(loadLayout)); // Propagate the new layout to the mask and optional offset operand. propagateIfChanged(operands[1], operands[1]->meet(maskLayout)); if (load.getOffsets()) @@ -717,8 +980,10 @@ void LayoutInfoPropagation::visitCreateDescOp( // Need the layout of the descriptor to propagate to the operands. if (!descLayout.isAssigned()) return; + auto uArch = getUArch(getChipStr(createDesc).value_or("")); // For offset operand propagate 1D default layout. - LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1); + LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1, + uArch->getSubgroupSize()); propagateIfChanged(operands[1], operands[1]->meet(layout)); } @@ -727,26 +992,56 @@ void LayoutInfoPropagation::visitCreateDescOp( void LayoutInfoPropagation::visitStoreScatterOp( xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - // Currently, for 2D StoreScatterOp we expect that the height dimension of - // the tensor descriptor is equal to the subgroup size. This is ensured by - // the op verifier. - auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType()); - if (!payloadTy) { - storeScatter.emitWarning("Not propagating, non-vector payload supplied."); - return; + + LayoutInfo payloadLayout; + LayoutInfo maskLayout; + xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr(); + if (hasParamsOfLayoutKind(anchorLayout)) { + payloadLayout = LayoutInfo(anchorLayout); + maskLayout = payloadLayout; + } else { + // Currently, for 2D StoreScatterOp we expect that the height dimension of + // the tensor descriptor is equal to the subgroup size. This is ensured by + // the op verifier. 
+ VectorType payloadTy = storeScatter.getValueType(); + if (!payloadTy) { + storeScatter.emitWarning("Not propagating, non-vector payload supplied."); + return; + } + + auto uArch = getUArch(getChipStr(storeScatter).value_or("")); + const int subgroupSize = uArch->getSubgroupSize(); + + if (layoutKind == LayoutKind::InstData) { + SmallVector<int> instData{subgroupSize}; + if (auto chunkSize = storeScatter.getChunkSize().value_or(0); + chunkSize > 1) + instData.push_back(chunkSize); + else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>( + storeScatter.getDestType())) { + if (dstTdescTy.getChunkSizeAsInt() > 1) + instData.push_back(chunkSize); + } + payloadLayout = LayoutInfo( + xegpu::LayoutAttr::get(storeScatter.getContext(), instData)); + } else { + auto payloadShape = payloadTy.getShape(); + if (payloadShape.size() > 1) + assert(payloadShape[0] == subgroupSize && + "Expected the first dimension of 2D tensor descriptor to be " + "equal to " + "subgroup size."); + payloadLayout = getDefaultSIMTLayoutInfo( + payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(), + /*scattered=*/true); + } + + maskLayout = + getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize); + + storeScatter.setLayoutAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get())); } - auto payloadShape = payloadTy.getShape(); - if (payloadShape.size() > 1) - assert( - payloadShape[0] == xegpu::targetinfo::subgroupSize && - "Expected the first dimension of 2D tensor descriptor to be equal to " - "subgroup size."); - - LayoutInfo payloadLayout = - getDefaultSIMTLayoutInfo(payloadTy, /*scattered=*/true); - - LayoutInfo maskLayout = - getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1); // Propagate the payload operand layout propagateIfChanged(operands[0], operands[0]->meet(payloadLayout)); // Propagate the destination (if tdesc) operand layout @@ -768,10 +1063,10 @@ class RunLayoutInfoPropagation { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation) - RunLayoutInfoPropagation(Operation *op) : target(op) { + RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) { SymbolTableCollection symbolTable; loadBaselineAnalyses(solver); - solver.load<LayoutInfoPropagation>(symbolTable); + solver.load<LayoutInfoPropagation>(symbolTable, layoutKind); (void)solver.initializeAndRun(op); } @@ -878,7 +1173,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, } // If the result is a vector type, add a temporary layout attribute to the // op. - xegpu::setDistributeLayoutAttr(result, layout); + xegpu::setDistributeLayoutAttr(result, layout, /*respectPermLayout*/ true); } return success(); } @@ -1011,7 +1306,18 @@ struct XeGPUPropagateLayoutPass final } // namespace void XeGPUPropagateLayoutPass::runOnOperation() { - auto &analysis = getAnalysis<RunLayoutInfoPropagation>(); + LayoutKind layoutKind; + if (this->layoutKind == "lane") { + layoutKind = LayoutKind::Lane; + } else if (this->layoutKind == "inst") { + layoutKind = LayoutKind::InstData; + } else { + getOperation()->emitError("Unsupported layout kind option: " + + this->layoutKind); + signalPassFailure(); + return; + } + RunLayoutInfoPropagation analysis(getOperation(), layoutKind); // Print the analysis result and exit. 
(for debugging purposes) if (printOnly) { auto &os = llvm::outs(); @@ -1023,9 +1329,11 @@ void XeGPUPropagateLayoutPass::runOnOperation() { LayoutInfo layout = analysis.getLayoutInfo(val); if (!layout.isAssigned()) return {}; + xegpu::DistributeLayoutAttr layoutAttr = + cast<xegpu::DistributeLayoutAttr>(layout.get()); if (layout.isSliceLayout()) - return cast<xegpu::SliceAttr>(layout.get()); - return cast<xegpu::LayoutAttr>(layout.get()); + return cast<xegpu::SliceAttr>(layoutAttr); + return cast<xegpu::LayoutAttr>(layoutAttr); }; mlir::OpBuilder builder(&getContext()); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index d09dc19..ca81c3c 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -7,14 +7,15 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/Utils/DistributionUtils.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" -#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h" #include "mlir/Dialect/XeGPU/Transforms/Passes.h" #include "mlir/Dialect/XeGPU/Transforms/Transforms.h" #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" +#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -98,7 +99,6 @@ getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout, for (auto [i, dim] : llvm::enumerate(originalType.getShape())) { if (i < distributionStart) continue; - // Check if the dimension can be distributed evenly. if (dim % effectiveLaneLayout[i - distributionStart] != 0) return failure(); @@ -159,17 +159,33 @@ static bool requirePacked(const xegpu::LayoutAttr layout) { /// Helper function to check if the layout requires a transpose effect. static bool requireTranspose(const xegpu::LayoutAttr layout, - const std::string &chipStr) { + const xegpu::uArch::uArch *uArch) { // Return false for unsupported targets. // TODO: Add more support or move to target info. - if (chipStr != "pvc" && chipStr != "bmg") + if (uArch->getName().equals_insensitive("pvc") && + uArch->getName().equals_insensitive("bmg")) return false; if (!layout) return false; auto laneLayout = layout.getEffectiveLaneLayoutAsInt(); if (laneLayout.size() != 2) return false; - return laneLayout[0] == xegpu::targetinfo::subgroupSize && laneLayout[1] == 1; + return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1; +} + +/// Given a vector type and its distributed vector type, return the list of +/// dimensions that are distributed. 
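As a quick illustration of that contract, the standalone sketch below (editorial, illustrative names) compares the two shapes dimension by dimension, which is what the helper that follows does.

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> distributedDims(const std::vector<int64_t> &original,
                                     const std::vector<int64_t> &distributed) {
  assert(original.size() == distributed.size());
  std::vector<int64_t> dims;
  for (size_t i = 0; i < original.size(); ++i)
    if (original[i] != distributed[i]) // size changed, so lanes split this dim
      dims.push_back(static_cast<int64_t>(i));
  return dims;
}

int main() {
  // A 16x16 vector distributed to 16x1 per lane: only dim 1 is distributed.
  auto dims = distributedDims({16, 16}, {16, 1});
  assert(dims.size() == 1 && dims[0] == 1);
  return 0;
}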
+static SmallVector<int64_t> getDistributedDims(VectorType originalType, + VectorType distributedType) { + assert(originalType.getRank() == distributedType.getRank() && + "sequential and distributed vector types must have the same rank"); + SmallVector<int64_t> distributedDims; + for (int64_t i = 0; i < originalType.getRank(); ++i) { + if (distributedType.getDimSize(i) != originalType.getDimSize(i)) { + distributedDims.push_back(i); + } + } + return distributedDims; } /// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body @@ -199,6 +215,11 @@ struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> { using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern; LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, PatternRewriter &rewriter) const override { + auto uArch = getUArch(xegpu::getChipStr(gpuFuncOp).value_or("")); + if (!uArch) + return rewriter.notifyMatchFailure( + gpuFuncOp, "Subgroup distribution requires target attribute attached " + "to set the warp size"); // If the function only contains a single void return, skip. if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) { return isa<gpu::ReturnOp>(op) && !op.getNumOperands(); @@ -230,7 +251,7 @@ struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> { ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults(); auto warpOp = gpu::WarpExecuteOnLane0Op::create( rewriter, laneId.getLoc(), gpuFuncResultType, laneId, - xegpu::targetinfo::subgroupSize, newGpuFunc.getArguments(), + uArch->getSubgroupSize(), newGpuFunc.getArguments(), newGpuFunc.getArgumentTypes()); Block &warpBodyBlock = warpOp.getBodyRegion().front(); // Replace the ReturnOp of the original gpu function with a YieldOp. @@ -495,14 +516,14 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern { warpOp, "warp result is not a xegpu::LoadNd op"); auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>(); + auto uArch = getUArch(xegpu::getChipStr(loadOp).value_or("")); + if (!uArch) + return rewriter.notifyMatchFailure( + loadOp, "xegpu::LoadNdOp require target attribute attached to " + "determine transpose " + "requirement"); // Chip information is required to decide if the layout requires transpose // effect. - auto chipStr = xegpu::getChipStr(loadOp); - if (!chipStr) - return rewriter.notifyMatchFailure( - loadOp, - "xegpu::LoadNdOp require chip information to determine transpose " - "requirement"); // Expecting offsets to be present. SmallVector<OpFoldResult> offsets = loadOp.getMixedOffsets(); if (offsets.empty()) @@ -556,7 +577,7 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern { // Set the packed attribute if the layout requires it. newLoadOp.setPacked(requirePacked(layout)); // Set the transpose attribute if the layout requires it. 
- if (requireTranspose(layout, chipStr.value())) + if (requireTranspose(layout, uArch)) newLoadOp.setTranspose( DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0})); Value distributedVal = newWarpOp.getResult(operandIdx); @@ -906,6 +927,183 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern { } }; +static SmallVector<Value> computeDistributedCoordinatesForMatrixOp( + PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout, + Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) { + SmallVector<Value> newCoods; + auto maybeCoords = + layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape); + if (failed(maybeCoords)) + return {}; + assert(maybeCoords.value().size() == 1 && + "Expected one set of distributed offsets"); + SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned( + rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]), + getAsOpFoldResult(origOffsets)); + newCoods = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>); + return newCoods; +} + +/// Pattern for distributing xegpu::LoadMatrixOp. +struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + gpu::YieldOp yield = warpOp.getTerminator(); + Operation *lastNode = yield->getPrevNode(); + auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode); + if (!matrixOp) + return failure(); + + OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) { + return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op; + }); + if (!producedByLastLoad) + return rewriter.notifyMatchFailure( + warpOp, "The last op is not xegpu::LoadMatrixOp"); + const int operandIdx = producedByLastLoad->getOperandNumber(); + + VectorType sgPayloadTy = + dyn_cast<VectorType>(matrixOp.getResult().getType()); + VectorType warpResultTy = + cast<VectorType>(warpOp.getResult(operandIdx).getType()); + if (!sgPayloadTy) + return rewriter.notifyMatchFailure( + matrixOp, "the matrix op payload must be a vector type"); + + auto loc = matrixOp.getLoc(); + auto offsets = matrixOp.getMixedOffsets(); + if (offsets.empty()) + return rewriter.notifyMatchFailure(matrixOp, + "the load op must have offsets"); + SmallVector<Value> offsetsAsValues = + vector::getAsValues(rewriter, matrixOp.getLoc(), offsets); + + auto layout = matrixOp.getLayoutAttr(); + if (!layout) + return rewriter.notifyMatchFailure( + matrixOp, "the matrix operation lacks layout attribute"); + + FailureOr<VectorType> distPayloadByWarpOpOrFailure = + getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy); + if (failed(distPayloadByWarpOpOrFailure)) + return rewriter.notifyMatchFailure( + matrixOp, "Failed to distribute matrix op payload based on layout."); + + SmallVector<Value> operands = {matrixOp.getMemDesc()}; + const unsigned offsetsStartIdx = operands.size(); + operands.append(offsetsAsValues); + + SmallVector<Type> operandTypes = llvm::to_vector( + llvm::map_range(operands, [](Value v) { return v.getType(); })); + + SmallVector<size_t> newRetIndices; + gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, operands, operandTypes, newRetIndices); + SmallVector<Value> newOperands = llvm::map_to_vector( + newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); }); + + SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(), + ShapedType::kDynamic); + 
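The coordinate math in computeDistributedCoordinatesForMatrixOp above reduces to adding the layout-provided per-lane coordinates onto the trailing entries of the op's original offsets; xegpu::addWithRightAligned appears to perform that alignment. A hedged standalone sketch with plain integers (function name and containers are illustrative; the real code emits index arithmetic ops on SSA values):
```
#include <cassert>
#include <cstdint>
#include <vector>

// Add `laneCoords` onto the trailing entries of `offsets`, keeping any
// leading offsets untouched (right-aligned elementwise add).
std::vector<int64_t>
addWithRightAlignedSketch(std::vector<int64_t> offsets,
                          const std::vector<int64_t> &laneCoords) {
  assert(laneCoords.size() <= offsets.size() &&
         "coords rank must not exceed offsets rank");
  size_t start = offsets.size() - laneCoords.size();
  for (size_t i = 0; i < laneCoords.size(); ++i)
    offsets[start + i] += laneCoords[i];
  return offsets;
}

// E.g. original offsets [2, 0, 32] and per-lane coords [8, 16] give
// per-lane coordinates [2, 8, 48].
```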
DenseI64ArrayAttr newConstOffsetsAttr = + rewriter.getDenseI64ArrayAttr(newConstOffsets); + ValueRange currentOffsets = + ValueRange(newOperands).drop_front(offsetsStartIdx); + + SmallVector<Value> newCoords = currentOffsets; + rewriter.setInsertionPointAfter(newWarpOp); + + if (!matrixOp.getSubgroupBlockIoAttr()) { + newCoords = computeDistributedCoordinatesForMatrixOp( + rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(), + currentOffsets); + } + xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create( + rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure, + newOperands[0], ValueRange(newCoords), newConstOffsetsAttr, + matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{}); + // Resolve the output type and replace all uses. + rewriter.replaceAllUsesWith( + newWarpOp.getResult(operandIdx), + resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter)); + return success(); + } +}; + +/// Pattern for distributing xegpu::StoreMatrixOp. +struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + gpu::YieldOp yield = warpOp.getTerminator(); + Operation *lastNode = yield->getPrevNode(); + auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode); + if (!matrixOp) + return failure(); + + VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType()); + if (!sgPayloadTy) + return rewriter.notifyMatchFailure( + matrixOp, "the matrix op payload must be a vector type"); + + auto loc = matrixOp.getLoc(); + auto offsets = matrixOp.getMixedOffsets(); + if (offsets.empty()) + return rewriter.notifyMatchFailure(matrixOp, + "the store op must have offsets"); + SmallVector<Value> offsetsAsValues = + vector::getAsValues(rewriter, matrixOp.getLoc(), offsets); + + auto layout = matrixOp.getLayoutAttr(); + if (!layout) + return rewriter.notifyMatchFailure( + matrixOp, "the matrix operation lacks layout attribute"); + + FailureOr<VectorType> distPayloadByWarpOpOrFailure = + getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy); + if (failed(distPayloadByWarpOpOrFailure)) + return rewriter.notifyMatchFailure( + matrixOp, "Failed to distribute matrix op payload based on layout."); + + SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()}; + const unsigned offsetsStartIdx = operands.size(); + operands.append(offsetsAsValues); + + SmallVector<Type> operandTypes = llvm::to_vector( + llvm::map_range(operands, [](Value v) { return v.getType(); })); + operandTypes[0] = *distPayloadByWarpOpOrFailure; + + SmallVector<size_t> newRetIndices; + gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, operands, operandTypes, newRetIndices); + SmallVector<Value> newOperands = llvm::map_to_vector( + newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); }); + + SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(), + ShapedType::kDynamic); + DenseI64ArrayAttr newConstOffsetsAttr = + rewriter.getDenseI64ArrayAttr(newConstOffsets); + ValueRange currentOffsets = + ValueRange(newOperands).drop_front(offsetsStartIdx); + + SmallVector<Value> newCoords = currentOffsets; + rewriter.setInsertionPointAfter(newWarpOp); + + if (!matrixOp.getSubgroupBlockIoAttr()) { + newCoords = computeDistributedCoordinatesForMatrixOp( + rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(), + 
currentOffsets); + } + + xegpu::StoreMatrixOp::create( + rewriter, loc, TypeRange{}, newOperands[0], newOperands[1], + ValueRange(newCoords), newConstOffsetsAttr, + matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{}); + rewriter.eraseOp(matrixOp); + return success(); + } +}; + /// Distribute a scattered load op. The logic and requirements are the same as /// for the scattered store distribution. The warpOp's payload vector is /// expected to be distributed by the load's result consumer. @@ -1225,6 +1423,166 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern { } }; +/// This pattern distributes the `vector.broadcast` operation across lanes in a +/// warp. The pattern supports three use cases: +/// +/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input +/// vector +/// must have a slice layout of the result. If the distributed source and +/// target vector types are identical, this lowers to a no-op; otherwise, it +/// remains a broadcast but operates on distributed vectors. +/// +/// 2) Broadcast a same-rank vector with identical layouts for source and +/// target: +/// The source vector must have unit dimensions, and lane_data must be unit +/// size for those unit dims. This always lowers to a no-op. +/// +/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast from +/// scalar to distributed result type. +/// +/// Example 1 (lowering to a broadcast with distributed types): +/// ``` +/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) { +/// %0 = "some_def"() {layout_result_0 = +/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>, +/// dims = [0]> } : () -> (vector<32xf32>) +/// %2 = vector.broadcast %0 {layout_result_0 = +/// #xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>} +/// : vector<32xf32> to vector<8x32xf32> +/// gpu.yield %1 : vector<8x32xf32> +/// } +/// ``` +/// is lowered to: +/// ``` +/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) { +/// %0 = "some_def"() {layout_result_0 = +/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>, +/// dims = [0]> } : () -> (vector<32xf32>) +/// gpu.yield %0 : vector<32xf32> +/// } +/// %2 = vector.broadcast %r#0 : vector<1xf32> to vector<8x1xf32> +/// +/// Example 2 (no-op): +/// ``` +/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x32xf32>) { +/// %0 = "some_def"() {layout_result_0 = +/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>, +/// dims = [1]> } : () -> (vector<8xf32>) +/// %1 = vector.shape_cast %0 +/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1, +/// 1]>}: vector<8xf32> to vector<8x1xf32> +/// %2 = vector.broadcast %1 +/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1, +/// 1]>}: vector<8x1xf32> to vector<8x32xf32> +/// gpu.yield %1 : vector<8x32xf32> +/// } +/// ``` +/// is lowered to: +/// ``` +/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) { +/// %0 = "some_def"() {layout_result_0 = +/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>, +/// dims = [1]> } : () -> (vector<8xf32>) +/// %1 = vector.shape_cast %0 +/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1, +/// 1]>}: vector<8xf32> to vector<8x1xf32> +/// gpu.yield %1 : vector<8x1xf32> +/// } +/// // The broadcast is implicit through layout transformation (no-op) +/// "some_use"(%r#0) +/// ``` +struct VectorBroadcastDistribution : public 
gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *yieldOperand = + getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>); + if (!yieldOperand) + return failure(); + auto broadcastOp = + cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp()); + unsigned operandIdx = yieldOperand->getOperandNumber(); + + VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType()); + VectorType destType = + dyn_cast<VectorType>(broadcastOp.getResult().getType()); + + xegpu::DistributeLayoutAttr sourceLayout = + xegpu::getDistributeLayoutAttr(broadcastOp->getOpOperand(0)); + xegpu::DistributeLayoutAttr resultLayout = + xegpu::getDistributeLayoutAttr(broadcastOp.getResult()); + + FailureOr<VectorType> sourceDistType; + Type sourceElemOrDistType; + if (sourceType) { + + // Case 1 and 2: source is a vector type. + int64_t rankDiff = destType.getRank() - sourceType.getRank(); + if (rankDiff > 0) { + // Case 1: source is lower-rank than result. + bool isSliceOf = sourceLayout.isSliceOf(resultLayout); + if (!isSliceOf) + return rewriter.notifyMatchFailure( + warpOp, + "Broadcast input layout must be a slice of result layout."); + } + // case 2: source and result have same rank + if (rankDiff == 0) { + SetVector<int64_t> broadcastUnitDims = + broadcastOp.computeBroadcastedUnitDims(); + resultLayout = resultLayout.setUnitDimData(broadcastUnitDims); + bool isEqualTo = sourceLayout.isEqualTo(resultLayout); + if (!isEqualTo) + return rewriter.notifyMatchFailure( + warpOp, "For same-rank broadcast, source must be identical to " + "adjusted result layouts with unit dims."); + sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims); + } + + sourceDistType = + getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType); + if (failed(sourceDistType)) { + return rewriter.notifyMatchFailure( + warpOp, "Failed to distribute the source vector type."); + } + sourceElemOrDistType = sourceDistType.value(); + + } else { + // Case 3: source is a scalar type. + if (sourceLayout) { + return rewriter.notifyMatchFailure( + warpOp, "Broadcast from scalar must not have a layout attribute."); + } + sourceElemOrDistType = broadcastOp.getSourceType(); + } + FailureOr<VectorType> destDistType = + getDistVecTypeBasedOnLaneLayout(resultLayout, destType); + if (failed(destDistType)) { + return rewriter.notifyMatchFailure( + warpOp, "Failed to distribute the dest vector type."); + } + + SmallVector<size_t> newRetIndices; + auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType, + newRetIndices); + + Value distributedSource = newWarpOp.getResult(newRetIndices[0]); + + Value newBroadcast = distributedSource; + + if (sourceElemOrDistType != destDistType.value()) { + rewriter.setInsertionPointAfter(newWarpOp); + newBroadcast = + vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(), + destDistType.value(), distributedSource); + } + + rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newBroadcast); + return success(); + } +}; + /// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing /// `gpu.warp_execute_on_lane_0` region. 
struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern { @@ -1285,6 +1643,226 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern { } }; +// Distribute a `vector.extract_strided_slice` op feeding into yield op of an +// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers +// advanced cases where the distributed dimension is partially extracted and +// currently not supported by the generic vector distribution patterns. +struct VectorExtractStridedSliceDistribution + : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *operand = + getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>); + if (!operand) + return failure(); + auto extractOp = + cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp()); + unsigned operandIdx = operand->getOperandNumber(); + auto distributedType = + cast<VectorType>(warpOp.getResult(operandIdx).getType()); + // Find the distributed dimensions. + auto extractResultType = cast<VectorType>(operand->get().getType()); + auto distributedDims = + getDistributedDims(extractResultType, distributedType); + // Collect updated source type, sizes and offsets. They may be adjusted + // later if the data is distributed to lanes (as opposed to being owned by + // all lanes uniformly). + VectorType updatedSourceType = extractOp.getSourceVectorType(); + SmallVector<Attribute> updatedSizes = llvm::map_to_vector( + extractOp.getSizes(), [](Attribute attr) { return attr; }); + SmallVector<Attribute> updatedOffsets = llvm::map_to_vector( + extractOp.getOffsets(), [](Attribute attr) { return attr; }); + // If the result is distributed, it must be distributed in exactly one + // dimension. In this case, we adjust the sourceDistType, distributedSizes + // and distributedOffsets accordingly. + if (distributedDims.size() > 0) { + if (distributedDims.size() != 1) + return rewriter.notifyMatchFailure( + warpOp, "Source can not be distributed in multiple dimensions."); + int64_t distributedDim = distributedDims[0]; + int sourceDistrDimSize = + extractOp.getSourceVectorType().getShape()[distributedDim]; + auto sourceLayout = + xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0)); + if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty()) + return rewriter.notifyMatchFailure( + warpOp, "the source of extract_strided_slice op lacks distribution " + "layout"); + auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt(); + // Because only single dimension distribution is supported, lane layout + // size at the distributed dim must be the subgroup size. + int subgroupSize = sourceLaneLayout[distributedDim]; + // Check if the source size in the distributed dimension is a multiple of + // subgroup size. + if (sourceDistrDimSize % subgroupSize != 0) + return rewriter.notifyMatchFailure( + warpOp, + "Source size along distributed dimension is not a multiple of " + "subgroup size."); + auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt(); + // We expect lane data to be all ones in this case. + if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; })) + return rewriter.notifyMatchFailure( + warpOp, "Expecting unit lane data in source layout"); + // The offsets in the distributed dimention must be a multiple of subgroup + // size. 
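The preconditions and rescaling spelled out in the checks above boil down to simple integer arithmetic along the single distributed dimension. A minimal sketch, assuming unit lane_data and that the per-lane extract size is simply the original size divided by the subgroup size (in the pattern it is taken from the already-distributed result type):
```
#include <cstdint>
#include <optional>
#include <utility>

// Returns the per-lane {size, offset} along the distributed dim, or nullopt
// if the preconditions mirrored from the pattern do not hold.
std::optional<std::pair<int64_t, int64_t>>
rescaleExtractAlongDistributedDim(int64_t sourceDimSize, int64_t extractSize,
                                  int64_t extractOffset, int64_t subgroupSize) {
  if (sourceDimSize % subgroupSize != 0 || extractOffset % subgroupSize != 0 ||
      extractSize % subgroupSize != 0)
    return std::nullopt;
  // Round-robin distribution with unit lane data: each lane keeps
  // extractSize / subgroupSize elements starting at offset / subgroupSize.
  return std::make_pair(extractSize / subgroupSize,
                        extractOffset / subgroupSize);
}

// E.g. extracting 32 columns at offset 16 from a 64-wide source with
// subgroup size 16 gives each lane a 2-wide extract at local offset 1:
// rescaleExtractAlongDistributedDim(64, 32, 16, 16) == {{2, 1}}.
```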
+ int64_t distrDimOffset = + cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt(); + if (distrDimOffset % subgroupSize != 0) + return rewriter.notifyMatchFailure( + warpOp, "Offset along distributed dimension " + "is not a multiple of subgroup size."); + updatedSourceType = getDistVecTypeBasedOnLaneLayout( + sourceLayout, extractOp.getSourceVectorType()) + .value(); + // Update the distributed sizes to match the distributed type. + updatedSizes[distributedDim] = rewriter.getI64IntegerAttr( + distributedType.getDimSize(distributedDim)); + // Update the distributed offsets to match round robin distribution (i.e. + // each lane owns data at `subgroupSize` stride given unit lane data). + updatedOffsets[distributedDim] = + rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize); + } + // Do the distribution by yielding the source of the extract op from + // the warp op and creating a new extract op outside the warp op. + SmallVector<size_t> newRetIndices; + auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType}, + newRetIndices); + rewriter.setInsertionPointAfter(newWarpOp); + Value source = newWarpOp.getResult(newRetIndices[0]); + // Create a new extract op outside the warp op. + Value newExtractOp = vector::ExtractStridedSliceOp::create( + rewriter, extractOp.getLoc(), distributedType, source, + ArrayAttr::get(rewriter.getContext(), updatedOffsets), + ArrayAttr::get(rewriter.getContext(), updatedSizes), + extractOp.getStrides()); + rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp); + return success(); + } +}; + +/// Distribute a `vector.insert_strided_slice` op feeding into yield op of an +/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers +/// advanced cases where the distributed dimension is partially inserted and +/// currently not supported by the generic vector distribution patterns. +struct VectorInsertStridedSliceDistribution + : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *operand = + getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>); + if (!operand) + return failure(); + unsigned int operandNumber = operand->getOperandNumber(); + auto insertOp = + operand->get().getDefiningOp<vector::InsertStridedSliceOp>(); + auto distributedType = + cast<VectorType>(warpOp.getResult(operandNumber).getType()); + // Find the distributed dimensions of the dest vector. + auto insertResultType = cast<VectorType>(operand->get().getType()); + auto destDistributedDims = + getDistributedDims(insertResultType, distributedType); + // Collect updated offsets, source type and dest type. They may be adjusted + // later if the data is distributed to lanes (as opposed to being owned by + // all lanes uniformly). + SmallVector<Attribute> updatedOffsets = llvm::map_to_vector( + insertOp.getOffsets(), [](Attribute attr) { return attr; }); + VectorType updatedSourceType = insertOp.getSourceVectorType(); + VectorType updatedDestType = insertOp.getDestVectorType(); + if (destDistributedDims.size() > 0) { + // Only single dimension distribution is supported. 
+ if (destDistributedDims.size() != 1) + return rewriter.notifyMatchFailure( + warpOp, + "Expecting source to be distributed in a single dimension."); + int64_t destDistributedDim = destDistributedDims[0]; + + VectorType srcType = insertOp.getSourceVectorType(); + VectorType destType = insertOp.getDestVectorType(); + // Currently we require that both source (kD) and dest (nD) vectors are + // distributed. This requires that distributedDim (d) is contained in the + // last k dims of the dest vector (d >= n - k). + int64_t sourceDistributedDim = + destDistributedDim - (destType.getRank() - srcType.getRank()); + if (sourceDistributedDim < 0) + return rewriter.notifyMatchFailure( + insertOp, + "distributed dimension must be in the last k (i.e. source " + "rank) dims of dest vector"); + int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim); + // Obtain the source and dest layouts. + auto destLayout = + xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(1)); + auto sourceLayout = + xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(0)); + if (!destLayout || !sourceLayout || + destLayout.getEffectiveLaneLayoutAsInt().empty() || + sourceLayout.getEffectiveLaneLayoutAsInt().empty()) + return rewriter.notifyMatchFailure( + warpOp, "the source or dest of insert_strided_slice op lacks " + "distribution layout"); + // Because only single dimension distribution is supported, lane layout + // size at the distributed dim must be the subgroup size. + int subgroupSize = + destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim]; + // We require that source and dest lane data are all ones to ensure + // uniform round robin distribution. + auto destLaneData = destLayout.getEffectiveLaneDataAsInt(); + auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt(); + if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) || + !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; })) + return rewriter.notifyMatchFailure( + warpOp, "Expecting unit lane data in source and dest layouts"); + // Source distributed dim size must be multiples of subgroup size. + if (srcDistrDimSize % subgroupSize != 0) + return rewriter.notifyMatchFailure( + warpOp, "Distributed dimension size in source is not a multiple of " + "subgroup size."); + // Offsets in the distributed dimension must be multiples of subgroup + // size. + int64_t destDistrDimOffset = + cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt(); + if (destDistrDimOffset % subgroupSize != 0) + return rewriter.notifyMatchFailure( + warpOp, + "Offset along distributed dimension in dest is not a multiple of " + "subgroup size."); + // Update the source and dest types based on their layouts. + updatedSourceType = getDistVecTypeBasedOnLaneLayout( + sourceLayout, insertOp.getSourceVectorType()) + .value(); + updatedDestType = getDistVecTypeBasedOnLaneLayout( + destLayout, insertOp.getDestVectorType()) + .value(); + // Update the distributed offsets to match round robin distribution (i.e. + // each lane owns data at `subgroupSize` stride given unit lane data). + updatedOffsets[destDistributedDim] = + rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize); + } + // Do the distribution by yielding the source and dest of the insert op + // from the warp op and creating a new insert op outside the warp op. 
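The index bookkeeping just performed reduces to two computations: mapping the distributed dest dimension into the source vector (d - (n - k)) and rescaling the dest offset to a per-lane, round-robin offset. A standalone sketch with illustrative names:
```
#include <cstdint>
#include <optional>

struct DistributedInsertInfo {
  int64_t sourceDistributedDim;
  int64_t perLaneOffset;
};

std::optional<DistributedInsertInfo>
mapInsertToLanes(int64_t destDistributedDim, int64_t destRank, int64_t srcRank,
                 int64_t destOffset, int64_t subgroupSize) {
  // d >= n - k: the distributed dim must fall into the trailing srcRank dims.
  int64_t sourceDistributedDim = destDistributedDim - (destRank - srcRank);
  if (sourceDistributedDim < 0)
    return std::nullopt;
  // Offsets along the distributed dim must be multiples of the subgroup size.
  if (destOffset % subgroupSize != 0)
    return std::nullopt;
  return DistributedInsertInfo{sourceDistributedDim,
                               destOffset / subgroupSize};
}

// E.g. inserting a 2-D source into dim 2 of a 3-D dest (rank difference 1)
// at dest offset 32 with subgroup size 16 yields source dim 1 and
// per-lane offset 2.
```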
+ SmallVector<size_t> newRetIndices; + auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()}, + {updatedSourceType, updatedDestType}, newRetIndices); + rewriter.setInsertionPointAfter(newWarpOp); + + Value valueToStore = newWarpOp.getResult(newRetIndices[0]); + Value dest = newWarpOp.getResult(newRetIndices[1]); + // Create a new insert op outside the warp op. + Value newInsertOp = vector::InsertStridedSliceOp::create( + rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest, + ArrayAttr::get(rewriter.getContext(), updatedOffsets), + insertOp.getStrides()); + rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), + newInsertOp); + return success(); + } +}; + /// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an /// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op /// outside of the warp op. @@ -1437,13 +2015,18 @@ void xegpu::populateXeGPUSubgroupDistributePatterns( LoadNdDistribution, DpasDistribution, PrefetchNdDistribution, GpuBarrierDistribution, VectorMultiReductionDistribution, LoadDistribution, StoreDistribution, VectorTransposeDistribution, - VectorBitcastDistribution, + VectorBitcastDistribution, LoadMatrixDistribution, + StoreMatrixDistribution, MemrefExtractAlignedPointerAsIndexDistribution>( patterns.getContext(), /*pattern benefit=*/regularPatternBenefit); - patterns.add<VectorShapeCastDistribution>( - patterns.getContext(), - /*pattern benefit=*/highPatternBenefit); + // For following patterns, we need to override the regular vector distribution + // patterns. Therefore, assign higher benefit. + patterns + .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution, + VectorInsertStridedSliceDistribution, VectorBroadcastDistribution>( + patterns.getContext(), + /*pattern benefit=*/highPatternBenefit); } void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns( @@ -1462,6 +2045,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() { // Layouts are needed for vector type only. 
if (!isa<VectorType>(operand.get().getType())) continue; + if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(op)) + continue; auto layout = xegpu::getDistributeLayoutAttr(operand.get()); if (!layout) { diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index e6e71cc..af63f09 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -238,6 +238,9 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> { if (!targetShape) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropInstData(); int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size()); bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr(); @@ -255,7 +258,7 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> { auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value { xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets, op.getL1HintAttr(), op.getL2HintAttr(), - op.getL3HintAttr()); + op.getL3HintAttr(), layout); // return dummy Value to satisfy function's signature return nullptr; }; @@ -282,6 +285,9 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> { if (!targetShape) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropInstData(); int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size()); bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr(); @@ -306,7 +312,7 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> { return xegpu::LoadNdOp::create( rewriter, loc, newValueTy, convertedTdescs[0], offsets, op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(), - op.getL2HintAttr(), op.getL3HintAttr()); + op.getL2HintAttr(), op.getL3HintAttr(), layout); }; newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape, createLoad, loc, rewriter); @@ -331,6 +337,9 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> { if (!targetShape) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropInstData(); int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size()); bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr(); @@ -354,7 +363,7 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> { xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++], convertedTdescs[0], offsets, op.getL1HintAttr(), op.getL2HintAttr(), - op.getL3HintAttr()); + op.getL3HintAttr(), layout); // return dummy Value to satisfy function's signature return nullptr; }; @@ -678,12 +687,16 @@ struct UnrollLoadGatherOpWithOffset pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter); } + auto layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropInstData(); + SmallVector<Value> newOps; for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) { auto newOp = xegpu::LoadGatherOp::create( rewriter, loc, newValueTy, op.getSource(), o, m, rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(), - op.getL2HintAttr(), op.getL3HintAttr()); + op.getL2HintAttr(), op.getL3HintAttr(), layout); newOps.push_back(newOp); } @@ -774,12 +787,16 @@ struct UnrollStoreScatterOpWithOffsets SmallVector<Value> convertedValues = pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter); + auto layout = op.getLayoutAttr(); + if (layout) + layout = 
layout.dropInstData(); + for (auto [v, o, m] : llvm::zip(convertedValues, convertedOffsets, convertedMasks)) { xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m, rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(), op.getL2HintAttr(), - op.getL3HintAttr()); + op.getL3HintAttr(), layout); } rewriter.eraseOp(op); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 9fc5ad9..be82cda 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -86,8 +86,16 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op, if (origOffsets.empty()) return failure(); + // if op is xegpu::CreateNdDescOp, call op.getDescLayoutAttr() + xegpu::DistributeLayoutAttr layout; + if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> || + std::is_same_v<OpType, xegpu::StoreMatrixOp>) { + layout = op.getLayoutAttr(); + } else { + layout = op.getDescLayoutAttr(); + } + // not applicable to ops without workgroup layout attributes - xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); if (!layout || !layout.isForWorkgroup()) return failure(); @@ -114,7 +122,8 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op, // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory // descriptors to be accessed, based on the layout information. ArrayRef<int64_t> wgShape = op.getDataShape(); - auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); + auto maybeDescOffsets = + layout.computeDistributedCoords(rewriter, loc, sgId, wgShape); if (failed(maybeDescOffsets)) return failure(); @@ -189,7 +198,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> { xegpu::TensorDescType tdescTy = op.getType(); ArrayRef<int64_t> wgShape = tdescTy.getShape(); Type elemTy = tdescTy.getElementType(); - xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr(); SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; auto newTdescTy = xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(), @@ -308,6 +317,9 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> { if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropSgLayoutAndData(); SmallVector<Value> newOps; for (auto [tdesc, offsets] : llvm::zip(adaptor.getTensorDesc(), offsetsList)) { @@ -317,7 +329,7 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> { auto newOp = xegpu::LoadNdOp::create( rewriter, op.getLoc(), newResTy, tdesc, offsets, /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(), - op.getL2HintAttr(), op.getL3HintAttr()); + op.getL2HintAttr(), op.getL3HintAttr(), layout); newOps.push_back(newOp); } rewriter.replaceOpWithMultiple(op, {newOps}); @@ -338,11 +350,14 @@ struct WgToSgStoreNdOpWithOffset if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropSgLayoutAndData(); for (auto [v, tdesc, offsets] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) { xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(), - op.getL3HintAttr()); + 
op.getL3HintAttr(), layout); } rewriter.eraseOp(op); @@ -362,11 +377,14 @@ struct WgToSgPrefetchNdOpWithOffset if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropSgLayoutAndData(); for (auto [tdesc, offsets] : llvm::zip(adaptor.getTensorDesc(), offsetsList)) { xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(), - op.getL3HintAttr()); + op.getL3HintAttr(), layout); } rewriter.eraseOp(op); @@ -488,10 +506,8 @@ struct WgToSgVectorBroadcastOp for (auto operand : adaptor.getOperands().front()) { auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(), newResultType, operand); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) - xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), - layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), + layout.dropSgLayoutAndData()); newBroadcastOps.push_back(newBroadcast.getResult()); } @@ -737,12 +753,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { Location loc = op.getLoc(); auto eltType = vecType.getElementType(); - auto setLayoutIfNeeded = [&](Value val) { - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) { - xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val), - layout.dropSgLayoutAndData()); - } + auto setLayout = [&](Value val) { + xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val), + layout.dropSgLayoutAndData()); }; if (vecAttr.isSplat()) { @@ -750,14 +763,14 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { Attribute singleVal = vecAttr.getSplatValue<Attribute>(); auto sgAttr = DenseElementsAttr::get(newType, singleVal); auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr); - setLayoutIfNeeded(cstOp->getResult(0)); + setLayout(cstOp->getResult(0)); rewriter.replaceOp(op, cstOp); return success(); } else if (sgShape == wgShape) { // if the entire vector is shared by all // subgroups, don't distribute auto newConstOp = arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr); - setLayoutIfNeeded(newConstOp->getResult(0)); + setLayout(newConstOp->getResult(0)); rewriter.replaceOp(op, newConstOp); return success(); } else { @@ -830,8 +843,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { // Get subgroup id Value sgId = gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); - - auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); + auto sgOffsets = + layout.computeDistributedCoords(rewriter, loc, sgId, wgShape); if (failed(sgOffsets)) return failure(); @@ -859,9 +872,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { rewriter, loc, baseConstVec.getType(), mulOffset); auto finalConst = arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset); - setLayoutIfNeeded(baseConstVec); - setLayoutIfNeeded(bcastOffset); - setLayoutIfNeeded(finalConst); + setLayout(baseConstVec); + setLayout(bcastOffset); + setLayout(finalConst); newConstOps.push_back(finalConst); } rewriter.replaceOpWithMultiple(op, {newConstOps}); @@ -912,11 +925,12 @@ struct WgToSgLoadGatherOpWithOffset VectorType newTy = VectorType::get(sgShape, resultType.getElementType()); for (auto [offsets, mask] : llvm::zip(adaptor.getOffsets(), adaptor.getMask())) { + 
auto newLayout = layout.dropSgLayoutAndData(); auto newLoadOp = xegpu::LoadGatherOp::create( rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr, - op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); - xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0), - layout.dropSgLayoutAndData()); + op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), + newLayout); + xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0), newLayout); newLoadOps.push_back(newLoadOp); } rewriter.replaceOpWithMultiple(op, {newLoadOps}); @@ -964,16 +978,14 @@ struct WgToSgStoreScatterOpWithOffset adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) { auto store = xegpu::StoreScatterOp::create( rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr, - op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); + op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), + layout.dropSgLayoutAndData()); // Update the layout attribute to drop sg_layout and sg_data. - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) { - for (OpOperand &operand : store->getOpOperands()) { - // Skip for operand one (memref) - if (operand.getOperandNumber() == 1) - continue; - xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData()); - } + for (OpOperand &operand : store->getOpOperands()) { + // Skip for operand one (memref) + if (operand.getOperandNumber() == 1) + continue; + xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData()); } } rewriter.eraseOp(op); @@ -1052,7 +1064,8 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> { Value sgId = gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); - auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); + auto sgOffsets = + layout.computeDistributedCoords(rewriter, loc, sgId, wgShape); if (failed(sgOffsets)) return failure(); @@ -1065,15 +1078,12 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> { vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]); auto finalSteps = arith::AddIOp::create(rewriter, loc, steps, bcastOffset); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) { - xegpu::setDistributeLayoutAttr(steps->getResult(0), - layout.dropSgLayoutAndData()); - xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0), - layout.dropSgLayoutAndData()); - xegpu::setDistributeLayoutAttr(finalSteps->getResult(0), - layout.dropSgLayoutAndData()); - } + xegpu::setDistributeLayoutAttr(steps->getResult(0), + layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0), + layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(finalSteps->getResult(0), + layout.dropSgLayoutAndData()); newOps.push_back(finalSteps); } @@ -1141,10 +1151,8 @@ struct WgToSgVectorShapeCastOp for (auto src : adaptor.getSource()) { auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(), newResultType, src); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) - xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), - layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), + layout.dropSgLayoutAndData()); newShapeCastOps.push_back(newShapeCast.getResult()); } @@ -1205,10 +1213,8 @@ struct WgToSgMultiDimReductionOp auto newOp = vector::MultiDimReductionOp::create( rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc, 
adaptor.getAcc()[0], op.getReductionDims()); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) - xegpu::setDistributeLayoutAttr(newOp->getResult(0), - layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(newOp->getResult(0), + layout.dropSgLayoutAndData()); newReductions.push_back(newOp.getResult()); } @@ -1217,6 +1223,142 @@ struct WgToSgMultiDimReductionOp } }; +// This pattern transforms vector.transpose ops to work at subgroup level. +struct WgToSgVectorTransposeOp + : public OpConversionPattern<vector::TransposeOp> { + using OpConversionPattern<vector::TransposeOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + VectorType resultType = op.getResultVectorType(); + + ArrayRef<int64_t> wgShape = resultType.getShape(); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); + if (!layout || !layout.isForWorkgroup()) + return failure(); + + xegpu::DistributeLayoutAttr sourceLayout = + xegpu::getDistributeLayoutAttr(op.getVector()); + if (!sourceLayout || !sourceLayout.isForWorkgroup()) + return failure(); + + SmallVector<int64_t> sourceSgLayout = + sourceLayout.getEffectiveSgLayoutAsInt(); + SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt(); + DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder(); + DenseI32ArrayAttr resultOrder = layout.getOrder(); + + if (!sourceOrder || !resultOrder) { + return rewriter.notifyMatchFailure( + op, "Both source and result must have order attributes"); + } + + ArrayRef<int64_t> permutation = op.getPermutation(); + size_t permutationSize = permutation.size(); + if (sourceSgLayout.size() != permutationSize || + resultSgLayout.size() != permutationSize) { + return rewriter.notifyMatchFailure( + op, "Layouts and permutation must have the same rank"); + } + + // Check that sgLayout, sgData & order are properly transposed for source + // and result + if (!layout.isTransposeOf(sourceLayout, permutation)) + return rewriter.notifyMatchFailure( + op, "Result layout is not a valid transpose of source layout " + "according to permutation"); + + SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; + VectorType newResultType = + VectorType::get(sgShape, resultType.getElementType()); + SmallVector<Value> newTransposeOps; + for (auto src : adaptor.getVector()) { + auto newTranspose = vector::TransposeOp::create( + rewriter, op.getLoc(), newResultType, src, permutation); + xegpu::setDistributeLayoutAttr(newTranspose->getResult(0), + layout.dropSgLayoutAndData()); + newTransposeOps.push_back(newTranspose.getResult()); + } + + rewriter.replaceOpWithMultiple(op, {newTransposeOps}); + return success(); + } +}; + +// Distribute vector mask ops to work at subgroup level. 
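The mask patterns that follow size each subgroup's local mask as min(max(wgMaskDimSize - offset, 0), sgDimSize) in every dimension. A standalone sketch of that arithmetic with a worked example; the shapes, sg_data, and offsets used below are assumptions for illustration only:
```
#include <algorithm>
#include <cstdint>
#include <vector>

// Per-dimension subgroup mask size: clamp the remaining workgroup mask
// extent past this subgroup's offset into [0, sgDimSize].
std::vector<int64_t> sgMaskSizes(const std::vector<int64_t> &wgMaskSizes,
                                 const std::vector<int64_t> &sgOffsets,
                                 const std::vector<int64_t> &sgShape) {
  std::vector<int64_t> result;
  for (size_t d = 0; d < wgMaskSizes.size(); ++d) {
    int64_t remaining = std::max<int64_t>(wgMaskSizes[d] - sgOffsets[d], 0);
    result.push_back(std::min(remaining, sgShape[d]));
  }
  return result;
}

// E.g. a workgroup-level constant mask of sizes [40, 64] over a 64x128
// vector with sg_data [32, 32]: the subgroup at offset [32, 96] keeps a
// local mask of [min(max(40-32,0),32), min(max(64-96,0),32)] == [8, 0].
```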
+template <typename MaskOpType> +struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> { + using OpConversionPattern<MaskOpType>::OpConversionPattern; + + LogicalResult matchAndRewrite( + MaskOpType op, + typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); + if (!layout || !layout.isForWorkgroup()) + return failure(); + + Location loc = op.getLoc(); + VectorType type = op.getResult().getType(); + auto wgShape = type.getShape(); + + SmallVector<Value> wgMaskDimSizes; + if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) { + for (int64_t maskSize : op.getMaskDimSizes()) { + wgMaskDimSizes.push_back( + arith::ConstantIndexOp::create(rewriter, loc, maskSize)); + } + } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) { + wgMaskDimSizes = llvm::to_vector(op.getOperands()); + } + + Value sgId = + gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); + auto sgOffsets = + layout.computeDistributedCoords(rewriter, loc, sgId, wgShape); + if (failed(sgOffsets)) + return failure(); + + SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; + VectorType resultType = VectorType::get(sgShape, type.getElementType()); + + // In each dimension, each subgroup computes its local mask size as: + // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d]) + SmallVector<Value> newCreateMaskOps; + for (auto offsetSet : *sgOffsets) { + SmallVector<Value> maskOperands; + + for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) { + Value dimSizeVal = + arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]); + Value offset = offsetSet[i]; + Value adjustedMaskSize = + arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value nonNegative = + arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero); + Value sgMaskSize = + arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal); + maskOperands.push_back(sgMaskSize); + } + + auto newCreateMaskOp = + vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands); + xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0), + layout.dropSgLayoutAndData()); + newCreateMaskOps.push_back(newCreateMaskOp.getResult()); + } + + rewriter.replaceOpWithMultiple(op, {newCreateMaskOps}); + return success(); + } +}; + +using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>; +using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>; } // namespace namespace mlir { @@ -1231,7 +1373,9 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset, WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp, - WgToSgMultiDimReductionOp>(patterns.getContext()); + WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp, + WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>( + patterns.getContext()); } } // namespace xegpu } // namespace mlir @@ -1358,7 +1502,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return isLegal(layout); }); - target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>( + target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp, + vector::TransposeOp, vector::BroadcastOp, + vector::MultiDimReductionOp, + vector::ConstantMaskOp, 
vector::CreateMaskOp>( [=](Operation *op) -> bool { // Check for either a SliceAttr or LayoutAttr on the result. auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0)); @@ -1377,16 +1524,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return isLegal(layout); }); - target.addDynamicallyLegalOp<vector::BroadcastOp>( - [=](vector::BroadcastOp op) -> bool { - return isLegal(xegpu::getDistributeLayoutAttr(op.getResult())); - }); - - target.addDynamicallyLegalOp<vector::MultiDimReductionOp>( - [=](vector::MultiDimReductionOp op) -> bool { - return isLegal(xegpu::getDistributeLayoutAttr(op.getResult())); - }); - target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>( [=](xegpu::ConvertLayoutOp op) -> bool { return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout()); diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index a38993e..9f126fe 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -12,7 +12,6 @@ #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/LLVMIR/XeVMDialect.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/Utils/IndexingUtils.h" @@ -140,10 +139,14 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) { // for StoreMatrixOp, the layout is attached to the property of the op if (auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(defOp)) return storeOp.getLayoutAttr(); - std::string layoutName = getLayoutName(result); if (defOp->hasAttr(layoutName)) return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName); + + // check for "permament" layout only after "temporary" layout name lookup + // for backward compatibility + if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(defOp)) + return loadGatherOp.getLayoutAttr(); } if (auto arg = dyn_cast<BlockArgument>(value)) { @@ -171,27 +174,77 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) { std::string layoutName = xegpu::getLayoutName(opr); if (op->hasAttr(layoutName)) return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName); + + // check for "permament" layout only after "temporary" layout name lookup + if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op)) + if (auto layout = storeScatterOp.getLayoutAttr()) + return layout; + return getDistributeLayoutAttr(opr.get()); } +// Returns the permanent layout attribute for the given result if it's +// available on the defining op. Otherwise returns the provided layout. +xegpu::DistributeLayoutAttr +maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout, + const OpResult &result, mlir::Operation *owner, + const std::string &name) { + xegpu::DistributeLayoutAttr candidate = layout; + + if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(owner)) { + if (auto perm = loadOp.getLayoutAttr()) + candidate = perm; + } + + return candidate; +} + +// Returns the permanent layout attribute for the given operand if it's +// available on the defining op. Otherwise returns the provided layout. 
+xegpu::DistributeLayoutAttr +maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout, + const OpOperand &operand, mlir::Operation *owner, + const std::string &name) { + xegpu::DistributeLayoutAttr candidate = layout; + unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber(); + + if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(owner)) { + if (idx == 0) { + if (auto perm = storeOp.getLayoutAttr()) + candidate = perm; + } + } + + return candidate; +} + template <typename T, typename> void xegpu::setDistributeLayoutAttr(const T &operandOrResult, - const DistributeLayoutAttr layout) { + const DistributeLayoutAttr layout, + bool respectPermLayout) { Operation *owner = operandOrResult.getOwner(); std::string name = xegpu::getLayoutName(operandOrResult); - if (layout && !owner->hasAttrOfType<DistributeLayoutAttr>(name)) - owner->setAttr(name, layout); + + if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) + return; + + DistributeLayoutAttr candidate = layout; + if (respectPermLayout) + candidate = maybePickPermanentLayout(layout, operandOrResult, owner, name); + + if (candidate) + owner->setAttr(name, candidate); } // Explicit instantiation for OpResult template void xegpu::setDistributeLayoutAttr<mlir::OpResult>( const mlir::OpResult &result, - const mlir::xegpu::DistributeLayoutAttr layout); + const mlir::xegpu::DistributeLayoutAttr layout, bool respectPermLayout); // Explicit instantiation for OpOperand template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>( const mlir::OpOperand &operand, - const mlir::xegpu::DistributeLayoutAttr layout); + const mlir::xegpu::DistributeLayoutAttr layout, bool respectPermLayout); void xegpu::setDistributeLayoutAttrs( Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) { @@ -253,7 +306,7 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, int64_t rankDiff = srcShapeRank - targetShapeRank; std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff, 1); - std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff); + llvm::copy(shape, adjustedTargetShape.begin() + rankDiff); SmallVector<Value> result; for (SmallVector<int64_t> offsets : @@ -473,7 +526,7 @@ SmallVector<OpFoldResult> xegpu::addElementwise(OpBuilder &builder, for (auto [l, r] : llvm::zip_equal(lhs, rhs)) { auto lval = getValueOrCreateConstantIndexOp(builder, loc, l); auto rval = getValueOrCreateConstantIndexOp(builder, loc, r); - results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval)); + results.push_back(builder.createOrFold<arith::AddIOp>(loc, lval, rval)); } return results; } @@ -500,3 +553,29 @@ xegpu::addWithRightAligned(OpBuilder &builder, Location loc, results.append(addElementwise(builder, loc, a, b)); return results; } + +template <typename T> +int xegpu::getLargestDivisor(T dim, ArrayRef<T> candidates, + ArrayRef<T> candidateMultiples) { + static_assert(std::is_integral<T>::value, "T must be an integer type"); + int largest = -1; + SmallVector<T> multiples = {1}; + if (!candidateMultiples.empty()) + multiples = + SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end()); + for (T candidate : candidates) { + for (T multiple : multiples) { + int value = static_cast<int>(candidate * multiple); + if (value != 0 && dim % value == 0 && value > largest) + largest = value; + } + } + return largest; +} + +/// Explicit instantiations +template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates, + ArrayRef<int> candidateMultiples); +template int 
+xegpu::getLargestDivisor<unsigned>(unsigned dim, ArrayRef<unsigned> candidates, + ArrayRef<unsigned> candidateMultiples);
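As a usage illustration of xegpu::getLargestDivisor above, the selection scans the candidate * multiple products and keeps the largest one that divides dim evenly. A standalone sketch in plain C++ (name and containers are illustrative):
```
#include <vector>

// Mirrors the selection logic: largest candidate * multiple dividing `dim`,
// or -1 if none does.
int largestDivisorSketch(int dim, const std::vector<int> &candidates,
                         const std::vector<int> &candidateMultiples) {
  int largest = -1;
  std::vector<int> multiples =
      candidateMultiples.empty() ? std::vector<int>{1} : candidateMultiples;
  for (int candidate : candidates)
    for (int multiple : multiples) {
      int value = candidate * multiple;
      if (value != 0 && dim % value == 0 && value > largest)
        largest = value;
    }
  return largest;
}

// E.g. largestDivisorSketch(24, {2, 3, 5}, {1, 2, 4}) == 12: of the products
// {2, 4, 8, 3, 6, 12, 5, 10, 20}, the divisors of 24 are {2, 3, 4, 6, 8, 12}
// and 12 is the largest.
```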
