diff options
Diffstat (limited to 'mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp')
| -rw-r--r-- | mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 85 |
1 files changed, 70 insertions, 15 deletions
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 397107b..fb5d1e7 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -280,27 +280,82 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, FailureOr<SmallVector<Value>> LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) { - // TODO: handle order attribute - auto hasDefaultOrder = [&]() { - DenseI32ArrayAttr order = getOrder(); - return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>( - llvm::reverse(order.asArrayRef()))); - }; - if (!hasDefaultOrder()) - return mlir::emitError(loc, "order attribute is currently not supported."); - SmallVector<int64_t> layout; + SmallVector<int64_t> sgLayoutInt; if (isForWorkgroup()) { - layout = getEffectiveSgLayoutAsInt(); + sgLayoutInt = getEffectiveSgLayoutAsInt(); } else if (isForSubgroup()) { - layout = getEffectiveLaneLayoutAsInt(); + sgLayoutInt = getEffectiveLaneLayoutAsInt(); } else { return failure(); } - auto dims = llvm::map_to_vector(layout, [&](int64_t d) -> Value { - return builder.createOrFold<arith::ConstantIndexOp>(loc, d); - }); - return affine::delinearizeIndex(builder, loc, linearId, dims); + DenseI32ArrayAttr orderAttr = getOrder(); + + // Handle order attribute + SmallVector<int64_t> order; + if (orderAttr && !orderAttr.empty()) { + order = llvm::to_vector( + llvm::map_range(orderAttr.asArrayRef(), + [](int32_t idx) { return static_cast<int64_t>(idx); })); + } else { + // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc. + order = llvm::to_vector( + llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size()))); + } + + if (order.size() != sgLayoutInt.size()) { + return failure(); + } + + SmallVector<Value> result(sgLayoutInt.size()); + Value remaining = linearId; + + /// Process dimensions in the order they appear in the order array + /// The first dimension in order is the fastest-changing + /// + /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]: + /// + /// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx], + /// result=[?,?,?] + /// + /// i=0 (process columns, dimIdx=2, dimSize=4): + /// result[2] = 22 % 4 = 2 (column coordinate) + /// remaining = 22 / 4 = 5 (5 complete groups of 4 columns processed) + /// + /// i=1 (process rows, dimIdx=1, dimSize=4): + /// result[1] = 5 % 4 = 1 (row coordinate) + /// remaining = 5 / 4 = 1 (1 complete group of 4 rows processed) + /// + /// i=2 (process layers, dimIdx=0, dimSize=2): + /// result[0] = 1 % 2 = 1 (layer coordinate) + /// (no remaining update - last iteration) + /// + /// Final result: [1,1,2] = Layer 1, Row 1, Column 2 + for (size_t i = 0; i < order.size(); ++i) { + int64_t dimIdx = order[i]; + int64_t dimSize = sgLayoutInt[dimIdx]; + + Value dimSizeVal = + builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize); + + /// Extract the coordinate for this dimension using modulo operation + /// This gives us "how far within this dimension" we are + /// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within + /// this dimension) + result[dimIdx] = + builder.createOrFold<index::RemUOp>(loc, remaining, dimSizeVal); + + /// Update remaining for the next dimension by removing what we've already + /// processed. Division tells us "how many complete groups of this dimension + /// we've gone through" e.g., linearId=22, dimSize=4: 22 / 4 = 5 (we've + /// completed 5 groups of 4) Skip this for the last iteration since there's + /// no next dimension to process + if (i < order.size() - 1) { + remaining = + builder.createOrFold<index::DivUOp>(loc, remaining, dimSizeVal); + } + } + return result; } /// Implements DistributeLayoutAttr::computeDistributedCoords to generate |
