author     Nishant Patel <nishant.b.patel@intel.com>  2025-11-04 19:37:08 -0800
committer  GitHub <noreply@github.com>                2025-11-04 19:37:08 -0800
commit     f291f335c9628ea8d855fcc7c246171d70ceff58 (patch)
tree       251817fd4e5d4187af1e460e7b11f7f172ef8568
parent     952d4b4c0bc959afe6bf18a7550fab024ab5a9b8 (diff)
[MLIR][XeGPU] Support order attribute and add pattern for vector.transpose in WgToSg Pass (#165307)
This PR does the following:
1. Handles the order attribute during delinearization of the linear subgroup id into a multi-dimensional id.
2. Adds a transformation pattern for vector.transpose to the wg-to-sg pass.
3. Updates CHECK lines in the wg-to-sg tests.
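For context, a minimal standalone C++ sketch of the order-aware delinearization this patch adds (plain integer arithmetic, not the actual MLIR builder code); the sgLayout/order values match the worked example in the XeGPUDialect.cpp comment below, and the helper name is invented for illustration.

#include <cstdint>
#include <iostream>
#include <vector>

// order[0] names the fastest-changing dimension; each step peels off that
// dimension's coordinate with a modulo and shrinks the linear id with a divide.
static std::vector<int64_t> delinearize(int64_t linearId,
                                        const std::vector<int64_t> &sgLayout,
                                        const std::vector<int64_t> &order) {
  std::vector<int64_t> coords(sgLayout.size());
  int64_t remaining = linearId;
  for (int64_t dim : order) {
    coords[dim] = remaining % sgLayout[dim];
    remaining /= sgLayout[dim];
  }
  return coords;
}

int main() {
  // linearId=22, sgLayout=[2,4,4], order=[2,1,0] -> coords [1, 1, 2]
  for (int64_t c : delinearize(22, {2, 4, 4}, {2, 1, 0}))
    std::cout << c << ' ';
  std::cout << '\n';
  return 0;
}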
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp                  |  85
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp |  81
-rw-r--r--  mlir/test/Dialect/XeGPU/subgroup-distribute.mlir            |  42
-rw-r--r--  mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir           |  39
-rw-r--r--  mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir        |   6
-rw-r--r--  mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir              |  88
-rw-r--r--  mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir    |  64
-rw-r--r--  mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 145
-rw-r--r--  mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir                 | 123
9 files changed, 387 insertions(+), 286 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 397107b..fb5d1e7 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -280,27 +280,82 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
FailureOr<SmallVector<Value>>
LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
- // TODO: handle order attribute
- auto hasDefaultOrder = [&]() {
- DenseI32ArrayAttr order = getOrder();
- return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>(
- llvm::reverse(order.asArrayRef())));
- };
- if (!hasDefaultOrder())
- return mlir::emitError(loc, "order attribute is currently not supported.");
- SmallVector<int64_t> layout;
+ SmallVector<int64_t> sgLayoutInt;
if (isForWorkgroup()) {
- layout = getEffectiveSgLayoutAsInt();
+ sgLayoutInt = getEffectiveSgLayoutAsInt();
} else if (isForSubgroup()) {
- layout = getEffectiveLaneLayoutAsInt();
+ sgLayoutInt = getEffectiveLaneLayoutAsInt();
} else {
return failure();
}
- auto dims = llvm::map_to_vector(layout, [&](int64_t d) -> Value {
- return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
- });
- return affine::delinearizeIndex(builder, loc, linearId, dims);
+ DenseI32ArrayAttr orderAttr = getOrder();
+
+ // Handle order attribute
+ SmallVector<int64_t> order;
+ if (orderAttr && !orderAttr.empty()) {
+ order = llvm::to_vector(
+ llvm::map_range(orderAttr.asArrayRef(),
+ [](int32_t idx) { return static_cast<int64_t>(idx); }));
+ } else {
+ // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc.
+ order = llvm::to_vector(
+ llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
+ }
+
+ if (order.size() != sgLayoutInt.size()) {
+ return failure();
+ }
+
+ SmallVector<Value> result(sgLayoutInt.size());
+ Value remaining = linearId;
+
+ /// Process dimensions in the order they appear in the order array
+ /// The first dimension in order is the fastest-changing
+ ///
+ /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]:
+ ///
+ /// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx],
+ /// result=[?,?,?]
+ ///
+ /// i=0 (process columns, dimIdx=2, dimSize=4):
+ /// result[2] = 22 % 4 = 2 (column coordinate)
+ /// remaining = 22 / 4 = 5 (5 complete groups of 4 columns processed)
+ ///
+ /// i=1 (process rows, dimIdx=1, dimSize=4):
+ /// result[1] = 5 % 4 = 1 (row coordinate)
+ /// remaining = 5 / 4 = 1 (1 complete group of 4 rows processed)
+ ///
+ /// i=2 (process layers, dimIdx=0, dimSize=2):
+ /// result[0] = 1 % 2 = 1 (layer coordinate)
+ /// (no remaining update - last iteration)
+ ///
+ /// Final result: [1,1,2] = Layer 1, Row 1, Column 2
+ for (size_t i = 0; i < order.size(); ++i) {
+ int64_t dimIdx = order[i];
+ int64_t dimSize = sgLayoutInt[dimIdx];
+
+ Value dimSizeVal =
+ builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize);
+
+ /// Extract the coordinate for this dimension using modulo operation
+ /// This gives us "how far within this dimension" we are
+ /// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within
+ /// this dimension)
+ result[dimIdx] =
+ builder.createOrFold<index::RemUOp>(loc, remaining, dimSizeVal);
+
+ /// Update remaining for the next dimension by removing what we've already
+ /// processed. Division tells us "how many complete groups of this dimension
+ /// we've gone through" e.g., linearId=22, dimSize=4: 22 / 4 = 5 (we've
+ /// completed 5 groups of 4) Skip this for the last iteration since there's
+ /// no next dimension to process
+ if (i < order.size() - 1) {
+ remaining =
+ builder.createOrFold<index::DivUOp>(loc, remaining, dimSizeVal);
+ }
+ }
+ return result;
}
/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index d12a04df..0a9ef0a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1219,6 +1219,70 @@ struct WgToSgMultiDimReductionOp
}
};
+// This pattern transforms vector.transpose ops to work at subgroup level.
+struct WgToSgVectorTransposeOp
+ : public OpConversionPattern<vector::TransposeOp> {
+ using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType resultType = op.getResultVectorType();
+
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ xegpu::DistributeLayoutAttr sourceLayout =
+ xegpu::getDistributeLayoutAttr(op.getVector());
+ if (!sourceLayout || !sourceLayout.isForWorkgroup())
+ return failure();
+
+ SmallVector<int64_t> sourceSgLayout =
+ sourceLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
+ DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
+ DenseI32ArrayAttr resultOrder = layout.getOrder();
+
+ if (!sourceOrder || !resultOrder) {
+ return rewriter.notifyMatchFailure(
+ op, "Both source and result must have order attributes");
+ }
+
+ ArrayRef<int64_t> permutation = op.getPermutation();
+ size_t permutationSize = permutation.size();
+ if (sourceSgLayout.size() != permutationSize ||
+ resultSgLayout.size() != permutationSize) {
+ return rewriter.notifyMatchFailure(
+ op, "Layouts and permutation must have the same rank");
+ }
+
+ // Check that sgLayout, sgData & order are properly transposed for source
+ // and result
+ if (!layout.isTransposeOf(sourceLayout, permutation))
+ return rewriter.notifyMatchFailure(
+ op, "Result layout is not a valid transpose of source layout "
+ "according to permutation");
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType newResultType =
+ VectorType::get(sgShape, resultType.getElementType());
+ SmallVector<Value> newTransposeOps;
+ for (auto src : adaptor.getVector()) {
+ auto newTranspose = vector::TransposeOp::create(
+ rewriter, op.getLoc(), newResultType, src, permutation);
+ xegpu::setDistributeLayoutAttr(newTranspose->getResult(0),
+ layout.dropSgLayoutAndData());
+ newTransposeOps.push_back(newTranspose.getResult());
+ }
+
+ rewriter.replaceOpWithMultiple(op, {newTransposeOps});
+ return success();
+ }
+};
+
} // namespace
namespace mlir {
@@ -1233,7 +1297,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
- WgToSgMultiDimReductionOp>(patterns.getContext());
+ WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>(
+ patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -1360,7 +1425,9 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});
- target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
+ target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
+ vector::TransposeOp, vector::BroadcastOp,
+ vector::MultiDimReductionOp>(
[=](Operation *op) -> bool {
// Check for either a SliceAttr or LayoutAttr on the result.
auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
@@ -1379,16 +1446,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});
- target.addDynamicallyLegalOp<vector::BroadcastOp>(
- [=](vector::BroadcastOp op) -> bool {
- return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
- });
-
- target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
- [=](vector::MultiDimReductionOp op) -> bool {
- return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
- });
-
target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
[=](xegpu::ConvertLayoutOp op) -> bool {
return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 8946d14..8fd3cca 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -268,15 +268,16 @@ gpu.module @xevm_module{
// -----
// CHECK-LABEL: gpu.func @load_store_matrix_1({{.*}}) {
-// CHECK: %[[LAYOUT_X:.*]] = arith.constant 8 : index
-// CHECK: %[[LAYOUT_Y:.*]] = arith.constant 2 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
-// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
-// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
-// CHECK: %[[LANE_Y_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_Y]], %[[LAYOUT_Y]]
-// CHECK: %[[LANE_X_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[LAYOUT_X]]
-// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
-// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+// CHECK: %[[REMU1:.*]] = index.remu %[[LANE_ID]], %[[C8]]
+// CHECK: %[[DIVU:.*]] = index.divu %[[LANE_ID]], %[[C8]]
+// CHECK: %[[REMU2:.*]] = index.remu %[[DIVU]], %[[C2]]
+// CHECK: %[[REMU3:.*]] = index.remu %[[REMU2]], %[[C2]]
+// CHECK: %[[REMU4:.*]] = index.remu %[[REMU1]], %[[C8]]
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[REMU3]], %[[REMU4]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[REMU3]], %[[REMU4]]] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
gpu.module @xevm_module{
gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) {
%c0 = arith.constant 0 : index
@@ -288,19 +289,20 @@ gpu.module @xevm_module{
// -----
// CHECK-LABEL: gpu.func @load_store_matrix_2({{.*}}) {
-// CHECK: %[[DIST_UNIT_HEIGHT_X:.*]] = arith.constant 4 : index
-// CHECK: %[[DIST_UNIT_HEIGHT_Y:.*]] = arith.constant 8 : index
-// CHECK: %[[LANE_DATA_Y:.*]] = arith.constant 2 : index
-// CHECK: %[[USER_OFFSET_X:.*]] = arith.constant 1 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
-// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
-// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
-// CHECK: %[[LANE_Y_OFFSET_1:.*]] = index.mul %[[DELINEARIZED_LANE_Y]], %[[LANE_DATA_Y]]
-// CHECK: %[[LANE_Y_OFFSET:.*]] = index.remu %[[LANE_Y_OFFSET_1]], %[[DIST_UNIT_HEIGHT_Y]]
-// CHECK: %[[LANE_X_OFFSET_1:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[DIST_UNIT_HEIGHT_X]]
-// CHECK: %[[LANE_X_OFFSET:.*]] = index.add %[[LANE_X_OFFSET_1]], %[[USER_OFFSET_X]]
-// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32>
-// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+// CHECK: %[[REMU1:.*]] = index.remu %[[LANE_ID]], %[[C4]]
+// CHECK: %[[DIVU:.*]] = index.divu %[[LANE_ID]], %[[C4]]
+// CHECK: %[[REMU2:.*]] = index.remu %[[DIVU]], %[[C4]]
+// CHECK: %[[MUL:.*]] = index.mul %[[REMU2]], %[[C2]]
+// CHECK: %[[REMU3:.*]] = index.remu %[[MUL]], %[[C8]]
+// CHECK: %[[REMU4:.*]] = index.remu %[[REMU1]], %[[C4]]
+// CHECK: %[[ADD:.*]] = index.add %[[REMU4]], %[[C1]]
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[REMU3]], %[[ADD]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[REMU3]], %[[ADD]]] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
gpu.module @xevm_module{
gpu.func @load_store_matrix_2(%arg0: !xegpu.mem_desc<32x32xf32>) {
%c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
index b73bc69..02c5f71 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
@@ -1,33 +1,32 @@
// RUN: mlir-opt --test-xegpu-layout-interface --cse -split-input-file %s | FileCheck %s
-//CHECk: #map = affine_map<()[s0] -> (s0 floordiv 8)>
gpu.module @test {
gpu.func @slice_attr() -> vector<128xindex> {
- //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
- //CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
- //CHECK: [[c32:%.+]] = arith.constant 32 : index
- //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
- //CHECK: [[c128:%.+]] = arith.constant 128 : index
- //CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
- //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
- //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
- //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[DIVU:.*]] = index.divu %[[SGID]], %[[C8:.*]]
+ // CHECK-DAG: %[[REMU:.*]] = index.remu %[[DIVU]], %[[C4:.*]]
+ // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C32:.*]]
+ // CHECK-DAG: %[[MOD:.*]] = index.remu %[[MUL]], %[[C128:.*]]
+ // CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex>
+ // CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex>
+ // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex>
%step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
gpu.return %step : vector<128xindex>
}
gpu.func @nested_slice_attr() -> vector<128xindex> {
- //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
- //CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
- //CHECK: [[c32:%.+]] = arith.constant 32 : index
- //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
- //CHECK: [[c128:%.+]] = arith.constant 128 : index
- //CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
- //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
- //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
- //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[DIVU1:.*]] = index.divu %[[SGID]], %[[C1:.*]]
+ // CHECK-DAG: %[[DIVU2:.*]] = index.divu %[[DIVU1]], %[[C8:.*]]
+ // CHECK-DAG: %[[REMU:.*]] = index.remu %[[DIVU2]], %[[C4:.*]]
+ // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C32:.*]]
+ // CHECK-DAG: %[[MOD:.*]] = index.remu %[[MUL]], %[[C128:.*]]
+ // CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex>
+ // CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex>
+ // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex>
%0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 1], sg_data = [32, 32, 1]>, dims = [2]>, dims = [1]>} : vector<128xindex>
gpu.return %0 : vector<128xindex>
}
-}
\ No newline at end of file
+}
+
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
index 09df1e4..9580769 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
@@ -166,14 +166,12 @@ gpu.module @test_elementwise_ops {
%load_b = xegpu.load_nd %tdesc_b
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-> vector<24x32xf32>
- // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
- // CHECK-SAME-COUNT-12: : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
// CHECK-NOT: arith.negf
%negf = arith.negf %load_a
{layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
: vector<24x32xf32>
- // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
- // CHECK-SAME-COUNT-12: : vector<2x2xf32>
+ // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
// CHECK-NOT: math.powf
%powf = math.powf %load_a, %load_b
{layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index d2d250c..01134d8e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -1,14 +1,10 @@
// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
-#map = affine_map<()[s0] -> (s0 floordiv 4)>
-#map1 = affine_map<()[s0] -> (s0 mod 4)>
-
gpu.module @test_round_robin_assignment {
// CHECK-LABEL: create_nd_tdesc
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32> -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.create_nd_tdesc
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -16,22 +12,23 @@ gpu.module @test_round_robin_assignment {
}
// CHECK-LABEL: create_nd_tdesc_with_shared_data
- // CHECK-SAME: [[ARG_0:%.*]]: memref<256x128xf32>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @create_nd_tdesc_with_shared_data(%src: memref<256x128xf32>) {
- //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
- //CHECK: [[IdY:%.+]] = affine.apply #map()[[[sgId]]]
- //CHECK: [[IdX:%.+]] = affine.apply #map1()[[[sgId]]]
- //CHECK: [[C16:%.+]] = arith.constant 16 : index
- //CHECK: [[LY:%.+]] = index.mul [[IdY]], [[C16]]
- //CHECK: [[C64:%.+]] = arith.constant 64 : index
- //CHECK: [[LX:%.+]] = index.mul [[IdX]], [[C64]]
- //CHECK: [[C0:%.+]] = arith.constant 0 : index
- //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
- //CHECK: [[C128:%.+]] = arith.constant 128 : index
- //CHECK: [[offY:%.+]] = index.remu [[LY]], [[C128]]
- //CHECK: [[C64_2:%.+]] = arith.constant 64 : index
- //CHECK: [[offX:%.+]] = index.remu [[LX]], [[C64_2]]
- //CHECK: xegpu.create_nd_tdesc [[ARG_0]][[[offY]], [[offX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
+ // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK: %[[IDX:.*]] = index.remu %[[SGID]], %[[C4]]
+ // CHECK: %[[IDY_DIV:.*]] = index.divu %[[SGID]], %[[C4]]
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK: %[[IDY:.*]] = index.remu %[[IDY_DIV]], %[[C8]]
+ // CHECK: %[[C16:.*]] = arith.constant 16 : index
+ // CHECK: %[[LY:.*]] = index.mul %[[IDY]], %[[C16]]
+ // CHECK: %[[C64:.*]] = arith.constant 64 : index
+ // CHECK: %[[LX:.*]] = index.mul %[[IDX]], %[[C64]]
+ // CHECK: %[[C128:.*]] = arith.constant 128 : index
+ // CHECK: %[[OFFY:.*]] = index.remu %[[LY]], %[[C128]]
+ // CHECK: %[[C64_1:.*]] = arith.constant 64 : index
+ // CHECK: %[[OFFX:.*]] = index.remu %[[LX]], %[[C64_1]]
+ // CHECK: xegpu.create_nd_tdesc %[[ARG_0]][%[[OFFY]], %[[OFFX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
gpu.return
@@ -42,9 +39,7 @@ gpu.module @test_round_robin_assignment {
gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK-COUNT-4: xegpu.load_nd %{{.*}}
- // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK-SAME-COUNT-4: -> vector<16x16xf32>
+ // CHECK-COUNT-4: xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf32>
// CHECK-NOT: xegpu.load_nd
%load = xegpu.load_nd %tdesc
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -57,9 +52,8 @@ gpu.module @test_round_robin_assignment {
gpu.func @store_nd(%src: memref<256x128xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}}
- // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK-NOT : xegpu.store_nd
+ // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}} : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-NOT: xegpu.store_nd
%load = xegpu.load_nd %tdesc
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-> vector<256x128xf32>
@@ -73,8 +67,7 @@ gpu.module @test_round_robin_assignment {
gpu.func @update_nd(%src: memref<256x128xf32>){
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16]
- // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>>
+ // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.update_nd_offset
%update = xegpu.update_nd_offset %tdesc, [0, 16]
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -84,15 +77,9 @@ gpu.module @test_round_robin_assignment {
// CHECK-LABEL: dpas
// CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>)
gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) {
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf16>
- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK-NOT: xegpu.create_nd_tdesc
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf16>
- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [4, 8], lane_data = [1, 1]>>
- // CHECK-NOT: xegpu.create_nd_tdesc
- // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
- // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
- // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+ // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
// CHECK-NOT: xegpu.dpas
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<256x128xf16>
-> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -113,8 +100,7 @@ gpu.module @test_round_robin_assignment {
// CHECK-LABEL: prefetch_nd_tdesc
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) {
- // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}}
- // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}} : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.prefetch_nd
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -131,9 +117,7 @@ gpu.module @test_round_robin_assignment {
%load = xegpu.load_nd %tdesc
: !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
-> vector<128x1xf32>
- // CHECK-COUNT-2: vector.broadcast {{.*}}
- // CHECK-SAME-COUNT-2: {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
- // CHECK-SAME-COUNT-2: : vector<16x1xf32> to vector<16x32xf32>
+ // CHECK-COUNT-2: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>} : vector<16x1xf32> to vector<16x32xf32>
// CHECK-NOT: vector.broadcast
%broadcast = vector.broadcast %load
{layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
@@ -171,10 +155,10 @@ gpu.module @test_round_robin_assignment {
%0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
%2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
- //CHECK: scf.while ({{.*}}) : (vector<16xf32>, vector<16xf32>, i32) -> (vector<16xf32>, vector<16xf32>, i32)
+ // CHECK: scf.while ({{.*}}) : (vector<16xf32>, vector<16xf32>, i32) -> (vector<16xf32>, vector<16xf32>, i32)
%3:2 = scf.while (%arg2 = %1, %arg3 = %c0_i32) : (vector<256xf32>, i32) -> (vector<256xf32>, i32) {
%4 = arith.cmpi slt, %arg3, %c10_i32 : i32
- //CHECK: scf.condition{{.*}} : vector<16xf32>, vector<16xf32>, i32
+ // CHECK: scf.condition{{.*}} : vector<16xf32>, vector<16xf32>, i32
scf.condition(%4) %arg2, %arg3 : vector<256xf32>, i32
} do {
// CHECK: ([[arg2:%.+]]: vector<16xf32>, [[arg3:%.+]]: vector<16xf32>, [[arg4:%.+]]: i32)
@@ -195,16 +179,16 @@ gpu.module @test_round_robin_assignment {
%2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
%3 = arith.cmpi eq, %0, %c10 : index
// CHECK-LABEL: scf.if
- // CHECK-SAME: (vector<16xf32>, vector<16xf32>)
+ // CHECK-SAME: (vector<16xf32>, vector<16xf32>)
%4 = scf.if %3 -> (vector<256xf32>) {
%5 = xegpu.load_nd %1 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
// CHECK-LABEL: scf.yield
- // CHECK-SAME: vector<16xf32>, vector<16xf32>
+ // CHECK-SAME: vector<16xf32>, vector<16xf32>
scf.yield %5 : vector<256xf32>
} else {
%5 = xegpu.load_nd %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
// CHECK-LABEL: scf.yield
- // CHECK-SAME: vector<16xf32>, vector<16xf32>
+ // CHECK-SAME: vector<16xf32>, vector<16xf32>
scf.yield %5 : vector<256xf32>
} {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [16]>}
xegpu.store_nd %4, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
@@ -220,16 +204,16 @@ gpu.module @test_round_robin_assignment {
%0 = arith.cmpi eq, %id, %c10 : index
// CHECK-LABEL: scf.if
- // CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>)
+ // CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>)
%1 = scf.if %0 -> (!xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>) {
%2 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
// CHECK-LABEL: scf.yield
- // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
+ // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
scf.yield %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
} else {
%3 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
// CHECK-LABEL: scf.yield
- // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
+ // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
scf.yield %3 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
}
xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
@@ -238,8 +222,8 @@ gpu.module @test_round_robin_assignment {
gpu.func @convert_layout_optimal(%arg0: memref<32x64xf32>) {
%0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<32x64xf32> -> !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>>
- //CHECK-2: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<16x16xf32>
- //CHECK-2: xegpu.convert_layout {{.*}} <{input_layout = #xegpu.layout<inst_data = [16, 16]>, target_layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<16x16xf32>
+ // CHECK-COUNT-2: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<16x16xf32>
+ // CHECK-COUNT-2: xegpu.convert_layout {{.*}} <{input_layout = #xegpu.layout<inst_data = [16, 16]>, target_layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<16x16xf32>
%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>> -> vector<32x64xf32>
%2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>,
target_layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>}> : vector<32x64xf32>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 86a021b..84ce80f4 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -14,13 +14,11 @@ gpu.module @test_distribution {
// CHECK-LABEL: load_nd_tdesc_with_offset
gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
- // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}]
- // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK-SAME-COUNT-4: -> vector<16x16xf32>
+ // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf32>
// CHECK-NOT: xegpu.load_nd
%tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
- %load = xegpu.load_nd %tdesc[0, 0]
+ %load = xegpu.load_nd %tdesc[0, 0]
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-> vector<256x128xf32>
gpu.return
@@ -28,8 +26,7 @@ gpu.module @test_distribution {
// CHECK-LABEL: store_nd_with_offset
gpu.func @store_nd_with_offset(%src: memref<256x128xf32>) {
- // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}]
- // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}] : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.store_nd
%tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -42,10 +39,8 @@ gpu.module @test_distribution {
}
// CHECK-LABEL: prefetch_nd_tdesc_with_offset
- // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
- // CHECK-COUNT-4: xegpu.prefetch_nd {{%.*}}[{{%.*}}, {{%.*}}]
- // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.prefetch_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.prefetch_nd
%tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -57,15 +52,11 @@ gpu.module @test_distribution {
// CHECK-LABEL: dpas
// CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>)
gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) {
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf16>
- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK-NOT: xegpu.create_nd_tdesc
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf16>
- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [4, 8], lane_data = [1, 1]>>
- // CHECK-NOT: xegpu.create_nd_tdesc
- // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
- // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
- // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+ // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+ // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
// CHECK-NOT: xegpu.dpas
%tdesc_a = xegpu.create_nd_tdesc %a : memref<256x128xf16>
-> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -102,27 +93,42 @@ gpu.module @test_distribution {
gpu.func @non_splat_constant() {
// CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}}> : vector<2x1xindex>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: %[[MAP4:.*]] = affine.apply #map4()[%[[SGID]]]
- // CHECK-DAG: %[[MAP5:.*]] = affine.apply #map5()[%[[SGID]]]
- // CHECK-DAG: %[[MUL:.*]] = index.mul %[[MAP4]], %[[C2:.*]]
- // CHECK-DAG: %[[REMU1:.*]] = index.remu %[[MUL]], %[[C32:.*]]
- // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MAP5]], %[[C1:.*]]
+ // CHECK-DAG: %[[REMU1:.*]] = index.remu %[[SGID]], %[[C1:.*]]
+ // CHECK-DAG: %[[DIVU:.*]] = index.divu %[[SGID]], %[[C1:.*]]
+ // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[DIVU]], %[[C8:.*]]
+ // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU2]], %[[C2:.*]]
+ // CHECK-DAG: %[[REMU3:.*]] = index.remu %[[MUL]], %[[C32:.*]]
+ // CHECK-DAG: %[[REMU4:.*]] = index.remu %[[REMU1]], %[[C1:.*]]
// CHECK-DAG: %[[ADD16:.*]] = arith.addi %[[MUL]], %[[C16:.*]] : index
- // CHECK-DAG: %[[REMU3:.*]] = index.remu %[[ADD16]], %[[C32:.*]]
- // CHECK-DAG: %[[REMU4:.*]] = index.remu %[[MAP5]], %[[C1:.*]]
- // CHECK-DAG: %[[STRIDE1:.*]] = arith.muli %[[REMU1]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[REMU5:.*]] = index.remu %[[ADD16]], %[[C32:.*]]
+ // CHECK-DAG: %[[REMU6:.*]] = index.remu %[[REMU1]], %[[C1:.*]]
+ // CHECK-DAG: %[[STRIDE1:.*]] = arith.muli %[[REMU3]], %[[C16:.*]] : index
// CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[C0:.*]], %[[STRIDE1]] : index
- // CHECK-DAG: %[[STRIDE2:.*]] = arith.muli %[[REMU2]], %[[C0:.*]] : index
+ // CHECK-DAG: %[[STRIDE2:.*]] = arith.muli %[[REMU4]], %[[C0:.*]] : index
// CHECK-DAG: %[[ADDSTRIDES1:.*]] = arith.addi %[[ADDSTRIDES]], %[[STRIDE2]] : index
// CHECK-DAG: %[[BCAST1:.*]] = vector.broadcast %[[ADDSTRIDES1]] : index to vector<2x1xindex>
// CHECK-DAG: %[[RESULT1:.*]] = arith.addi %[[BASECST]], %[[BCAST1]] : vector<2x1xindex>
- // CHECK-DAG: %[[STRIDE3:.*]] = arith.muli %[[REMU3]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STRIDE3:.*]] = arith.muli %[[REMU5]], %[[C16:.*]] : index
// CHECK-DAG: %[[ADDSTRIDES2:.*]] = arith.addi %[[C0:.*]], %[[STRIDE3]] : index
- // CHECK-DAG: %[[STRIDE4:.*]] = arith.muli %[[REMU4]], %[[C0:.*]] : index
+ // CHECK-DAG: %[[STRIDE4:.*]] = arith.muli %[[REMU6]], %[[C0:.*]] : index
// CHECK-DAG: %[[ADDSTRIDES3:.*]] = arith.addi %[[ADDSTRIDES2]], %[[STRIDE4]] : index
// CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[ADDSTRIDES3]] : index to vector<2x1xindex>
// CHECK-DAG: %[[RESULT2:.*]] = arith.addi %[[BASECST]], %[[BCAST2]] : vector<2x1xindex>
%cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
gpu.return
}
+
+ // CHECK-LABEL: vector_transpose
+ gpu.func @vector_transpose(%src: memref<256x128xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 16], lane_layout = [16, 1], lane_data = [1, 1], order =[0, 1]>>
+ %load = xegpu.load_nd %tdesc[0, 0]
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 16], lane_layout = [16, 1], lane_data = [1, 1], order =[0, 1]>>
+ -> vector<256x128xf32>
+ // CHECK-COUNT-2: vector.transpose {{.*}}, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<32x16xf32> to vector<16x32xf32>
+ // CHECK-NOT: vector.transpose
+ %trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 32], lane_layout = [1, 16], lane_data = [1, 1], order =[1, 0]>} : vector<256x128xf32> to vector<128x256xf32>
+ gpu.return
+ }
}
+
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 8d98fcf..4fbb566c 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -1,8 +1,5 @@
// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
-//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
-//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
-//CHECK: #map2 = affine_map<()[s0] -> (s0 floordiv 8)>
gpu.module @test_distribution {
// CHECK-LABEL: create_nd_tdesc_no_offset
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -26,13 +23,23 @@ gpu.module @test_distribution {
}
// CHECK-LABEL: load_nd_tdesc_with_offset
- // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
- //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
- //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
- //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
- //CHECK: %[[LOAD:.*]] = xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
- %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+ //CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ //CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ //CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+ //CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %[[C4]]
+ //CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %[[C4]]
+ //CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+ //CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %[[C8]]
+ //CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+ //CHECK-DAG: %[[L_OFF_Y:.*]] = index.mul %[[SGIDY]], %[[C32]]
+ //CHECK-DAG: %[[L_OFF_X:.*]] = index.mul %[[SGIDX]], %[[C32]]
+ //CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index
+ //CHECK-DAG: %[[OFF_Y:.*]] = index.remu %[[L_OFF_Y]], %[[C256]]
+ //CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
+ //CHECK-DAG: %[[OFF_X:.*]] = index.remu %[[L_OFF_X]], %[[C128]]
+ //CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]][{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc[0, 0]
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -43,9 +50,6 @@ gpu.module @test_distribution {
// CHECK-LABEL: store_nd_with_offsets
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @store_nd_with_offsets(%src: memref<256x128xf32>) {
- //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
- //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
- //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
//CHECK: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}] : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -60,9 +64,6 @@ gpu.module @test_distribution {
// CHECK-LABEL: prefetch_nd_tdesc_with_offset
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
- //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
- //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
- //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
//CHECK: xegpu.prefetch_nd %{{.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%cst0 = arith.constant 0 : index
%tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
@@ -319,17 +320,15 @@ gpu.module @test_distribution {
gpu.func @distribute_load_matrix(%arg0: memref<32768xi8, 3>) {
//CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
//CHECK: [[sgid:%.+]] = gpu.subgroup_id : index
- //CHECK: [[c2:%.+]] = arith.constant 2 : index
//CHECK: [[c4:%.+]] = arith.constant 4 : index
- //CHECK: [[c4_0:%.+]] = arith.constant 4 : index
- //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
- //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
+ //CHECK: [[sgidx:%.+]] = index.remu [[sgid]], [[c4]]
+ //CHECK: [[sgidy_tmp:%.+]] = index.divu [[sgid]], [[c4]]
+ //CHECK: [[c2:%.+]] = arith.constant 2 : index
+ //CHECK: [[sgidy:%.+]] = index.remu [[sgidy_tmp]], [[c2]]
//CHECK: [[c32:%.+]] = arith.constant 32 : index
- //CHECK: [[l_off_y:%.+]] = index.mul [[id_y]], [[c32]]
- //CHECK: [[c32_1:%.+]] = arith.constant 32 : index
- //CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]]
- //CHECK: [[c0:%.+]] = arith.constant 0 : index
- //CHECK: [[c0_1:%.+]] = arith.constant 0 : index
+ //CHECK: [[l_off_y:%.+]] = index.mul [[sgidy]], [[c32]]
+ //CHECK: [[c32_0:%.+]] = arith.constant 32 : index
+ //CHECK: [[l_off_x:%.+]] = index.mul [[sgidx]], [[c32_0]]
//CHECK: [[c64:%.+]] = arith.constant 64 : index
//CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]]
//CHECK: [[c128:%.+]] = arith.constant 128 : index
@@ -346,17 +345,15 @@ gpu.module @test_distribution {
//CHECK: [[cst:%.+]] = arith.constant dense<1.000000e+00> : vector<32x32xf32>
//CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
//CHECK: [[sgid:%.+]] = gpu.subgroup_id : index
- //CHECK: [[c2:%.+]] = arith.constant 2 : index
//CHECK: [[c4:%.+]] = arith.constant 4 : index
- //CHECK: [[c4_0:%.+]] = arith.constant 4 : index
- //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
- //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
+ //CHECK: [[sgidx:%.+]] = index.remu [[sgid]], [[c4]]
+ //CHECK: [[sgidy_tmp:%.+]] = index.divu [[sgid]], [[c4]]
+ //CHECK: [[c2:%.+]] = arith.constant 2 : index
+ //CHECK: [[sgidy:%.+]] = index.remu [[sgidy_tmp]], [[c2]]
//CHECK: [[c32:%.+]] = arith.constant 32 : index
- //CHECK: [[l_off_y:%.+]] = index.mul [[id_y]], [[c32]]
- //CHECK: [[c32_1:%.+]] = arith.constant 32 : index
- //CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]]
- //CHECK: [[c0:%.+]] = arith.constant 0 : index
- //CHECK: [[c0_2:%.+]] = arith.constant 0 : index
+ //CHECK: [[l_off_y:%.+]] = index.mul [[sgidy]], [[c32]]
+ //CHECK: [[c32_0:%.+]] = arith.constant 32 : index
+ //CHECK: [[l_off_x:%.+]] = index.mul [[sgidx]], [[c32_0]]
//CHECK: [[c64:%.+]] = arith.constant 64 : index
//CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]]
//CHECK: [[c128:%.+]] = arith.constant 128 : index
@@ -411,14 +408,17 @@ gpu.module @test_distribution {
// CHECK-LABEL: vector_step_op
gpu.func @vector_step_op_slice_attr() {
//CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
- //CHECK-DAG: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]]
- //CHECK-DAG: [[c32:%.+]] = arith.constant 32 : index
- //CHECK-DAG: [[LY:%.+]] = index.mul [[IDY]], [[c32]]
- //CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
- //CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index
- //CHECK-DAG: [[MODY:%.+]] = index.remu [[LY]], [[c128]]
- //CHECK-DAG: [[BASE:%.+]] = vector.step : vector<32xindex>
- //CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
+ //CHECK: [[c8:%.+]] = arith.constant 8 : index
+ //CHECK: [[sgidx:%.+]] = index.remu [[sgId]], [[c8]]
+ //CHECK: [[sgidy_tmp:%.+]] = index.divu [[sgId]], [[c8]]
+ //CHECK: [[c4:%.+]] = arith.constant 4 : index
+ //CHECK: [[sgidy:%.+]] = index.remu [[sgidy_tmp]], [[c4]]
+ //CHECK: [[c32:%.+]] = arith.constant 32 : index
+ //CHECK: [[LY:%.+]] = index.mul [[sgidy]], [[c32]]
+ //CHECK: [[c128:%.+]] = arith.constant 128 : index
+ //CHECK: [[MODY:%.+]] = index.remu [[LY]], [[c128]]
+ //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
+ //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
%step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
gpu.return
@@ -426,14 +426,14 @@ gpu.module @test_distribution {
gpu.func @vector_step_op_layout_attr() {
//CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
- //CHECK-DAG: [[c16:%.+]] = arith.constant 16 : index
- //CHECK-DAG: [[c8:%.+]] = arith.constant 8 : index
- //CHECK-DAG: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]]
- //CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
- //CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index
- //CHECK-DAG: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
- //CHECK-DAG: [[BASE:%.+]] = vector.step : vector<8xindex>
- //CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex>
+ //CHECK: [[c16:%.+]] = arith.constant 16 : index
+ //CHECK: [[sgidx:%.+]] = index.remu [[sgId]], [[c16]]
+ //CHECK: [[c8:%.+]] = arith.constant 8 : index
+ //CHECK: [[LOCALY:%.+]] = index.mul [[sgidx]], [[c8]]
+ //CHECK: [[c128:%.+]] = arith.constant 128 : index
+ //CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
+ //CHECK: [[BASE:%.+]] = vector.step : vector<8xindex>
+ //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex>
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<8xindex>
%step = vector.step {layout_result_0 = #xegpu.layout<sg_layout = [16], sg_data = [8]>}: vector<128xindex>
gpu.return
@@ -464,14 +464,27 @@ gpu.module @test_distribution {
gpu.return
}
+ // CHECK-LABEL: vector_transpose
+ gpu.func @vector_transpose(%src: memref<256x32xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x32xf32>
+ -> !xegpu.tensor_desc<256x32xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], lane_layout = [16, 1], lane_data = [1, 1], order =[0, 1]>>
+ %load = xegpu.load_nd %tdesc[0, 0]
+ : !xegpu.tensor_desc<256x32xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], lane_layout = [16, 1], lane_data = [1, 1], order =[0, 1]>>
+ -> vector<256x32xf32>
+ //CHECK: vector.transpose {{.*}}, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<64x32xf32> to vector<32x64xf32>
+ %trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], lane_layout = [1, 16], lane_data = [1, 1], order =[1, 0]>} : vector<256x32xf32> to vector<32x256xf32>
+ gpu.return
+ }
+
// CHECK-LABEL: non_splat_constant_2D
gpu.func @non_splat_constant_2D() {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1x1xindex>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: affine.apply #map4()[%[[SGID]]]
- // CHECK-DAG: affine.apply #map5()[%[[SGID]]]
- // CHECK-DAG: %[[IDY:.*]] = index.remu %{{.*}}, %[[C32:.*]]
- // CHECK-DAG: %[[IDX:.*]] = index.remu %{{.*}}, %[[C1:.*]]
+ // CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %{{.*}}
+ // CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %{{.*}}
+ // CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %{{.*}}
+ // CHECK-DAG: %[[IDY:.*]] = index.remu %[[SGIDY]], %{{.*}}
+ // CHECK-DAG: %[[IDX:.*]] = index.remu %[[SGIDX]], %{{.*}}
// CHECK-DAG: %[[STRIDECOL:.*]] = arith.muli %[[IDY]], %[[C16:.*]] : index
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[C0:.*]], %[[STRIDECOL]] : index
// CHECK-DAG: %[[STRIDEROW:.*]] = arith.muli %[[IDX]], %[[C0:.*]] : index
@@ -484,20 +497,19 @@ gpu.module @test_distribution {
// CHECK-LABEL: non_splat_constant_2D_non_unit_dim
gpu.func @non_splat_constant_2D_non_unit_dim() {
- // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}} : vector<2x2xindex>
+ // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{\[}}{{\[}}0, 16{{\]}}, {{\[}}8, 24{{\]}}{{\]}}> : vector<2x2xindex>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: %[[IDY:.*]] = affine.apply #map()[%[[SGID]]]
- // CHECK-DAG: %[[IDX:.*]] = affine.apply #map1()[%[[SGID]]]
- // CHECK-DAG: %[[MULY:.*]] = index.mul %[[IDY]], %[[C2:.*]]
- // CHECK-DAG: %[[C2_2:.*]] = arith.constant 2 : index
- // CHECK-DAG: %[[MULX:.*]] = index.mul %[[IDX]], %[[C2:.*]]
+ // CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %{{.*}}
+ // CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %{{.*}}
+ // CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %{{.*}}
+ // CHECK-DAG: %[[MULY:.*]] = index.mul %[[SGIDY]], %[[C2:.*]]
+ // CHECK-DAG: %[[MULX:.*]] = index.mul %[[SGIDX]], %{{.*}}
// CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[MULY]], %[[C8:.*]]
- // CHECK-DAG: %[[C8_2:.*]] = arith.constant 8 : index
- // CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[MULX]], %[[C8:.*]]
- // CHECK-DAG: %[[MUL5:.*]] = arith.muli %[[REMU_Y]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[MULX]], %{{.*}}
+ // CHECK-DAG: %[[MUL5:.*]] = arith.muli %[[REMU_Y]], %{{.*}} : index
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[C0:.*]], %[[MUL5]] : index
// CHECK-DAG: %[[MUL6:.*]] = arith.muli %[[REMU_X]], %[[C16:.*]] : index
- // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[ADD]], %[[MUL6]] : index
+ // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[ADD]], %[[MUL6]] : index
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<2x2xindex>
// CHECK-DAG: %[[ADDCST:.*]] = arith.addi %[[BASECST]], %[[BCAST]] : vector<2x2xindex>
%cst_8x8 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2]>} dense<[
@@ -517,13 +529,14 @@ gpu.module @test_distribution {
gpu.func @non_splat_constant() {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C32:.*]]
- // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %{{.*}}
+ // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[REMU]], %{{.*}}
+ // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU2]], %[[C16:.*]] : index
// CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[C0:.*]], %[[MUL]] : index
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<1xindex>
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<1xindex>
%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [1]>} dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432, 448, 464, 480, 496]> : vector<32xindex>
- // CHECK: arith.constant dense<{{\[}}[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]{{\]}}> : vector<1x16xindex>
+ // CHECK: arith.constant dense<{{\[}}{{\[}}0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15{{\]}}{{\]}}> : vector<1x16xindex>
%cst_1 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 16]>} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
gpu.return
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index e83229e..5ce3d1d 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -1,47 +1,35 @@
// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
-//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
-//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
gpu.module @test_1_1_assignment {
// CHECK-LABEL: create_nd_tdesc
- // CHECK-SAME: [[ARG_0:%.*]]: memref<256x128xf32>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
- //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
- //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
- //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
- //CHECK: [[C32:%.+]] = arith.constant 32 : index
- //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
- //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32]]
- //CHECK: [[C0:%.+]] = arith.constant 0 : index
- //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
- //CHECK: [[C256:%.+]] = arith.constant 256 : index
- //CHECK: [[Y:%.+]] = index.remu [[LY]], [[C256]]
- //CHECK: [[C128:%.+]] = arith.constant 128 : index
- //CHECK: [[X:%.+]] = index.remu [[LX]], [[C128]]
- //CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][[[Y]], [[X]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[REMUX:.*]] = index.remu %[[SGID]], %[[C4:.*]]
+ // CHECK-DAG: %[[DIVU:.*]] = index.divu %[[SGID]], %[[C4:.*]]
+ // CHECK-DAG: %[[REMUY:.*]] = index.remu %[[DIVU]], %[[C8:.*]]
+ // CHECK-DAG: %[[MULY:.*]] = index.mul %[[REMUY]], %[[C32:.*]]
+ // CHECK-DAG: %[[MULX:.*]] = index.mul %[[REMUX]], %[[C32:.*]]
+ // CHECK-DAG: %[[MODY:.*]] = index.remu %[[MULY]], %[[C256:.*]]
+ // CHECK-DAG: %[[MODX:.*]] = index.remu %[[MULX]], %[[C128:.*]]
+ // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[MODY]], %[[MODX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: create_nd_tdesc_from_higher_rank_memref
- // CHECK-SAME: [[ARG_0:%.*]]: memref<3x256x128xf32>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<3x256x128xf32>
gpu.func @create_nd_tdesc_from_higher_rank_memref(%src: memref<3x256x128xf32>) {
- //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
- //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
- //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
- //CHECK: [[C32:%.+]] = arith.constant 32 : index
- //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
- //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32]]
- //CHECK: [[C0:%.+]] = arith.constant 0 : index
- //CHECK: [[C0_2:%.+]] = arith.constant 0 : index
- //CHECK: [[C256:%.+]] = arith.constant 256 : index
- //CHECK: [[MODY:%.+]] = index.remu [[LY]], [[C256]]
- //CHECK: [[C128:%.+]] = arith.constant 128 : index
- //CHECK: [[MODX:%.+]] = index.remu [[LX]], [[C128]]
- //CHECK: [[C0_3:%.+]] = arith.constant 0 : index
- //CHECK: [[C0_4:%.+]] = arith.constant 0 : index
- //CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][1, [[MODY]], [[MODX]]] : memref<3x256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[REMUX:.*]] = index.remu %[[SGID]], %[[C4:.*]]
+ // CHECK-DAG: %[[DIVU:.*]] = index.divu %[[SGID]], %[[C4:.*]]
+ // CHECK-DAG: %[[REMUY:.*]] = index.remu %[[DIVU]], %[[C8:.*]]
+ // CHECK-DAG: %[[MULY:.*]] = index.mul %[[REMUY]], %[[C32:.*]]
+ // CHECK-DAG: %[[MULX:.*]] = index.mul %[[REMUX]], %[[C32:.*]]
+ // CHECK-DAG: %[[MODY:.*]] = index.remu %[[MULY]], %[[C256:.*]]
+ // CHECK-DAG: %[[MODX:.*]] = index.remu %[[MULX]], %[[C128:.*]]
+ // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][1, %[[MODY]], %[[MODX]]] : memref<3x256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%tdesc = xegpu.create_nd_tdesc %src[1, 0, 0] : memref<3x256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
@@ -81,25 +69,24 @@ gpu.module @test_1_1_assignment {
xegpu.store_nd %load, %tdesc
: vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
-}
+ }
-// CHECK-LABEL: update_nd
-// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
-gpu.func @update_nd(%src: memref<256x128xf32>){
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16]
- // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
- -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
- %update = xegpu.update_nd_offset %tdesc, [0, 16]
- : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
- gpu.return
-}
+ // CHECK-LABEL: update_nd
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @update_nd(%src: memref<256x128xf32>){
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16]
+ // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ %update = xegpu.update_nd_offset %tdesc, [0, 16]
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+ }
-// CHECK-LABEL: dpas
-gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
- // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
+ // CHECK-LABEL: dpas
+ gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16>
-> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
@@ -110,16 +97,15 @@ gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
%load_b = xegpu.load_nd %tdesc_b
: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
-> vector<128x128xf16>
+ // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
%dpas = xegpu.dpas %load_a, %load_b
{layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
: vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
gpu.return
}
-
-// CHECK-LABEL: dpas_no_sg_data
-gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
- // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+ // CHECK-LABEL: dpas_no_sg_data
+ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16>
-> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
order = [1, 0]>>
@@ -134,6 +120,7 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
order = [1, 0]>>
-> vector<128x128xf16>
+ // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
%dpas = xegpu.dpas %load_a, %load_b
{layout_result_0 = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>}
: vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
@@ -196,9 +183,9 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
}
gpu.func @scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
- //CHECK: [[c0:%.+]] = arith.constant 0 : index
- //CHECK: [[c128:%.+]] = arith.constant 128 : index
- //CHECK: [[c1024:%.+]] = arith.constant 1024 : index
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
+ // CHECK-DAG: %[[C1024:.*]] = arith.constant 1024 : index
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%c1024 = arith.constant 1024 : index
@@ -211,15 +198,15 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
%4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
%5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>
- // CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]]
- // CHECK-SAME: iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) ->
+ // CHECK: %[[SCF:.*]]:3 = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C1024]] step %[[C128]]
+ // CHECK-SAME: iter_args(%[[ARG4:.*]] = {{.*}}, %[[ARG5:.*]] = {{.*}}, %[[ARG6:.*]] = {{.*}}) ->
// CHECK-SAME: (!xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>)
- // CHECK: [[a:%.+]] = xegpu.load_nd [[arg4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
- // CHECK: [[b:%.+]] = xegpu.load_nd [[arg5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
- // CHECK: [[c:%.+]] = xegpu.dpas [[a]], [[b]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
- // CHECK: [[at:%.+]] = xegpu.update_nd_offset [[arg4]], [[[c0]], [[c128]]] : !xegpu.tensor_desc<16x128xf16>
- // CHECK: [[bt:%.+]] = xegpu.update_nd_offset [[arg5]], [[[c128]], [[c0]]] : !xegpu.tensor_desc<128x16xf16>
- // CHECK: scf.yield [[at]], [[bt]], [[c]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>
+ // CHECK: %[[A:.*]] = xegpu.load_nd %[[ARG4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
+ // CHECK: %[[B:.*]] = xegpu.load_nd %[[ARG5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
+ // CHECK: %[[C:.*]] = xegpu.dpas %[[A]], %[[B]], %[[ARG6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
+ // CHECK: %[[AT:.*]] = xegpu.update_nd_offset %[[ARG4]], [%[[C0]], %[[C128]]] : !xegpu.tensor_desc<16x128xf16>
+ // CHECK: %[[BT:.*]] = xegpu.update_nd_offset %[[ARG5]], [%[[C128]], %[[C0]]] : !xegpu.tensor_desc<128x16xf16>
+ // CHECK: scf.yield %[[AT]], %[[BT]], %[[C]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>
%6:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3)
-> (!xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>,
!xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>, vector<128x128xf32>) {
@@ -252,7 +239,7 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
// CHECK: scf.condition{{.*}} : vector<16xf32>, i32
scf.condition(%4) %arg2, %arg3 : vector<256xf32>, i32
} do {
- // CHECK: ([[arg2:%.+]]: vector<16xf32>, [[arg3:%.+]]: i32)
+ // CHECK: (%[[ARG2:.*]]: vector<16xf32>, %[[ARG3:.*]]: i32)
^bb0(%arg2: vector<256xf32>, %arg3: i32):
xegpu.store_nd %arg2, %2 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
%4 = arith.addi %arg3, %c1_i32 : i32
@@ -344,9 +331,9 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
%cond4 = arith.cmpi slt, %sg_id, %c31 : index
%cond5 = arith.andi %cond3, %cond4 : i1
scf.if %cond5 {
- // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK: %[[C2:.*]] = arith.constant 2 : index
- // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
+ // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
%tdesc = xegpu.create_nd_tdesc %src2[0, 0] : memref<128x64xf32>
-> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc