Diffstat (limited to 'mlir/lib/Dialect')
 -rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp                    |  82
 -rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp             |   5
 -rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp          | 219
 -rw-r--r--  mlir/lib/Dialect/Linalg/Utils/Utils.cpp                       |  15
 -rw-r--r--  mlir/lib/Dialect/XeGPU/CMakeLists.txt                         |   1
 -rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp                    |  85
 -rw-r--r--  mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt            |  17
 -rw-r--r--  mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp     | 225
 -rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp   |  81
 9 files changed, 576 insertions, 154 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 262d9b7..d43f881 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -1752,15 +1752,21 @@ std::string NVVM::MBarrierInitOp::getPtx() { // getIntrinsicID/getIntrinsicIDAndArgs methods //===----------------------------------------------------------------------===// +static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) { + auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType()); + return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS); +} + +static bool isPtrInSharedCTASpace(mlir::Value ptr) { + return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared); +} + mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast<NVVM::MBarrierInitOp>(op); - unsigned addressSpace = - llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType()) - .getAddressSpace(); - llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared) - ? llvm::Intrinsic::nvvm_mbarrier_init_shared - : llvm::Intrinsic::nvvm_mbarrier_init; + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared + : llvm::Intrinsic::nvvm_mbarrier_init; // Fill the Intrinsic Args llvm::SmallVector<llvm::Value *> args; @@ -1773,16 +1779,72 @@ mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs( mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast<NVVM::MBarrierInvalOp>(op); - unsigned addressSpace = - llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType()) - .getAddressSpace(); - llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared) + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_inval_shared : llvm::Intrinsic::nvvm_mbarrier_inval; return {id, {mt.lookupValue(thisOp.getAddr())}}; } +mlir::NVVM::IDArgPair MBarrierArriveOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveOp>(op); + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = isShared + ? llvm::Intrinsic::nvvm_mbarrier_arrive_shared + : llvm::Intrinsic::nvvm_mbarrier_arrive; + + return {id, {mt.lookupValue(thisOp.getAddr())}}; +} + +mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op); + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = + isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared + : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete; + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getCount())); + + return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op); + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = isShared + ? 
llvm::Intrinsic::nvvm_mbarrier_test_wait_shared + : llvm::Intrinsic::nvvm_mbarrier_test_wait; + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getState())); + + return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op); + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + + llvm::Intrinsic::ID id; + if (thisOp.getNoinc()) { + id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared + : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc; + } else { + id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared + : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive; + } + + return {id, {mt.lookupValue(thisOp.getAddr())}}; +} + #define CP_ASYNC_ID_IMPL(mod, size, suffix) \ llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index bd25e94..027268c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -232,10 +232,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, // 2. Compute the permutation vector to shuffle packed shape into the shape // before any outer or inner permutations have been applied. - PackingMetadata packingMetadata = computePackingMetadata( - packedTensorType.getRank(), packOp.getInnerDimsPos()); + PackingMetadata packingMetadata; SmallVector<int64_t> packedToStripMinedShapePerm = - getPackInverseDestPerm(packOp); + getPackInverseDestPerm(packOp, packingMetadata); // 3. Compute the stripMinedShape: this is the packed shape before any outer // or inner permutations have been applied. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index cb6199f..19d2d85 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1564,13 +1564,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, return success(); } -/// Given a linalg::PackOp, return the `dest` shape before any packing -/// permutations. 
-static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp, - ArrayRef<int64_t> destShape) { - return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp)); -} - /// Determines whether a mask for xfer_write is trivially "all true" /// /// Given all the inputs required to generate a mask (mask sizes and shapes), @@ -1761,99 +1754,6 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore, return mlir::vector::maskOperation(builder, write, maskForWrite); } -/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant -/// padding value and (3) input vector sizes into: -/// -/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds -/// -/// As in the following example: -/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2] -/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32> -/// -/// This pack would be vectorized to: -/// -/// %load = vector.mask %mask { -/// vector.transfer_read %arg0[%c0, %c0, %c0], %cst -/// {in_bounds = [true, true, true]} : -/// tensor<32x7x16xf32>, vector<32x8x16xf32> -/// } : vector<32x8x16xi1> -> vector<32x8x16xf32> -/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32> -/// to vector<32x4x2x1x16xf32> -/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2] -/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32> -/// %write = vector.transfer_write %transpose, -/// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0] -/// {in_bounds = [true, true, true, true, true]} -/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32> -/// -/// If the (3) input vector sizes are not provided, the vector sizes are -/// determined by the result tensor shape and the `in_bounds` -/// attribute is used instead of masking to mark out-of-bounds accesses. -/// -/// NOTE: The input vector sizes specify the dimensions corresponding to the -/// outer dimensions of the output tensor. The remaining dimensions are -/// computed based on, e.g., the static inner tiles. -/// Supporting dynamic inner tiles will require the user to specify the -/// missing vector sizes. This is left as a TODO. -static LogicalResult -vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp, - ArrayRef<int64_t> inputVectorSizes, - SmallVectorImpl<Value> &newResults) { - // TODO: Introduce a parent class that will handle the insertion point update. - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(packOp); - - Location loc = packOp.getLoc(); - std::optional<Value> padValue = packOp.getPaddingValue() - ? std::optional(packOp.getPaddingValue()) - : std::nullopt; - - // If the input vector sizes are not provided, then the vector sizes are - // determined by the result tensor shape. In case the vector sizes aren't - // provided, we update the inBounds attribute instead of masking. - bool useInBoundsInsteadOfMasking = false; - if (inputVectorSizes.empty()) { - ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape(); - inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank()); - useInBoundsInsteadOfMasking = true; - } - - // Create masked TransferReadOp. 
-  SmallVector<int64_t> inputShape(inputVectorSizes);
-  auto innerTiles = packOp.getStaticInnerTiles();
-  auto innerDimsPos = packOp.getInnerDimsPos();
-  auto outerDimsPerm = packOp.getOuterDimsPerm();
-  if (!outerDimsPerm.empty())
-    applyPermutationToVector(inputShape,
-                             invertPermutationVector(outerDimsPerm));
-  for (auto [idx, size] : enumerate(innerTiles))
-    inputShape[innerDimsPos[idx]] *= size;
-  auto maskedRead = vector::createReadOrMaskedRead(
-      rewriter, loc, packOp.getSource(), inputShape, padValue,
-      useInBoundsInsteadOfMasking,
-      /*inputScalableVecSizes=*/{});
-
-  // Create ShapeCastOp.
-  SmallVector<int64_t> destShape(inputVectorSizes);
-  destShape.append(innerTiles.begin(), innerTiles.end());
-  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
-                                       packOp.getDestType().getElementType());
-  auto shapeCastOp =
-      vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead);
-
-  // Create TransposeOp.
-  auto destPermutation =
-      invertPermutationVector(getPackInverseDestPerm(packOp));
-  auto transposeOp = vector::TransposeOp::create(
-      rewriter, loc, shapeCastOp.getResult(), destPermutation);
-
-  // Create TransferWriteOp.
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), packOp.getDest());
-  newResults.push_back(write->getResult(0));
-  return success();
-}
-
 /// Given the re-associations, "collapses" the input Vector type
 ///
 /// This is similar to CollapseShapeOp::inferCollapsedType with two notable
@@ -1901,12 +1801,121 @@ static VectorType getCollapsedVecType(VectorType type,
   return VectorType::get(newShape, type.getElementType(), newScalableFlags);
 }
 
+/// Vectorize `linalg.pack` as:
+/// * xfer_read -> shape_cast -> transpose -> xfer_write
+///
+/// The input-vector-sizes specify the _write_ vector sizes (i.e. the vector
+/// sizes for the xfer_write operation). This is sufficient to infer the other
+/// vector sizes required here.
+///
+/// If the vector sizes are not provided:
+/// * the vector sizes are determined from the destination tensor static shape.
+/// * the inBounds attribute is used instead of masking.
+///
+/// EXAMPLE (no vector sizes):
+/// ```
+///   %pack = tensor.pack %src
+///     inner_dims_pos = [2, 1]
+///     inner_tiles = [16, 2]
+///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
+/// ```
+/// is vectorized as:
+/// ```
+///   %read = vector.transfer_read %src
+///     : tensor<32x7x16xf32>, vector<32x8x16xf32>
+///   %sc = vector.shape_cast %read
+///     : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+///   %tr = vector.transpose %sc, [0, 1, 3, 4, 2]
+///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+///   %write = vector.transfer_write %tr into %dest
+///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+/// ```
+static LogicalResult
+vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
+                        ArrayRef<int64_t> inputVectorSizes,
+                        SmallVectorImpl<Value> &newResults) {
+  if (!inputVectorSizes.empty()) {
+    assert(inputVectorSizes.size() == packOp.getDestRank() &&
+           "Invalid number of input vector sizes!");
+  }
+
+  // TODO: Introduce a parent class that will handle the insertion point update.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(packOp);
+
+  Location loc = packOp.getLoc();
+  std::optional<Value> padValue = packOp.getPaddingValue()
+                                      ? std::optional(packOp.getPaddingValue())
+                                      : std::nullopt;
+
+  SmallVector<int64_t> destShape =
+      SmallVector<int64_t>(packOp.getDestType().getShape());
+
+  // This is just a convenience alias to clearly communicate that the input
+  // vector sizes determine the _write_ sizes.
+  ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;
+
+  // In the absence of input-vector-sizes, use the _static_ input tensor shape.
+  // In addition, use the inBounds attribute instead of masking.
+  bool useInBoundsInsteadOfMasking = false;
+  if (writeVectorSizes.empty()) {
+    if (ShapedType::isDynamicShape(destShape))
+      return rewriter.notifyMatchFailure(packOp,
+                                         "unable to infer vector sizes");
+
+    writeVectorSizes = destShape;
+    useInBoundsInsteadOfMasking = true;
+  }
+
+  // Compute pre-transpose-write-vector-type, i.e. the write vector type
+  // _before_ the transposition (i.e. before dimension permutation). This is
+  // done by inverting the permutation/transposition that's part of the Pack
+  // operation. This type is required to:
+  //  1) compute the read vector type for masked-read below, and
+  //  2) generate shape-cast Op below that expands the read vector type.
+  PackingMetadata packMetadata;
+  SmallVector<int64_t> preTransposeWriteVecSizses(writeVectorSizes);
+  auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata);
+  applyPermutationToVector(preTransposeWriteVecSizses, destInvPermutation);
+  auto preTransposeWriteVecType = VectorType::get(
+      preTransposeWriteVecSizses, packOp.getType().getElementType());
+
+  // Compute vector type for the _read_ operation. This is simply
+  // pre-transpose-write-vector-type with the dimensions collapsed
+  // as per the Pack operation.
+  VectorType readVecType = getCollapsedVecType(
+      preTransposeWriteVecType,
+      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+          rewriter.getContext(), packMetadata.reassociations)));
+
+  // Create masked TransferReadOp.
+  auto maskedRead = vector::createReadOrMaskedRead(
+      rewriter, loc, packOp.getSource(), readVecType.getShape(), padValue,
+      useInBoundsInsteadOfMasking,
+      /*inputScalableVecSizes=*/{});
+
+  // Create ShapeCastOp.
+  auto shapeCastOp = vector::ShapeCastOp::create(
+      rewriter, loc, preTransposeWriteVecType, maskedRead);
+
+  // Create TransposeOp.
+  auto destPermutation = invertPermutationVector(destInvPermutation);
+  auto transposeOp = vector::TransposeOp::create(
+      rewriter, loc, shapeCastOp.getResult(), destPermutation);
+
+  // Create TransferWriteOp.
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, transposeOp.getResult(), packOp.getDest());
+  newResults.push_back(write->getResult(0));
+  return success();
+}
+
 /// Vectorize `linalg.unpack` as:
 /// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
 ///
-/// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
-/// for the xfer_read operation). This is sufficient to infer the other vector
-/// sizes required here.
+/// The input-vector-sizes specify the _read_ vector sizes (i.e. the vector
+/// sizes for the xfer_read operation). This is sufficient to infer the other
+/// vector sizes required here.
 ///
 /// If the vector sizes are not provided:
 /// * the vector sizes are determined from the input tensor static shape.
@@ -1960,7 +1969,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   // In the absence of input-vector-sizes, use the _static_ input tensor shape.
if (inputVectorSizes.empty()) { if (ShapedType::isDynamicShape(sourceShape)) - return failure(); + return rewriter.notifyMatchFailure(unpackOp, + "Unable to infer vector sizes!"); readVectorSizes.assign(sourceShape.begin(), sourceShape.end()); useInBoundsInsteadOfMasking = true; @@ -2443,6 +2453,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp, ArrayRef<int64_t> inputVectorSizes) { auto padValue = packOp.getPaddingValue(); Attribute cstAttr; + // TODO: Relax this condiiton if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) { LDBG() << "pad value is not constant: " << packOp; return failure(); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 24d3722..6eeb206 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -171,29 +171,24 @@ computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos, namespace mlir { namespace linalg { -SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp) { +SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp, + PackingMetadata &metadata) { - PackingMetadata pMetadata; int64_t packedRank = packOp.getDestType().getRank(); ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos(); ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm(); SmallVector<int64_t> packInvDestPerm = - computePackUnPackPerm(packedRank, innerDimPos, outerPerm, pMetadata); + computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata); return packInvDestPerm; } -SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp) { - PackingMetadata metadata; - return getUnPackInverseSrcPerm(unpackOp, metadata); -} - SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp, PackingMetadata &metadata) { - int64_t unpackRank = unpackOp.getSourceType().getRank(); + int64_t packedRank = unpackOp.getSourceType().getRank(); ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos(); ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm(); SmallVector<int64_t> unpackInvSrcPerm = - computePackUnPackPerm(unpackRank, innerDimPos, outerPerm, metadata); + computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata); return unpackInvSrcPerm; } diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt index 31167e6..46b8251 100644 --- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(IR) add_subdirectory(Transforms) add_subdirectory(Utils) +add_subdirectory(TransformOps) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 397107b..fb5d1e7 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -280,27 +280,82 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, FailureOr<SmallVector<Value>> LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) { - // TODO: handle order attribute - auto hasDefaultOrder = [&]() { - DenseI32ArrayAttr order = getOrder(); - return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>( - llvm::reverse(order.asArrayRef()))); - }; - if (!hasDefaultOrder()) - return mlir::emitError(loc, "order attribute is currently not supported."); - SmallVector<int64_t> layout; + SmallVector<int64_t> sgLayoutInt; if (isForWorkgroup()) { - layout = getEffectiveSgLayoutAsInt(); + sgLayoutInt = getEffectiveSgLayoutAsInt(); } else if (isForSubgroup()) { - layout = getEffectiveLaneLayoutAsInt(); + 
sgLayoutInt = getEffectiveLaneLayoutAsInt(); } else { return failure(); } - auto dims = llvm::map_to_vector(layout, [&](int64_t d) -> Value { - return builder.createOrFold<arith::ConstantIndexOp>(loc, d); - }); - return affine::delinearizeIndex(builder, loc, linearId, dims); + DenseI32ArrayAttr orderAttr = getOrder(); + + // Handle order attribute + SmallVector<int64_t> order; + if (orderAttr && !orderAttr.empty()) { + order = llvm::to_vector( + llvm::map_range(orderAttr.asArrayRef(), + [](int32_t idx) { return static_cast<int64_t>(idx); })); + } else { + // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc. + order = llvm::to_vector( + llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size()))); + } + + if (order.size() != sgLayoutInt.size()) { + return failure(); + } + + SmallVector<Value> result(sgLayoutInt.size()); + Value remaining = linearId; + + /// Process dimensions in the order they appear in the order array + /// The first dimension in order is the fastest-changing + /// + /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]: + /// + /// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx], + /// result=[?,?,?] + /// + /// i=0 (process columns, dimIdx=2, dimSize=4): + /// result[2] = 22 % 4 = 2 (column coordinate) + /// remaining = 22 / 4 = 5 (5 complete groups of 4 columns processed) + /// + /// i=1 (process rows, dimIdx=1, dimSize=4): + /// result[1] = 5 % 4 = 1 (row coordinate) + /// remaining = 5 / 4 = 1 (1 complete group of 4 rows processed) + /// + /// i=2 (process layers, dimIdx=0, dimSize=2): + /// result[0] = 1 % 2 = 1 (layer coordinate) + /// (no remaining update - last iteration) + /// + /// Final result: [1,1,2] = Layer 1, Row 1, Column 2 + for (size_t i = 0; i < order.size(); ++i) { + int64_t dimIdx = order[i]; + int64_t dimSize = sgLayoutInt[dimIdx]; + + Value dimSizeVal = + builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize); + + /// Extract the coordinate for this dimension using modulo operation + /// This gives us "how far within this dimension" we are + /// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within + /// this dimension) + result[dimIdx] = + builder.createOrFold<index::RemUOp>(loc, remaining, dimSizeVal); + + /// Update remaining for the next dimension by removing what we've already + /// processed. 
Division tells us "how many complete groups of this dimension + /// we've gone through" e.g., linearId=22, dimSize=4: 22 / 4 = 5 (we've + /// completed 5 groups of 4) Skip this for the last iteration since there's + /// no next dimension to process + if (i < order.size() - 1) { + remaining = + builder.createOrFold<index::DivUOp>(loc, remaining, dimSizeVal); + } + } + return result; } /// Implements DistributeLayoutAttr::computeDistributedCoords to generate diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt new file mode 100644 index 0000000..48fe841 --- /dev/null +++ b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_dialect_library(MLIRXeGPUTransformOps + XeGPUTransformOps.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/mlir/Dialect/XeGPU/TransformOps/ + + DEPENDS + MLIRXeGPUTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRXeGPUDialect + MLIRXeGPUTransforms + MLIRIR + MLIRTransformDialect + MLIRFuncDialect + MLIRSCFDialect +) diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp new file mode 100644 index 0000000..8943ba0 --- /dev/null +++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp @@ -0,0 +1,225 @@ +//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" + +#include <optional> + +using namespace mlir; +using namespace mlir::transform; + +/// Assuming that `ofr` is an index attr or a param of index type +/// or a transform dialect handle mapped to exactly one op +/// with one index result, get that value and cast it to int type. +static DiagnosedSilenceableFailure convertMixedValuesToInt( + transform::TransformState &state, TransformOpInterface transformOp, + SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs) { + for (OpFoldResult ofr : ofrs) { + // Attribute case. + if (auto attr = dyn_cast<Attribute>(ofr)) { + if (auto intAttr = dyn_cast<IntegerAttr>(attr)) { + result.push_back(intAttr.getInt()); + continue; + } + return transformOp.emitDefiniteFailure() << "expected IntegerAttr"; + } + + // Transform param case. + Value transformValue = cast<Value>(ofr); + if (isa<TransformParamTypeInterface>(transformValue.getType())) { + ArrayRef<Attribute> params = state.getParams(transformValue); + if (params.size() != 1) + return transformOp.emitDefiniteFailure() + << "requires exactly one parameter associated"; + result.push_back( + cast<IntegerAttr>(params.front()).getValue().getSExtValue()); + continue; + } + + // Payload value case. 
+ auto payloadOps = state.getPayloadOps(transformValue); + if (!llvm::hasSingleElement(payloadOps)) { + DiagnosedSilenceableFailure diag = + transformOp.emitSilenceableError() + << "handle must be mapped to exactly one payload op"; + diag.attachNote(transformValue.getLoc()) + << "mapped to " << llvm::range_size(payloadOps) << " payload ops"; + return diag; + } + + Operation *op = *payloadOps.begin(); + if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { + DiagnosedSilenceableFailure diag = + transformOp.emitSilenceableError() + << "payload op must have exactly 1 index result"; + diag.attachNote(op->getLoc()) + << "has " << op->getNumResults() << " results"; + return diag; + } + + IntegerAttr intAttr; + if (!matchPattern(op->getResult(0), m_Constant(&intAttr))) + return transformOp.emitSilenceableError() + << "requires param or handle to be the result of a constant like " + "op"; + + result.push_back(intAttr.getInt()); + } + return DiagnosedSilenceableFailure::success(); +} + +/// Create a layout attribute from the given parameters. +static xegpu::LayoutAttr +createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout, + ArrayRef<int32_t> sgData, + std::optional<ArrayRef<int32_t>> instData) { + return xegpu::LayoutAttr::get( + ctx, DenseI32ArrayAttr::get(ctx, sgLayout), + DenseI32ArrayAttr::get(ctx, sgData), + instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr, + /*lane_layout=*/nullptr, + /*lane_data=*/nullptr, + /*order=*/nullptr); +} + +/// Replace xegpu.create_nd_desc op with a new one with the given layout. +static xegpu::CreateNdDescOp +setDescLayout(transform::TransformRewriter &rewriter, + xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) { + assert(descOp.getMixedOffsets().size() == 0 && + "create desc op with offsets is not supported"); + auto oldTensorDesc = descOp.getType(); + auto descType = xegpu::TensorDescType::get( + oldTensorDesc.getShape(), oldTensorDesc.getElementType(), + /*array_length=*/oldTensorDesc.getArrayLength(), + /*boundary_check=*/oldTensorDesc.getBoundaryCheck(), + /*memory_space=*/oldTensorDesc.getMemorySpace(), + /*layout=*/layout); + + rewriter.setInsertionPointAfter(descOp); + auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>( + descOp, descType, descOp.getSource(), descOp.getMixedSizes(), + descOp.getMixedStrides()); + return newDescOp; +} + +void transform::SetDescLayoutOp::build(OpBuilder &builder, + OperationState &result, Value target, + ArrayRef<OpFoldResult> mixedSgLayout, + ArrayRef<OpFoldResult> mixedSgData, + ArrayRef<OpFoldResult> mixedInstData) { + SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData; + SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData; + dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout); + dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData); + dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData); + build(builder, result, target.getType(), + /*target=*/target, + /*sg_layout=*/dynamicSgLayout, + /*sg_data=*/dynamicSgData, + /*inst_data=*/dynamicInstData, + /*static_sg_layout=*/staticSgLayout, + /*static_sg_data=*/staticSgData, + /*static_inst_data=*/staticInstData); +} + +DiagnosedSilenceableFailure +transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetOps = state.getPayloadOps(getTarget()); + if (!llvm::hasSingleElement(targetOps)) { + return 
emitDefiniteFailure() << "requires exactly one targetOp handle (got " + << llvm::range_size(targetOps) << ")"; + } + Operation *target = *targetOps.begin(); + + SmallVector<int32_t> sgLayout; + DiagnosedSilenceableFailure status = + convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout()); + if (!status.succeeded()) + return status; + + SmallVector<int32_t> sgData; + status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData()); + if (!status.succeeded()) + return status; + + SmallVector<int32_t> instData; + status = + convertMixedValuesToInt(state, (*this), instData, getMixedInstData()); + if (!status.succeeded()) + return status; + auto maybeInstData = instData.empty() + ? std::nullopt + : std::optional<ArrayRef<int32_t>>(instData); + + // For now only create_nd_desc op is supported. + auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target); + if (!descOp) { + auto diag = emitSilenceableFailure(getLoc()) + << "Expected a xegpu.create_nd_desc op, but got: " + << target->getName(); + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + + // Set layout attr in desc op's return type. Replaces old desc op. + auto layoutAttr = + createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData); + auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr); + + // Map result handles. + results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()}); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::SetDescLayoutOp::getEffects( + ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + consumesHandle(getTargetMutable(), effects); + onlyReadsHandle(getSgLayoutMutable(), effects); + onlyReadsHandle(getSgDataMutable(), effects); + onlyReadsHandle(getInstDataMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + +namespace { +class XeGPUTransformDialectExtension + : public transform::TransformDialectExtension< + XeGPUTransformDialectExtension> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension) + + using Base::Base; + + void init(); +}; + +void XeGPUTransformDialectExtension::init() { + declareGeneratedDialect<scf::SCFDialect>(); + declareGeneratedDialect<arith::ArithDialect>(); + declareGeneratedDialect<xegpu::XeGPUDialect>(); + + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc" + >(); +} +} // namespace + +#define GET_OP_CLASSES +#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc" + +void mlir::xegpu::registerTransformDialectExtension(DialectRegistry ®istry) { + registry.addExtensions<XeGPUTransformDialectExtension>(); +} diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index d12a04df..0a9ef0a 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -1219,6 +1219,70 @@ struct WgToSgMultiDimReductionOp } }; +// This pattern transforms vector.transpose ops to work at subgroup level. 
+struct WgToSgVectorTransposeOp + : public OpConversionPattern<vector::TransposeOp> { + using OpConversionPattern<vector::TransposeOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + VectorType resultType = op.getResultVectorType(); + + ArrayRef<int64_t> wgShape = resultType.getShape(); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); + if (!layout || !layout.isForWorkgroup()) + return failure(); + + xegpu::DistributeLayoutAttr sourceLayout = + xegpu::getDistributeLayoutAttr(op.getVector()); + if (!sourceLayout || !sourceLayout.isForWorkgroup()) + return failure(); + + SmallVector<int64_t> sourceSgLayout = + sourceLayout.getEffectiveSgLayoutAsInt(); + SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt(); + DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder(); + DenseI32ArrayAttr resultOrder = layout.getOrder(); + + if (!sourceOrder || !resultOrder) { + return rewriter.notifyMatchFailure( + op, "Both source and result must have order attributes"); + } + + ArrayRef<int64_t> permutation = op.getPermutation(); + size_t permutationSize = permutation.size(); + if (sourceSgLayout.size() != permutationSize || + resultSgLayout.size() != permutationSize) { + return rewriter.notifyMatchFailure( + op, "Layouts and permutation must have the same rank"); + } + + // Check that sgLayout, sgData & order are properly transposed for source + // and result + if (!layout.isTransposeOf(sourceLayout, permutation)) + return rewriter.notifyMatchFailure( + op, "Result layout is not a valid transpose of source layout " + "according to permutation"); + + SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; + VectorType newResultType = + VectorType::get(sgShape, resultType.getElementType()); + SmallVector<Value> newTransposeOps; + for (auto src : adaptor.getVector()) { + auto newTranspose = vector::TransposeOp::create( + rewriter, op.getLoc(), newResultType, src, permutation); + xegpu::setDistributeLayoutAttr(newTranspose->getResult(0), + layout.dropSgLayoutAndData()); + newTransposeOps.push_back(newTranspose.getResult()); + } + + rewriter.replaceOpWithMultiple(op, {newTransposeOps}); + return success(); + } +}; + } // namespace namespace mlir { @@ -1233,7 +1297,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset, WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp, - WgToSgMultiDimReductionOp>(patterns.getContext()); + WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>( + patterns.getContext()); } } // namespace xegpu } // namespace mlir @@ -1360,7 +1425,9 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return isLegal(layout); }); - target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>( + target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp, + vector::TransposeOp, vector::BroadcastOp, + vector::MultiDimReductionOp>( [=](Operation *op) -> bool { // Check for either a SliceAttr or LayoutAttr on the result. 
        auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
@@ -1379,16 +1446,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
-  target.addDynamicallyLegalOp<vector::BroadcastOp>(
-      [=](vector::BroadcastOp op) -> bool {
-        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
-      });
-
-  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
-      [=](vector::MultiDimReductionOp op) -> bool {
-        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
-      });
-
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
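
The new LayoutAttr::delinearizeId walks through linearId=22, sgLayout=[2,4,4], order=[2,1,0] -> [1,1,2] in its comments. A minimal standalone C++ sketch of that same arithmetic follows; it is an illustration only: plain integers stand in for the index.remu/index.divu ops the patch emits, and the helper name delinearize is hypothetical, not part of the patch.

    // Standalone sketch of the order-driven delinearization described in the
    // new LayoutAttr::delinearizeId comments; not dialect code.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    static std::vector<int64_t> delinearize(int64_t linearId,
                                            const std::vector<int64_t> &layout,
                                            const std::vector<int64_t> &order) {
      std::vector<int64_t> result(layout.size());
      int64_t remaining = linearId;
      for (size_t i = 0; i < order.size(); ++i) {
        int64_t dimIdx = order[i];            // dimension handled at this step
        int64_t dimSize = layout[dimIdx];
        result[dimIdx] = remaining % dimSize; // coordinate within this dimension
        remaining /= dimSize;                 // drop the groups already consumed
      }
      return result;
    }

    int main() {
      // Same walkthrough as the comment: linearId=22, sgLayout=[2,4,4],
      // order=[2,1,0] -> [layer=1, row=1, column=2].
      std::vector<int64_t> coords = delinearize(22, {2, 4, 4}, {2, 1, 0});
      assert((coords == std::vector<int64_t>{1, 1, 2}));
      return 0;
    }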

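The shape bookkeeping in the new vectorizeAsTensorPackOp can be checked the same way, using the figures from its doc comment (tensor<32x8x16xf32> packed into tensor<32x4x1x16x2xf32>). In the sketch below, the permutation [0, 1, 3, 4, 2] and the reassociation groups [[0], [1, 2], [3, 4]] are copied from that example rather than computed via getPackInverseDestPerm/PackingMetadata, and applyPerm/invertPerm are hand-rolled stand-ins for applyPermutationToVector/invertPermutationVector, matching how those utilities are used in the example.

    // Sketch of the write -> pre-transpose -> read shape computation.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    using Shape = std::vector<int64_t>;

    // out[i] = v[perm[i]], consistent with the example's use of
    // applyPermutationToVector.
    static Shape applyPerm(const Shape &v, const Shape &perm) {
      Shape out(perm.size());
      for (size_t i = 0; i < perm.size(); ++i)
        out[i] = v[perm[i]];
      return out;
    }

    // inv[perm[i]] = i, consistent with invertPermutationVector.
    static Shape invertPerm(const Shape &perm) {
      Shape inv(perm.size());
      for (size_t i = 0; i < perm.size(); ++i)
        inv[perm[i]] = static_cast<int64_t>(i);
      return inv;
    }

    int main() {
      Shape writeVecSizes = {32, 4, 1, 16, 2};  // dest (write) vector shape
      Shape destPerm = {0, 1, 3, 4, 2};         // transpose applied before the write
      Shape destInvPerm = invertPerm(destPerm); // = [0, 1, 4, 2, 3]

      // Pre-transpose write shape: undo the pack's permutation.
      Shape preTranspose = applyPerm(writeVecSizes, destInvPerm);
      assert((preTranspose == Shape{32, 4, 2, 1, 16}));

      // Read shape: collapse the pre-transpose shape along the pack's
      // reassociation groups [[0], [1, 2], [3, 4]].
      Shape readShape = {preTranspose[0], preTranspose[1] * preTranspose[2],
                         preTranspose[3] * preTranspose[4]};
      assert((readShape == Shape{32, 8, 16}));
      return 0;
    }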