Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp                    82
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp              5
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp          219
-rw-r--r--  mlir/lib/Dialect/Linalg/Utils/Utils.cpp                        15
-rw-r--r--  mlir/lib/Dialect/XeGPU/CMakeLists.txt                           1
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp                     85
-rw-r--r--  mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt             17
-rw-r--r--  mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp     225
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp    81
9 files changed, 576 insertions, 154 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 262d9b7..d43f881 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1752,15 +1752,21 @@ std::string NVVM::MBarrierInitOp::getPtx() {
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
+static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
+ auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
+ return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
+}
+
+static bool isPtrInSharedCTASpace(mlir::Value ptr) {
+ return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
+}
+
mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::MBarrierInitOp>(op);
- unsigned addressSpace =
- llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType())
- .getAddressSpace();
- llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared)
- ? llvm::Intrinsic::nvvm_mbarrier_init_shared
- : llvm::Intrinsic::nvvm_mbarrier_init;
+ bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
+ llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared
+ : llvm::Intrinsic::nvvm_mbarrier_init;
// Fill the Intrinsic Args
llvm::SmallVector<llvm::Value *> args;
@@ -1773,16 +1779,72 @@ mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
- unsigned addressSpace =
- llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType())
- .getAddressSpace();
- llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared)
+ bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
+ llvm::Intrinsic::ID id = isShared
? llvm::Intrinsic::nvvm_mbarrier_inval_shared
: llvm::Intrinsic::nvvm_mbarrier_inval;
return {id, {mt.lookupValue(thisOp.getAddr())}};
}
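
For illustration, the shared-vs-generic dispatch above maps ops to intrinsics
as follows (a sketch; `!llvm.ptr<3>` assumes the NVVM shared-CTA address
space, which is numbered 3):

  nvvm.mbarrier.init %addr, %count : !llvm.ptr<3>   // shared CTA
    --> llvm::Intrinsic::nvvm_mbarrier_init_shared
  nvvm.mbarrier.init %addr, %count : !llvm.ptr      // generic
    --> llvm::Intrinsic::nvvm_mbarrier_init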
+mlir::NVVM::IDArgPair MBarrierArriveOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierArriveOp>(op);
+ bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
+ llvm::Intrinsic::ID id = isShared
+ ? llvm::Intrinsic::nvvm_mbarrier_arrive_shared
+ : llvm::Intrinsic::nvvm_mbarrier_arrive;
+
+ return {id, {mt.lookupValue(thisOp.getAddr())}};
+}
+
+mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op);
+ bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
+ llvm::Intrinsic::ID id =
+ isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared
+ : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete;
+ // Fill the Intrinsic Args
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getAddr()));
+ args.push_back(mt.lookupValue(thisOp.getCount()));
+
+ return {id, std::move(args)};
+}
+
+mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
+ bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
+ llvm::Intrinsic::ID id = isShared
+ ? llvm::Intrinsic::nvvm_mbarrier_test_wait_shared
+ : llvm::Intrinsic::nvvm_mbarrier_test_wait;
+ // Fill the Intrinsic Args
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getAddr()));
+ args.push_back(mt.lookupValue(thisOp.getState()));
+
+ return {id, std::move(args)};
+}
+
+mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
+ bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
+
+ llvm::Intrinsic::ID id;
+ if (thisOp.getNoinc()) {
+ id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared
+ : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc;
+ } else {
+ id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared
+ : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive;
+ }
+
+ return {id, {mt.lookupValue(thisOp.getAddr())}};
+}
+
#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
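
As a sketch, the macro expands by token pasting into concrete intrinsic IDs;
for example, assuming the illustrative instantiation `CP_ASYNC_ID_IMPL(ca, 4, )`:

  CP_ASYNC_ID_IMPL(ca, 4, )
    --> llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4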
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index bd25e94..027268c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -232,10 +232,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
// 2. Compute the permutation vector to shuffle packed shape into the shape
// before any outer or inner permutations have been applied.
- PackingMetadata packingMetadata = computePackingMetadata(
- packedTensorType.getRank(), packOp.getInnerDimsPos());
+ PackingMetadata packingMetadata;
SmallVector<int64_t> packedToStripMinedShapePerm =
- getPackInverseDestPerm(packOp);
+ getPackInverseDestPerm(packOp, packingMetadata);
// 3. Compute the stripMinedShape: this is the packed shape before any outer
// or inner permutations have been applied.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index cb6199f..19d2d85 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1564,13 +1564,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}
-/// Given a linalg::PackOp, return the `dest` shape before any packing
-/// permutations.
-static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
- ArrayRef<int64_t> destShape) {
- return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
-}
-
/// Determines whether a mask for xfer_write is trivially "all true"
///
/// Given all the inputs required to generate a mask (mask sizes and shapes),
@@ -1761,99 +1754,6 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
return mlir::vector::maskOperation(builder, write, maskForWrite);
}
-/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
-/// padding value and (3) input vector sizes into:
-///
-/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
-///
-/// As in the following example:
-/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
-/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
-///
-/// This pack would be vectorized to:
-///
-/// %load = vector.mask %mask {
-/// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
-/// {in_bounds = [true, true, true]} :
-/// tensor<32x7x16xf32>, vector<32x8x16xf32>
-/// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
-/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
-/// to vector<32x4x2x1x16xf32>
-/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
-/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
-/// %write = vector.transfer_write %transpose,
-/// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
-/// {in_bounds = [true, true, true, true, true]}
-/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
-///
-/// If the (3) input vector sizes are not provided, the vector sizes are
-/// determined by the result tensor shape and the `in_bounds`
-/// attribute is used instead of masking to mark out-of-bounds accesses.
-///
-/// NOTE: The input vector sizes specify the dimensions corresponding to the
-/// outer dimensions of the output tensor. The remaining dimensions are
-/// computed based on, e.g., the static inner tiles.
-/// Supporting dynamic inner tiles will require the user to specify the
-/// missing vector sizes. This is left as a TODO.
-static LogicalResult
-vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
- ArrayRef<int64_t> inputVectorSizes,
- SmallVectorImpl<Value> &newResults) {
- // TODO: Introduce a parent class that will handle the insertion point update.
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(packOp);
-
- Location loc = packOp.getLoc();
- std::optional<Value> padValue = packOp.getPaddingValue()
- ? std::optional(packOp.getPaddingValue())
- : std::nullopt;
-
- // If the input vector sizes are not provided, then the vector sizes are
- // determined by the result tensor shape. In case the vector sizes aren't
- // provided, we update the inBounds attribute instead of masking.
- bool useInBoundsInsteadOfMasking = false;
- if (inputVectorSizes.empty()) {
- ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
- inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
- useInBoundsInsteadOfMasking = true;
- }
-
- // Create masked TransferReadOp.
- SmallVector<int64_t> inputShape(inputVectorSizes);
- auto innerTiles = packOp.getStaticInnerTiles();
- auto innerDimsPos = packOp.getInnerDimsPos();
- auto outerDimsPerm = packOp.getOuterDimsPerm();
- if (!outerDimsPerm.empty())
- applyPermutationToVector(inputShape,
- invertPermutationVector(outerDimsPerm));
- for (auto [idx, size] : enumerate(innerTiles))
- inputShape[innerDimsPos[idx]] *= size;
- auto maskedRead = vector::createReadOrMaskedRead(
- rewriter, loc, packOp.getSource(), inputShape, padValue,
- useInBoundsInsteadOfMasking,
- /*inputScalableVecSizes=*/{});
-
- // Create ShapeCastOp.
- SmallVector<int64_t> destShape(inputVectorSizes);
- destShape.append(innerTiles.begin(), innerTiles.end());
- auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
- packOp.getDestType().getElementType());
- auto shapeCastOp =
- vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead);
-
- // Create TransposeOp.
- auto destPermutation =
- invertPermutationVector(getPackInverseDestPerm(packOp));
- auto transposeOp = vector::TransposeOp::create(
- rewriter, loc, shapeCastOp.getResult(), destPermutation);
-
- // Create TransferWriteOp.
- Operation *write = createWriteOrMaskedWrite(
- rewriter, loc, transposeOp.getResult(), packOp.getDest());
- newResults.push_back(write->getResult(0));
- return success();
-}
-
/// Given the re-associations, "collapses" the input Vector type
///
/// This is similar to CollapseShapeOp::inferCollapsedType with two notable
@@ -1901,12 +1801,121 @@ static VectorType getCollapsedVecType(VectorType type,
return VectorType::get(newShape, type.getElementType(), newScalableFlags);
}
+/// Vectorize `linalg.pack` as:
+/// * xfer_read -> shape_cast -> transpose -> xfer_write
+///
+/// The input-vector-sizes specify the _write_ vector sizes (i.e. the vector
+/// sizes for the xfer_write operation). This is sufficient to infer the other
+/// vector sizes required here.
+///
+/// If the vector sizes are not provided:
+/// * the vector sizes are determined from the destination tensor static shape.
+/// * the inBounds attribute is used instead of masking.
+///
+/// EXAMPLE (no vector sizes):
+/// ```
+/// %pack = tensor.pack %src
+/// inner_dims_pos = [2, 1]
+/// inner_tiles = [16, 2]
+/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
+/// ```
+/// is vectorized as:
+/// ```
+/// %read = vector.transfer_read %src
+///           : tensor<32x8x16xf32>, vector<32x8x16xf32>
+/// %sc = vector.shape_cast %read
+/// : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+/// %tr = vector.transpose %sc, [0, 1, 3, 4, 2]
+/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+/// %write = vector.transfer_write %tr into %dest
+/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+/// ```
+static LogicalResult
+vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+ if (!inputVectorSizes.empty()) {
+ assert(inputVectorSizes.size() == packOp.getDestRank() &&
+ "Invalid number of input vector sizes!");
+ }
+
+ // TODO: Introduce a parent class that will handle the insertion point update.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(packOp);
+
+ Location loc = packOp.getLoc();
+ std::optional<Value> padValue = packOp.getPaddingValue()
+ ? std::optional(packOp.getPaddingValue())
+ : std::nullopt;
+
+  SmallVector<int64_t> destShape(packOp.getDestType().getShape());
+
+ // This is just a convenience alias to clearly communicate that the input
+ // vector sizes determine the _write_ sizes.
+ ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;
+
+  // In the absence of input-vector-sizes, use the _static_ destination tensor
+  // shape. In addition, use the inBounds attribute instead of masking.
+ bool useInBoundsInsteadOfMasking = false;
+ if (writeVectorSizes.empty()) {
+ if (ShapedType::isDynamicShape(destShape))
+ return rewriter.notifyMatchFailure(packOp,
+ "unable to infer vector sizes");
+
+ writeVectorSizes = destShape;
+ useInBoundsInsteadOfMasking = true;
+ }
+
+ // Compute pre-transpose-write-vector-type, i.e. the write vector type
+ // _before_ the transposition (i.e. before dimension permutation). This is
+ // done by inverting the permutation/transposition that's part of the Pack
+ // operation. This type is required to:
+ // 1) compute the read vector type for masked-read below, and
+ // 2) generate shape-cast Op below that expands the read vector type.
+ PackingMetadata packMetadata;
+  SmallVector<int64_t> preTransposeWriteVecSizes(writeVectorSizes);
+  auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata);
+  applyPermutationToVector(preTransposeWriteVecSizes, destInvPermutation);
+  auto preTransposeWriteVecType = VectorType::get(
+      preTransposeWriteVecSizes, packOp.getType().getElementType());
+
+  // Compute the vector type for the _read_ operation. This is simply the
+  // pre-transpose-write-vector-type with the dimensions collapsed as per the
+  // Pack operation.
+ VectorType readVecType = getCollapsedVecType(
+ preTransposeWriteVecType,
+ getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+ rewriter.getContext(), packMetadata.reassociations)));
+
+ // Create masked TransferReadOp.
+ auto maskedRead = vector::createReadOrMaskedRead(
+ rewriter, loc, packOp.getSource(), readVecType.getShape(), padValue,
+ useInBoundsInsteadOfMasking,
+ /*inputScalableVecSizes=*/{});
+
+ // Create ShapeCastOp.
+ auto shapeCastOp = vector::ShapeCastOp::create(
+ rewriter, loc, preTransposeWriteVecType, maskedRead);
+
+ // Create TransposeOp.
+ auto destPermutation = invertPermutationVector(destInvPermutation);
+ auto transposeOp = vector::TransposeOp::create(
+ rewriter, loc, shapeCastOp.getResult(), destPermutation);
+
+ // Create TransferWriteOp.
+ Operation *write = createWriteOrMaskedWrite(
+ rewriter, loc, transposeOp.getResult(), packOp.getDest());
+ newResults.push_back(write->getResult(0));
+ return success();
+}
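
As a worked sketch of the shape bookkeeping above, using the pack from the doc
comment (tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>, inner_dims_pos =
[2, 1], inner_tiles = [16, 2]); all values below are derived from that example:

  writeVectorSizes          = [32, 4, 1, 16, 2]   // dest shape
  destInvPermutation        = [0, 1, 4, 2, 3]
  preTransposeWriteVecSizes = [32, 4, 2, 1, 16]
  // reassociations [{0}, {1, 2}, {3, 4}] collapse the tiled dims back:
  readVecType               = vector<32x8x16xf32> // 4*2 = 8, 1*16 = 16
  destPermutation           = [0, 1, 3, 4, 2]     // matches the transpose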
+
/// Vectorize `linalg.unpack` as:
/// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
///
-/// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
-/// for the xfer_read operation). This is sufficient to infer the other vector
-/// sizes required here.
+/// The input-vector-sizes specify the _read_ vector sizes (i.e. the vector
+/// sizes for the xfer_read operation). This is sufficient to infer the other
+/// vector sizes required here.
///
/// If the vector sizes are not provided:
/// * the vector sizes are determined from the input tensor static shape.
@@ -1960,7 +1969,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
// In the absence of input-vector-sizes, use the _static_ input tensor shape.
if (inputVectorSizes.empty()) {
if (ShapedType::isDynamicShape(sourceShape))
- return failure();
+      return rewriter.notifyMatchFailure(unpackOp,
+                                         "unable to infer vector sizes");
readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
useInBoundsInsteadOfMasking = true;
@@ -2443,6 +2453,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes) {
auto padValue = packOp.getPaddingValue();
Attribute cstAttr;
+  // TODO: Relax this condition.
if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
LDBG() << "pad value is not constant: " << packOp;
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 24d3722..6eeb206 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -171,29 +171,24 @@ computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
namespace mlir {
namespace linalg {
-SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp) {
+SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp,
+ PackingMetadata &metadata) {
- PackingMetadata pMetadata;
int64_t packedRank = packOp.getDestType().getRank();
ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
SmallVector<int64_t> packInvDestPerm =
- computePackUnPackPerm(packedRank, innerDimPos, outerPerm, pMetadata);
+ computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata);
return packInvDestPerm;
}
-SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp) {
- PackingMetadata metadata;
- return getUnPackInverseSrcPerm(unpackOp, metadata);
-}
-
SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp,
PackingMetadata &metadata) {
- int64_t unpackRank = unpackOp.getSourceType().getRank();
+ int64_t packedRank = unpackOp.getSourceType().getRank();
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
SmallVector<int64_t> unpackInvSrcPerm =
- computePackUnPackPerm(unpackRank, innerDimPos, outerPerm, metadata);
+ computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata);
return unpackInvSrcPerm;
}
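
A minimal caller-side sketch of the updated contract (assuming a `PackOp
packOp` is in scope): the inverse permutation and the packing metadata now
come from a single computePackUnPackPerm call, so callers no longer recompute
the metadata separately.

  PackingMetadata metadata;
  SmallVector<int64_t> invPerm = getPackInverseDestPerm(packOp, metadata);
  // metadata.reassociations describes how the packed dims collapse back onto
  // the pre-packing dims.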
diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
index 31167e6..46b8251 100644
--- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(Utils)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 397107b..fb5d1e7 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -280,27 +280,82 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
FailureOr<SmallVector<Value>>
LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
- // TODO: handle order attribute
- auto hasDefaultOrder = [&]() {
- DenseI32ArrayAttr order = getOrder();
- return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>(
- llvm::reverse(order.asArrayRef())));
- };
- if (!hasDefaultOrder())
- return mlir::emitError(loc, "order attribute is currently not supported.");
- SmallVector<int64_t> layout;
+ SmallVector<int64_t> sgLayoutInt;
if (isForWorkgroup()) {
- layout = getEffectiveSgLayoutAsInt();
+ sgLayoutInt = getEffectiveSgLayoutAsInt();
} else if (isForSubgroup()) {
- layout = getEffectiveLaneLayoutAsInt();
+ sgLayoutInt = getEffectiveLaneLayoutAsInt();
} else {
return failure();
}
- auto dims = llvm::map_to_vector(layout, [&](int64_t d) -> Value {
- return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
- });
- return affine::delinearizeIndex(builder, loc, linearId, dims);
+ DenseI32ArrayAttr orderAttr = getOrder();
+
+ // Handle order attribute
+ SmallVector<int64_t> order;
+ if (orderAttr && !orderAttr.empty()) {
+ order = llvm::to_vector(
+ llvm::map_range(orderAttr.asArrayRef(),
+ [](int32_t idx) { return static_cast<int64_t>(idx); }));
+ } else {
+ // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc.
+ order = llvm::to_vector(
+ llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
+ }
+
+ if (order.size() != sgLayoutInt.size()) {
+ return failure();
+ }
+
+ SmallVector<Value> result(sgLayoutInt.size());
+ Value remaining = linearId;
+
+  /// Process dimensions in the order they appear in the `order` array.
+  /// The first dimension in `order` is the fastest-changing.
+ ///
+ /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]:
+ ///
+ /// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx],
+ /// result=[?,?,?]
+ ///
+ /// i=0 (process columns, dimIdx=2, dimSize=4):
+ /// result[2] = 22 % 4 = 2 (column coordinate)
+ /// remaining = 22 / 4 = 5 (5 complete groups of 4 columns processed)
+ ///
+ /// i=1 (process rows, dimIdx=1, dimSize=4):
+ /// result[1] = 5 % 4 = 1 (row coordinate)
+ /// remaining = 5 / 4 = 1 (1 complete group of 4 rows processed)
+ ///
+ /// i=2 (process layers, dimIdx=0, dimSize=2):
+ /// result[0] = 1 % 2 = 1 (layer coordinate)
+ /// (no remaining update - last iteration)
+ ///
+ /// Final result: [1,1,2] = Layer 1, Row 1, Column 2
+ for (size_t i = 0; i < order.size(); ++i) {
+ int64_t dimIdx = order[i];
+ int64_t dimSize = sgLayoutInt[dimIdx];
+
+ Value dimSizeVal =
+ builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize);
+
+    /// Extract the coordinate for this dimension using a modulo operation.
+    /// This gives us "how far within this dimension" we are, e.g.,
+    /// linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within this
+    /// dimension).
+ result[dimIdx] =
+ builder.createOrFold<index::RemUOp>(loc, remaining, dimSizeVal);
+
+    /// Update remaining for the next dimension by removing what we've
+    /// already processed. Division tells us "how many complete groups of
+    /// this dimension we've gone through", e.g., linearId=22, dimSize=4:
+    /// 22 / 4 = 5 (we've completed 5 groups of 4). Skip this on the last
+    /// iteration since there's no next dimension to process.
+ if (i < order.size() - 1) {
+ remaining =
+ builder.createOrFold<index::DivUOp>(loc, remaining, dimSizeVal);
+ }
+ }
+ return result;
}
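
As a sketch, for sgLayout = [2, 4, 4] with the default order [2, 1, 0], the
loop above emits roughly the following (SSA names are illustrative and
constants may fold away):

  %c4  = arith.constant 4 : index
  %c2  = arith.constant 2 : index
  %id2 = index.remu %linearId, %c4   // fastest-changing dim
  %q0  = index.divu %linearId, %c4
  %id1 = index.remu %q0, %c4
  %q1  = index.divu %q0, %c4
  %id0 = index.remu %q1, %c2         // no final divu on the last dim
  // result = [%id0, %id1, %id2]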
/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000..48fe841
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRXeGPUTransformOps
+ XeGPUTransformOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${PROJECT_SOURCE_DIR}/mlir/Dialect/XeGPU/TransformOps/
+
+ DEPENDS
+ MLIRXeGPUTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRXeGPUDialect
+ MLIRXeGPUTransforms
+ MLIRIR
+ MLIRTransformDialect
+ MLIRFuncDialect
+ MLIRSCFDialect
+)
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
new file mode 100644
index 0000000..8943ba0
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -0,0 +1,225 @@
+//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+
+#include <optional>
+
+using namespace mlir;
+using namespace mlir::transform;
+
+/// Assuming that `ofr` is an index attribute, a param of index type, or a
+/// transform dialect handle mapped to exactly one op with a single index
+/// result, get that value and cast it to an integer.
+static DiagnosedSilenceableFailure convertMixedValuesToInt(
+ transform::TransformState &state, TransformOpInterface transformOp,
+ SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs) {
+ for (OpFoldResult ofr : ofrs) {
+ // Attribute case.
+ if (auto attr = dyn_cast<Attribute>(ofr)) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+ result.push_back(intAttr.getInt());
+ continue;
+ }
+ return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+ }
+
+ // Transform param case.
+ Value transformValue = cast<Value>(ofr);
+ if (isa<TransformParamTypeInterface>(transformValue.getType())) {
+ ArrayRef<Attribute> params = state.getParams(transformValue);
+ if (params.size() != 1)
+ return transformOp.emitDefiniteFailure()
+ << "requires exactly one parameter associated";
+ result.push_back(
+ cast<IntegerAttr>(params.front()).getValue().getSExtValue());
+ continue;
+ }
+
+ // Payload value case.
+ auto payloadOps = state.getPayloadOps(transformValue);
+ if (!llvm::hasSingleElement(payloadOps)) {
+ DiagnosedSilenceableFailure diag =
+ transformOp.emitSilenceableError()
+ << "handle must be mapped to exactly one payload op";
+ diag.attachNote(transformValue.getLoc())
+ << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
+ return diag;
+ }
+
+ Operation *op = *payloadOps.begin();
+ if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+ DiagnosedSilenceableFailure diag =
+ transformOp.emitSilenceableError()
+ << "payload op must have exactly 1 index result";
+ diag.attachNote(op->getLoc())
+ << "has " << op->getNumResults() << " results";
+ return diag;
+ }
+
+ IntegerAttr intAttr;
+ if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
+ return transformOp.emitSilenceableError()
+ << "requires param or handle to be the result of a constant like "
+ "op";
+
+ result.push_back(intAttr.getInt());
+ }
+ return DiagnosedSilenceableFailure::success();
+}
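
For illustration, each accepted form resolves to the same integer (a hedged
sketch; the param and payload snippets below are schematic):

  // 1) Attribute:       an IntegerAttr such as 8 : i64.
  // 2) Transform param: %p = transform.param.constant 8 : i64
  //                         -> !transform.param<i64>
  // 3) Payload handle:  a handle mapped to exactly one constant-like op with
  //                     a single index result, e.g. arith.constant 8 : index.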
+
+/// Create a layout attribute from the given parameters.
+static xegpu::LayoutAttr
+createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
+ ArrayRef<int32_t> sgData,
+ std::optional<ArrayRef<int32_t>> instData) {
+ return xegpu::LayoutAttr::get(
+ ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
+ DenseI32ArrayAttr::get(ctx, sgData),
+ instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
+ /*lane_layout=*/nullptr,
+ /*lane_data=*/nullptr,
+ /*order=*/nullptr);
+}
+
+/// Replace xegpu.create_nd_desc op with a new one with the given layout.
+static xegpu::CreateNdDescOp
+setDescLayout(transform::TransformRewriter &rewriter,
+ xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) {
+  assert(descOp.getMixedOffsets().empty() &&
+         "create desc op with offsets is not supported");
+ auto oldTensorDesc = descOp.getType();
+ auto descType = xegpu::TensorDescType::get(
+ oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
+ /*array_length=*/oldTensorDesc.getArrayLength(),
+ /*boundary_check=*/oldTensorDesc.getBoundaryCheck(),
+ /*memory_space=*/oldTensorDesc.getMemorySpace(),
+ /*layout=*/layout);
+
+ rewriter.setInsertionPointAfter(descOp);
+ auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
+ descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
+ descOp.getMixedStrides());
+ return newDescOp;
+}
+
+void transform::SetDescLayoutOp::build(OpBuilder &builder,
+ OperationState &result, Value target,
+ ArrayRef<OpFoldResult> mixedSgLayout,
+ ArrayRef<OpFoldResult> mixedSgData,
+ ArrayRef<OpFoldResult> mixedInstData) {
+ SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+ SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+ dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+ dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+ dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+ build(builder, result, target.getType(),
+ /*target=*/target,
+ /*sg_layout=*/dynamicSgLayout,
+ /*sg_data=*/dynamicSgData,
+ /*inst_data=*/dynamicInstData,
+ /*static_sg_layout=*/staticSgLayout,
+ /*static_sg_data=*/staticSgData,
+ /*static_inst_data=*/staticInstData);
+}
+
+DiagnosedSilenceableFailure
+transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetOps = state.getPayloadOps(getTarget());
+ if (!llvm::hasSingleElement(targetOps)) {
+    return emitDefiniteFailure()
+           << "requires exactly one target op handle (got "
+           << llvm::range_size(targetOps) << ")";
+ }
+ Operation *target = *targetOps.begin();
+
+ SmallVector<int32_t> sgLayout;
+ DiagnosedSilenceableFailure status =
+ convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout());
+ if (!status.succeeded())
+ return status;
+
+ SmallVector<int32_t> sgData;
+ status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
+ if (!status.succeeded())
+ return status;
+
+ SmallVector<int32_t> instData;
+ status =
+ convertMixedValuesToInt(state, (*this), instData, getMixedInstData());
+ if (!status.succeeded())
+ return status;
+ auto maybeInstData = instData.empty()
+ ? std::nullopt
+ : std::optional<ArrayRef<int32_t>>(instData);
+
+ // For now only create_nd_desc op is supported.
+ auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
+ if (!descOp) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "expected an xegpu.create_nd_desc op, but got: "
+                << target->getName();
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+
+ // Set layout attr in desc op's return type. Replaces old desc op.
+ auto layoutAttr =
+ createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
+  auto newDescOp = setDescLayout(rewriter, descOp, layoutAttr);
+
+  // Map result handles.
+  results.set(cast<OpResult>(getTransformed()), {newDescOp.getOperation()});
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetDescLayoutOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getSgLayoutMutable(), effects);
+ onlyReadsHandle(getSgDataMutable(), effects);
+ onlyReadsHandle(getInstDataMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
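
A hedged usage sketch in a transform script (the op's exact assembly format
is defined by its TableGen declaration, which is not part of this diff; the
syntax below is an assumption based on the build() signature above):

  transform.sequence failures(propagate) {
  ^bb0(%root: !transform.any_op):
    %desc = transform.structured.match ops{["xegpu.create_nd_desc"]} in %root
        : (!transform.any_op) -> !transform.any_op
    %new = transform.xegpu.set_desc_layout %desc
        sg_layout [8, 4] sg_data [32, 32] inst_data [8, 16]
        : (!transform.any_op) -> !transform.any_op
  }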
+
+namespace {
+class XeGPUTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ XeGPUTransformDialectExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)
+
+ using Base::Base;
+
+ void init();
+};
+
+void XeGPUTransformDialectExtension::init() {
+ declareGeneratedDialect<scf::SCFDialect>();
+ declareGeneratedDialect<arith::ArithDialect>();
+ declareGeneratedDialect<xegpu::XeGPUDialect>();
+
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
+ >();
+}
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
+
+void mlir::xegpu::registerTransformDialectExtension(DialectRegistry &registry) {
+ registry.addExtensions<XeGPUTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index d12a04df..0a9ef0a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1219,6 +1219,70 @@ struct WgToSgMultiDimReductionOp
}
};
+// This pattern transforms vector.transpose ops to work at subgroup level.
+struct WgToSgVectorTransposeOp
+ : public OpConversionPattern<vector::TransposeOp> {
+ using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType resultType = op.getResultVectorType();
+
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ xegpu::DistributeLayoutAttr sourceLayout =
+ xegpu::getDistributeLayoutAttr(op.getVector());
+ if (!sourceLayout || !sourceLayout.isForWorkgroup())
+ return failure();
+
+ SmallVector<int64_t> sourceSgLayout =
+ sourceLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
+ DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
+ DenseI32ArrayAttr resultOrder = layout.getOrder();
+
+ if (!sourceOrder || !resultOrder) {
+ return rewriter.notifyMatchFailure(
+ op, "Both source and result must have order attributes");
+ }
+
+ ArrayRef<int64_t> permutation = op.getPermutation();
+ size_t permutationSize = permutation.size();
+ if (sourceSgLayout.size() != permutationSize ||
+ resultSgLayout.size() != permutationSize) {
+ return rewriter.notifyMatchFailure(
+ op, "Layouts and permutation must have the same rank");
+ }
+
+    // Check that sgLayout, sgData, and order are properly transposed for the
+    // source and result.
+ if (!layout.isTransposeOf(sourceLayout, permutation))
+ return rewriter.notifyMatchFailure(
+ op, "Result layout is not a valid transpose of source layout "
+ "according to permutation");
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType newResultType =
+ VectorType::get(sgShape, resultType.getElementType());
+ SmallVector<Value> newTransposeOps;
+ for (auto src : adaptor.getVector()) {
+ auto newTranspose = vector::TransposeOp::create(
+ rewriter, op.getLoc(), newResultType, src, permutation);
+ xegpu::setDistributeLayoutAttr(newTranspose->getResult(0),
+ layout.dropSgLayoutAndData());
+ newTransposeOps.push_back(newTranspose.getResult());
+ }
+
+ rewriter.replaceOpWithMultiple(op, {newTransposeOps});
+ return success();
+ }
+};
+
} // namespace
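
As an illustrative sketch (layout attributes abbreviated; shapes chosen so
each dim is exactly sg_layout * sg_data), a workgroup-level transpose whose
layouts satisfy isTransposeOf distributes to one subgroup-sized transpose per
adaptor value:

  // Before (wg-level): source layout sg_layout = [4, 8], sg_data = [8, 16];
  // result layout sg_layout = [8, 4], sg_data = [16, 8].
  %t = vector.transpose %v, [1, 0] : vector<32x128xf32> to vector<128x32xf32>

  // After (per subgroup), with sg_layout/sg_data dropped from the layout:
  %t_sg = vector.transpose %v_sg, [1, 0] : vector<8x16xf32> to vector<16x8xf32>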
namespace mlir {
@@ -1233,7 +1297,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
- WgToSgMultiDimReductionOp>(patterns.getContext());
+ WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>(
+ patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -1360,7 +1425,9 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});
- target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
+ target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
+ vector::TransposeOp, vector::BroadcastOp,
+ vector::MultiDimReductionOp>(
[=](Operation *op) -> bool {
// Check for either a SliceAttr or LayoutAttr on the result.
auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
@@ -1379,16 +1446,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});
- target.addDynamicallyLegalOp<vector::BroadcastOp>(
- [=](vector::BroadcastOp op) -> bool {
- return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
- });
-
- target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
- [=](vector::MultiDimReductionOp op) -> bool {
- return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
- });
-
target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
[=](xegpu::ConvertLayoutOp op) -> bool {
return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());