aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp')
-rw-r--r-- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 85
1 files changed, 70 insertions, 15 deletions
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 397107b..fb5d1e7 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -280,27 +280,82 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
FailureOr<SmallVector<Value>>
LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
- // TODO: handle order attribute
- auto hasDefaultOrder = [&]() {
- DenseI32ArrayAttr order = getOrder();
- return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>(
- llvm::reverse(order.asArrayRef())));
- };
- if (!hasDefaultOrder())
- return mlir::emitError(loc, "order attribute is currently not supported.");
- SmallVector<int64_t> layout;
+ SmallVector<int64_t> sgLayoutInt;
if (isForWorkgroup()) {
- layout = getEffectiveSgLayoutAsInt();
+ sgLayoutInt = getEffectiveSgLayoutAsInt();
} else if (isForSubgroup()) {
- layout = getEffectiveLaneLayoutAsInt();
+ sgLayoutInt = getEffectiveLaneLayoutAsInt();
} else {
return failure();
}
- auto dims = llvm::map_to_vector(layout, [&](int64_t d) -> Value {
- return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
- });
- return affine::delinearizeIndex(builder, loc, linearId, dims);
+ DenseI32ArrayAttr orderAttr = getOrder();
+
+ // Handle order attribute
+ SmallVector<int64_t> order;
+ if (orderAttr && !orderAttr.empty()) {
+ order = llvm::to_vector(
+ llvm::map_range(orderAttr.asArrayRef(),
+ [](int32_t idx) { return static_cast<int64_t>(idx); }));
+ } else {
+ // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc.
+ order = llvm::to_vector(
+ llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
+ }
+
+ if (order.size() != sgLayoutInt.size()) {
+ return failure();
+ }
+
+ SmallVector<Value> result(sgLayoutInt.size());
+ Value remaining = linearId;
+
+ /// Process dimensions in the sequence given by the order array.
+ /// The first entry in order names the fastest-changing dimension.
+ ///
+ /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]:
+ ///
+ /// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx],
+ /// result=[?,?,?]
+ ///
+ /// i=0 (process columns, dimIdx=2, dimSize=4):
+ /// result[2] = 22 % 4 = 2 (column coordinate)
+ /// remaining = 22 / 4 = 5 (5 complete groups of 4 columns processed)
+ ///
+ /// i=1 (process rows, dimIdx=1, dimSize=4):
+ /// result[1] = 5 % 4 = 1 (row coordinate)
+ /// remaining = 5 / 4 = 1 (1 complete group of 4 rows processed)
+ ///
+ /// i=2 (process layers, dimIdx=0, dimSize=2):
+ /// result[0] = 1 % 2 = 1 (layer coordinate)
+ /// (no remaining update - last iteration)
+ ///
+ /// Final result: [1,1,2] = Layer 1, Row 1, Column 2
+ for (size_t i = 0; i < order.size(); ++i) {
+ int64_t dimIdx = order[i];
+ int64_t dimSize = sgLayoutInt[dimIdx];
+
+ Value dimSizeVal =
+ builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize);
+
+ /// Extract the coordinate for this dimension using a modulo operation.
+ /// This gives us "how far within this dimension" we are,
+ /// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within
+ /// this dimension).
+ result[dimIdx] =
+ builder.createOrFold<index::RemUOp>(loc, remaining, dimSizeVal);
+
+ /// Update remaining for the next dimension by removing what we've already
+ /// processed. Division tells us "how many complete groups of this dimension
+ /// we've gone through", e.g., linearId=22, dimSize=4: 22 / 4 = 5 (we've
+ /// completed 5 groups of 4). Skip this for the last iteration since there's
+ /// no next dimension to process.
+ if (i < order.size() - 1) {
+ remaining =
+ builder.createOrFold<index::DivUOp>(loc, remaining, dimSizeVal);
+ }
+ }
+ return result;
}
/// Implements DistributeLayoutAttr::computeDistributedCoords to generate