Diffstat (limited to 'mlir/lib/Dialect/XeGPU/IR')
-rw-r--r--   mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp   390
-rw-r--r--   mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp       137
2 files changed, 400 insertions(+), 127 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index f9aa28d5..1a19ab5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -8,10 +8,8 @@
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
@@ -38,55 +36,61 @@ void XeGPUDialect::initialize() {
>();
}
-/// Generates instructions to compute offsets for a subgroup identified by
-/// its multidimensional indices (sgId), using the specified subgroup layout
-/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
-/// dimensions (sizePerWg).
+// A `srcShape` consists of N distribution units, each of shape
+// `subShapesLayout` x `subShape`. A `delinearizedId` identifies a particular
+// `subShape` within each distribution unit.
+// Example:
+// WG data is 128x256 and SG data is 16x32 in a 4x2 layout. This gives a
+// distribution unit of shape 64x64, and there are 2x4 such distribution
+// units. `delinearizedId` identifies the 16x32 block of a subgroup within
+// each distribution unit.
static SmallVector<SmallVector<Value>>
-genOffsetsComputingInsts(OpBuilder &builder, Location loc,
- SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout,
- ArrayRef<int64_t> sizePerSg,
- ArrayRef<int64_t> sizePerWg) {
-
- SmallVector<SmallVector<Value>> offsets;
+genCoordinates(OpBuilder &builder, Location loc,
+ SmallVector<Value> delinearizedId,
+ ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
+ ArrayRef<int64_t> srcShape) {
+ SmallVector<SmallVector<Value>> coordinates;
+
+  // A distribution unit must not exceed `srcShape` in any dimension.
+ SmallVector<int64_t> distUnitShape = llvm::map_to_vector(
+ llvm::zip_equal(srcShape,
+ computeElementwiseMul(subShapesLayout, subShape)),
+ [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
- // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i]
- SmallVector<Value> localOffsets = llvm::map_to_vector(
- llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value {
- return builder.createOrFold<index::MulOp>(
+ // Get the offset of `subShape` within a distribution unit.
+ SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
+ llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
+ return builder.createOrFold<arith::MulIOp>(
loc, std::get<0>(t),
builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
});
- // distUnit[i] is the minimum value between sizePerWg[i] and
- // sgLayout[i] * sizePerSg[i]
- SmallVector<int64_t> distUnit = llvm::map_to_vector(
- llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)),
- [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
-
+ // For each dist unit
for (SmallVector<int64_t> unitOffs :
- StaticTileOffsetRange(sizePerWg, distUnit)) {
+ StaticTileOffsetRange(srcShape, distUnitShape)) {
+ // Get dist unit offset within `srcShape`.
SmallVector<Value> base =
llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
return arith::ConstantIndexOp::create(builder, loc, d);
});
-
- SmallVector<Value> adds = llvm::map_to_vector(
- llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
- return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
- std::get<1>(t));
- });
-
+ // Calculate `subShape` offset within `srcShape`.
+ SmallVector<Value> adds =
+ llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
+ [&](const auto &t) -> Value {
+ return builder.createOrFold<arith::AddIOp>(
+ loc, std::get<0>(t), std::get<1>(t));
+ });
+ // Do not go beyond `srcShape` bounds.
SmallVector<Value> mods = llvm::map_to_vector(
- llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value {
- return builder.createOrFold<index::RemUOp>(
+ llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
+ return builder.createOrFold<arith::RemUIOp>(
loc, std::get<0>(t),
arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
});
- offsets.push_back(mods);
+ coordinates.push_back(mods);
}
- return offsets;
+ return coordinates;
}
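
For reference, the arithmetic genCoordinates folds into the IR can be traced with plain integers. The standalone sketch below is not part of the patch; it just replays the 128x256 / 16x32 / 4x2 example from the comment above for the subgroup with delinearized id (1, 1), printing one coordinate per distribution unit (2x4 of them).

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // WG shape, SG tile and SG layout from the example above.
  const int64_t srcShape[2] = {128, 256}; // whole workgroup data
  const int64_t subShape[2] = {16, 32};   // data owned by one subgroup
  const int64_t layout[2]   = {4, 2};     // sg_layout
  const int64_t sgId[2]     = {1, 1};     // delinearized subgroup id

  // distUnit[i] = min(srcShape[i], layout[i] * subShape[i])  -> {64, 64}
  int64_t distUnit[2], localOff[2];
  for (int i = 0; i < 2; ++i) {
    distUnit[i] = std::min(srcShape[i], layout[i] * subShape[i]);
    localOff[i] = sgId[i] * subShape[i]; // offset inside one dist unit
  }

  // Walk all distribution units (2x4 of them) and emit one coordinate each,
  // wrapping with modulo so we never step outside srcShape.
  for (int64_t r = 0; r < srcShape[0]; r += distUnit[0])
    for (int64_t c = 0; c < srcShape[1]; c += distUnit[1])
      std::printf("(%lld, %lld)\n",
                  (long long)((r + localOff[0]) % srcShape[0]),
                  (long long)((c + localOff[1]) % srcShape[1]));
  return 0;
}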
// Checks if the given shape can be evenly distributed based on the layout
@@ -229,8 +233,10 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
}
if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
- return emitError()
- << "expected inst_data and lane_layout to have the same rank";
+ return emitError() << "expected inst_data and lane_layout to have the same "
+ "rank, got inst_data "
+ << inst_data.size() << ", lane_layout "
+ << lane_layout.size();
}
// sg_data is optional for Workgroup layout, but its presence requires
@@ -271,56 +277,197 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
}
FailureOr<SmallVector<Value>>
-LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
- Value linearId) {
- // delinearizeSubgroupId is only available for
- // workgroup-level layout attribute
- if (!isForWorkgroup())
+LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
+
+ SmallVector<int64_t> sgLayoutInt;
+ if (isForWorkgroup()) {
+ sgLayoutInt = getEffectiveSgLayoutAsInt();
+ } else if (isForSubgroup()) {
+ sgLayoutInt = getEffectiveLaneLayoutAsInt();
+ } else {
return failure();
+ }
- // TODO: handle order attribute
- auto hasDefaultOrder = [&]() {
- DenseI32ArrayAttr order = getOrder();
- return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>(
- llvm::reverse(order.asArrayRef())));
- };
- if (!hasDefaultOrder())
- return mlir::emitError(loc, "order attribute is currently not supported.");
+ DenseI32ArrayAttr orderAttr = getOrder();
- auto dims =
- llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value {
- return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
- });
+ // Handle order attribute
+ SmallVector<int64_t> order;
+ if (orderAttr && !orderAttr.empty()) {
+ order = llvm::to_vector(
+ llvm::map_range(orderAttr.asArrayRef(),
+ [](int32_t idx) { return static_cast<int64_t>(idx); }));
+ } else {
+ // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc.
+ order = llvm::to_vector(
+ llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
+ }
- return affine::delinearizeIndex(builder, loc, linearId, dims);
+ if (order.size() != sgLayoutInt.size()) {
+ return failure();
+ }
+
+ SmallVector<Value> result(sgLayoutInt.size());
+ Value remaining = linearId;
+
+ /// Process dimensions in the order they appear in the order array
+ /// The first dimension in order is the fastest-changing
+ ///
+ /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]:
+ ///
+ /// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx],
+ /// result=[?,?,?]
+ ///
+ /// i=0 (process columns, dimIdx=2, dimSize=4):
+ /// result[2] = 22 % 4 = 2 (column coordinate)
+ /// remaining = 22 / 4 = 5 (5 complete groups of 4 columns processed)
+ ///
+ /// i=1 (process rows, dimIdx=1, dimSize=4):
+ /// result[1] = 5 % 4 = 1 (row coordinate)
+ /// remaining = 5 / 4 = 1 (1 complete group of 4 rows processed)
+ ///
+ /// i=2 (process layers, dimIdx=0, dimSize=2):
+ /// result[0] = 1 % 2 = 1 (layer coordinate)
+ /// (no remaining update - last iteration)
+ ///
+ /// Final result: [1,1,2] = Layer 1, Row 1, Column 2
+ for (size_t i = 0; i < order.size(); ++i) {
+ int64_t dimIdx = order[i];
+ int64_t dimSize = sgLayoutInt[dimIdx];
+
+ Value dimSizeVal =
+ builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize);
+
+ /// Extract the coordinate for this dimension using modulo operation
+ /// This gives us "how far within this dimension" we are
+ /// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within
+ /// this dimension)
+ result[dimIdx] =
+ builder.createOrFold<arith::RemUIOp>(loc, remaining, dimSizeVal);
+
+    /// Update remaining for the next dimension by removing what we've already
+    /// processed. Division tells us how many complete groups of this dimension
+    /// we've gone through, e.g., linearId=22, dimSize=4: 22 / 4 = 5 (we've
+    /// completed 5 groups of 4). Skip this for the last iteration since there's
+    /// no next dimension to process.
+ if (i < order.size() - 1) {
+ remaining =
+ builder.createOrFold<arith::DivUIOp>(loc, remaining, dimSizeVal);
+ }
+ }
+ return result;
}
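
The order-aware delinearization above constant-folds to the same arithmetic as this plain-integer sketch (an illustration only, reusing the linearId=22, sgLayout=[2,4,4], order=[2,1,0] values from the walkthrough):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  int64_t linearId = 22;
  std::vector<int64_t> sgLayout = {2, 4, 4};
  std::vector<int64_t> order    = {2, 1, 0}; // order[0] is fastest-changing
  std::vector<int64_t> result(sgLayout.size());

  int64_t remaining = linearId;
  for (size_t i = 0; i < order.size(); ++i) {
    int64_t dimIdx = order[i];
    result[dimIdx] = remaining % sgLayout[dimIdx]; // coordinate in this dim
    if (i + 1 < order.size())
      remaining /= sgLayout[dimIdx];               // move to next dimension
  }
  // Prints "1 1 2", i.e. layer 1, row 1, column 2, as in the comment above.
  std::printf("%lld %lld %lld\n", (long long)result[0], (long long)result[1],
              (long long)result[2]);
  return 0;
}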
-/// Implements DistributeLayoutAttr::getOffsets to generate
+/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
/// instructions for computing multi-dimensional offsets when distributed by
/// LayoutAttr.
FailureOr<SmallVector<SmallVector<Value>>>
-LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
- ArrayRef<int64_t> shape) {
- if (!isForWorkgroup())
+LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
+ Value linearId, ArrayRef<int64_t> shape) {
+ SmallVector<int64_t> layout;
+ SmallVector<int64_t> subShape;
+ if (isForWorkgroup()) {
+ layout = getEffectiveSgLayoutAsInt();
+ subShape = getEffectiveSgDataAsInt();
+ } else if (isForSubgroup()) {
+ layout = getEffectiveLaneLayoutAsInt();
+ subShape = getEffectiveLaneDataAsInt();
+ } else {
return failure();
-
- SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
- if (sgShape.empty()) {
- if (auto derivedShape = computeShapeRatio(shape, sgLayout))
- sgShape = derivedShape.value();
+ }
+ if (subShape.empty()) {
+ if (auto derivedShape = computeShapeRatio(shape, layout))
+ subShape = derivedShape.value();
else
return failure();
}
// delinearize Ids
- auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+ auto maybeIds = delinearizeId(builder, loc, linearId);
if (failed(maybeIds))
return failure();
- SmallVector<Value> sgIds = *maybeIds;
+ SmallVector<Value> ids = *maybeIds;
+
+ return genCoordinates(builder, loc, ids, layout, subShape, shape);
+}
+
+bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
+ if (dyn_cast<xegpu::SliceAttr>(other))
+ return false;
+
+ return *this == dyn_cast<xegpu::LayoutAttr>(other);
+}
+
+// Set sg_data, inst_data and lane_data to 1 for the specified unit dims.
+DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) {
+ auto sgDataOpt = getSgData();
+ auto instDataOpt = getInstData();
+ auto laneDataOpt = getLaneData();
+
+ SmallVector<int32_t> sgData;
+ SmallVector<int32_t> instData;
+ SmallVector<int32_t> laneData;
+
+ if (sgDataOpt) {
+ sgData = llvm::to_vector(sgDataOpt.asArrayRef());
+ }
+ if (instDataOpt) {
+ instData = llvm::to_vector(instDataOpt.asArrayRef());
+ }
+ if (laneDataOpt) {
+ laneData = llvm::to_vector(laneDataOpt.asArrayRef());
+ }
- return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
- shape);
+ for (auto dim : unitDims) {
+ if (dim < static_cast<int64_t>(sgData.size()))
+ sgData[dim] = 1;
+ if (dim < static_cast<int64_t>(instData.size()))
+ instData[dim] = 1;
+ if (dim < static_cast<int64_t>(laneData.size()))
+ laneData[dim] = 1;
+ }
+
+ return LayoutAttr::get(
+ getContext(), getSgLayout(),
+ sgData.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), sgData),
+ instData.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), instData),
+ getLaneLayout(),
+ laneData.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), laneData),
+ getOrder());
+}
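
A minimal sketch of what setUnitDimData does to the data fields, using plain vectors rather than DenseI32ArrayAttr (values made up for illustration): with unitDims = {0}, sg_data [16, 32] becomes [1, 32] and lane_data [8, 1] becomes [1, 1], while the layout fields are left untouched.

#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

// Mirrors the clamping done by setUnitDimData (illustrative only; the real
// method rebuilds DenseI32ArrayAttr instances on the attribute).
static void setUnitDims(std::vector<int32_t> &data,
                        const std::set<int64_t> &unitDims) {
  for (int64_t dim : unitDims)
    if (dim < (int64_t)data.size())
      data[dim] = 1;
}

int main() {
  std::vector<int32_t> sgData = {16, 32}, laneData = {8, 1};
  setUnitDims(sgData, {0});   // sg_data   [16, 32] -> [1, 32]
  setUnitDims(laneData, {0}); // lane_data [8, 1]   -> [1, 1]
  std::printf("%d %d | %d %d\n", sgData[0], sgData[1], laneData[0], laneData[1]);
  return 0;
}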
+
+// Set sg_layout and lane_layout to 1 for the specified unit dims.
+DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
+ auto sgLayoutOpt = getSgLayout();
+ auto laneLayoutOpt = getLaneLayout();
+
+ SmallVector<int32_t> sgLayout;
+ SmallVector<int32_t> laneLayout;
+
+ if (sgLayoutOpt) {
+ sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef());
+ }
+ if (laneLayoutOpt) {
+ laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef());
+ }
+
+ for (auto dim : unitDims) {
+ if (dim < static_cast<int64_t>(sgLayout.size()))
+ sgLayout[dim] = 1;
+ if (dim < static_cast<int64_t>(laneLayout.size()))
+ laneLayout[dim] = 1;
+ }
+
+ return LayoutAttr::get(
+ getContext(),
+ sgLayout.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), sgLayout),
+ getSgData(), getInstData(),
+ laneLayout.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), laneLayout),
+ getLaneData(), getOrder());
}
//===----------------------------------------------------------------------===//
@@ -374,34 +521,43 @@ SliceAttr SliceAttr::flatten() const {
}
FailureOr<SmallVector<Value>>
-SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
- Value linearId) {
+SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
- return parent.delinearizeSubgroupId(builder, loc, linearId);
+ return parent.delinearizeId(builder, loc, linearId);
}
-/// Implements DistributeLayoutAttr::getOffsets to generate
-/// instructions for computing multi-dimensional offsets when distributed by
-/// SliceAttr.
+// Implements DistributeLayoutAttr::computeDistributedCoords to generate
+// instructions for computing multi-dimensional offsets when distributed by
+// SliceAttr.
FailureOr<SmallVector<SmallVector<Value>>>
-SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
- ArrayRef<int64_t> shape) {
+SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
+ Value linearId, ArrayRef<int64_t> shape) {
assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
if (!isForWorkgroup())
return failure();
- SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
- if (sgShape.empty()) {
- if (auto derivedShape = computeShapeRatio(shape, sgLayout))
- sgShape = derivedShape.value();
+ SmallVector<int64_t> layout;
+ SmallVector<int64_t> subShape;
+ if (isForWorkgroup()) {
+ layout = getEffectiveSgLayoutAsInt();
+ subShape = getEffectiveSgDataAsInt();
+ } else if (isForSubgroup()) {
+ layout = getEffectiveLaneLayoutAsInt();
+ subShape = getEffectiveLaneDataAsInt();
+ } else {
+ return failure();
+ }
+
+ if (subShape.empty()) {
+ if (auto derivedShape = computeShapeRatio(shape, layout))
+ subShape = derivedShape.value();
else
return failure();
}
// delinearize Ids
- auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+ auto maybeIds = delinearizeId(builder, loc, linearId);
if (failed(maybeIds))
return failure();
@@ -411,8 +567,7 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
SmallVector<Value> sgIds =
XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
- return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
- shape);
+ return genCoordinates(builder, loc, sgIds, layout, subShape, shape);
}
bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
@@ -435,6 +590,69 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
[&](int64_t dim) { return thisDims.contains(dim); });
}
+bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
+ if (dyn_cast<xegpu::LayoutAttr>(other))
+ return false;
+
+ auto flattenedThis = flatten();
+ auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
+
+ return ((flattenedThis.getParent() == flattenedOther.getParent()) &&
+ (flattenedThis.getDims() == flattenedOther.getDims()));
+}
+
+// Helper function to adjust unit dimensions from sliced space to parent space
+static SetVector<int64_t>
+adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims,
+ ArrayRef<int64_t> sliceDims) {
+ // Reconstruct parent's non-sliced dimensions
+
+ int64_t parentRank = sliceDims.size() + unitDims.size();
+ llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
+ sliceDims.end());
+ SmallVector<int64_t> nonSlicedDims;
+ for (int64_t i = 0; i < parentRank; ++i) {
+ if (!slicedDimsSet.contains(i))
+ nonSlicedDims.push_back(i);
+ }
+
+ // Map unit dims from sliced space to parent space
+ SetVector<int64_t> adjustUnitDims;
+ for (auto dim : unitDims) {
+ if (dim < static_cast<int64_t>(nonSlicedDims.size())) {
+ adjustUnitDims.insert(nonSlicedDims[dim]);
+ }
+ }
+
+ return adjustUnitDims;
+}
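
A worked trace of adjustUnitDimsWithSliceDims with hypothetical values: sliceDims = [1] and unitDims = {0, 1} give parentRank = 3, non-sliced parent dims [0, 2], so sliced-space unit dims 0 and 1 map to parent dims 0 and 2. The sketch below replays the same steps with plain containers.

#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

int main() {
  std::set<int64_t> unitDims = {0, 1};  // unit dims in the sliced space
  std::vector<int64_t> sliceDims = {1}; // dims removed by the slice

  // Same reconstruction the helper performs.
  int64_t parentRank = (int64_t)(sliceDims.size() + unitDims.size()); // 3
  std::set<int64_t> sliced(sliceDims.begin(), sliceDims.end());
  std::vector<int64_t> nonSlicedDims;
  for (int64_t i = 0; i < parentRank; ++i)
    if (!sliced.count(i))
      nonSlicedDims.push_back(i); // {0, 2}

  for (int64_t dim : unitDims)
    if (dim < (int64_t)nonSlicedDims.size())
      std::printf("sliced dim %lld -> parent dim %lld\n", (long long)dim,
                  (long long)nonSlicedDims[dim]);
  // Prints: sliced dim 0 -> parent dim 0, sliced dim 1 -> parent dim 2.
  return 0;
}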
+
+// Set sg_data, inst_data and lane_data to 1 for the specified unit dims.
+DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) {
+ SliceAttr attr = flatten();
+ ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+
+ SetVector<int64_t> adjustUnitDims =
+ adjustUnitDimsWithSliceDims(unitDims, sliceDims);
+
+ return SliceAttr::get(getContext(), parent.setUnitDimData(adjustUnitDims),
+ attr.getDims());
+}
+
+// Set sg_layout and lane_layout to 1 for the specified unit dims.
+DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
+ SliceAttr attr = flatten();
+ ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+
+ SetVector<int64_t> adjustUnitDims =
+ adjustUnitDimsWithSliceDims(unitDims, sliceDims);
+
+ return SliceAttr::get(getContext(), parent.setUnitDimLayout(adjustUnitDims),
+ attr.getDims());
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_RangeAttr
//===----------------------------------------------------------------------===//
@@ -569,8 +787,8 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
// for gather and scatter ops, Low-precision types are packed in 32-bit units.
unsigned bitWidth = elementType.getIntOrFloatBitWidth();
int chunkAlignmentFactor =
- bitWidth < targetinfo::packedSizeInBitsForGatherScatter
- ? targetinfo::packedSizeInBitsForGatherScatter / bitWidth
+ bitWidth < xegpu::uArch::generalPackedFormatBitSize
+ ? xegpu::uArch::generalPackedFormatBitSize / bitWidth
: 1;
auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
if (scatterAttr) {
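
As a concrete instance of the packing rule (the packed-format width is assumed to be 32 bits here, matching the constant this change replaces; xegpu::uArch holds the authoritative value): an f16 element gives a chunk alignment factor of 2, while 32-bit and wider elements keep a factor of 1.

constexpr unsigned packedBits = 32; // assumed value of generalPackedFormatBitSize
constexpr unsigned bitWidth = 16;   // e.g. an f16 element
constexpr int chunkAlignmentFactor =
    bitWidth < packedBits ? packedBits / bitWidth : 1; // -> 2
static_assert(chunkAlignmentFactor == 2, "two f16 values pack into 32 bits");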
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index abd12e2..91ba07a 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -175,13 +175,13 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
LogicalResult
IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
- UnitAttr subgroup_block_io,
+ UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
function_ref<InFlightDiagnostic()> emitError) {
if (!dataTy) {
if (subgroup_block_io)
return emitError() << "subgroup_block_io "
- "are only allowed when result is a 1D VectorType.";
+                             "is only allowed when the result is a VectorType.";
else
return success();
}
@@ -192,15 +192,37 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
ArrayRef<int64_t> dataShape = dataTy.getShape();
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+ SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
+ ArrayAttr strideAttr = mdescTy.getStrideAttr();
+ SmallVector<int64_t> strides;
+ for (Attribute attr : strideAttr.getValue()) {
+ strides.push_back(cast<IntegerAttr>(attr).getInt());
+ }
+ if (subgroup_block_io && layout) {
+ auto laneData = layout.getEffectiveLaneDataAsInt();
+ auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+ if (!laneData.empty()) {
+ bool isLaneDataContiguous =
+ std::all_of(laneData.begin(), std::prev(laneData.end()),
+ [](int x) { return x == 1; });
+ if (!isLaneDataContiguous)
+ return emitError() << "With subgroup_block_io, accessed data must be "
+ "contiguous and coalesced.";
+ for (size_t i = 0; i < laneData.size(); ++i) {
+ if (laneLayout[i] != blockShape[i])
+ return emitError() << "With subgroup_block_io, the block shape must "
+ "match the lane layout.";
+ if (laneLayout[i] != 1 && strides[i] != 1)
+ return emitError() << "With subgroup_block_io, the distributed "
+ "dimensions must be contiguous.";
+ }
+ }
+ }
if (dataShape.size() == 2) {
- if (subgroup_block_io)
- return emitError() << "subgroup_block_io "
- "are only allowed when result is a 1D VectorType.";
if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
return emitError() << "data shape must not exceed mem_desc shape.";
} else {
- SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
// if the subgroup_block_io attribute is set, mdescTy must have block
// attribute
if (subgroup_block_io && !blockShape.size())
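
The three subgroup_block_io checks above can be read off a made-up configuration: lane_layout [1, 16], lane_data [1, 1], block shape [1, 16] and strides [16, 1] pass all three, while swapping the block shape to [16, 1] fails the layout/block-shape match. The sketch below restates the checks on plain vectors; it is an illustration, not the verifier itself.

#include <cstdint>
#include <cstdio>
#include <vector>

static bool isValid(const std::vector<int64_t> &laneLayout,
                    const std::vector<int64_t> &laneData,
                    const std::vector<int64_t> &blockShape,
                    const std::vector<int64_t> &strides) {
  for (size_t i = 0; i + 1 < laneData.size(); ++i)
    if (laneData[i] != 1)
      return false; // data per lane must be contiguous and coalesced
  for (size_t i = 0; i < laneLayout.size(); ++i) {
    if (laneLayout[i] != blockShape[i])
      return false; // block shape must match the lane layout
    if (laneLayout[i] != 1 && strides[i] != 1)
      return false; // distributed dimensions must be contiguous
  }
  return true;
}

int main() {
  std::printf("%d\n", isValid({1, 16}, {1, 1}, {1, 16}, {16, 1})); // 1
  std::printf("%d\n", isValid({1, 16}, {1, 1}, {16, 1}, {16, 1})); // 0
  return 0;
}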
@@ -258,8 +280,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
// if shape and strides are from Memref, we don't need attributes for them
- // to keep the IR print clean.
- if (staticShape == memrefShape && staticStrides == memrefStrides) {
+  // to keep the IR print clean (only do so in the fully static case; otherwise
+  // the printer would fail trying to print an empty array attribute).
+ if (staticShape == memrefShape && staticStrides == memrefStrides &&
+ dynamicShape.empty() && dynamicStrides.empty()) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
@@ -320,8 +344,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
// if shape and strides are from Memref, we don't need attributes for them
- // to keep the IR print clean.
- if (staticShape == memrefShape && staticStrides == memrefStrides) {
+  // to keep the IR print clean (only do so in the fully static case; otherwise
+  // the printer would fail trying to print an empty array attribute).
+ if (staticShape == memrefShape && staticStrides == memrefStrides &&
+ dynamicShape.empty() && dynamicStrides.empty()) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
@@ -439,14 +465,15 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
xegpu::CachePolicyAttr l3_hint) {
return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
- l1_hint, l2_hint, l3_hint);
+ l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
}
void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
Value tensorDesc, ArrayRef<OpFoldResult> offsets,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
- xegpu::CachePolicyAttr l3_hint) {
+ xegpu::CachePolicyAttr l3_hint,
+ xegpu::DistributeLayoutAttr layout) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -454,7 +481,7 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
- l2_hint, l3_hint);
+ l2_hint, l3_hint, /*anchor_layout=*/layout);
}
LogicalResult PrefetchNdOp::verify() {
@@ -472,11 +499,8 @@ LogicalResult PrefetchNdOp::verify() {
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
int64_t tDescRank = tdescTy.getRank();
- int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
- int64_t constOffsetSize =
- getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
- if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
+ int64_t offsetSize = getMixedOffsets().size();
+ if (offsetSize != 0 && offsetSize != tDescRank)
return emitOpError(
"Mismatched ranks between offsets and tensor descriptor");
@@ -496,7 +520,7 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
return build(builder, state, retType, tensorDesc, ValueRange(),
DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
- l3_hint);
+ l3_hint, /*anchor_layout=*/nullptr);
}
void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
@@ -504,7 +528,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
UnitAttr packed, DenseI64ArrayAttr transpose,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
- xegpu::CachePolicyAttr l3_hint) {
+ xegpu::CachePolicyAttr l3_hint,
+ xegpu::DistributeLayoutAttr layout) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -512,7 +537,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
- packed, transpose, l1_hint, l2_hint, l3_hint);
+ packed, transpose, l1_hint, l2_hint, l3_hint,
+ /*anchor_layout=*/layout);
}
LogicalResult LoadNdOp::verify() {
@@ -597,11 +623,8 @@ LogicalResult LoadNdOp::verify() {
<< tdescTy;
int64_t tDescRank = tdescTy.getRank();
- int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
- int64_t constOffsetSize =
- getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
- if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
+ int64_t offsetSize = getMixedOffsets().size();
+ if (offsetSize != 0 && offsetSize != tDescRank)
return emitOpError(
"Mismatched ranks between offsets and tensor descriptor");
@@ -618,14 +641,16 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
xegpu::CachePolicyAttr l3_hint) {
return build(builder, state, value, tensorDesc, ValueRange(),
- DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
+ DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint,
+ /*anchor_layout=*/nullptr);
}
void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
Value tensorDesc, ArrayRef<OpFoldResult> offsets,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
- xegpu::CachePolicyAttr l3_hint) {
+ xegpu::CachePolicyAttr l3_hint,
+ xegpu::DistributeLayoutAttr layout) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -633,7 +658,7 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
- l1_hint, l2_hint, l3_hint);
+ l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout);
}
LogicalResult StoreNdOp::verify() {
@@ -691,11 +716,8 @@ LogicalResult StoreNdOp::verify() {
<< dstTy;
int64_t tDescRank = dstTy.getRank();
- int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
- int64_t constOffsetSize =
- getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
- if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
+ int64_t offsetSize = getMixedOffsets().size();
+ if (offsetSize != 0 && offsetSize != tDescRank)
return emitOpError(
"Mismatched ranks between offsets and tensor descriptor");
@@ -809,7 +831,7 @@ void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint,
- IntegerAttr{});
+ IntegerAttr{}, /*anchor_layout=*/nullptr);
}
//===----------------------------------------------------------------------===//
@@ -859,7 +881,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
- l1_hint, l2_hint, l3_hint);
+ l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
}
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
@@ -875,7 +897,24 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
auto offset = vector::FromElementsOp::create(builder, loc, type, values);
build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
- l2_hint, l3_hint);
+ l2_hint, l3_hint, /*anchor_layout=*/nullptr);
+}
+
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+ Type valueType, Value source,
+ ArrayRef<OpFoldResult> offsets, Value mask,
+ IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint,
+ DistributeLayoutAttr layout) {
+ auto loc = source.getLoc();
+ int64_t size = static_cast<int64_t>(offsets.size());
+ auto type = VectorType::get(size, builder.getIndexType());
+ auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+ auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+ build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
+ l2_hint, l3_hint, layout);
}
//===----------------------------------------------------------------------===//
@@ -926,7 +965,7 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
- l2_hint, l3_hint);
+ l2_hint, l3_hint, /*anchor_layout=*/nullptr);
}
void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
@@ -944,7 +983,23 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
// Call the correct builder overload that does not expect result types.
build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
- l3_hint);
+ l3_hint, /*anchor_layout=*/nullptr);
+}
+
+void StoreScatterOp::build(
+ OpBuilder &builder, OperationState &state, Value value, Value dest,
+ ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
+ xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
+ auto loc = dest.getLoc();
+ int64_t size = static_cast<int64_t>(offsets.size());
+ auto type = VectorType::get(size, builder.getIndexType());
+ auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+ auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+ // Call the correct builder overload that does not expect result types.
+ build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
+ l3_hint, layout);
}
//===----------------------------------------------------------------------===//
@@ -1105,7 +1160,7 @@ LogicalResult LoadMatrixOp::verify() {
MemDescType mdescTy = getMemDesc().getType();
return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
- [&]() { return emitError(); });
+ getLayoutAttr(), [&]() { return emitError(); });
}
//===----------------------------------------------------------------------===//
@@ -1129,7 +1184,7 @@ LogicalResult StoreMatrixOp::verify() {
UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
MemDescType mdescTy = getMemDesc().getType();
return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
- [&]() { return emitError(); });
+ getLayoutAttr(), [&]() { return emitError(); });
}
namespace mlir {