path: root/mlir/lib/Dialect/XeGPU
Diffstat (limited to 'mlir/lib/Dialect/XeGPU')
-rw-r--r--  mlir/lib/Dialect/XeGPU/CMakeLists.txt                        |    1
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp                   |  390
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp                       |  137
-rw-r--r--  mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt           |   17
-rw-r--r--  mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp    |  695
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt             |    1
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp |  490
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp   |  556
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp |  619
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp            |   27
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp  |  267
-rw-r--r--  mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp                  |   97
12 files changed, 2950 insertions, 347 deletions
diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
index 31167e6..46b8251 100644
--- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(Utils)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index f9aa28d5..1a19ab5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -8,10 +8,8 @@
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
@@ -38,55 +36,61 @@ void XeGPUDialect::initialize() {
>();
}
-/// Generates instructions to compute offsets for a subgroup identified by
-/// its multidimensional indices (sgId), using the specified subgroup layout
-/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
-/// dimensions (sizePerWg).
+// A `srcShape` consists of N distribution units, each of shape
+// `subShapesLayout` x `subShape`. A `delinearizedId` identifies a particular
+// `subShape` within each distribution unit.
+// Example:
+// WG data is 128x256 and SG data is 16x32 in a 4x2 subgroup layout. This
+// gives a distribution unit of shape 64x64, so there are 2x4 such
+// distribution units. `delinearizedId` identifies the 16x32 block owned by
+// a subgroup within each distribution unit.
static SmallVector<SmallVector<Value>>
-genOffsetsComputingInsts(OpBuilder &builder, Location loc,
- SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout,
- ArrayRef<int64_t> sizePerSg,
- ArrayRef<int64_t> sizePerWg) {
-
- SmallVector<SmallVector<Value>> offsets;
+genCoordinates(OpBuilder &builder, Location loc,
+ SmallVector<Value> delinearizedId,
+ ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
+ ArrayRef<int64_t> srcShape) {
+ SmallVector<SmallVector<Value>> coordinates;
+
+ // A distribution unit must be less than or equal to `srcShape`
+ SmallVector<int64_t> distUnitShape = llvm::map_to_vector(
+ llvm::zip_equal(srcShape,
+ computeElementwiseMul(subShapesLayout, subShape)),
+ [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
- // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i]
- SmallVector<Value> localOffsets = llvm::map_to_vector(
- llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value {
- return builder.createOrFold<index::MulOp>(
+ // Get the offset of `subShape` within a distribution unit.
+ SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
+ llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
+ return builder.createOrFold<arith::MulIOp>(
loc, std::get<0>(t),
builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
});
- // distUnit[i] is the minimum value between sizePerWg[i] and
- // sgLayout[i] * sizePerSg[i]
- SmallVector<int64_t> distUnit = llvm::map_to_vector(
- llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)),
- [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
-
+ // For each dist unit
for (SmallVector<int64_t> unitOffs :
- StaticTileOffsetRange(sizePerWg, distUnit)) {
+ StaticTileOffsetRange(srcShape, distUnitShape)) {
+ // Get dist unit offset within `srcShape`.
SmallVector<Value> base =
llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
return arith::ConstantIndexOp::create(builder, loc, d);
});
-
- SmallVector<Value> adds = llvm::map_to_vector(
- llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
- return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
- std::get<1>(t));
- });
-
+ // Calculate `subShape` offset within `srcShape`.
+ SmallVector<Value> adds =
+ llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
+ [&](const auto &t) -> Value {
+ return builder.createOrFold<arith::AddIOp>(
+ loc, std::get<0>(t), std::get<1>(t));
+ });
+ // Do not go beyond `srcShape` bounds.
SmallVector<Value> mods = llvm::map_to_vector(
- llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value {
- return builder.createOrFold<index::RemUOp>(
+ llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
+ return builder.createOrFold<arith::RemUIOp>(
loc, std::get<0>(t),
arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
});
- offsets.push_back(mods);
+ coordinates.push_back(mods);
}
- return offsets;
+ return coordinates;
}
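As an illustration, the following standalone C++ sketch (not code from the patch; all values come from the 128x256 example in the comment above) reproduces the arithmetic that genCoordinates emits as IR:

// Plain-integer model of the coordinate computation above (illustrative only).
#include <algorithm>
#include <array>
#include <cstdio>

int main() {
  std::array<int, 2> srcShape = {128, 256}; // workgroup data
  std::array<int, 2> subShape = {16, 32};   // per-subgroup data
  std::array<int, 2> layout = {4, 2};       // subgroup layout
  std::array<int, 2> sgId = {1, 1};         // a delinearized subgroup id
  std::array<int, 2> distUnit, local;
  for (int i = 0; i < 2; ++i) {
    // The distribution unit is capped by the source shape: min(128, 4*16) = 64.
    distUnit[i] = std::min(srcShape[i], layout[i] * subShape[i]);
    local[i] = sgId[i] * subShape[i]; // offset within one distribution unit
  }
  // One coordinate per distribution unit: (128/64) x (256/64) = 2x4 units.
  for (int r = 0; r < srcShape[0]; r += distUnit[0])
    for (int c = 0; c < srcShape[1]; c += distUnit[1])
      std::printf("(%d, %d)\n", (r + local[0]) % srcShape[0],
                  (c + local[1]) % srcShape[1]);
}

For sgId = [1, 1] this prints the eight coordinates (16, 32), (16, 96), ..., (80, 224), one for each of the 2x4 distribution units.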
// Checks if the given shape can be evenly distributed based on the layout
@@ -229,8 +233,10 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
}
if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
- return emitError()
- << "expected inst_data and lane_layout to have the same rank";
+ return emitError() << "expected inst_data and lane_layout to have the same "
+ "rank, got inst_data "
+ << inst_data.size() << ", lane_layout "
+ << lane_layout.size();
}
// sg_data is optional for Workgroup layout, but its presence requires
@@ -271,56 +277,197 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
}
FailureOr<SmallVector<Value>>
-LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
- Value linearId) {
- // delinearizeSubgroupId is only available for
- // workgroup-level layout attribute
- if (!isForWorkgroup())
+LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
+
+ SmallVector<int64_t> sgLayoutInt;
+ if (isForWorkgroup()) {
+ sgLayoutInt = getEffectiveSgLayoutAsInt();
+ } else if (isForSubgroup()) {
+ sgLayoutInt = getEffectiveLaneLayoutAsInt();
+ } else {
return failure();
+ }
- // TODO: handle order attribute
- auto hasDefaultOrder = [&]() {
- DenseI32ArrayAttr order = getOrder();
- return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>(
- llvm::reverse(order.asArrayRef())));
- };
- if (!hasDefaultOrder())
- return mlir::emitError(loc, "order attribute is currently not supported.");
+ DenseI32ArrayAttr orderAttr = getOrder();
- auto dims =
- llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value {
- return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
- });
+ // Handle order attribute
+ SmallVector<int64_t> order;
+ if (orderAttr && !orderAttr.empty()) {
+ order = llvm::to_vector(
+ llvm::map_range(orderAttr.asArrayRef(),
+ [](int32_t idx) { return static_cast<int64_t>(idx); }));
+ } else {
+ // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc.
+ order = llvm::to_vector(
+ llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
+ }
- return affine::delinearizeIndex(builder, loc, linearId, dims);
+ if (order.size() != sgLayoutInt.size()) {
+ return failure();
+ }
+
+ SmallVector<Value> result(sgLayoutInt.size());
+ Value remaining = linearId;
+
+ /// Process dimensions in the order they appear in the order array
+ /// The first dimension in order is the fastest-changing
+ ///
+ /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]:
+ ///
+ /// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx],
+ /// result=[?,?,?]
+ ///
+ /// i=0 (process columns, dimIdx=2, dimSize=4):
+ /// result[2] = 22 % 4 = 2 (column coordinate)
+ /// remaining = 22 / 4 = 5 (5 complete groups of 4 columns processed)
+ ///
+ /// i=1 (process rows, dimIdx=1, dimSize=4):
+ /// result[1] = 5 % 4 = 1 (row coordinate)
+ /// remaining = 5 / 4 = 1 (1 complete group of 4 rows processed)
+ ///
+ /// i=2 (process layers, dimIdx=0, dimSize=2):
+ /// result[0] = 1 % 2 = 1 (layer coordinate)
+ /// (no remaining update - last iteration)
+ ///
+ /// Final result: [1,1,2] = Layer 1, Row 1, Column 2
+ for (size_t i = 0; i < order.size(); ++i) {
+ int64_t dimIdx = order[i];
+ int64_t dimSize = sgLayoutInt[dimIdx];
+
+ Value dimSizeVal =
+ builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize);
+
+ /// Extract the coordinate for this dimension using modulo operation
+ /// This gives us "how far within this dimension" we are
+ /// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within
+ /// this dimension)
+ result[dimIdx] =
+ builder.createOrFold<arith::RemUIOp>(loc, remaining, dimSizeVal);
+
+    /// Update remaining for the next dimension by removing what we've already
+    /// processed. Division tells us how many complete groups of this dimension
+    /// we've gone through, e.g., linearId=22, dimSize=4: 22 / 4 = 5 (we've
+    /// completed 5 groups of 4). Skip this for the last iteration since there
+    /// is no next dimension to process.
+ if (i < order.size() - 1) {
+ remaining =
+ builder.createOrFold<arith::DivUIOp>(loc, remaining, dimSizeVal);
+ }
+ }
+ return result;
}
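The walkthrough above can be verified with a plain-integer model of the same loop (a hedged sketch, not code from the patch):

#include <cstdio>
#include <vector>

// Mirrors the delinearization loop above on plain integers.
std::vector<long> delinearize(long linearId, const std::vector<long> &layout,
                              const std::vector<long> &order) {
  std::vector<long> result(layout.size());
  long remaining = linearId;
  for (long dim : order) { // order[0] is the fastest-changing dimension
    result[dim] = remaining % layout[dim];
    remaining /= layout[dim];
  }
  return result;
}

int main() {
  // sgLayout = [2, 4, 4], order = [2, 1, 0], linearId = 22 -> [1, 1, 2].
  for (long v : delinearize(22, {2, 4, 4}, {2, 1, 0}))
    std::printf("%ld ", v);
  std::printf("\n");
}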
-/// Implements DistributeLayoutAttr::getOffsets to generate
+/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
/// instructions for computing multi-dimensional offsets when distributed by
/// LayoutAttr.
FailureOr<SmallVector<SmallVector<Value>>>
-LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
- ArrayRef<int64_t> shape) {
- if (!isForWorkgroup())
+LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
+ Value linearId, ArrayRef<int64_t> shape) {
+ SmallVector<int64_t> layout;
+ SmallVector<int64_t> subShape;
+ if (isForWorkgroup()) {
+ layout = getEffectiveSgLayoutAsInt();
+ subShape = getEffectiveSgDataAsInt();
+ } else if (isForSubgroup()) {
+ layout = getEffectiveLaneLayoutAsInt();
+ subShape = getEffectiveLaneDataAsInt();
+ } else {
return failure();
-
- SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
- if (sgShape.empty()) {
- if (auto derivedShape = computeShapeRatio(shape, sgLayout))
- sgShape = derivedShape.value();
+ }
+ if (subShape.empty()) {
+ if (auto derivedShape = computeShapeRatio(shape, layout))
+ subShape = derivedShape.value();
else
return failure();
}
// delinearize Ids
- auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+ auto maybeIds = delinearizeId(builder, loc, linearId);
if (failed(maybeIds))
return failure();
- SmallVector<Value> sgIds = *maybeIds;
+ SmallVector<Value> ids = *maybeIds;
+
+ return genCoordinates(builder, loc, ids, layout, subShape, shape);
+}
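Putting the two steps together for the 128x256 example: a subgroup with linear id 3 under sg_layout = [4, 2] delinearizes (with the default [1, 0] order) to [1, 1], and genCoordinates then yields its eight 16x32 coordinates, one for each of the 2x4 distribution units.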
+
+bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
+ if (dyn_cast<xegpu::SliceAttr>(other))
+ return false;
+
+ return *this == dyn_cast<xegpu::LayoutAttr>(other);
+}
+
+// Set sg_data, inst_data and lane_data to 1 for the specified unit dims.
+DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) {
+ auto sgDataOpt = getSgData();
+ auto instDataOpt = getInstData();
+ auto laneDataOpt = getLaneData();
+
+ SmallVector<int32_t> sgData;
+ SmallVector<int32_t> instData;
+ SmallVector<int32_t> laneData;
+
+ if (sgDataOpt) {
+ sgData = llvm::to_vector(sgDataOpt.asArrayRef());
+ }
+ if (instDataOpt) {
+ instData = llvm::to_vector(instDataOpt.asArrayRef());
+ }
+ if (laneDataOpt) {
+ laneData = llvm::to_vector(laneDataOpt.asArrayRef());
+ }
- return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
- shape);
+ for (auto dim : unitDims) {
+ if (dim < static_cast<int64_t>(sgData.size()))
+ sgData[dim] = 1;
+ if (dim < static_cast<int64_t>(instData.size()))
+ instData[dim] = 1;
+ if (dim < static_cast<int64_t>(laneData.size()))
+ laneData[dim] = 1;
+ }
+
+ return LayoutAttr::get(
+ getContext(), getSgLayout(),
+ sgData.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), sgData),
+ instData.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), instData),
+ getLaneLayout(),
+ laneData.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), laneData),
+ getOrder());
+}
+
+// Set sg_layout and lane_layout to 1 for the specified unit dims.
+DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
+ auto sgLayoutOpt = getSgLayout();
+ auto laneLayoutOpt = getLaneLayout();
+
+ SmallVector<int32_t> sgLayout;
+ SmallVector<int32_t> laneLayout;
+
+ if (sgLayoutOpt) {
+ sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef());
+ }
+ if (laneLayoutOpt) {
+ laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef());
+ }
+
+ for (auto dim : unitDims) {
+ if (dim < static_cast<int64_t>(sgLayout.size()))
+ sgLayout[dim] = 1;
+ if (dim < static_cast<int64_t>(laneLayout.size()))
+ laneLayout[dim] = 1;
+ }
+
+ return LayoutAttr::get(
+ getContext(),
+ sgLayout.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), sgLayout),
+ getSgData(), getInstData(),
+ laneLayout.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), laneLayout),
+ getLaneData(), getOrder());
}
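For a concrete picture of these two helpers: given a hypothetical #xegpu.layout<sg_layout = [4, 2], sg_data = [16, 32]>, setUnitDimData({1}) yields sg_data = [16, 1] with sg_layout unchanged, whereas setUnitDimLayout({1}) yields sg_layout = [4, 1] and leaves sg_data untouched.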
//===----------------------------------------------------------------------===//
@@ -374,34 +521,43 @@ SliceAttr SliceAttr::flatten() const {
}
FailureOr<SmallVector<Value>>
-SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
- Value linearId) {
+SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
- return parent.delinearizeSubgroupId(builder, loc, linearId);
+ return parent.delinearizeId(builder, loc, linearId);
}
-/// Implements DistributeLayoutAttr::getOffsets to generate
-/// instructions for computing multi-dimensional offsets when distributed by
-/// SliceAttr.
+// Implements DistributeLayoutAttr::computeDistributedCoords to generate
+// instructions for computing multi-dimensional offsets when distributed by
+// SliceAttr.
FailureOr<SmallVector<SmallVector<Value>>>
-SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
- ArrayRef<int64_t> shape) {
+SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
+ Value linearId, ArrayRef<int64_t> shape) {
assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
if (!isForWorkgroup())
return failure();
- SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
- if (sgShape.empty()) {
- if (auto derivedShape = computeShapeRatio(shape, sgLayout))
- sgShape = derivedShape.value();
+ SmallVector<int64_t> layout;
+ SmallVector<int64_t> subShape;
+ if (isForWorkgroup()) {
+ layout = getEffectiveSgLayoutAsInt();
+ subShape = getEffectiveSgDataAsInt();
+ } else if (isForSubgroup()) {
+ layout = getEffectiveLaneLayoutAsInt();
+ subShape = getEffectiveLaneDataAsInt();
+ } else {
+ return failure();
+ }
+
+ if (subShape.empty()) {
+ if (auto derivedShape = computeShapeRatio(shape, layout))
+ subShape = derivedShape.value();
else
return failure();
}
// delinearize Ids
- auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+ auto maybeIds = delinearizeId(builder, loc, linearId);
if (failed(maybeIds))
return failure();
@@ -411,8 +567,7 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
SmallVector<Value> sgIds =
XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
- return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
- shape);
+ return genCoordinates(builder, loc, sgIds, layout, subShape, shape);
}
bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
@@ -435,6 +590,69 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
[&](int64_t dim) { return thisDims.contains(dim); });
}
+bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
+ if (dyn_cast<xegpu::LayoutAttr>(other))
+ return false;
+
+ auto flattenedThis = flatten();
+ auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
+
+ return ((flattenedThis.getParent() == flattenedOther.getParent()) &&
+ (flattenedThis.getDims() == flattenedOther.getDims()));
+}
+
+// Helper function to adjust unit dimensions from sliced space to parent space
+static SetVector<int64_t>
+adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims,
+ ArrayRef<int64_t> sliceDims) {
+ // Reconstruct parent's non-sliced dimensions
+
+ int64_t parentRank = sliceDims.size() + unitDims.size();
+ llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
+ sliceDims.end());
+ SmallVector<int64_t> nonSlicedDims;
+ for (int64_t i = 0; i < parentRank; ++i) {
+ if (!slicedDimsSet.contains(i))
+ nonSlicedDims.push_back(i);
+ }
+
+ // Map unit dims from sliced space to parent space
+ SetVector<int64_t> adjustUnitDims;
+ for (auto dim : unitDims) {
+ if (dim < static_cast<int64_t>(nonSlicedDims.size())) {
+ adjustUnitDims.insert(nonSlicedDims[dim]);
+ }
+ }
+
+ return adjustUnitDims;
+}
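As a worked example of the remapping: with sliceDims = [0] and unitDims = {0, 1}, the parent rank is reconstructed as 3, the non-sliced parent dims are [1, 2], so the unit dims of the sliced space map to parent dims {1, 2}.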
+
+// Set sg_data, inst_data and lane_data to 1 for the specified unit dims.
+DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) {
+ SliceAttr attr = flatten();
+ ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+
+ SetVector<int64_t> adjustUnitDims =
+ adjustUnitDimsWithSliceDims(unitDims, sliceDims);
+
+ return SliceAttr::get(getContext(), parent.setUnitDimData(adjustUnitDims),
+ attr.getDims());
+}
+
+// Set sg_layout and lane_layout to 1 for the specified unit dims.
+DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
+ SliceAttr attr = flatten();
+ ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+
+ SetVector<int64_t> adjustUnitDims =
+ adjustUnitDimsWithSliceDims(unitDims, sliceDims);
+
+ return SliceAttr::get(getContext(), parent.setUnitDimLayout(adjustUnitDims),
+ attr.getDims());
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_RangeAttr
//===----------------------------------------------------------------------===//
@@ -569,8 +787,8 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
// for gather and scatter ops, Low-precision types are packed in 32-bit units.
unsigned bitWidth = elementType.getIntOrFloatBitWidth();
int chunkAlignmentFactor =
- bitWidth < targetinfo::packedSizeInBitsForGatherScatter
- ? targetinfo::packedSizeInBitsForGatherScatter / bitWidth
+ bitWidth < xegpu::uArch::generalPackedFormatBitSize
+ ? xegpu::uArch::generalPackedFormatBitSize / bitWidth
: 1;
auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
if (scatterAttr) {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index abd12e2..91ba07a 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -175,13 +175,13 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
LogicalResult
IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
- UnitAttr subgroup_block_io,
+ UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
function_ref<InFlightDiagnostic()> emitError) {
if (!dataTy) {
if (subgroup_block_io)
return emitError() << "subgroup_block_io "
- "are only allowed when result is a 1D VectorType.";
+ "are only allowed when result is a VectorType.";
else
return success();
}
@@ -192,15 +192,37 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
ArrayRef<int64_t> dataShape = dataTy.getShape();
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+ SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
+ ArrayAttr strideAttr = mdescTy.getStrideAttr();
+ SmallVector<int64_t> strides;
+ for (Attribute attr : strideAttr.getValue()) {
+ strides.push_back(cast<IntegerAttr>(attr).getInt());
+ }
+ if (subgroup_block_io && layout) {
+ auto laneData = layout.getEffectiveLaneDataAsInt();
+ auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+ if (!laneData.empty()) {
+ bool isLaneDataContiguous =
+ std::all_of(laneData.begin(), std::prev(laneData.end()),
+ [](int x) { return x == 1; });
+ if (!isLaneDataContiguous)
+ return emitError() << "With subgroup_block_io, accessed data must be "
+ "contiguous and coalesced.";
+ for (size_t i = 0; i < laneData.size(); ++i) {
+ if (laneLayout[i] != blockShape[i])
+ return emitError() << "With subgroup_block_io, the block shape must "
+ "match the lane layout.";
+ if (laneLayout[i] != 1 && strides[i] != 1)
+ return emitError() << "With subgroup_block_io, the distributed "
+ "dimensions must be contiguous.";
+ }
+ }
+ }
if (dataShape.size() == 2) {
- if (subgroup_block_io)
- return emitError() << "subgroup_block_io "
- "are only allowed when result is a 1D VectorType.";
if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
return emitError() << "data shape must not exceed mem_desc shape.";
} else {
- SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
// if the subgroup_block_io attribute is set, mdescTy must have block
// attribute
if (subgroup_block_io && !blockShape.size())
@@ -258,8 +280,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
// if shape and strides are from Memref, we don't need attributes for them
- // to keep the IR print clean.
- if (staticShape == memrefShape && staticStrides == memrefStrides) {
+  // to keep the IR print clean (only do so for the fully static case;
+  // otherwise the printer would fail trying to print an empty array attribute).
+ if (staticShape == memrefShape && staticStrides == memrefStrides &&
+ dynamicShape.empty() && dynamicStrides.empty()) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
@@ -320,8 +344,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
// if shape and strides are from Memref, we don't need attributes for them
- // to keep the IR print clean.
- if (staticShape == memrefShape && staticStrides == memrefStrides) {
+  // to keep the IR print clean (only do so for the fully static case;
+  // otherwise the printer would fail trying to print an empty array attribute).
+ if (staticShape == memrefShape && staticStrides == memrefStrides &&
+ dynamicShape.empty() && dynamicStrides.empty()) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
@@ -439,14 +465,15 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
xegpu::CachePolicyAttr l3_hint) {
return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
- l1_hint, l2_hint, l3_hint);
+ l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
}
void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
Value tensorDesc, ArrayRef<OpFoldResult> offsets,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
- xegpu::CachePolicyAttr l3_hint) {
+ xegpu::CachePolicyAttr l3_hint,
+ xegpu::DistributeLayoutAttr layout) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -454,7 +481,7 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
- l2_hint, l3_hint);
+ l2_hint, l3_hint, /*anchor_layout=*/layout);
}
LogicalResult PrefetchNdOp::verify() {
@@ -472,11 +499,8 @@ LogicalResult PrefetchNdOp::verify() {
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
int64_t tDescRank = tdescTy.getRank();
- int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
- int64_t constOffsetSize =
- getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
- if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
+ int64_t offsetSize = getMixedOffsets().size();
+ if (offsetSize != 0 && offsetSize != tDescRank)
return emitOpError(
"Mismatched ranks between offsets and tensor descriptor");
@@ -496,7 +520,7 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
return build(builder, state, retType, tensorDesc, ValueRange(),
DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
- l3_hint);
+ l3_hint, /*anchor_layout=*/nullptr);
}
void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
@@ -504,7 +528,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
UnitAttr packed, DenseI64ArrayAttr transpose,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
- xegpu::CachePolicyAttr l3_hint) {
+ xegpu::CachePolicyAttr l3_hint,
+ xegpu::DistributeLayoutAttr layout) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -512,7 +537,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
- packed, transpose, l1_hint, l2_hint, l3_hint);
+ packed, transpose, l1_hint, l2_hint, l3_hint,
+ /*anchor_layout=*/layout);
}
LogicalResult LoadNdOp::verify() {
@@ -597,11 +623,8 @@ LogicalResult LoadNdOp::verify() {
<< tdescTy;
int64_t tDescRank = tdescTy.getRank();
- int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
- int64_t constOffsetSize =
- getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
- if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
+ int64_t offsetSize = getMixedOffsets().size();
+ if (offsetSize != 0 && offsetSize != tDescRank)
return emitOpError(
"Mismatched ranks between offsets and tensor descriptor");
@@ -618,14 +641,16 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
xegpu::CachePolicyAttr l3_hint) {
return build(builder, state, value, tensorDesc, ValueRange(),
- DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
+ DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint,
+ /*anchor_layout=*/nullptr);
}
void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
Value tensorDesc, ArrayRef<OpFoldResult> offsets,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
- xegpu::CachePolicyAttr l3_hint) {
+ xegpu::CachePolicyAttr l3_hint,
+ xegpu::DistributeLayoutAttr layout) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -633,7 +658,7 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
- l1_hint, l2_hint, l3_hint);
+ l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout);
}
LogicalResult StoreNdOp::verify() {
@@ -691,11 +716,8 @@ LogicalResult StoreNdOp::verify() {
<< dstTy;
int64_t tDescRank = dstTy.getRank();
- int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
- int64_t constOffsetSize =
- getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
- if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
+ int64_t offsetSize = getMixedOffsets().size();
+ if (offsetSize != 0 && offsetSize != tDescRank)
return emitOpError(
"Mismatched ranks between offsets and tensor descriptor");
@@ -809,7 +831,7 @@ void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint,
- IntegerAttr{});
+ IntegerAttr{}, /*anchor_layout=*/nullptr);
}
//===----------------------------------------------------------------------===//
@@ -859,7 +881,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
- l1_hint, l2_hint, l3_hint);
+ l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
}
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
@@ -875,7 +897,24 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
auto offset = vector::FromElementsOp::create(builder, loc, type, values);
build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
- l2_hint, l3_hint);
+ l2_hint, l3_hint, /*anchor_layout=*/nullptr);
+}
+
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+ Type valueType, Value source,
+ ArrayRef<OpFoldResult> offsets, Value mask,
+ IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint,
+ DistributeLayoutAttr layout) {
+ auto loc = source.getLoc();
+ int64_t size = static_cast<int64_t>(offsets.size());
+ auto type = VectorType::get(size, builder.getIndexType());
+ auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+ auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+ build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
+ l2_hint, l3_hint, layout);
}
//===----------------------------------------------------------------------===//
@@ -926,7 +965,7 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
- l2_hint, l3_hint);
+ l2_hint, l3_hint, /*anchor_layout=*/nullptr);
}
void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
@@ -944,7 +983,23 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
// Call the correct builder overload that does not expect result types.
build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
- l3_hint);
+ l3_hint, /*anchor_layout=*/nullptr);
+}
+
+void StoreScatterOp::build(
+ OpBuilder &builder, OperationState &state, Value value, Value dest,
+ ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
+ xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
+ auto loc = dest.getLoc();
+ int64_t size = static_cast<int64_t>(offsets.size());
+ auto type = VectorType::get(size, builder.getIndexType());
+ auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+ auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+ // Call the correct builder overload that does not expect result types.
+ build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
+ l3_hint, layout);
}
//===----------------------------------------------------------------------===//
@@ -1105,7 +1160,7 @@ LogicalResult LoadMatrixOp::verify() {
MemDescType mdescTy = getMemDesc().getType();
return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
- [&]() { return emitError(); });
+ getLayoutAttr(), [&]() { return emitError(); });
}
//===----------------------------------------------------------------------===//
@@ -1129,7 +1184,7 @@ LogicalResult StoreMatrixOp::verify() {
UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
MemDescType mdescTy = getMemDesc().getType();
return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
- [&]() { return emitError(); });
+ getLayoutAttr(), [&]() { return emitError(); });
}
namespace mlir {
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000..48fe841
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRXeGPUTransformOps
+ XeGPUTransformOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${PROJECT_SOURCE_DIR}/mlir/Dialect/XeGPU/TransformOps/
+
+ DEPENDS
+ MLIRXeGPUTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRXeGPUDialect
+ MLIRXeGPUTransforms
+ MLIRIR
+ MLIRTransformDialect
+ MLIRFuncDialect
+ MLIRSCFDialect
+)
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
new file mode 100644
index 0000000..e6009d5
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -0,0 +1,695 @@
+//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+
+#include <optional>
+
+#include "llvm/Support/DebugLog.h"
+#define DEBUG_TYPE "xegpu-transforms"
+
+using namespace mlir;
+using namespace mlir::transform;
+
+/// Assuming that `ofr` is an index attr or a param of index type
+/// or a transform dialect handle mapped to exactly one op
+/// with one index result, get that value and cast it to int type.
+static DiagnosedSilenceableFailure convertMixedValuesToInt(
+ transform::TransformState &state, TransformOpInterface transformOp,
+ SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs) {
+ for (OpFoldResult ofr : ofrs) {
+ // Attribute case.
+ if (auto attr = dyn_cast<Attribute>(ofr)) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+ result.push_back(intAttr.getInt());
+ continue;
+ }
+ return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+ }
+
+ // Transform param case.
+ Value transformValue = cast<Value>(ofr);
+ if (isa<TransformParamTypeInterface>(transformValue.getType())) {
+ ArrayRef<Attribute> params = state.getParams(transformValue);
+ if (params.size() != 1)
+ return transformOp.emitDefiniteFailure()
+ << "requires exactly one parameter associated";
+ result.push_back(
+ cast<IntegerAttr>(params.front()).getValue().getSExtValue());
+ continue;
+ }
+
+ // Payload value case.
+ auto payloadOps = state.getPayloadOps(transformValue);
+ if (!llvm::hasSingleElement(payloadOps)) {
+ DiagnosedSilenceableFailure diag =
+ transformOp.emitSilenceableError()
+ << "handle must be mapped to exactly one payload op";
+ diag.attachNote(transformValue.getLoc())
+ << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
+ return diag;
+ }
+
+ Operation *op = *payloadOps.begin();
+ if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+ DiagnosedSilenceableFailure diag =
+ transformOp.emitSilenceableError()
+ << "payload op must have exactly 1 index result";
+ diag.attachNote(op->getLoc())
+ << "has " << op->getNumResults() << " results";
+ return diag;
+ }
+
+ IntegerAttr intAttr;
+ if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
+ return transformOp.emitSilenceableError()
+ << "requires param or handle to be the result of a constant like "
+ "op";
+
+ result.push_back(intAttr.getInt());
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
+/// Find producer operation of type T for the given value.
+/// It's assumed that producer ops are chained through their first operand.
+/// The producer chain is traced through loop block arguments (init values).
+template <typename T>
+static std::optional<T> findProducerOfType(Value val) {
+ Value currentValue = val;
+ if (!currentValue.getDefiningOp()) {
+ // Value may be a block argument initialized outside a loop.
+ if (val.getNumUses() == 0) {
+ LDBG() << "Failed to find producer op, value has no uses.";
+ return std::nullopt;
+ }
+ auto userOp = val.getUsers().begin();
+ auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
+ if (!parentLoop) {
+ LDBG() << "Failed to find producer op, not in a loop.";
+ return std::nullopt;
+ }
+ int64_t iterArgIdx;
+ if (auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) {
+ auto numInductionVars = parentLoop.getLoopInductionVars()->size();
+ iterArgIdx = iterArg.getArgNumber() - numInductionVars;
+ currentValue = parentLoop.getInits()[iterArgIdx];
+ } else {
+ LDBG() << "Failed to find producer op, value not in init values.";
+ return std::nullopt;
+ }
+ }
+ Operation *producerOp = currentValue.getDefiningOp();
+
+ if (auto matchingOp = dyn_cast<T>(producerOp))
+ return matchingOp;
+
+ if (producerOp->getNumOperands() == 0)
+ return std::nullopt;
+
+ return findProducerOfType<T>(producerOp->getOperand(0));
+}
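For instance, when the starting value is an scf.for iter_arg whose init operand is produced by an xegpu.load_nd that was fed by an xegpu.create_nd_desc, findProducerOfType<xegpu::CreateNdDescOp> first hops from the block argument to the matching init value and then keeps following first operands until it reaches the create_nd_desc op.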
+
+/// Create a layout attribute from the given parameters.
+static xegpu::LayoutAttr
+createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
+ ArrayRef<int32_t> sgData,
+ std::optional<ArrayRef<int32_t>> instData) {
+ return xegpu::LayoutAttr::get(
+ ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
+ DenseI32ArrayAttr::get(ctx, sgData),
+ instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
+ /*lane_layout=*/nullptr,
+ /*lane_data=*/nullptr,
+ /*order=*/nullptr);
+}
+
+/// Generate `xegpu::LayoutAttr` from op mixed layout values.
+DiagnosedSilenceableFailure
+getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state,
+ TransformOpInterface transformOp,
+ ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
+ ArrayRef<::mlir::OpFoldResult> mixedSgData,
+ ArrayRef<::mlir::OpFoldResult> mixedInstData,
+ xegpu::LayoutAttr &layoutAttr) {
+ SmallVector<int32_t> sgLayout, sgData, instData;
+ auto status =
+ convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
+ if (!status.succeeded())
+ return status;
+
+ status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
+ if (!status.succeeded())
+ return status;
+
+ status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
+ if (!status.succeeded())
+ return status;
+ auto maybeInstData = instData.empty()
+ ? std::nullopt
+ : std::optional<ArrayRef<int32_t>>(instData);
+
+ layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData);
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+/// Replace xegpu.create_nd_desc op with a new one with the given layout.
+static xegpu::CreateNdDescOp
+setDescLayout(transform::TransformRewriter &rewriter,
+ xegpu::CreateNdDescOp descOp,
+ xegpu::DistributeLayoutAttr layout) {
+ assert(descOp.getMixedOffsets().size() == 0 &&
+ "create desc op with offsets is not supported");
+ auto oldTensorDesc = descOp.getType();
+ auto descType = xegpu::TensorDescType::get(
+ oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
+ /*array_length=*/oldTensorDesc.getArrayLength(),
+ /*boundary_check=*/oldTensorDesc.getBoundaryCheck(),
+ /*memory_space=*/oldTensorDesc.getMemorySpace(),
+ /*layout=*/layout);
+
+ rewriter.setInsertionPointAfter(descOp);
+ auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
+ descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
+ descOp.getMixedStrides());
+ return newDescOp;
+}
+
+DiagnosedSilenceableFailure
+transform::GetDescOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetValues = state.getPayloadValues(getTarget());
+ if (!llvm::hasSingleElement(targetValues)) {
+ return emitDefiniteFailure()
+ << "requires exactly one target value handle (got "
+ << llvm::range_size(targetValues) << ")";
+ }
+
+ auto maybeDescOp =
+ findProducerOfType<xegpu::CreateNdDescOp>(*targetValues.begin());
+ if (!maybeDescOp) {
+ return emitSilenceableFailure(getLoc())
+ << "Could not find a matching descriptor op when walking the "
+ "producer chain of the first operand.";
+ }
+
+ results.set(llvm::cast<OpResult>(getResult()), {*maybeDescOp});
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetDescLayoutOp::build(OpBuilder &builder,
+ OperationState &result, Value target,
+ ArrayRef<OpFoldResult> mixedSgLayout,
+ ArrayRef<OpFoldResult> mixedSgData,
+ ArrayRef<OpFoldResult> mixedInstData,
+ ArrayRef<int64_t> sliceDims) {
+ SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+ SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+ dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+ dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+ dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+ build(builder, result, target.getType(),
+ /*target=*/target,
+ /*sg_layout=*/dynamicSgLayout,
+ /*sg_data=*/dynamicSgData,
+ /*inst_data=*/dynamicInstData,
+ /*static_sg_layout=*/staticSgLayout,
+ /*static_sg_data=*/staticSgData,
+ /*static_inst_data=*/staticInstData,
+ /*slice_dims=*/sliceDims);
+}
+
+DiagnosedSilenceableFailure
+transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetOps = state.getPayloadOps(getTarget());
+ if (!llvm::hasSingleElement(targetOps)) {
+ return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
+ << llvm::range_size(targetOps) << ")";
+ }
+ Operation *target = *targetOps.begin();
+
+ xegpu::LayoutAttr layoutAttr = nullptr;
+ auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
+ getMixedSgLayout(), getMixedSgData(),
+ getMixedInstData(), layoutAttr);
+ if (!status.succeeded())
+ return status;
+
+ xegpu::DistributeLayoutAttr layout = layoutAttr;
+ auto sliceDims = getSliceDims();
+ if (sliceDims.size() > 0) {
+ // Wrap layoutAttr in a slice attribute.
+ layout = xegpu::SliceAttr::get(
+ getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
+ }
+
+ // For now only create_nd_desc op is supported.
+ auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
+ if (!descOp) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Expected a xegpu.create_nd_desc op, but got: "
+ << target->getName();
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+
+ // Set layout attr in desc op's return type. Replaces old desc op.
+ auto newdescOp = setDescLayout(rewriter, descOp, layout);
+
+ // Map result handles.
+ results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetDescLayoutOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getSgLayoutMutable(), effects);
+ onlyReadsHandle(getSgDataMutable(), effects);
+ onlyReadsHandle(getInstDataMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
+
+void transform::SetOpLayoutAttrOp::build(
+ OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
+ ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
+ ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int64_t> sliceDims,
+ bool result) {
+ SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+ SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+ dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+ dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+ dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+ build(builder, ostate, target.getType(),
+ /*target=*/target,
+ /*index=*/index,
+ /*sg_layout=*/dynamicSgLayout,
+ /*sg_data=*/dynamicSgData,
+ /*inst_data=*/dynamicInstData,
+ /*static_sg_layout=*/staticSgLayout,
+ /*static_sg_data=*/staticSgData,
+ /*static_inst_data=*/staticInstData,
+ /*slice_dims=*/sliceDims,
+ /*result=*/result);
+}
+
+DiagnosedSilenceableFailure
+transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetOps = state.getPayloadOps(getTarget());
+ if (!llvm::hasSingleElement(targetOps)) {
+ return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
+ << llvm::range_size(targetOps) << ")";
+ }
+ Operation *target = *targetOps.begin();
+
+ bool resultTarget = getResult();
+
+ int64_t index = getIndex();
+ if (resultTarget && index >= target->getNumResults()) {
+ return emitSilenceableFailure(getLoc())
+ << "Index exceeds the number of op results";
+ }
+ if (!resultTarget && index >= target->getNumOperands()) {
+ return emitSilenceableFailure(getLoc())
+ << "Index exceeds the number of op operands";
+ }
+
+ xegpu::LayoutAttr layoutAttr = nullptr;
+ auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
+ getMixedSgLayout(), getMixedSgData(),
+ getMixedInstData(), layoutAttr);
+ if (!status.succeeded())
+ return status;
+
+ xegpu::DistributeLayoutAttr layout = layoutAttr;
+ auto sliceDims = getSliceDims();
+ if (sliceDims.size() > 0) {
+ // Wrap layoutAttr in a slice attribute.
+ layout = xegpu::SliceAttr::get(
+ getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
+ }
+
+ // Set layout attribute for the op result or operand
+ if (resultTarget)
+ xegpu::setDistributeLayoutAttr(target->getResult(index), layout);
+ else
+ xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout);
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetOpLayoutAttrOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getSgLayoutMutable(), effects);
+ onlyReadsHandle(getSgDataMutable(), effects);
+ onlyReadsHandle(getInstDataMutable(), effects);
+ modifiesPayload(effects);
+}
+
+void transform::SetGPULaunchThreadsOp::build(
+ OpBuilder &builder, OperationState &ostate, Value target,
+ ArrayRef<OpFoldResult> mixedThreads) {
+ SmallVector<int64_t> staticThreads;
+ SmallVector<Value> dynamicThreads;
+ dispatchIndexOpFoldResults(mixedThreads, dynamicThreads, staticThreads);
+ build(builder, ostate, target.getType(),
+ /*target=*/target,
+ /*threads=*/dynamicThreads,
+ /*static_threads=*/staticThreads);
+}
+
+DiagnosedSilenceableFailure
+transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetOps = state.getPayloadOps(getTarget());
+ if (!llvm::hasSingleElement(targetOps)) {
+ return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
+ << llvm::range_size(targetOps) << ")";
+ }
+ Operation *target = *targetOps.begin();
+
+ auto launchOp = dyn_cast<gpu::LaunchOp>(target);
+ if (!launchOp) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Expected a gpu.launch op, but got: " << target->getName();
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+
+ SmallVector<int32_t> threads;
+ DiagnosedSilenceableFailure status =
+ convertMixedValuesToInt(state, (*this), threads, getMixedThreads());
+ if (!status.succeeded())
+ return status;
+
+ if (threads.size() != 3) {
+ return emitSilenceableFailure(getLoc())
+ << "Expected threads argument to consist of three values (got "
+ << threads.size() << ")";
+ }
+
+ rewriter.setInsertionPoint(launchOp);
+ auto createConstValue = [&](int value) {
+ return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value);
+ };
+
+ // Replace threads in-place.
+ launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
+ launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
+ launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetGPULaunchThreadsOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getThreadsMutable(), effects);
+ modifiesPayload(effects);
+}
+
+DiagnosedSilenceableFailure
+transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetValues = state.getPayloadValues(getTarget());
+ if (!llvm::hasSingleElement(targetValues))
+ return emitDefiniteFailure()
+ << "requires exactly one target value handle (got "
+ << llvm::range_size(targetValues) << ")";
+ auto value = *targetValues.begin();
+
+ int64_t nbPrefetch = getStaticNbPrefetch();
+ if (getDynamicNbPrefetch()) {
+ // Get dynamic prefetch count from transform param or handle.
+ SmallVector<int32_t> dynamicNbPrefetch;
+ auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch,
+ {getDynamicNbPrefetch()});
+ if (!status.succeeded())
+ return status;
+ if (dynamicNbPrefetch.size() != 1)
+ return emitDefiniteFailure()
+ << "requires exactly one value for dynamic_nb_prefetch";
+ nbPrefetch = dynamicNbPrefetch[0];
+ }
+ if (nbPrefetch <= 0)
+ return emitSilenceableFailure(getLoc())
+ << "nb_prefetch must be a positive integer.";
+
+ // Find load operation of the operand.
+ auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
+ if (!maybeLoadOp)
+ return emitSilenceableFailure(getLoc()) << "Could not find load op.";
+ auto loadOp = *maybeLoadOp;
+ if (loadOp.getMixedOffsets().size() == 0) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Load op must have offsets.";
+ diag.attachNote(loadOp.getLoc()) << "load op";
+ return diag;
+ }
+
+ // Find the parent scf.for loop.
+ auto forOp = loadOp->getParentOfType<scf::ForOp>();
+ if (!forOp) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Load op is not contained in a scf.for loop.";
+ diag.attachNote(loadOp.getLoc()) << "load op";
+ return diag;
+ }
+
+ // Find descriptor op.
+ auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value);
+ if (!maybeDescOp)
+ return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
+ auto descOp = *maybeDescOp;
+ if (descOp.getMixedOffsets().size() > 0) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "desc op with offsets is not supported.";
+    diag.attachNote(descOp.getLoc()) << "desc op";
+    return diag;
+  }
+
+ // Clone desc op outside the loop.
+ rewriter.setInsertionPoint(forOp);
+ auto newDescOp =
+ cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation()));
+
+ // Clone reduction loop to emit initial prefetches.
+ // Compute upper bound of the init loop: start + nbPrefetch * step.
+ auto nbPrefetchCst =
+ arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch);
+ auto nbStep = rewriter.createOrFold<arith::MulIOp>(
+ forOp.getLoc(), nbPrefetchCst, forOp.getStep());
+ auto initUpBound = rewriter.createOrFold<arith::AddIOp>(
+ forOp.getLoc(), forOp.getLowerBound(), nbStep);
+ auto initForOp =
+ scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
+ initUpBound, forOp.getStep());
+
+ auto ctx = rewriter.getContext();
+ auto readCacheHint =
+ xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
+
+ // Modify loadOp mixedOffsets by replacing the for loop induction variable
+ // with the given value.
+ auto getPrefetchOffsets =
+ [&](Value replacementVal) -> SmallVector<OpFoldResult> {
+ IRMapping mapping;
+ mapping.map(forOp.getInductionVar(), replacementVal);
+ SmallVector<Value> dynamicOffsets =
+ llvm::to_vector(llvm::map_range(loadOp.getOffsets(), [&](Value v) {
+ return mapping.lookupOrDefault(v);
+ }));
+ auto constOffsets = loadOp.getConstOffsets().value();
+ return getMixedValues(constOffsets, dynamicOffsets, ctx);
+ };
+
+ // Insert prefetch op in init loop.
+ // Replace induction var with the init loop induction var.
+ rewriter.setInsertionPointToStart(initForOp.getBody());
+ xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+ newDescOp.getResult(),
+ getPrefetchOffsets(initForOp.getInductionVar()),
+ readCacheHint, readCacheHint, readCacheHint,
+ /*layout=*/nullptr);
+
+ // Insert prefetch op in main loop.
+ // Calculate prefetch offset after the init prefetches have been issued.
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
+ forOp.getInductionVar(), nbStep);
+ // Replace induction var with correct offset.
+ xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+ newDescOp.getResult(),
+ getPrefetchOffsets(prefetchOffset), readCacheHint,
+ readCacheHint, readCacheHint, /*layout=*/nullptr);
+
+ // Unroll the init loop.
+ if (failed(loopUnrollFull(initForOp)))
+ return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop";
+
+ results.set(llvm::cast<OpResult>(getResult()), {newDescOp});
+
+ return DiagnosedSilenceableFailure::success();
+}
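To sketch the effect on a concrete (hypothetical) loop: with nb_prefetch = 2 and scf.for %k = %c0 to %N step %c32 wrapping a load at offset [%k, ...], the rewrite builds an init loop that prefetches offsets 0 and 32 ahead of the main loop (and then fully unrolls it), while inside the main loop a prefetch for offset %k + 64 is issued alongside the load that still reads offset %k.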
+
+void transform::InsertPrefetchOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getDynamicNbPrefetchMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
+
+void transform::ConvertLayoutOp::build(
+ OpBuilder &builder, OperationState &ostate, Value target,
+ ArrayRef<OpFoldResult> mixedInputSgLayout,
+ ArrayRef<OpFoldResult> mixedInputSgData,
+ ArrayRef<OpFoldResult> mixedInputInstData,
+ ArrayRef<OpFoldResult> mixedTargetSgLayout,
+ ArrayRef<OpFoldResult> mixedTargetSgData,
+ ArrayRef<OpFoldResult> mixedTargetInstData) {
+ SmallVector<int64_t> staticInputSgLayout, staticInputSgData,
+ staticInputInstData;
+ SmallVector<Value> dynamicInputSgLayout, dynamicInputSgData,
+ dynamicInputInstData;
+ dispatchIndexOpFoldResults(mixedInputSgLayout, dynamicInputSgLayout,
+ staticInputSgLayout);
+ dispatchIndexOpFoldResults(mixedInputSgData, dynamicInputSgData,
+ staticInputSgData);
+ dispatchIndexOpFoldResults(mixedInputInstData, dynamicInputInstData,
+ staticInputInstData);
+ SmallVector<int64_t> staticTargetSgLayout, staticTargetSgData,
+ staticTargetInstData;
+ SmallVector<Value> dynamicTargetSgLayout, dynamicTargetSgData,
+ dynamicTargetInstData;
+ dispatchIndexOpFoldResults(mixedTargetSgLayout, dynamicTargetSgLayout,
+ staticTargetSgLayout);
+ dispatchIndexOpFoldResults(mixedTargetSgData, dynamicTargetSgData,
+ staticTargetSgData);
+ dispatchIndexOpFoldResults(mixedTargetInstData, dynamicTargetInstData,
+ staticTargetInstData);
+ build(builder, ostate, target.getType(),
+ /*target=*/target,
+ /*input_sg_layout=*/dynamicInputSgLayout,
+ /*input_sg_data=*/dynamicInputSgData,
+ /*input_inst_data=*/dynamicInputInstData,
+ /*target_sg_layout=*/dynamicTargetSgLayout,
+ /*target_sg_data=*/dynamicTargetSgData,
+ /*target_inst_data=*/dynamicTargetInstData,
+ /*static_input_sg_layout=*/staticInputSgLayout,
+ /*static_input_sg_data=*/staticInputSgData,
+ /*static_input_inst_data=*/staticInputInstData,
+ /*static_target_sg_layout=*/staticTargetSgLayout,
+ /*static_target_sg_data=*/staticTargetSgData,
+ /*static_target_inst_data=*/staticTargetInstData);
+}
+
+DiagnosedSilenceableFailure
+transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetValues = state.getPayloadValues(getTarget());
+ if (!llvm::hasSingleElement(targetValues))
+ return emitDefiniteFailure()
+ << "requires exactly one target value handle (got "
+ << llvm::range_size(targetValues) << ")";
+ auto value = *targetValues.begin();
+
+ // Construct layout attributes.
+ xegpu::LayoutAttr inputLayoutAttr = nullptr;
+ auto status = getLayoutAttrFromOperands(
+ getContext(), state, (*this), getMixedInputSgLayout(),
+ getMixedInputSgData(), getMixedInputInstData(), inputLayoutAttr);
+ if (!status.succeeded())
+ return status;
+
+ xegpu::LayoutAttr targetLayoutAttr = nullptr;
+ status = getLayoutAttrFromOperands(
+ getContext(), state, (*this), getMixedTargetSgLayout(),
+ getMixedTargetSgData(), getMixedTargetInstData(), targetLayoutAttr);
+ if (!status.succeeded())
+ return status;
+
+  // Find the first user op to define the insertion point for the conversion.
+  if (value.use_empty())
+    return emitSilenceableFailure(getLoc())
+           << "Value has no users; cannot insert a layout conversion.";
+ Operation *userOp = *value.getUsers().begin();
+
+ // Emit convert_layout op.
+ rewriter.setInsertionPoint(userOp);
+ auto convLayoutOp =
+ xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(),
+ value, inputLayoutAttr, targetLayoutAttr);
+  // Replace uses of the original value with the layout-converted result.
+ rewriter.replaceUsesWithIf(
+ value, convLayoutOp.getResult(), [&](OpOperand &use) {
+ return use.getOwner() != convLayoutOp.getOperation();
+ });
+
+ results.set(llvm::cast<OpResult>(getResult()), {convLayoutOp});
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::ConvertLayoutOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getInputSgLayoutMutable(), effects);
+ onlyReadsHandle(getInputSgDataMutable(), effects);
+ onlyReadsHandle(getInputInstDataMutable(), effects);
+ onlyReadsHandle(getTargetSgLayoutMutable(), effects);
+ onlyReadsHandle(getTargetSgDataMutable(), effects);
+ onlyReadsHandle(getTargetInstDataMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
+
+namespace {
+class XeGPUTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ XeGPUTransformDialectExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)
+
+ using Base::Base;
+
+ void init();
+};
+
+void XeGPUTransformDialectExtension::init() {
+ declareGeneratedDialect<scf::SCFDialect>();
+ declareGeneratedDialect<arith::ArithDialect>();
+ declareGeneratedDialect<xegpu::XeGPUDialect>();
+
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
+ >();
+}
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
+
+void mlir::xegpu::registerTransformDialectExtension(DialectRegistry &registry) {
+ registry.addExtensions<XeGPUTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index e6f7606..29b645f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
XeGPUWgToSgDistribute.cpp
XeGPUPropagateLayout.cpp
XeGPUVectorLinearize.cpp
+ XeGPUOptimizeBlockLoads.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
new file mode 100644
index 0000000..ab41fe4
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
@@ -0,0 +1,490 @@
+//===- XeGPUOptimizeBlockLoads.cpp - XeGPU optimize block loads -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include <optional>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUOPTIMIZEBLOCKLOADS
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-optimize-block-loads"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+
+namespace {
+
+/// Get the 2D lane data from a tensor desc type if it exists.
+static std::optional<SmallVector<int64_t>>
+getMaybeLaneData(xegpu::TensorDescType tdescType) {
+ auto layout = tdescType.getLayoutAttr();
+ if (!layout)
+ return std::nullopt;
+ auto laneData = layout.getEffectiveLaneDataAsInt();
+ if (laneData.size() != 2)
+ return std::nullopt;
+ return laneData;
+}
+
+/// Get the 2D lane layout from a tensor desc type if it exists.
+static std::optional<SmallVector<int64_t>>
+getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
+ auto layout = tdescType.getLayoutAttr();
+ if (!layout)
+ return std::nullopt;
+ auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+ if (laneLayout.size() != 2)
+ return std::nullopt;
+ return laneLayout;
+}
+
+/// A layout can be optimized if its lane layout is transposed (lane[0] != 1 &&
+/// lane[1] == 1) but its lane data is not [1, 1] (i.e., inner lane data > 1).
+/// Example:
+/// !xegpu.tensor_desc<16x16xf16,
+/// #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
+/// In this case, lane layout is transposed (from the usual [1, SG_SIZE] form)
+/// indicating that this is a load that requires transpose effect. However,
+/// lane data is [1, 2], meaning that each lane must grab 2 f16 elements from
+/// the inner dimension. We convert this to an optimized form by converting the
+/// tensor_desc to i32 type such that lane data becomes [1, 1]. This allows the
+/// later lowering to directly use the load-with-transpose instruction.
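+/// For the example above, the optimized type would be (illustrative):
+///   !xegpu.tensor_desc<16x8xi32,
+///     #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>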
+static bool canBeOptimizedForTranspose(ArrayRef<int64_t> laneLayout,
+ ArrayRef<int64_t> laneData) {
+ if (laneLayout.size() != 2 || laneData.size() != 2)
+ return false;
+ if (laneLayout[0] == 1 || laneLayout[1] != 1)
+ return false;
+ if (laneData[0] != 1 || laneData[1] == 1)
+ return false;
+ return true;
+}
+
+/// A tensor desc type can be optimized if its element type is less than 32 bits
+/// and its layout can be optimized.
+static bool canBeOptimizedForTranspose(xegpu::TensorDescType tdescType) {
+  // Element types of 32 bits or wider need no conversion; the layout is valid.
+ int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
+ if (elementTyBitwidth >= 32)
+ return false;
+ auto maybeLaneLayout = getMaybeLaneLayout(tdescType);
+ auto maybeLaneData = getMaybeLaneData(tdescType);
+ if (!maybeLaneData || !maybeLaneLayout)
+ return false;
+ return canBeOptimizedForTranspose(*maybeLaneLayout, *maybeLaneData);
+}
+
+/// Check if a tensor desc type can be optimized for transpose; if so, return the
+/// new optimized tensor desc type with a valid transpose layout.
+static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
+ const uArch *targetuArch) {
+ if (!canBeOptimizedForTranspose(tdescType))
+ return tdescType;
+ auto laneData = getMaybeLaneData(tdescType)
+ .value(); // Lane data must exist if we reach here.
+ int64_t innerLaneData = laneData[1];
+ int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
+  // Required shape is the total shape of the vector result that this tensor
+  // desc must eventually load, after adjusting for the new bitwidth and array
+  // length.
+ SmallVector<int64_t> requiredShape(tdescType.getShape());
+ requiredShape.back() =
+ requiredShape.back() * tdescType.getArrayLength() / innerLaneData;
+ int newBitWidth = elementTyBitwidth * innerLaneData;
+ Type newElemTy = IntegerType::get(tdescType.getContext(), newBitWidth);
+  // Supported shape is the largest transpose shape supported by the hardware
+  // that is less than or equal to the required shape.
+ auto *blockLoadTarget = dyn_cast<Subgroup2DBlockLoadInstruction>(
+ targetuArch->getInstruction(InstructionKind::Subgroup2DBlockLoad));
+ auto maybeHWParams = blockLoadTarget->getBlockWidthHeightCount(
+ newElemTy, /** has transform */ false, /** has transpose */ true);
+ // If no HW params found, return the original type.
+ if (!maybeHWParams)
+ return tdescType;
+ auto [widths, heights, counts] = maybeHWParams.value();
+ // TODO: Currently we expect array length to be 1 for transpose case.
+ if (counts.size() != 1 || counts[0] != 1)
+ return tdescType;
+ int arrayLen = counts[0];
+ int supportedHeight =
+ xegpu::getLargestDivisor(static_cast<int>(requiredShape[0]), heights);
+ int supportedWidth =
+ xegpu::getLargestDivisor(static_cast<int>(requiredShape[1]), widths);
+ // If no supported height or width found, return the original type.
+ if (supportedHeight == -1 || supportedWidth == -1)
+ return tdescType;
+
+ SmallVector<int64_t> supportedShape = {supportedHeight, supportedWidth};
+ xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
+ tdescType.getContext(),
+ tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
+  // Array length cannot be larger than 1 for the transpose case.
+ return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen,
+ tdescType.getBoundaryCheck(),
+ tdescType.getMemorySpace(), newLayout);
+}
+
+/// Helper to convert an OpFoldResult to Value.
+static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc,
+ OpFoldResult ofr) {
+ std::optional<int64_t> mayBeInt = getConstantIntValue(ofr);
+ if (mayBeInt)
+ return arith::ConstantIndexOp::create(rewriter, loc, *mayBeInt).getResult();
+ return llvm::cast<Value>(ofr);
+}
+
+/// Helper to divide a Value by a constant integer.
+static Value divideByConstant(ConversionPatternRewriter &rewriter, Location loc,
+ Value val, int64_t constant) {
+ // If the constant is a power of 2, use right shift for division.
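+  // E.g., dividing by 8 (a power of 2) becomes a logical right shift by 3.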
+ if (llvm::isPowerOf2_64(constant)) {
+ int64_t shiftAmount = llvm::Log2_64(constant);
+ return arith::ShRUIOp::create(
+ rewriter, loc, val,
+ arith::ConstantIndexOp::create(rewriter, loc, shiftAmount)
+ .getResult())
+ .getResult();
+ }
+ auto constantOp =
+ arith::ConstantIndexOp::create(rewriter, loc, constant).getResult();
+ return arith::DivUIOp::create(rewriter, loc, val, constantOp).getResult();
+}
+
+/// This function takes a larger register block `data` and generates multiple
+/// smaller loads (size given by `newTensorDesc`) to fill in the `data` block
+/// starting from `offsets`.
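+/// For example (illustrative), filling a 32x16 `data` block with 8x16 loads
+/// produces a 4x1 grid of loads.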
+static Value generateLoads(ConversionPatternRewriter &rewriter,
+ TypedValue<VectorType> data,
+ SmallVector<OpFoldResult> offsets,
+ TypedValue<xegpu::TensorDescType> newTensorDesc,
+ xegpu::LoadNdOp origLoadOp) {
+ Location loc = data.getLoc();
+ assert(offsets.size() >= 2 && "Expecting at least 2 offsets for 2D LoadNdOp");
+ Value offsetDim0 = convertToValue(rewriter, loc, offsets[offsets.size() - 2]);
+ Value offsetDim1 = convertToValue(rewriter, loc, offsets[offsets.size() - 1]);
+ SmallVector<int64_t> supportedShape(newTensorDesc.getType().getShape());
+ // Compute the ratio between original shape and supported shape. We need to
+ // generate loads in this ratio arrangement.
+ auto shapeRatio = computeShapeRatio(data.getType().getShape(),
+ supportedShape)
+ .value(); // `ratio` must be defined if we reach here.
+ for (int64_t h = 0; h < shapeRatio[0]; ++h) {
+ for (int64_t w = 0; w < shapeRatio[1]; ++w) {
+ int64_t localOffsetDim0 = h * supportedShape[0];
+ int64_t localOffsetDim1 = w * supportedShape[1];
+ Value loadOffsetX = arith::AddIOp::create(
+ rewriter, loc, offsetDim0,
+ arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim0)
+ .getResult());
+ Value loadOffsetY = arith::AddIOp::create(
+ rewriter, loc, offsetDim1,
+ arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim1)
+ .getResult());
+ auto loadOp = xegpu::LoadNdOp::create(
+ rewriter, loc,
+ VectorType::get(supportedShape, data.getType().getElementType()),
+ newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY},
+ origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
+ origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
+ origLoadOp.getL3HintAttr(), origLoadOp.getLayoutAttr());
+ // Set the layout for the loadOp.
+ auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
+ xegpu::setDistributeLayoutAttr(loadOp->getOpResult(0), layoutAttr);
+ // Insert the loaded block into the right position in data.
+ auto insertOp = vector::InsertStridedSliceOp::create(
+ rewriter, loc, loadOp.getResult(), data,
+ ArrayRef<int64_t>{localOffsetDim0, localOffsetDim1},
+ ArrayRef<int64_t>{1, 1});
+ // InsertOp must have the same layout as newTensorDesc.
+ xegpu::setDistributeLayoutAttr(insertOp->getOpResult(0), layoutAttr);
+ data = insertOp.getResult();
+ }
+ }
+ return data;
+}
+
+/// Checks if a CreateNdDescOp can be optimized for transpose; if so, creates a
+/// new CreateNdDescOp with optimized tensor desc type. This involves extracting
+/// the base pointer from the original memory source and adjusting the shape and
+/// strides of the tensor desc to fit with the new optimized transpose layout.
+class XeGPUCreateNdDescOpPattern final
+ : public OpConversionPattern<xegpu::CreateNdDescOp> {
+public:
+ using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto tdescTy = createNdOp.getType();
+ // Get the target uArch info.
+ auto chipStr = xegpu::getChipStr(createNdOp);
+ // Check if the chip is supported.
+ assert(
+ chipStr && (chipStr.value() == "pvc" || chipStr.value() == "bmg") &&
+ "Expecting target chip to be pvc or bmg for transpose optimization.");
+ const uArch *targetuArch = xegpu::uArch::getUArch(chipStr.value());
+
+ auto convertType = tryOptimize(tdescTy, targetuArch);
+ if (convertType == tdescTy)
+ return failure();
+ auto strides = createNdOp.getMixedStrides();
+ auto maybeConstInnerStride = getConstantIntValue(strides.back());
+ // Only row-major memrefs are expected for now.
+ if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
+ return rewriter.notifyMatchFailure(
+ createNdOp, "Expecting row-major memref for transpose optimization.");
+ Value source = createNdOp.getSource();
+ auto optionalLaneData = getMaybeLaneData(tdescTy);
+ assert(optionalLaneData && "Expected 2D lane data");
+ auto laneData = optionalLaneData.value();
+ int64_t innerLaneData = laneData[1];
+ auto memrefType = dyn_cast<MemRefType>(source.getType());
+ // Inner dimension of the shape must be adjusted based on innerLaneData.
+ SmallVector<OpFoldResult> modifiedShape(createNdOp.getMixedSizes());
+ modifiedShape.back() = divideByConstant(
+ rewriter, createNdOp.getLoc(),
+ convertToValue(rewriter, createNdOp.getLoc(), modifiedShape.back()),
+ innerLaneData);
+ // Similarly, second to last stride must be adjusted.
+ assert(strides.size() >= 2 &&
+ "Expected at least 2 strides for CreateNdDescOp");
+ SmallVector<OpFoldResult> modifiedStrides(strides);
+ modifiedStrides[modifiedStrides.size() - 2] = divideByConstant(
+ rewriter, createNdOp.getLoc(),
+ convertToValue(rewriter, createNdOp.getLoc(),
+ modifiedStrides[modifiedStrides.size() - 2]),
+ innerLaneData);
+
+    // If the source is a static memref, we need to extract the pointer to the
+    // base address.
+ if (memrefType && memrefType.hasStaticShape()) {
+ auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, createNdOp.getLoc(), source);
+ source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(),
+ rewriter.getI64Type(),
+ extractOp.getResult())
+ .getResult();
+ }
+ // Create a new CreateNdDescOp with the modified shape and converted type.
+ auto newCreateNdDescOp = xegpu::CreateNdDescOp::create(
+ rewriter, createNdOp.getLoc(), convertType, source, modifiedShape,
+ modifiedStrides);
+ rewriter.replaceOp(createNdOp, newCreateNdDescOp.getResult());
+ return success();
+ }
+};
+
+/// Checks if a LoadNdOp consumes a tensor desc type that was rewritten for
+/// transpose optimization. If so, rewrites the LoadNdOp to align with the
+/// adjusted tensor desc type. This can result in multiple LoadNdOps being
+/// generated to fill in the original load shape.
+class XeGPULoadNdDescOpPattern final
+ : public OpConversionPattern<xegpu::LoadNdOp> {
+public:
+ using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto origTensorDescType = loadNdOp.getTensorDescType();
+ auto adaptorType =
+ cast<xegpu::TensorDescType>(adaptor.getTensorDesc().getType());
+ if (adaptorType == origTensorDescType)
+ return failure();
+ // Offsets must be adjusted based on innerLaneData.
+ auto laneData = getMaybeLaneData(loadNdOp.getTensorDescType()).value();
+ int64_t innerLaneData = laneData[1];
+ auto offsets = loadNdOp.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(loadNdOp,
+ "Expecting offsets in LoadNd");
+ SmallVector<OpFoldResult> modifiedOffsets(offsets);
+ modifiedOffsets.back() = divideByConstant(
+ rewriter, loadNdOp.getLoc(),
+ convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()),
+ innerLaneData);
+ // Get the 2D data shape of this loadNdOp in its original type including
+ // array length.
+ SmallVector<int64_t> origDataShape(origTensorDescType.getShape());
+ // Adjust the data shape based on innerLaneData.
+ origDataShape.back() /= innerLaneData;
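+    // E.g., a 16-element f16 row with innerLaneData = 2 becomes 8 i32 elements.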
+ // HW supported shape is the new tensor desc shape after conversion.
+ SmallVector<int64_t> hwSupportedShape(adaptorType.getShape());
+ VectorType origVectorType =
+ VectorType::get(origDataShape, adaptorType.getElementType());
+ Value data;
+    // When array length > 1, the original data is effectively 3D; load each
+    // array slice separately.
+    if (origTensorDescType.getArrayLength() > 1) {
+      SmallVector<Value> arraySlices;
+      // Each array slice i starts at baseOffsetY + i * origDataShape[1] (in
+      // units of the converted element type).
+      Value baseOffsetY = convertToValue(rewriter, loadNdOp->getLoc(),
+                                         modifiedOffsets.back());
+      for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
+        Value slice = arith::ConstantOp::create(
+            rewriter, loadNdOp->getLoc(), origVectorType,
+            rewriter.getZeroAttr(origVectorType));
+        // Advance the Y offset for this array slice.
+        modifiedOffsets.back() =
+            arith::AddIOp::create(
+                rewriter, loadNdOp->getLoc(), baseOffsetY,
+                arith::ConstantIndexOp::create(rewriter, loadNdOp->getLoc(),
+                                               i * origDataShape[1])
+                    .getResult())
+                .getResult();
+ slice = generateLoads(
+ rewriter, cast<TypedValue<VectorType>>(slice), modifiedOffsets,
+ cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
+ loadNdOp);
+ // BitCast back to original load shape without array length.
+ auto bitcastType = VectorType::get(origTensorDescType.getShape(),
+ origTensorDescType.getElementType());
+ auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+ bitcastType, slice);
+ // BitCastOp must have the same layout as the original loadNdOp.
+ xegpu::setDistributeLayoutAttr(bitCastOp->getOpResult(0),
+ origTensorDescType.getLayoutAttr());
+ arraySlices.push_back(bitCastOp.getResult());
+ }
+ rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices});
+ return success();
+ }
+ data = arith::ConstantOp::create(
+ rewriter, loadNdOp->getLoc(),
+ VectorType::get(origDataShape, adaptorType.getElementType()),
+ rewriter.getZeroAttr(origVectorType));
+ data = generateLoads(
+ rewriter, cast<TypedValue<VectorType>>(data), modifiedOffsets,
+ cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
+ loadNdOp);
+ auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+ loadNdOp.getType(), data);
+ // BitCastOp must have the same layout as the original loadNdOp.
+ xegpu::setDistributeLayoutAttr(bitCastOp->getOpResult(0),
+ origTensorDescType.getLayoutAttr());
+ rewriter.replaceOp(loadNdOp, bitCastOp);
+ return success();
+ }
+};
+
+/// Vector ExtractOp must be processed if the original tensor desc type has an
+/// array length greater than 1. In that case, the LoadNdOp is replaced with
+/// one LoadNdOp per array slice, making the extraction unnecessary, so we fold
+/// the ExtractOp away by forwarding the corresponding array slice.
+class VectorExtractOpPattern final
+ : public OpConversionPattern<vector::ExtractOp> {
+public:
+ using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Check if the source of the extraction is split to multiple values.
+ if (adaptor.getSource().size() == 1)
+ return failure();
+ auto mixedPos = extractOp.getMixedPosition();
+ if (mixedPos.size() != 1)
+ return failure();
+ auto mayBeInt = getConstantIntValue(mixedPos[0]);
+ if (!mayBeInt)
+ return failure();
+ rewriter.replaceOp(extractOp, adaptor.getSource()[*mayBeInt]);
+ return success();
+ }
+};
+
+} // namespace
+
+void xegpu::populateXeGPUOptimizeBlockLoadsPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern,
+ VectorExtractOpPattern>(patterns.getContext());
+}
+
+namespace {
+
+struct XeGPUOptimizeBlockLoadsPass final
+ : public xegpu::impl::XeGPUOptimizeBlockLoadsBase<
+ XeGPUOptimizeBlockLoadsPass> {
+ void runOnOperation() override {
+ MLIRContext &context = getContext();
+ TypeConverter converter;
+ RewritePatternSet patterns(&context);
+ ConversionTarget target(context);
+
+    // This pass is only meant for PVC and BMG targets. If an unsupported
+    // target is found, exit early.
+ bool isTargetSupported = false;
+ getOperation()->walk([&](gpu::GPUFuncOp funcOp) {
+ auto chipStr = xegpu::getChipStr(funcOp);
+ if (chipStr && (chipStr.value() == "pvc" || chipStr.value() == "bmg"))
+ isTargetSupported = true;
+ });
+
+ if (!isTargetSupported) {
+ DBGS() << "XeGPUOptimizeBlockLoadsPass only supports PVC and BMG targets."
+ << "\n";
+ return;
+ }
+
+ // CreateNdDescOp and LoadNdOp with optimizable tensor desc types must be
+ // converted.
+ target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
+ [&](xegpu::CreateNdDescOp createNdOp) {
+ return !canBeOptimizedForTranspose(createNdOp.getType());
+ });
+ target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
+ [&](xegpu::LoadNdOp loadNdOp) {
+ return !canBeOptimizedForTranspose(loadNdOp.getTensorDescType());
+ });
+ // Vector ExtractOps can have optimizable layouts if they extract from
+ // LoadNdOps with array length greater than 1. These ExtractOps must be
+ // converted.
+ target.addDynamicallyLegalOp<vector::ExtractOp>(
+ [&](vector::ExtractOp extractOp) {
+ auto layout = xegpu::getDistributeLayoutAttr(extractOp.getResult());
+ if (!layout)
+ return true;
+ auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+ auto laneData = layout.getEffectiveLaneDataAsInt();
+ return !canBeOptimizedForTranspose(laneLayout, laneData);
+ });
+ converter.addConversion([](Type type) { return type; });
+
+ target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
+ vector::VectorDialect>();
+ scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+ target);
+ xegpu::populateXeGPUOptimizeBlockLoadsPatterns(patterns);
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns)))) {
+ DBGS() << "Optimize block loads pass failed.\n";
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 8fab255..dc9eb96 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -14,7 +14,6 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/Attributes.h"
@@ -37,6 +36,8 @@
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+
namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
@@ -52,6 +53,8 @@ using namespace mlir::dataflow;
namespace {
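+// Selects which layout parameters the propagation assigns: per-lane SIMT
+// layouts (lane_layout / lane_data) or per-instruction tile sizes (inst_data).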
+enum class LayoutKind { Lane, InstData };
+
//===----------------------------------------------------------------------===//
// LayoutInfo
//===----------------------------------------------------------------------===//
@@ -104,6 +107,8 @@ public:
SmallVector<int> getLaneData() const;
+ SmallVector<int> getInstData() const;
+
bool isSliceLayout() const {
if (!isAssigned())
return false;
@@ -137,6 +142,13 @@ SmallVector<int> LayoutInfo::getLaneData() const {
[](int64_t val) { return static_cast<int>(val); });
}
+SmallVector<int> LayoutInfo::getInstData() const {
+ if (!isAssigned())
+ return {};
+ return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
+ [](int64_t val) { return static_cast<int>(val); });
+}
+
void LayoutInfo::print(raw_ostream &os) const {
if (isAssigned()) {
os << storage;
@@ -156,7 +168,8 @@ LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
llvm_unreachable("Join should not be triggered by layout propagation.");
}
-/// Construct a new layout with the transposed lane layout and lane data.
+/// Construct a new layout with the transposed inst_data or lane_layout,
+/// lane_data.
LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
if (!isAssigned())
return {};
@@ -174,12 +187,22 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
SmallVector<int32_t> laneLayout;
SmallVector<int32_t> laneData;
+ SmallVector<int32_t> instData;
for (int64_t idx : permutation) {
- laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
- laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
+ if (getLaneLayout().size()) {
+ laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
+ laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
+ }
+ if (getInstData().size())
+ instData.push_back(static_cast<int32_t>(getInstData()[idx]));
}
- return LayoutInfo(
- xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData));
+ xegpu::LayoutAttr layoutAttr;
+ if (getLaneLayout().size())
+ layoutAttr =
+ xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
+ if (getInstData().size())
+ layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
+ return LayoutInfo(layoutAttr);
}
//===----------------------------------------------------------------------===//
@@ -200,18 +223,30 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
- unsigned rank) {
+ unsigned rank,
+ const xegpu::uArch::uArch *uArch) {
assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
if (rank == 1) {
return LayoutInfo(
- xegpu::LayoutAttr::get(ctx, {xegpu::targetinfo::subgroupSize}, {1}));
+ xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
}
- return LayoutInfo(xegpu::LayoutAttr::get(
- ctx, {1, xegpu::targetinfo::subgroupSize}, {1, 1}));
+ return LayoutInfo(
+ xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1}));
+}
+
+static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
+ unsigned rank, int subgroupSize) {
+ assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
+ if (rank == 1) {
+ return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
+ }
+ return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
}
/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
+ const xegpu::uArch::uArch *uArch,
+ unsigned packingSize,
bool isScattered = false) {
// Expecting a 1D or 2D vector.
assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
@@ -221,28 +256,24 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
"Expected int or float element type.");
// If the rank is 1, then return default layout for 1D vector.
if (vectorTy.getRank() == 1)
- return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1);
+ return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
// Packing factor is determined by the element type bitwidth.
- int packingFactor = 1;
unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
+ int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
if (isScattered) {
- packingFactor =
- bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
- ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
- : 1;
- return LayoutInfo(xegpu::LayoutAttr::get(
- vectorTy.getContext(), {xegpu::targetinfo::subgroupSize, 1},
- {1, packingFactor}));
+ return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
+ {uArch->getSubgroupSize(), 1},
+ {1, packingFactor}));
}
- if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
- packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
- {1, xegpu::targetinfo::subgroupSize},
+ {1, uArch->getSubgroupSize()},
{1, packingFactor}));
}
/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
+ const xegpu::uArch::uArch *uArch,
+ unsigned packingSize,
bool isScattered = false) {
// Expecting a 1D or 2D vector.
assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
@@ -252,27 +283,18 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
"Expected int or float element type.");
// If the rank is 1, then return default layout for 1D vector.
if (tdescTy.getRank() == 1)
- return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1);
+ return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch);
// Packing factor is determined by the element type bitwidth.
unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
-
+ int subgroupSize = uArch->getSubgroupSize();
+ int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
if (isScattered) {
- int packingFactor =
- bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
- ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
- : 1;
return LayoutInfo(xegpu::LayoutAttr::get(
- tdescTy.getContext(), {xegpu::targetinfo::subgroupSize, 1},
- {1, packingFactor}));
+ tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor}));
}
- int packingFactor =
- (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
- ? xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth
- : 1;
- return LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(),
- {1, xegpu::targetinfo::subgroupSize},
- {1, packingFactor}));
+ return LayoutInfo(xegpu::LayoutAttr::get(
+ tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor}));
}
/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
@@ -281,25 +303,25 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
/// `packedSizeInBitsForDefault`
/// * For B operand, the data must be packed in minimum
/// `packedSizeInBitsForDpasB`
-static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy,
- unsigned operandNum) {
+static LayoutInfo
+getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
+ const xegpu::uArch::uArch *uArch,
+ unsigned packingSize) {
Type elementTy = vectorTy.getElementType();
assert(elementTy.isIntOrFloat() &&
"Expected int or float type in DPAS operands");
- SmallVector<int32_t, 2> layout({1, xegpu::targetinfo::subgroupSize});
+ SmallVector<int32_t, 2> layout({1, uArch->getSubgroupSize()});
// For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
// must have the VNNI format.
- if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() <
- xegpu::targetinfo::packedSizeInBitsForDpasB) {
+ if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < packingSize) {
SmallVector<int32_t, 2> data(
- {static_cast<int32_t>(xegpu::targetinfo::packedSizeInBitsForDpasB /
- elementTy.getIntOrFloatBitWidth()),
+ {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
1});
return LayoutInfo(
xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
}
// Otherwise, return the default layout for the vector type.
- return getDefaultSIMTLayoutInfo(vectorTy);
+ return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
}
//===----------------------------------------------------------------------===//
@@ -314,6 +336,7 @@ static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy,
class LayoutInfoPropagation
: public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
+ LayoutKind layoutKind;
void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
@@ -364,10 +387,14 @@ private:
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
+ bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
+
public:
LayoutInfoPropagation(DataFlowSolver &solver,
- SymbolTableCollection &symbolTable)
- : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
+ SymbolTableCollection &symbolTable,
+ LayoutKind layoutKind)
+ : SparseBackwardDataFlowAnalysis(solver, symbolTable),
+ layoutKind(layoutKind) {}
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
LogicalResult
@@ -450,13 +477,71 @@ LogicalResult LayoutInfoPropagation::visitOperation(
return success();
}
+bool LayoutInfoPropagation::hasParamsOfLayoutKind(
+ xegpu::DistributeLayoutAttr anchorLayout) {
+ if (anchorLayout == nullptr) {
+ return false;
+ }
+ if (layoutKind == LayoutKind::InstData) {
+ return !(anchorLayout.getEffectiveInstDataAsInt().empty());
+ } else if (layoutKind == LayoutKind::Lane) {
+ return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
+ anchorLayout.getEffectiveLaneDataAsInt().empty());
+ }
+ return false;
+}
+
void LayoutInfoPropagation::visitPrefetchNdOp(
xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- // Here we assign the default layout to the tensor descriptor operand of
- // prefetch.
- auto tdescTy = prefetch.getTensorDescType();
- auto prefetchLayout = getDefaultSIMTLayoutInfo(tdescTy);
+
+ LayoutInfo prefetchLayout;
+ xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
+ if (hasParamsOfLayoutKind(anchorLayout)) {
+ prefetchLayout = LayoutInfo(anchorLayout);
+ } else {
+ // Here we assign the default layout to the tensor descriptor operand of
+ // prefetch.
+ auto tdescTy = prefetch.getTensorDescType();
+
+ auto uArch = getUArch(getChipStr(prefetch).value_or(""));
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
+ uArch->getInstruction(
+ xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
+
+ auto blockWHC =
+ uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
+ if (!blockWHC)
+ prefetch.emitWarning("No known block params found for the element type.");
+ auto [bWidth, bHeight, bCount] = blockWHC.value();
+ SmallVector<int> instData;
+ int instWidth = xegpu::getLargestDivisor(
+ static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
+ if (instWidth == -1)
+ prefetch.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ if (tdescTy.getRank() == 1)
+ instData = {instWidth};
+ else {
+ int instHeight = xegpu::getLargestDivisor(
+ static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
+ if (instHeight == -1)
+ prefetch.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ instData = {instHeight, instWidth};
+ }
+
+ if (layoutKind == LayoutKind::InstData)
+ prefetchLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
+ else
+ prefetchLayout = getDefaultSIMTLayoutInfo(
+ tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
+
+ prefetch.setLayoutAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
+ }
// Propagate the layout to the source tensor descriptor.
propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}
@@ -475,10 +560,11 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
reduction.emitWarning("Expecting output type to be 1D vector.");
return;
}
+ auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
// Given that the result is 1D, the layout of the operand should be 2D with
// default layout.
- LayoutInfo operandLayout =
- getDefaultSIMTLayoutInfo(reduction->getContext(), 2);
+ LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(
+ reduction->getContext(), 2, uArch->getSubgroupSize());
propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
// Accumulator should have the same layout as the result.
propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
@@ -494,23 +580,39 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
// Only consider vector to vector broadcasts for now.
VectorType resultTy = broadcast.getResultVectorType();
VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
- if (!sourceTy) {
- broadcast.emitWarning("Expecting source type to be a vector type.");
+  // Skip layout propagation for a non-vector source operand.
+ if (!sourceTy)
return;
- }
- // Only consider nD -> nD broadcast.
+  // Handle broadcast from a lower-rank to a higher-rank (e.g., 1D to 2D) source.
if (sourceTy.getRank() != resultTy.getRank()) {
- broadcast.emitWarning("Expecting source and result to have same rank.");
+ auto sourceDims = sourceTy.getShape();
+ auto resultDims = resultTy.getShape();
+ SmallVector<int64_t> bcastDims;
+ auto dimDiff = resultTy.getRank() - sourceTy.getRank();
+    // Add the missing leading dims.
+ for (int i = 0; i < dimDiff; i++)
+ bcastDims.push_back(i);
+
+    // For the remaining dims in resultTy, a sourceTy dim of 1 mapping to a
+    // non-unit result dim is a broadcasted dim.
+ for (size_t i = 0; i < sourceDims.size(); i++)
+ if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1))
+ bcastDims.push_back(i + dimDiff);
+
+    // Create a slice layout for the source.
+ xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
+ broadcast->getContext(),
+ cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
+ DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims));
+
+ propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
return;
}
+
SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
- if (broadcastUnitDims.size() != 1) {
- broadcast.emitWarning("Expecting source type to be nD vector only with "
- "one broadcasted dimension.");
- return;
- }
- // Propagate the result layout to the source operand.
+ resultLayout = cast<xegpu::DistributeLayoutAttr>(resultLayout.get())
+ .setUnitDimData(broadcastUnitDims);
propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
@@ -555,17 +657,97 @@ void LayoutInfoPropagation::visitUpdateNdOffsetOp(
void LayoutInfoPropagation::visitDpasOp(
xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- VectorType aTy = dpas.getLhsType();
- VectorType bTy = dpas.getRhsType();
- propagateIfChanged(
- operands[0], operands[0]->meet(getSIMTLayoutInfoForDPASOperand(aTy, 0)));
- propagateIfChanged(
- operands[1], operands[1]->meet(getSIMTLayoutInfoForDPASOperand(bTy, 1)));
+
+ LayoutInfo dpasALayout;
+ LayoutInfo dpasBLayout;
+ LayoutInfo dpasCDLayout;
+
+ xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
+ if (hasParamsOfLayoutKind(anchorLayoutCD)) {
+ xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
+ xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
+ assert(hasParamsOfLayoutKind(anchorLayoutA) &&
+ "Expected anchor layout for DPAS A operand.");
+ assert(hasParamsOfLayoutKind(anchorLayoutB) &&
+ "Expected anchor layout for DPAS B operand.");
+ dpasALayout = LayoutInfo(anchorLayoutA);
+ dpasBLayout = LayoutInfo(anchorLayoutB);
+ dpasCDLayout = LayoutInfo(anchorLayoutCD);
+
+ } else {
+
+ VectorType aTy = dpas.getLhsType();
+ VectorType bTy = dpas.getRhsType();
+
+ auto uArch = getUArch(getChipStr(dpas).value_or(""));
+ const int subgroupSize = uArch->getSubgroupSize();
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
+ xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
+
+ const unsigned dataALen = aTy.getShape().front();
+ auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
+ const int maxALen =
+ xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
+ if (maxALen == -1)
+ dpas.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+
+ const unsigned dataBLen = bTy.getShape().back();
+ auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
+
+ const int maxBLen =
+ xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
+
+ if (maxBLen == -1)
+ dpas.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ SmallVector<int> instDataA = {maxALen, subgroupSize};
+ SmallVector<int> instDataB = {subgroupSize, maxBLen};
+
+ if (layoutKind == LayoutKind::InstData) {
+ dpasALayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
+ dpasBLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
+ } else {
+ dpasALayout = getSIMTLayoutInfoForDPASOperand(
+ aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
+ dpasBLayout = getSIMTLayoutInfoForDPASOperand(
+ bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
+ }
+
+ if (operands.size() > 2) {
+ VectorType cTy = dpas.getAccType();
+ if (layoutKind == LayoutKind::InstData) {
+ const unsigned dataCLen = bTy.getShape().back();
+ auto supportedCLen =
+ uArchInstruction->getSupportedN(bTy.getElementType());
+ const int maxCLen = xegpu::getLargestDivisor(
+ dataCLen, ArrayRef<unsigned>(supportedCLen));
+ if (maxCLen == -1)
+ dpas.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ SmallVector<int> instDataC = {maxALen, maxCLen};
+ dpasCDLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
+ } else
+ dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
+ cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
+
+ dpas.setLayoutCdAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
+ }
+ dpas.setLayoutAAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(dpasALayout.get()));
+ dpas.setLayoutBAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(dpasBLayout.get()));
+ }
+
+ propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
+ propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
if (operands.size() > 2) {
- VectorType cTy = dpas.getAccType();
- propagateIfChanged(
- operands[2],
- operands[2]->meet(getSIMTLayoutInfoForDPASOperand(cTy, 2)));
+ propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
}
}
@@ -573,7 +755,51 @@ void LayoutInfoPropagation::visitDpasOp(
void LayoutInfoPropagation::visitStoreNdOp(
xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(store.getValueType());
+
+ LayoutInfo storeLayout;
+ xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
+ if (hasParamsOfLayoutKind(anchorLayout)) {
+ storeLayout = LayoutInfo(anchorLayout);
+ } else {
+ auto uArch = getUArch(getChipStr(store).value_or(""));
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
+ uArch->getInstruction(
+ xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
+ VectorType dataTy = store.getValueType();
+ auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
+ store.getValueType().getElementType());
+ if (!blockWHC)
+ store.emitWarning("No known block params found for the element type.");
+ auto [bWidth, bHeight, bCount] = blockWHC.value();
+ SmallVector<int> instData;
+ int instWidth = xegpu::getLargestDivisor(
+ static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
+ if (instWidth == -1)
+ store.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ if (dataTy.getRank() == 1)
+ instData = {instWidth};
+ else {
+ int instHeight = xegpu::getLargestDivisor(
+ static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
+ if (instHeight == -1)
+ store.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ instData = {instHeight, instWidth};
+ }
+
+ if (layoutKind == LayoutKind::InstData)
+ storeLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
+ else
+ storeLayout =
+ getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
+ uArchInstruction->getPackedFormatBitSize());
+ store.setLayoutAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
+ }
+ // Propagate the layout to the value operand.
// Both operands should have the same layout
for (LayoutInfoLattice *operand : operands)
propagateIfChanged(operand, operand->meet(storeLayout));
@@ -584,21 +810,30 @@ void LayoutInfoPropagation::visitStoreNdOp(
void LayoutInfoPropagation::visitLoadNdOp(
xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- LayoutInfo valueLayout = results[0]->getValue();
- // Need the layout of the value to propagate to the tensor descriptor.
- if (!valueLayout.isAssigned())
- return;
- LayoutInfo tensorDescLayout = valueLayout;
- // LoadNdOp has the transpose effect. However, at the stage of this analysis
- // this effect is not expected and should be abstracted away. Emit a
- // warning.
- if (auto transpose = load.getTranspose()) {
- load.emitWarning("Transpose effect is not expected for LoadNdOp at "
- "LayoutInfoPropagation stage.");
- tensorDescLayout = valueLayout.transpose(transpose.value());
+
+ LayoutInfo loadLayout;
+ xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
+ if (hasParamsOfLayoutKind(anchorLayout)) {
+ loadLayout = LayoutInfo(anchorLayout);
+ } else {
+
+ LayoutInfo valueLayout = results[0]->getValue();
+ // Need the layout of the value to propagate to the tensor descriptor.
+ if (!valueLayout.isAssigned())
+ return;
+ loadLayout = valueLayout;
+ // LoadNdOp has the transpose effect. However, at the stage of this analysis
+ // this effect is not expected and should be abstracted away. Emit a
+ // warning.
+ if (auto transpose = load.getTranspose()) {
+ load.emitWarning("Transpose effect is not expected for LoadNdOp at "
+ "LayoutInfoPropagation stage.");
+ loadLayout = valueLayout.transpose(transpose.value());
+ }
+ load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
}
// Propagate the new layout to the tensor descriptor operand.
- propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
+ propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
}
/// For vector::TransposeOp, the layout of the result is transposed and
@@ -688,20 +923,48 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
void LayoutInfoPropagation::visitLoadGatherOp(
xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- // The layout is strictly determined by the payload type.
- auto payloadTy = dyn_cast<VectorType>(load.getValueType());
- if (!payloadTy) {
- load.emitWarning("Not propagating, non-vector payload supplied.");
- return;
- }
- LayoutInfo layout = getDefaultSIMTLayoutInfo(payloadTy, /*scattered*/ true);
- // Mask operand should have 1D default layout.
- LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1);
+ LayoutInfo loadLayout;
+ LayoutInfo maskLayout;
+ xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
+ if (hasParamsOfLayoutKind(anchorLayout)) {
+ loadLayout = LayoutInfo(anchorLayout);
+ maskLayout = loadLayout;
+ } else {
+
+ // The layout is strictly determined by the payload type.
+ VectorType payloadTy = load.getValueType();
+ if (!payloadTy) {
+ load.emitWarning("Not propagating, non-vector payload supplied.");
+ return;
+ }
+ auto uArch = getUArch(getChipStr(load).value_or(""));
+ const int subgroupSize = uArch->getSubgroupSize();
+ SmallVector<int> instData{subgroupSize};
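+    // For chunked accesses, the chunk size becomes the second inst_data dim.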
+ if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
+ instData.push_back(chunkSize);
+ else if (auto srcTdescTy =
+ dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
+ if (srcTdescTy.getChunkSizeAsInt() > 1)
+        instData.push_back(srcTdescTy.getChunkSizeAsInt());
+ }
+
+ if (layoutKind == LayoutKind::InstData)
+ loadLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
+ else
+ loadLayout = getDefaultSIMTLayoutInfo(
+ payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
+ /*scattered*/ true);
+ // Mask operand should have 1D default layout.
+ maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
+
+ load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
+ }
// Propagate the new layout to the tensor descriptor operand.
if (isa<xegpu::TensorDescType>(load.getSourceType()))
- propagateIfChanged(operands[0], operands[0]->meet(layout));
+ propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
// Propagate the new layout to the mask and optional offset operand.
propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
if (load.getOffsets())
@@ -717,8 +980,10 @@ void LayoutInfoPropagation::visitCreateDescOp(
// Need the layout of the descriptor to propagate to the operands.
if (!descLayout.isAssigned())
return;
+ auto uArch = getUArch(getChipStr(createDesc).value_or(""));
// For offset operand propagate 1D default layout.
- LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1);
+ LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
+ uArch->getSubgroupSize());
propagateIfChanged(operands[1], operands[1]->meet(layout));
}
@@ -727,26 +992,56 @@ void LayoutInfoPropagation::visitCreateDescOp(
void LayoutInfoPropagation::visitStoreScatterOp(
xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- // Currently, for 2D StoreScatterOp we expect that the height dimension of
- // the tensor descriptor is equal to the subgroup size. This is ensured by
- // the op verifier.
- auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
- if (!payloadTy) {
- storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
- return;
+
+ LayoutInfo payloadLayout;
+ LayoutInfo maskLayout;
+ xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr();
+ if (hasParamsOfLayoutKind(anchorLayout)) {
+ payloadLayout = LayoutInfo(anchorLayout);
+ maskLayout = payloadLayout;
+ } else {
+ // Currently, for 2D StoreScatterOp we expect that the height dimension of
+ // the tensor descriptor is equal to the subgroup size. This is ensured by
+ // the op verifier.
+ VectorType payloadTy = storeScatter.getValueType();
+ if (!payloadTy) {
+ storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
+ return;
+ }
+
+ auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
+ const int subgroupSize = uArch->getSubgroupSize();
+
+ if (layoutKind == LayoutKind::InstData) {
+ SmallVector<int> instData{subgroupSize};
+ if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
+ chunkSize > 1)
+ instData.push_back(chunkSize);
+ else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
+ storeScatter.getDestType())) {
+ if (dstTdescTy.getChunkSizeAsInt() > 1)
+          instData.push_back(dstTdescTy.getChunkSizeAsInt());
+ }
+ payloadLayout = LayoutInfo(
+ xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
+ } else {
+ auto payloadShape = payloadTy.getShape();
+ if (payloadShape.size() > 1)
+ assert(payloadShape[0] == subgroupSize &&
+ "Expected the first dimension of 2D tensor descriptor to be "
+ "equal to "
+ "subgroup size.");
+ payloadLayout = getDefaultSIMTLayoutInfo(
+ payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
+ /*scattered=*/true);
+ }
+
+ maskLayout =
+ getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
+
+ storeScatter.setLayoutAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get()));
}
- auto payloadShape = payloadTy.getShape();
- if (payloadShape.size() > 1)
- assert(
- payloadShape[0] == xegpu::targetinfo::subgroupSize &&
- "Expected the first dimension of 2D tensor descriptor to be equal to "
- "subgroup size.");
-
- LayoutInfo payloadLayout =
- getDefaultSIMTLayoutInfo(payloadTy, /*scattered=*/true);
-
- LayoutInfo maskLayout =
- getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1);
// Propagate the payload operand layout
propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
// Propagate the destination (if tdesc) operand layout
@@ -768,10 +1063,10 @@ class RunLayoutInfoPropagation {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
- RunLayoutInfoPropagation(Operation *op) : target(op) {
+ RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
SymbolTableCollection symbolTable;
loadBaselineAnalyses(solver);
- solver.load<LayoutInfoPropagation>(symbolTable);
+ solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
(void)solver.initializeAndRun(op);
}
@@ -878,7 +1173,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
}
// If the result is a vector type, add a temporary layout attribute to the
// op.
- xegpu::setDistributeLayoutAttr(result, layout);
+ xegpu::setDistributeLayoutAttr(result, layout, /*respectPermLayout*/ true);
}
return success();
}
@@ -1011,7 +1306,18 @@ struct XeGPUPropagateLayoutPass final
} // namespace
void XeGPUPropagateLayoutPass::runOnOperation() {
- auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
+ LayoutKind layoutKind;
+ if (this->layoutKind == "lane") {
+ layoutKind = LayoutKind::Lane;
+ } else if (this->layoutKind == "inst") {
+ layoutKind = LayoutKind::InstData;
+ } else {
+ getOperation()->emitError("Unsupported layout kind option: " +
+ this->layoutKind);
+ signalPassFailure();
+ return;
+ }
+ RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
// Print the analysis result and exit. (for debugging purposes)
if (printOnly) {
auto &os = llvm::outs();
@@ -1023,9 +1329,11 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
LayoutInfo layout = analysis.getLayoutInfo(val);
if (!layout.isAssigned())
return {};
+ xegpu::DistributeLayoutAttr layoutAttr =
+ cast<xegpu::DistributeLayoutAttr>(layout.get());
if (layout.isSliceLayout())
- return cast<xegpu::SliceAttr>(layout.get());
- return cast<xegpu::LayoutAttr>(layout.get());
+ return cast<xegpu::SliceAttr>(layoutAttr);
+ return cast<xegpu::LayoutAttr>(layoutAttr);
};
mlir::OpBuilder builder(&getContext());
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d09dc19..ca81c3c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -7,14 +7,15 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -98,7 +99,6 @@ getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
if (i < distributionStart)
continue;
-
// Check if the dimension can be distributed evenly.
if (dim % effectiveLaneLayout[i - distributionStart] != 0)
return failure();
@@ -159,17 +159,33 @@ static bool requirePacked(const xegpu::LayoutAttr layout) {
/// Helper function to check if the layout requires a transpose effect.
static bool requireTranspose(const xegpu::LayoutAttr layout,
- const std::string &chipStr) {
+ const xegpu::uArch::uArch *uArch) {
// Return false for unsupported targets.
// TODO: Add more support or move to target info.
- if (chipStr != "pvc" && chipStr != "bmg")
+  if (!uArch->getName().equals_insensitive("pvc") &&
+      !uArch->getName().equals_insensitive("bmg"))
return false;
if (!layout)
return false;
auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
if (laneLayout.size() != 2)
return false;
- return laneLayout[0] == xegpu::targetinfo::subgroupSize && laneLayout[1] == 1;
+ return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
+}
+
+/// Given a vector type and its distributed vector type, return the list of
+/// dimensions that are distributed.
+static SmallVector<int64_t> getDistributedDims(VectorType originalType,
+ VectorType distributedType) {
+ assert(originalType.getRank() == distributedType.getRank() &&
+ "sequential and distributed vector types must have the same rank");
+ SmallVector<int64_t> distributedDims;
+ for (int64_t i = 0; i < originalType.getRank(); ++i) {
+ if (distributedType.getDimSize(i) != originalType.getDimSize(i)) {
+ distributedDims.push_back(i);
+ }
+ }
+ return distributedDims;
}
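
For illustration only (not part of the patch): a standalone sketch of the `getDistributedDims` contract above, using plain integer shapes in place of MLIR `VectorType`s.

```
#include <cstddef>
#include <cstdint>
#include <vector>

// A dimension counts as distributed when the per-lane (distributed) extent
// differs from the sequential extent.
std::vector<int64_t> distributedDims(const std::vector<int64_t> &original,
                                     const std::vector<int64_t> &distributed) {
  std::vector<int64_t> dims;
  for (std::size_t i = 0; i < original.size(); ++i)
    if (original[i] != distributed[i])
      dims.push_back(static_cast<int64_t>(i));
  return dims;
}

// Example: vector<16x16xf32> distributed to vector<16x1xf32> gives dims {1}.
```
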
/// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
@@ -199,6 +215,11 @@ struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
PatternRewriter &rewriter) const override {
+ auto uArch = getUArch(xegpu::getChipStr(gpuFuncOp).value_or(""));
+ if (!uArch)
+ return rewriter.notifyMatchFailure(
+ gpuFuncOp, "Subgroup distribution requires target attribute attached "
+ "to set the warp size");
// If the function only contains a single void return, skip.
if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
@@ -230,7 +251,7 @@ struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
auto warpOp = gpu::WarpExecuteOnLane0Op::create(
rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
- xegpu::targetinfo::subgroupSize, newGpuFunc.getArguments(),
+ uArch->getSubgroupSize(), newGpuFunc.getArguments(),
newGpuFunc.getArgumentTypes());
Block &warpBodyBlock = warpOp.getBodyRegion().front();
// Replace the ReturnOp of the original gpu function with a YieldOp.
@@ -495,14 +516,14 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
warpOp, "warp result is not a xegpu::LoadNd op");
auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
+ auto uArch = getUArch(xegpu::getChipStr(loadOp).value_or(""));
+ if (!uArch)
+ return rewriter.notifyMatchFailure(
+          loadOp, "xegpu::LoadNdOp requires a target attribute attached to "
+                  "determine the transpose requirement");
// Chip information is required to decide if the layout requires transpose
// effect.
- auto chipStr = xegpu::getChipStr(loadOp);
- if (!chipStr)
- return rewriter.notifyMatchFailure(
- loadOp,
- "xegpu::LoadNdOp require chip information to determine transpose "
- "requirement");
// Expecting offsets to be present.
SmallVector<OpFoldResult> offsets = loadOp.getMixedOffsets();
if (offsets.empty())
@@ -556,7 +577,7 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
// Set the packed attribute if the layout requires it.
newLoadOp.setPacked(requirePacked(layout));
// Set the transpose attribute if the layout requires it.
- if (requireTranspose(layout, chipStr.value()))
+ if (requireTranspose(layout, uArch))
newLoadOp.setTranspose(
DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
Value distributedVal = newWarpOp.getResult(operandIdx);
@@ -906,6 +927,183 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
}
};
+static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
+ PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
+ Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
+  SmallVector<Value> newCoords;
+ auto maybeCoords =
+ layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
+ if (failed(maybeCoords))
+ return {};
+ assert(maybeCoords.value().size() == 1 &&
+ "Expected one set of distributed offsets");
+ SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
+ rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
+ getAsOpFoldResult(origOffsets));
+  newCoords = llvm::map_to_vector(
+      ofrVec, [](OpFoldResult ofr) { return cast<Value>(ofr); });
+  return newCoords;
+}
+
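
Illustration only, not part of the patch: an assumed, simplified model of the per-lane coordinates that the helper above produces via `computeDistributedCoords` plus `addWithRightAligned`, restricted to row-major lane layouts with unit lane_data. The function name `laneCoords` and its parameters are hypothetical.

```
#include <cassert>
#include <cstdint>
#include <vector>

// Delinearize the lane id over lane_layout (row-major) and add the resulting
// per-dimension lane coordinate to the op's original offsets.
std::vector<int64_t> laneCoords(int64_t laneId,
                                const std::vector<int64_t> &laneLayout,
                                const std::vector<int64_t> &origOffsets) {
  assert(laneLayout.size() == origOffsets.size());
  std::vector<int64_t> coords(laneLayout.size());
  for (int64_t i = static_cast<int64_t>(laneLayout.size()) - 1; i >= 0; --i) {
    coords[i] = laneId % laneLayout[i] + origOffsets[i];
    laneId /= laneLayout[i];
  }
  return coords;
}

// Example: laneId = 5, lane_layout = [1, 16], offsets = [8, 0] -> [8, 5].
```
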
+/// Pattern for distributing xegpu::LoadMatrixOp.
+struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ gpu::YieldOp yield = warpOp.getTerminator();
+ Operation *lastNode = yield->getPrevNode();
+ auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
+ if (!matrixOp)
+ return failure();
+
+ OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+ return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
+ });
+ if (!producedByLastLoad)
+ return rewriter.notifyMatchFailure(
+ warpOp, "The last op is not xegpu::LoadMatrixOp");
+ const int operandIdx = producedByLastLoad->getOperandNumber();
+
+ VectorType sgPayloadTy =
+ dyn_cast<VectorType>(matrixOp.getResult().getType());
+ VectorType warpResultTy =
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
+ if (!sgPayloadTy)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix op payload must be a vector type");
+
+ auto loc = matrixOp.getLoc();
+ auto offsets = matrixOp.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(matrixOp,
+ "the load op must have offsets");
+ SmallVector<Value> offsetsAsValues =
+ vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+ auto layout = matrixOp.getLayoutAttr();
+ if (!layout)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix operation lacks layout attribute");
+
+ FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+ if (failed(distPayloadByWarpOpOrFailure))
+ return rewriter.notifyMatchFailure(
+ matrixOp, "Failed to distribute matrix op payload based on layout.");
+
+ SmallVector<Value> operands = {matrixOp.getMemDesc()};
+ const unsigned offsetsStartIdx = operands.size();
+ operands.append(offsetsAsValues);
+
+ SmallVector<Type> operandTypes = llvm::to_vector(
+ llvm::map_range(operands, [](Value v) { return v.getType(); }));
+
+ SmallVector<size_t> newRetIndices;
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, operands, operandTypes, newRetIndices);
+ SmallVector<Value> newOperands = llvm::map_to_vector(
+ newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+ SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
+ ShapedType::kDynamic);
+ DenseI64ArrayAttr newConstOffsetsAttr =
+ rewriter.getDenseI64ArrayAttr(newConstOffsets);
+ ValueRange currentOffsets =
+ ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+ SmallVector<Value> newCoords = currentOffsets;
+ rewriter.setInsertionPointAfter(newWarpOp);
+
+ if (!matrixOp.getSubgroupBlockIoAttr()) {
+ newCoords = computeDistributedCoordinatesForMatrixOp(
+ rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
+ currentOffsets);
+ }
+ xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
+ rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
+ newOperands[0], ValueRange(newCoords), newConstOffsetsAttr,
+ matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+ // Resolve the output type and replace all uses.
+ rewriter.replaceAllUsesWith(
+ newWarpOp.getResult(operandIdx),
+ resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
+ return success();
+ }
+};
+
+/// Pattern for distributing xegpu::StoreMatrixOp.
+struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ gpu::YieldOp yield = warpOp.getTerminator();
+ Operation *lastNode = yield->getPrevNode();
+ auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
+ if (!matrixOp)
+ return failure();
+
+ VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
+ if (!sgPayloadTy)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix op payload must be a vector type");
+
+ auto loc = matrixOp.getLoc();
+ auto offsets = matrixOp.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(matrixOp,
+ "the store op must have offsets");
+ SmallVector<Value> offsetsAsValues =
+ vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+ auto layout = matrixOp.getLayoutAttr();
+ if (!layout)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix operation lacks layout attribute");
+
+ FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+ if (failed(distPayloadByWarpOpOrFailure))
+ return rewriter.notifyMatchFailure(
+ matrixOp, "Failed to distribute matrix op payload based on layout.");
+
+ SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+ const unsigned offsetsStartIdx = operands.size();
+ operands.append(offsetsAsValues);
+
+ SmallVector<Type> operandTypes = llvm::to_vector(
+ llvm::map_range(operands, [](Value v) { return v.getType(); }));
+ operandTypes[0] = *distPayloadByWarpOpOrFailure;
+
+ SmallVector<size_t> newRetIndices;
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, operands, operandTypes, newRetIndices);
+ SmallVector<Value> newOperands = llvm::map_to_vector(
+ newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+ SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
+ ShapedType::kDynamic);
+ DenseI64ArrayAttr newConstOffsetsAttr =
+ rewriter.getDenseI64ArrayAttr(newConstOffsets);
+ ValueRange currentOffsets =
+ ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+ SmallVector<Value> newCoords = currentOffsets;
+ rewriter.setInsertionPointAfter(newWarpOp);
+
+ if (!matrixOp.getSubgroupBlockIoAttr()) {
+ newCoords = computeDistributedCoordinatesForMatrixOp(
+ rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
+ currentOffsets);
+ }
+
+ xegpu::StoreMatrixOp::create(
+ rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
+ ValueRange(newCoords), newConstOffsetsAttr,
+ matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+ rewriter.eraseOp(matrixOp);
+ return success();
+ }
+};
+
/// Distribute a scattered load op. The logic and requirements are the same as
/// for the scattered store distribution. The warpOp's payload vector is
/// expected to be distributed by the load's result consumer.
@@ -1225,6 +1423,166 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
}
};
+/// This pattern distributes the `vector.broadcast` operation across lanes in a
+/// warp. The pattern supports three use cases:
+///
+/// 1) Broadcast a low-rank vector to a high-rank vector: The low-rank input
+///    vector must have a slice layout of the result. If the distributed
+///    source and target vector types are identical, this lowers to a no-op;
+///    otherwise, it remains a broadcast but operates on distributed vectors.
+///
+/// 2) Broadcast a same-rank vector with identical layouts for source and
+///    target: The source vector must have unit dimensions, and lane_data must
+///    be unit size for those unit dims. This always lowers to a no-op.
+///
+/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast from
+/// scalar to distributed result type.
+///
+/// Example 1 (lowering to a broadcast with distributed types):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
+/// %0 = "some_def"() {layout_result_0 =
+/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+/// dims = [0]> } : () -> (vector<32xf32>)
+/// %2 = vector.broadcast %0 {layout_result_0 =
+/// #xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>}
+/// : vector<32xf32> to vector<8x32xf32>
+///     gpu.yield %2 : vector<8x32xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
+/// %0 = "some_def"() {layout_result_0 =
+/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+/// dims = [0]> } : () -> (vector<32xf32>)
+/// gpu.yield %0 : vector<32xf32>
+/// }
+/// %2 = vector.broadcast %r#0 : vector<1xf32> to vector<8x1xf32>
+/// ```
+///
+/// Example 2 (no-op):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x32xf32>) {
+/// %0 = "some_def"() {layout_result_0 =
+/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+/// dims = [1]> } : () -> (vector<8xf32>)
+/// %1 = vector.shape_cast %0
+/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+/// 1]>}: vector<8xf32> to vector<8x1xf32>
+/// %2 = vector.broadcast %1
+/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+/// 1]>}: vector<8x1xf32> to vector<8x32xf32>
+///     gpu.yield %2 : vector<8x32xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
+/// %0 = "some_def"() {layout_result_0 =
+/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+/// dims = [1]> } : () -> (vector<8xf32>)
+/// %1 = vector.shape_cast %0
+/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+/// 1]>}: vector<8xf32> to vector<8x1xf32>
+/// gpu.yield %1 : vector<8x1xf32>
+/// }
+/// // The broadcast is implicit through layout transformation (no-op)
+/// "some_use"(%r#0)
+/// ```
+struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *yieldOperand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
+ if (!yieldOperand)
+ return failure();
+ auto broadcastOp =
+ cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp());
+ unsigned operandIdx = yieldOperand->getOperandNumber();
+
+ VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+ VectorType destType =
+ dyn_cast<VectorType>(broadcastOp.getResult().getType());
+
+ xegpu::DistributeLayoutAttr sourceLayout =
+ xegpu::getDistributeLayoutAttr(broadcastOp->getOpOperand(0));
+ xegpu::DistributeLayoutAttr resultLayout =
+ xegpu::getDistributeLayoutAttr(broadcastOp.getResult());
+
+ FailureOr<VectorType> sourceDistType;
+ Type sourceElemOrDistType;
+ if (sourceType) {
+
+ // Case 1 and 2: source is a vector type.
+ int64_t rankDiff = destType.getRank() - sourceType.getRank();
+ if (rankDiff > 0) {
+ // Case 1: source is lower-rank than result.
+ bool isSliceOf = sourceLayout.isSliceOf(resultLayout);
+ if (!isSliceOf)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "Broadcast input layout must be a slice of result layout.");
+ }
+      // Case 2: source and result have the same rank.
+ if (rankDiff == 0) {
+ SetVector<int64_t> broadcastUnitDims =
+ broadcastOp.computeBroadcastedUnitDims();
+ resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
+ bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
+ if (!isEqualTo)
+ return rewriter.notifyMatchFailure(
+            warpOp, "For same-rank broadcast, the source layout must match "
+                    "the result layout adjusted for the broadcast unit dims.");
+ sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
+ }
+
+ sourceDistType =
+ getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
+ if (failed(sourceDistType)) {
+ return rewriter.notifyMatchFailure(
+ warpOp, "Failed to distribute the source vector type.");
+ }
+ sourceElemOrDistType = sourceDistType.value();
+
+ } else {
+ // Case 3: source is a scalar type.
+ if (sourceLayout) {
+ return rewriter.notifyMatchFailure(
+ warpOp, "Broadcast from scalar must not have a layout attribute.");
+ }
+ sourceElemOrDistType = broadcastOp.getSourceType();
+ }
+ FailureOr<VectorType> destDistType =
+ getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
+ if (failed(destDistType)) {
+ return rewriter.notifyMatchFailure(
+ warpOp, "Failed to distribute the dest vector type.");
+ }
+
+ SmallVector<size_t> newRetIndices;
+ auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType,
+ newRetIndices);
+
+ Value distributedSource = newWarpOp.getResult(newRetIndices[0]);
+
+ Value newBroadcast = distributedSource;
+
+ if (sourceElemOrDistType != destDistType.value()) {
+ rewriter.setInsertionPointAfter(newWarpOp);
+ newBroadcast =
+ vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(),
+ destDistType.value(), distributedSource);
+ }
+
+ rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newBroadcast);
+ return success();
+ }
+};
+
/// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing
/// `gpu.warp_execute_on_lane_0` region.
struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
@@ -1285,6 +1643,226 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
}
};
+/// Distribute a `vector.extract_strided_slice` op feeding into the yield op of
+/// an enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
+/// advanced cases where the distributed dimension is partially extracted,
+/// which are currently not supported by the generic vector distribution
+/// patterns.
+struct VectorExtractStridedSliceDistribution
+ : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+ if (!operand)
+ return failure();
+ auto extractOp =
+ cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
+ unsigned operandIdx = operand->getOperandNumber();
+ auto distributedType =
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
+ // Find the distributed dimensions.
+ auto extractResultType = cast<VectorType>(operand->get().getType());
+ auto distributedDims =
+ getDistributedDims(extractResultType, distributedType);
+ // Collect updated source type, sizes and offsets. They may be adjusted
+ // later if the data is distributed to lanes (as opposed to being owned by
+ // all lanes uniformly).
+ VectorType updatedSourceType = extractOp.getSourceVectorType();
+ SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
+ extractOp.getSizes(), [](Attribute attr) { return attr; });
+ SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+ extractOp.getOffsets(), [](Attribute attr) { return attr; });
+    // If the result is distributed, it must be distributed in exactly one
+    // dimension. In this case, we adjust updatedSourceType, updatedSizes and
+    // updatedOffsets accordingly.
+ if (distributedDims.size() > 0) {
+ if (distributedDims.size() != 1)
+ return rewriter.notifyMatchFailure(
+            warpOp, "Source cannot be distributed in multiple dimensions.");
+ int64_t distributedDim = distributedDims[0];
+ int sourceDistrDimSize =
+ extractOp.getSourceVectorType().getShape()[distributedDim];
+ auto sourceLayout =
+ xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
+ if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+ return rewriter.notifyMatchFailure(
+ warpOp, "the source of extract_strided_slice op lacks distribution "
+ "layout");
+ auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
+ // Because only single dimension distribution is supported, lane layout
+ // size at the distributed dim must be the subgroup size.
+ int subgroupSize = sourceLaneLayout[distributedDim];
+ // Check if the source size in the distributed dimension is a multiple of
+ // subgroup size.
+ if (sourceDistrDimSize % subgroupSize != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "Source size along distributed dimension is not a multiple of "
+ "subgroup size.");
+ auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+ // We expect lane data to be all ones in this case.
+ if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+ return rewriter.notifyMatchFailure(
+ warpOp, "Expecting unit lane data in source layout");
+      // The offset in the distributed dimension must be a multiple of the
+      // subgroup size.
+ int64_t distrDimOffset =
+ cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt();
+ if (distrDimOffset % subgroupSize != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Offset along distributed dimension "
+ "is not a multiple of subgroup size.");
+ updatedSourceType = getDistVecTypeBasedOnLaneLayout(
+ sourceLayout, extractOp.getSourceVectorType())
+ .value();
+ // Update the distributed sizes to match the distributed type.
+ updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
+ distributedType.getDimSize(distributedDim));
+ // Update the distributed offsets to match round robin distribution (i.e.
+ // each lane owns data at `subgroupSize` stride given unit lane data).
+ updatedOffsets[distributedDim] =
+ rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
+ }
+ // Do the distribution by yielding the source of the extract op from
+ // the warp op and creating a new extract op outside the warp op.
+ SmallVector<size_t> newRetIndices;
+ auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
+ newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ Value source = newWarpOp.getResult(newRetIndices[0]);
+ // Create a new extract op outside the warp op.
+ Value newExtractOp = vector::ExtractStridedSliceOp::create(
+ rewriter, extractOp.getLoc(), distributedType, source,
+ ArrayAttr::get(rewriter.getContext(), updatedOffsets),
+ ArrayAttr::get(rewriter.getContext(), updatedSizes),
+ extractOp.getStrides());
+ rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
+ return success();
+ }
+};
+
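
Illustration only (not part of the patch): the round-robin offset/size adjustment used by the two strided-slice patterns above, reduced to plain arithmetic. The helper name `distributeSlice` and the concrete numbers are hypothetical; unit lane_data and exact divisibility are assumed, as the patterns require.

```
#include <cassert>
#include <cstdint>

// Per-lane offset/size along the distributed dimension.
struct PerLaneSlice {
  int64_t offset;
  int64_t size;
};

PerLaneSlice distributeSlice(int64_t offset, int64_t size, int64_t sgSize) {
  assert(offset % sgSize == 0 && "offset must be a multiple of subgroup size");
  assert(size % sgSize == 0 && "size must be a multiple of subgroup size");
  return {offset / sgSize, size / sgSize};
}

// Example: extracting rows [32, 64) of a 64-row source over 16 lanes gives
// each lane rows [2, 4) of the 4 rows it owns: distributeSlice(32, 32, 16)
// == {2, 2}.
```
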
+/// Distribute a `vector.insert_strided_slice` op feeding into yield op of an
+/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
+/// advanced cases where the distributed dimension is partially inserted and
+/// currently not supported by the generic vector distribution patterns.
+struct VectorInsertStridedSliceDistribution
+ : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
+ if (!operand)
+ return failure();
+ unsigned int operandNumber = operand->getOperandNumber();
+ auto insertOp =
+ operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
+ auto distributedType =
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
+ // Find the distributed dimensions of the dest vector.
+ auto insertResultType = cast<VectorType>(operand->get().getType());
+ auto destDistributedDims =
+ getDistributedDims(insertResultType, distributedType);
+ // Collect updated offsets, source type and dest type. They may be adjusted
+ // later if the data is distributed to lanes (as opposed to being owned by
+ // all lanes uniformly).
+ SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+ insertOp.getOffsets(), [](Attribute attr) { return attr; });
+ VectorType updatedSourceType = insertOp.getSourceVectorType();
+ VectorType updatedDestType = insertOp.getDestVectorType();
+ if (destDistributedDims.size() > 0) {
+ // Only single dimension distribution is supported.
+ if (destDistributedDims.size() != 1)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+            "Expecting dest to be distributed in a single dimension.");
+ int64_t destDistributedDim = destDistributedDims[0];
+
+ VectorType srcType = insertOp.getSourceVectorType();
+ VectorType destType = insertOp.getDestVectorType();
+ // Currently we require that both source (kD) and dest (nD) vectors are
+ // distributed. This requires that distributedDim (d) is contained in the
+ // last k dims of the dest vector (d >= n - k).
+ int64_t sourceDistributedDim =
+ destDistributedDim - (destType.getRank() - srcType.getRank());
+ if (sourceDistributedDim < 0)
+ return rewriter.notifyMatchFailure(
+ insertOp,
+ "distributed dimension must be in the last k (i.e. source "
+ "rank) dims of dest vector");
+ int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
+ // Obtain the source and dest layouts.
+ auto destLayout =
+ xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(1));
+ auto sourceLayout =
+ xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(0));
+ if (!destLayout || !sourceLayout ||
+ destLayout.getEffectiveLaneLayoutAsInt().empty() ||
+ sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+ return rewriter.notifyMatchFailure(
+ warpOp, "the source or dest of insert_strided_slice op lacks "
+ "distribution layout");
+ // Because only single dimension distribution is supported, lane layout
+ // size at the distributed dim must be the subgroup size.
+ int subgroupSize =
+ destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
+ // We require that source and dest lane data are all ones to ensure
+ // uniform round robin distribution.
+ auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
+ auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+ if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
+ !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+ return rewriter.notifyMatchFailure(
+ warpOp, "Expecting unit lane data in source and dest layouts");
+      // Source distributed dim size must be a multiple of the subgroup size.
+ if (srcDistrDimSize % subgroupSize != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Distributed dimension size in source is not a multiple of "
+ "subgroup size.");
+ // Offsets in the distributed dimension must be multiples of subgroup
+ // size.
+ int64_t destDistrDimOffset =
+ cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
+ if (destDistrDimOffset % subgroupSize != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "Offset along distributed dimension in dest is not a multiple of "
+ "subgroup size.");
+ // Update the source and dest types based on their layouts.
+ updatedSourceType = getDistVecTypeBasedOnLaneLayout(
+ sourceLayout, insertOp.getSourceVectorType())
+ .value();
+ updatedDestType = getDistVecTypeBasedOnLaneLayout(
+ destLayout, insertOp.getDestVectorType())
+ .value();
+ // Update the distributed offsets to match round robin distribution (i.e.
+ // each lane owns data at `subgroupSize` stride given unit lane data).
+ updatedOffsets[destDistributedDim] =
+ rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
+ }
+ // Do the distribution by yielding the source and dest of the insert op
+ // from the warp op and creating a new insert op outside the warp op.
+ SmallVector<size_t> newRetIndices;
+ auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
+ {updatedSourceType, updatedDestType}, newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+
+ Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
+ Value dest = newWarpOp.getResult(newRetIndices[1]);
+ // Create a new insert op outside the warp op.
+ Value newInsertOp = vector::InsertStridedSliceOp::create(
+ rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
+ ArrayAttr::get(rewriter.getContext(), updatedOffsets),
+ insertOp.getStrides());
+ rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
+ newInsertOp);
+ return success();
+ }
+};
+
/// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an
/// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op
/// outside of the warp op.
@@ -1437,13 +2015,18 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
GpuBarrierDistribution, VectorMultiReductionDistribution,
LoadDistribution, StoreDistribution, VectorTransposeDistribution,
- VectorBitcastDistribution,
+ VectorBitcastDistribution, LoadMatrixDistribution,
+ StoreMatrixDistribution,
MemrefExtractAlignedPointerAsIndexDistribution>(
patterns.getContext(),
/*pattern benefit=*/regularPatternBenefit);
- patterns.add<VectorShapeCastDistribution>(
- patterns.getContext(),
- /*pattern benefit=*/highPatternBenefit);
+  // The following patterns need to override the regular vector distribution
+  // patterns, so they are assigned a higher benefit.
+ patterns
+ .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
+ VectorInsertStridedSliceDistribution, VectorBroadcastDistribution>(
+ patterns.getContext(),
+ /*pattern benefit=*/highPatternBenefit);
}
void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
@@ -1462,6 +2045,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
// Layouts are needed for vector type only.
if (!isa<VectorType>(operand.get().getType()))
continue;
+ if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(op))
+ continue;
auto layout = xegpu::getDistributeLayoutAttr(operand.get());
if (!layout) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index e6e71cc..af63f09 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -238,6 +238,9 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
if (!targetShape)
return failure();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropInstData();
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
@@ -255,7 +258,7 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
op.getL1HintAttr(), op.getL2HintAttr(),
- op.getL3HintAttr());
+ op.getL3HintAttr(), layout);
// return dummy Value to satisfy function's signature
return nullptr;
};
@@ -282,6 +285,9 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
if (!targetShape)
return failure();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropInstData();
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
@@ -306,7 +312,7 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
return xegpu::LoadNdOp::create(
rewriter, loc, newValueTy, convertedTdescs[0], offsets,
op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
- op.getL2HintAttr(), op.getL3HintAttr());
+ op.getL2HintAttr(), op.getL3HintAttr(), layout);
};
newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
*targetShape, createLoad, loc, rewriter);
@@ -331,6 +337,9 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
if (!targetShape)
return failure();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropInstData();
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
@@ -354,7 +363,7 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
convertedTdescs[0], offsets,
op.getL1HintAttr(), op.getL2HintAttr(),
- op.getL3HintAttr());
+ op.getL3HintAttr(), layout);
// return dummy Value to satisfy function's signature
return nullptr;
};
@@ -678,12 +687,16 @@ struct UnrollLoadGatherOpWithOffset
pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
}
+ auto layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropInstData();
+
SmallVector<Value> newOps;
for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
auto newOp = xegpu::LoadGatherOp::create(
rewriter, loc, newValueTy, op.getSource(), o, m,
rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
- op.getL2HintAttr(), op.getL3HintAttr());
+ op.getL2HintAttr(), op.getL3HintAttr(), layout);
newOps.push_back(newOp);
}
@@ -774,12 +787,16 @@ struct UnrollStoreScatterOpWithOffsets
SmallVector<Value> convertedValues =
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+ auto layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropInstData();
+
for (auto [v, o, m] :
llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
rewriter.getI64IntegerAttr(chunkSize),
op.getL1HintAttr(), op.getL2HintAttr(),
- op.getL3HintAttr());
+ op.getL3HintAttr(), layout);
}
rewriter.eraseOp(op);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9fc5ad9..be82cda 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -86,8 +86,16 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
if (origOffsets.empty())
return failure();
+  // LoadMatrixOp and StoreMatrixOp carry the layout on the op itself; other
+  // ops (e.g. xegpu::CreateNdDescOp) expose it via getDescLayoutAttr().
+ xegpu::DistributeLayoutAttr layout;
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
+ std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
+ layout = op.getLayoutAttr();
+ } else {
+ layout = op.getDescLayoutAttr();
+ }
+
// not applicable to ops without workgroup layout attributes
- xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
if (!layout || !layout.isForWorkgroup())
return failure();
@@ -114,7 +122,8 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
// Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
// descriptors to be accessed, based on the layout information.
ArrayRef<int64_t> wgShape = op.getDataShape();
- auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ auto maybeDescOffsets =
+ layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
if (failed(maybeDescOffsets))
return failure();
@@ -189,7 +198,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
xegpu::TensorDescType tdescTy = op.getType();
ArrayRef<int64_t> wgShape = tdescTy.getShape();
Type elemTy = tdescTy.getElementType();
- xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
auto newTdescTy =
xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
@@ -308,6 +317,9 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
if (failed(genOffsetsList(rewriter, op, offsetsList)))
return failure();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropSgLayoutAndData();
SmallVector<Value> newOps;
for (auto [tdesc, offsets] :
llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
@@ -317,7 +329,7 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
auto newOp = xegpu::LoadNdOp::create(
rewriter, op.getLoc(), newResTy, tdesc, offsets,
/*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
- op.getL2HintAttr(), op.getL3HintAttr());
+ op.getL2HintAttr(), op.getL3HintAttr(), layout);
newOps.push_back(newOp);
}
rewriter.replaceOpWithMultiple(op, {newOps});
@@ -338,11 +350,14 @@ struct WgToSgStoreNdOpWithOffset
if (failed(genOffsetsList(rewriter, op, offsetsList)))
return failure();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropSgLayoutAndData();
for (auto [v, tdesc, offsets] :
llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
op.getL1HintAttr(), op.getL2HintAttr(),
- op.getL3HintAttr());
+ op.getL3HintAttr(), layout);
}
rewriter.eraseOp(op);
@@ -362,11 +377,14 @@ struct WgToSgPrefetchNdOpWithOffset
if (failed(genOffsetsList(rewriter, op, offsetsList)))
return failure();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropSgLayoutAndData();
for (auto [tdesc, offsets] :
llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
op.getL1HintAttr(), op.getL2HintAttr(),
- op.getL3HintAttr());
+ op.getL3HintAttr(), layout);
}
rewriter.eraseOp(op);
@@ -488,10 +506,8 @@ struct WgToSgVectorBroadcastOp
for (auto operand : adaptor.getOperands().front()) {
auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
newResultType, operand);
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty())
- xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
- layout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
+ layout.dropSgLayoutAndData());
newBroadcastOps.push_back(newBroadcast.getResult());
}
@@ -737,12 +753,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
Location loc = op.getLoc();
auto eltType = vecType.getElementType();
- auto setLayoutIfNeeded = [&](Value val) {
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty()) {
- xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
- layout.dropSgLayoutAndData());
- }
+ auto setLayout = [&](Value val) {
+ xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
+ layout.dropSgLayoutAndData());
};
if (vecAttr.isSplat()) {
@@ -750,14 +763,14 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
Attribute singleVal = vecAttr.getSplatValue<Attribute>();
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
- setLayoutIfNeeded(cstOp->getResult(0));
+ setLayout(cstOp->getResult(0));
rewriter.replaceOp(op, cstOp);
return success();
} else if (sgShape == wgShape) { // if the entire vector is shared by all
// subgroups, don't distribute
auto newConstOp =
arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
- setLayoutIfNeeded(newConstOp->getResult(0));
+ setLayout(newConstOp->getResult(0));
rewriter.replaceOp(op, newConstOp);
return success();
} else {
@@ -830,8 +843,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
// Get subgroup id
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
- auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ auto sgOffsets =
+ layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
if (failed(sgOffsets))
return failure();
@@ -859,9 +872,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
rewriter, loc, baseConstVec.getType(), mulOffset);
auto finalConst =
arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
- setLayoutIfNeeded(baseConstVec);
- setLayoutIfNeeded(bcastOffset);
- setLayoutIfNeeded(finalConst);
+ setLayout(baseConstVec);
+ setLayout(bcastOffset);
+ setLayout(finalConst);
newConstOps.push_back(finalConst);
}
rewriter.replaceOpWithMultiple(op, {newConstOps});
@@ -912,11 +925,12 @@ struct WgToSgLoadGatherOpWithOffset
VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
for (auto [offsets, mask] :
llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
+ auto newLayout = layout.dropSgLayoutAndData();
auto newLoadOp = xegpu::LoadGatherOp::create(
rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
- op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
- xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
- layout.dropSgLayoutAndData());
+ op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
+ newLayout);
+ xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0), newLayout);
newLoadOps.push_back(newLoadOp);
}
rewriter.replaceOpWithMultiple(op, {newLoadOps});
@@ -964,16 +978,14 @@ struct WgToSgStoreScatterOpWithOffset
adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
auto store = xegpu::StoreScatterOp::create(
rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
- op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+ op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
+ layout.dropSgLayoutAndData());
// Update the layout attribute to drop sg_layout and sg_data.
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty()) {
- for (OpOperand &operand : store->getOpOperands()) {
- // Skip for operand one (memref)
- if (operand.getOperandNumber() == 1)
- continue;
- xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
- }
+ for (OpOperand &operand : store->getOpOperands()) {
+ // Skip for operand one (memref)
+ if (operand.getOperandNumber() == 1)
+ continue;
+ xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
}
}
rewriter.eraseOp(op);
@@ -1052,7 +1064,8 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
- auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ auto sgOffsets =
+ layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
if (failed(sgOffsets))
return failure();
@@ -1065,15 +1078,12 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
auto finalSteps =
arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty()) {
- xegpu::setDistributeLayoutAttr(steps->getResult(0),
- layout.dropSgLayoutAndData());
- xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
- layout.dropSgLayoutAndData());
- xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
- layout.dropSgLayoutAndData());
- }
+ xegpu::setDistributeLayoutAttr(steps->getResult(0),
+ layout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
+ layout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
+ layout.dropSgLayoutAndData());
newOps.push_back(finalSteps);
}
@@ -1141,10 +1151,8 @@ struct WgToSgVectorShapeCastOp
for (auto src : adaptor.getSource()) {
auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
newResultType, src);
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty())
- xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
- layout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
+ layout.dropSgLayoutAndData());
newShapeCastOps.push_back(newShapeCast.getResult());
}
@@ -1205,10 +1213,8 @@ struct WgToSgMultiDimReductionOp
auto newOp = vector::MultiDimReductionOp::create(
rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
adaptor.getAcc()[0], op.getReductionDims());
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty())
- xegpu::setDistributeLayoutAttr(newOp->getResult(0),
- layout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(newOp->getResult(0),
+ layout.dropSgLayoutAndData());
newReductions.push_back(newOp.getResult());
}
@@ -1217,6 +1223,142 @@ struct WgToSgMultiDimReductionOp
}
};
+// This pattern transforms vector.transpose ops to work at subgroup level.
+struct WgToSgVectorTransposeOp
+ : public OpConversionPattern<vector::TransposeOp> {
+ using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType resultType = op.getResultVectorType();
+
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ xegpu::DistributeLayoutAttr sourceLayout =
+ xegpu::getDistributeLayoutAttr(op.getVector());
+ if (!sourceLayout || !sourceLayout.isForWorkgroup())
+ return failure();
+
+ SmallVector<int64_t> sourceSgLayout =
+ sourceLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
+ DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
+ DenseI32ArrayAttr resultOrder = layout.getOrder();
+
+ if (!sourceOrder || !resultOrder) {
+ return rewriter.notifyMatchFailure(
+ op, "Both source and result must have order attributes");
+ }
+
+ ArrayRef<int64_t> permutation = op.getPermutation();
+ size_t permutationSize = permutation.size();
+ if (sourceSgLayout.size() != permutationSize ||
+ resultSgLayout.size() != permutationSize) {
+ return rewriter.notifyMatchFailure(
+ op, "Layouts and permutation must have the same rank");
+ }
+
+ // Check that sgLayout, sgData & order are properly transposed for source
+ // and result
+ if (!layout.isTransposeOf(sourceLayout, permutation))
+ return rewriter.notifyMatchFailure(
+ op, "Result layout is not a valid transpose of source layout "
+ "according to permutation");
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType newResultType =
+ VectorType::get(sgShape, resultType.getElementType());
+ SmallVector<Value> newTransposeOps;
+ for (auto src : adaptor.getVector()) {
+ auto newTranspose = vector::TransposeOp::create(
+ rewriter, op.getLoc(), newResultType, src, permutation);
+ xegpu::setDistributeLayoutAttr(newTranspose->getResult(0),
+ layout.dropSgLayoutAndData());
+ newTransposeOps.push_back(newTranspose.getResult());
+ }
+
+ rewriter.replaceOpWithMultiple(op, {newTransposeOps});
+ return success();
+ }
+};
+
+// Distribute vector mask ops to work at subgroup level.
+template <typename MaskOpType>
+struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
+ using OpConversionPattern<MaskOpType>::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ MaskOpType op,
+ typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ Location loc = op.getLoc();
+ VectorType type = op.getResult().getType();
+ auto wgShape = type.getShape();
+
+ SmallVector<Value> wgMaskDimSizes;
+ if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
+ for (int64_t maskSize : op.getMaskDimSizes()) {
+ wgMaskDimSizes.push_back(
+ arith::ConstantIndexOp::create(rewriter, loc, maskSize));
+ }
+ } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
+ wgMaskDimSizes = llvm::to_vector(op.getOperands());
+ }
+
+ Value sgId =
+ gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+ auto sgOffsets =
+ layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
+ if (failed(sgOffsets))
+ return failure();
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType resultType = VectorType::get(sgShape, type.getElementType());
+
+ // In each dimension, each subgroup computes its local mask size as:
+ // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
+ SmallVector<Value> newCreateMaskOps;
+ for (auto offsetSet : *sgOffsets) {
+ SmallVector<Value> maskOperands;
+
+ for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
+ Value dimSizeVal =
+ arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
+ Value offset = offsetSet[i];
+ Value adjustedMaskSize =
+ arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value nonNegative =
+ arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
+ Value sgMaskSize =
+ arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
+ maskOperands.push_back(sgMaskSize);
+ }
+
+ auto newCreateMaskOp =
+ vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
+ xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0),
+ layout.dropSgLayoutAndData());
+ newCreateMaskOps.push_back(newCreateMaskOp.getResult());
+ }
+
+ rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
+ return success();
+ }
+};
+
+using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
+using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
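
Illustration only (not part of the patch): a minimal sketch of the per-subgroup mask-size formula stated in the comment inside WgToSgVectorMaskOp above, min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d]), using plain integers instead of MLIR values.

```
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Per dimension d: min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d]).
std::vector<int64_t> sgMaskSizes(const std::vector<int64_t> &wgMask,
                                 const std::vector<int64_t> &sgOffsets,
                                 const std::vector<int64_t> &sgShape) {
  std::vector<int64_t> sizes;
  for (std::size_t d = 0; d < wgMask.size(); ++d)
    sizes.push_back(std::min(std::max<int64_t>(wgMask[d] - sgOffsets[d], 0),
                             sgShape[d]));
  return sizes;
}

// Example: wgMask = {24, 64}, sgShape = {16, 64}. The subgroup at offset
// {0, 0} computes {16, 64}; the one at offset {16, 0} computes {8, 64}.
```
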
} // namespace
namespace mlir {
@@ -1231,7 +1373,9 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
- WgToSgMultiDimReductionOp>(patterns.getContext());
+ WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
+ WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
+ patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -1358,7 +1502,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});
- target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
+ target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
+ vector::TransposeOp, vector::BroadcastOp,
+ vector::MultiDimReductionOp,
+ vector::ConstantMaskOp, vector::CreateMaskOp>(
[=](Operation *op) -> bool {
// Check for either a SliceAttr or LayoutAttr on the result.
auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
@@ -1377,16 +1524,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});
- target.addDynamicallyLegalOp<vector::BroadcastOp>(
- [=](vector::BroadcastOp op) -> bool {
- return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
- });
-
- target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
- [=](vector::MultiDimReductionOp op) -> bool {
- return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
- });
-
target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
[=](xegpu::ConvertLayoutOp op) -> bool {
return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index a38993e..9f126fe 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -12,7 +12,6 @@
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -140,10 +139,14 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
// for StoreMatrixOp, the layout is attached to the property of the op
if (auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(defOp))
return storeOp.getLayoutAttr();
-
std::string layoutName = getLayoutName(result);
if (defOp->hasAttr(layoutName))
return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+
+ // check for "permament" layout only after "temporary" layout name lookup
+ // for backward compatibility
+ if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(defOp))
+ return loadGatherOp.getLayoutAttr();
}
if (auto arg = dyn_cast<BlockArgument>(value)) {
@@ -171,27 +174,77 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
std::string layoutName = xegpu::getLayoutName(opr);
if (op->hasAttr(layoutName))
return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+
+ // check for "permament" layout only after "temporary" layout name lookup
+ if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
+ if (auto layout = storeScatterOp.getLayoutAttr())
+ return layout;
+
return getDistributeLayoutAttr(opr.get());
}
+// Returns the permanent layout attribute for the given result if it's
+// available on the defining op. Otherwise returns the provided layout.
+xegpu::DistributeLayoutAttr
+maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout,
+ const OpResult &result, mlir::Operation *owner,
+ const std::string &name) {
+ xegpu::DistributeLayoutAttr candidate = layout;
+
+ if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(owner)) {
+ if (auto perm = loadOp.getLayoutAttr())
+ candidate = perm;
+ }
+
+ return candidate;
+}
+
+// Returns the permanent layout attribute for the given operand if it's
+// available on the owning op. Otherwise returns the provided layout.
+xegpu::DistributeLayoutAttr
+maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout,
+ const OpOperand &operand, mlir::Operation *owner,
+ const std::string &name) {
+ xegpu::DistributeLayoutAttr candidate = layout;
+ unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
+
+ if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(owner)) {
+ if (idx == 0) {
+ if (auto perm = storeOp.getLayoutAttr())
+ candidate = perm;
+ }
+ }
+
+ return candidate;
+}
+
template <typename T, typename>
void xegpu::setDistributeLayoutAttr(const T &operandOrResult,
- const DistributeLayoutAttr layout) {
+ const DistributeLayoutAttr layout,
+ bool respectPermLayout) {
Operation *owner = operandOrResult.getOwner();
std::string name = xegpu::getLayoutName(operandOrResult);
- if (layout && !owner->hasAttrOfType<DistributeLayoutAttr>(name))
- owner->setAttr(name, layout);
+
+ if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
+ return;
+
+ DistributeLayoutAttr candidate = layout;
+ if (respectPermLayout)
+ candidate = maybePickPermanentLayout(layout, operandOrResult, owner, name);
+
+ if (candidate)
+ owner->setAttr(name, candidate);
}
// Explicit instantiation for OpResult
template void xegpu::setDistributeLayoutAttr<mlir::OpResult>(
const mlir::OpResult &result,
- const mlir::xegpu::DistributeLayoutAttr layout);
+ const mlir::xegpu::DistributeLayoutAttr layout, bool respectPermLayout);
// Explicit instantiation for OpOperand
template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>(
const mlir::OpOperand &operand,
- const mlir::xegpu::DistributeLayoutAttr layout);
+ const mlir::xegpu::DistributeLayoutAttr layout, bool respectPermLayout);
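
A simplified model (illustration only, with strings standing in for `DistributeLayoutAttr`) of the precedence that `setDistributeLayoutAttr` follows once `respectPermLayout` is in play; `pickLayout` is a hypothetical name.

```
#include <optional>
#include <string>

// An attribute already present under the layout name always wins (nothing is
// overwritten); otherwise the op's permanent layout is preferred when
// respectPermLayout is set; otherwise the provided layout is used.
std::optional<std::string>
pickLayout(const std::optional<std::string> &existing,
           const std::optional<std::string> &permanent,
           const std::optional<std::string> &provided, bool respectPermLayout) {
  if (existing)
    return existing;
  if (respectPermLayout && permanent)
    return permanent;
  return provided;
}
```
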
void xegpu::setDistributeLayoutAttrs(
Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {
@@ -253,7 +306,7 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
int64_t rankDiff = srcShapeRank - targetShapeRank;
std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
1);
- std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff);
+ llvm::copy(shape, adjustedTargetShape.begin() + rankDiff);
SmallVector<Value> result;
for (SmallVector<int64_t> offsets :
@@ -473,7 +526,7 @@ SmallVector<OpFoldResult> xegpu::addElementwise(OpBuilder &builder,
for (auto [l, r] : llvm::zip_equal(lhs, rhs)) {
auto lval = getValueOrCreateConstantIndexOp(builder, loc, l);
auto rval = getValueOrCreateConstantIndexOp(builder, loc, r);
- results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval));
+ results.push_back(builder.createOrFold<arith::AddIOp>(loc, lval, rval));
}
return results;
}
@@ -500,3 +553,29 @@ xegpu::addWithRightAligned(OpBuilder &builder, Location loc,
results.append(addElementwise(builder, loc, a, b));
return results;
}
+
+template <typename T>
+int xegpu::getLargestDivisor(T dim, ArrayRef<T> candidates,
+ ArrayRef<T> candidateMultiples) {
+ static_assert(std::is_integral<T>::value, "T must be an integer type");
+ int largest = -1;
+ SmallVector<T> multiples = {1};
+ if (!candidateMultiples.empty())
+ multiples =
+ SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end());
+ for (T candidate : candidates) {
+ for (T multiple : multiples) {
+ int value = static_cast<int>(candidate * multiple);
+ if (value != 0 && dim % value == 0 && value > largest)
+ largest = value;
+ }
+ }
+ return largest;
+}
+
+/// Explicit instantiations
+template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
+ ArrayRef<int> candidateMultiples);
+template int
+xegpu::getLargestDivisor<unsigned>(unsigned dim, ArrayRef<unsigned> candidates,
+ ArrayRef<unsigned> candidateMultiples);
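
Illustrative usage of `getLargestDivisor` (the numbers are hypothetical, and the declaration is assumed to live in XeGPUUtils.h alongside the explicit instantiations above):

```
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"

// Candidate products are {8, 16, 16, 32, 32, 64}; the largest one dividing
// dim = 48 is 16. If no candidate product divides dim, -1 is returned.
void getLargestDivisorExample() {
  int tile = mlir::xegpu::getLargestDivisor<int>(48, {8, 16, 32}, {1, 2});
  (void)tile; // tile == 16
}
```
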