//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include <optional>

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

using namespace mlir;

namespace {

// Retrieve the RangeAttr if it is specified.
static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
  Operation *parent = op->getParentOfType<scf::IfOp>();
  while (parent) {
    if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
            parent->getAttr("sg_id_range")))
      return attr;
    parent = parent->getParentOfType<scf::IfOp>();
  }
  return {};
}

static std::pair<SmallVector<int64_t>, int>
getSgShapeAndCount(ArrayRef<int64_t> shape,
                   xegpu::DistributeLayoutAttr layout) {
  int count = 1;
  SmallVector<int64_t> sgShape(shape);
  if (layout && layout.isForWorkgroup()) {
    SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
    if (!layout.getEffectiveSgDataAsInt().empty())
      sgShape = layout.getEffectiveSgDataAsInt();
    else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
      sgShape = *maybeDerivedSgData;
    SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
    // Clamp distUnit to the original shape to handle cases where data is
    // shared among subgroups, which may cause distUnit to exceed the original
    // shape.
    for (size_t i = 0; i < distUnit.size(); ++i)
      distUnit[i] = std::min(shape[i], distUnit[i]);
    count = computeProduct(shape) / computeProduct(distUnit);
  }
  return std::make_pair(sgShape, count);
}

/// Utility helper for deriving a list of offsets for each sub-TensorDesc
/// or sub-MemDesc to be accessed by the current subgroup (sgId), based on the
/// associated distribute layout attribute, the shape, the subgroup id, and
/// the original offsets of the op.
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
        xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
static LogicalResult
genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
               SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
  Location loc = op.getLoc();
  SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
  // Not applicable to ops without offsets operands.
  if (origOffsets.empty())
    return failure();

  // Not applicable to ops without workgroup layout attributes.
  xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
  if (!layout || !layout.isForWorkgroup())
    return failure();

  Value sgId =
      gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);

  // Verify and adjust the sgId if a range specifier is present.
  xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
  if (sgIdRange) {
    int64_t startOfRange = sgIdRange.getStart().getInt();
    int64_t endOfRange = sgIdRange.getEnd().getInt();
    // Verify the RangeAttr against the layout attribute.
    if (layout.getNumSubgroups() != endOfRange - startOfRange)
      return rewriter.notifyMatchFailure(
          op, "sg_layout size must match the sg_id_range");
    // Adjust the sgId if necessary.
    if (startOfRange > 0) {
      Value startOfRangeVal =
          arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
      sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
    }
  }

  // Compute the list of subgroup-relative offsets for the sub-tensors or
  // sub-memory descriptors to be accessed, based on the layout information.
  ArrayRef<int64_t> wgShape = op.getDataShape();
  auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
  if (failed(maybeDescOffsets))
    return failure();

  // Compute the final global offsets for each accessed sub-tensor
  // or sub-memory descriptor.
  for (const auto &sgOffsets : *maybeDescOffsets) {
    SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
        rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
    offsetsList.push_back(std::move(newOffsets));
  }

  return success();
}

/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
/// from a workgroup descriptor. It replaces the offsets and sizes with
/// appropriate values for the subgroup.
/// It uses round-robin assignment to distribute the work to the subgroups.
/// The following create_nd_desc operation:
///    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
///       -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
///           sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
/// is converted to 9 subgroup-level operations based on the sg_layout &
/// sg_data:
///    %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
///           !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
///           lane_data = [1, 1]>>
///
/// The sg_layout and sg_data attributes are dropped after the pass as they are
/// no longer needed.
///
/// 24x24 matrix distribution example:
/// sg_layout = [4, 4], sg_data = [2, 2]
/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
/// dist_unit_shape = [8, 8]  --> sg_layout[i] * sg_data[i]
///
/// +------------------------+
/// | 8x8 | 8x8 | 8x8 |  <- 3 tiles across
/// |-----+-----+-----|
/// | 8x8 | 8x8 | 8x8 |  <- 3 tiles down
/// |-----+-----+-----|
/// | 8x8 | 8x8 | 8x8 |
/// +------------------------+
///
/// Each 8x8 tile is further subdivided among subgroups:
/// +------------------------+
/// | 2x2 2x2 2x2 2x2 |  <- 4 subgroups across (each handles 2 columns)
/// | 2x2 2x2 2x2 2x2 |  <- 4 subgroups down (each handles 2 rows)
/// | 2x2 2x2 2x2 2x2 |
/// | 2x2 2x2 2x2 2x2 |
/// +------------------------+
///
/// Since the 24x24 matrix is divided into 8x8 distribution units, there will
/// be 9 distribution units (3x3) in total. Hence the 9 subgroup-level
/// operations.
/// The pass currently has the entire distribution logic in the
/// WgToSgCreateNdOp pattern and all the other ops just follow.
/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
/// ops in the pass.
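///
/// As an illustrative sketch of the round-robin assignment above (subgroup
/// ids are assumed to be delinearized in row-major order): with
/// sg_layout = [4, 4] and sg_data = [2, 2], subgroup id 5 maps to coordinate
/// (1, 1), i.e. local offset (2, 2) inside every 8x8 distribution unit.
/// Adding the unit origins (0, 0), (0, 8), (0, 16), (8, 0), ... yields the
/// nine descriptor offsets created for that subgroup:
///   (2, 2),  (2, 10),  (2, 18),
///   (10, 2), (10, 10), (10, 18),
///   (18, 2), (18, 10), (18, 18)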
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    MLIRContext *ctx = op.getContext();
    xegpu::TensorDescType tdescTy = op.getType();
    ArrayRef<int64_t> wgShape = tdescTy.getShape();
    Type elemTy = tdescTy.getElementType();
    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    auto newTdescTy =
        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                   layout.dropSgLayoutAndData());

    SmallVector<Value> newOps;
    for (auto offsets : offsetsList) {
      auto newOp = xegpu::CreateNdDescOp::create(
          rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
          op.getMixedSizes(), op.getMixedStrides());
      newOps.push_back(newOp);
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};

// This pattern transforms the CreateNdDescOp without offsets to create a
// subgroup descriptor from a workgroup descriptor.
struct WgToSgCreateNdOpNoOffset
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check no offsets are specified.
    if (!op.getMixedOffsets().empty())
      return failure();

    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();
    xegpu::TensorDescType tdescTy = op.getType();
    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    Type elemTy = tdescTy.getElementType();
    ArrayRef<int64_t> wgShape = tdescTy.getShape();

    SmallVector<int64_t> sgShape;
    int count;
    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
    xegpu::TensorDescType newTdescTy =
        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                   layout.dropSgLayoutAndData());

    SmallVector<Value> newCreateNdOps(count);
    std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
      return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
                                           op.getSource(), op.getMixedSizes(),
                                           op.getMixedStrides());
    });
    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
    return success();
  }
};

/// This pattern transforms the LoadNdOp to load subgroup data.
struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getMixedOffsets().empty())
      return failure();

    SmallVector<Value> newLoadOps;
    for (auto src : adaptor.getTensorDesc()) {
      xegpu::TensorDescType tdescTy =
          dyn_cast<xegpu::TensorDescType>(src.getType());
      ArrayRef<int64_t> srcShape = tdescTy.getShape();
      VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
      auto newLoadOp = xegpu::LoadNdOp::create(rewriter, op.getLoc(), newResTy,
                                               src, op->getAttrs());
      newLoadOps.push_back(newLoadOp);
    }
    rewriter.replaceOpWithMultiple(op, {newLoadOps});
    return mlir::success();
  }
};

/// This pattern transforms the StoreNdOp to store to a subgroup descriptor.
/// It creates a StoreNdOp op to store the updated values to the new subgroup
/// src tensor descriptors.
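///
/// For illustration (shapes and layout values are assumed, not taken from a
/// real test): a workgroup-level
///   xegpu.store_nd %val, %tdesc
///     : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32,
///         #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
/// becomes one subgroup-level store per distributed value/descriptor pair:
///   xegpu.store_nd %val_sg, %tdesc_sg
///     : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32>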
struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getMixedOffsets().empty())
      return failure();

    for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
      xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
                               op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};

// This pattern transforms the LoadNdOp with explicit offsets to load
// subgroup data.
struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    SmallVector<Value> newOps;
    for (auto [tdesc, offsets] :
         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
      auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
      VectorType newResTy =
          VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
      auto newOp = xegpu::LoadNdOp::create(
          rewriter, op.getLoc(), newResTy, tdesc, offsets,
          /*packed=*/nullptr, /*transpose=*/nullptr, op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};

// This pattern transforms the StoreNdOp with explicit offsets to store
// subgroup data.
struct WgToSgStoreNdOpWithOffset
    : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    for (auto [v, tdesc, offsets] : llvm::zip(
             adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
      xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
                               op.getL1HintAttr(), op.getL2HintAttr(),
                               op.getL3HintAttr());
    }
    rewriter.eraseOp(op);
    return success();
  }
};

// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
// subgroup data.
struct WgToSgPrefetchNdOpWithOffset
    : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    for (auto [tdesc, offsets] :
         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
      xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
                                  op.getL1HintAttr(), op.getL2HintAttr(),
                                  op.getL3HintAttr());
    }
    rewriter.eraseOp(op);
    return success();
  }
};

/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
/// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
/// offsets of the new subgroup src tensor descriptors.
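///
/// For illustration (shapes assumed): a workgroup-level
///   xegpu.update_nd_offset %tdesc, [0, 32]
///     : !xegpu.tensor_desc<256x128xf32, #wg_layout>
/// becomes one op per distributed descriptor, each applying the same [0, 32]
/// offset to its subgroup-level !xegpu.tensor_desc<32x32xf32>.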
struct WgToSgUpdateNdOffsetOp
    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
  using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<Value> newUpdateTileOffsetOps;
    for (auto tDesc : adaptor.getTensorDesc()) {
      auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
          rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
          op.getConstOffsets());
      newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
    }

    rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
    return success();
  }
};

/// This pattern transforms the DpasOp to work at subgroup level.
struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType resultTy = op.getResult().getType();
    if (resultTy.getRank() != 2)
      return failure();

    auto originalLayout = xegpu::getDistributeLayoutAttr(op.getResult());
    if (!originalLayout)
      return failure();

    size_t i = 0;
    SmallVector<Value> newDpasOps;
    for (auto aVec : adaptor.getLhs()) {
      for (auto bVec : adaptor.getRhs()) {
        llvm::SmallVector<Value> operands({aVec, bVec});
        Value tmpC;
        if (op.getAcc()) {
          tmpC = adaptor.getAcc()[i++];
          operands.push_back(tmpC);
        }

        ArrayRef<int64_t> aVecShape =
            llvm::cast<VectorType>(aVec.getType()).getShape();
        ArrayRef<int64_t> bVecShape =
            llvm::cast<VectorType>(bVec.getType()).getShape();
        VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                           resultTy.getElementType());
        tmpC = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
        xegpu::setDistributeLayoutAttr(cast<OpResult>(tmpC),
                                       originalLayout.dropSgLayoutAndData());

        newDpasOps.push_back(tmpC);
      }
    }
    rewriter.replaceOpWithMultiple(op, {newDpasOps});
    return success();
  }
};

/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    if ((offsetSize != 0) || op.getConstOffsetsAttr())
      return failure();

    for (auto src : adaptor.getTensorDesc())
      xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), TypeRange(), src,
                                  op->getAttrs());
    rewriter.eraseOp(op);
    return success();
  }
};

/// This pattern transforms vector.broadcast ops to work at subgroup level.
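///
/// For illustration (shapes and layout values assumed): a workgroup-level
///   vector.broadcast %x : f32 to vector<256x128xf32>
/// whose result carries #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>
/// is rewritten so that each subgroup produces
///   vector.broadcast %x : f32 to vector<32x32xf32>
/// with the sg_layout/sg_data fields dropped from the result layout.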
struct WgToSgVectorBroadcastOp
    : public OpConversionPattern<vector::BroadcastOp> {
  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = op.getResult().getType();
    ArrayRef<int64_t> wgShape = resultType.getShape();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
      return failure();

    SmallVector<Value> newBroadcastOps;
    for (auto operand : adaptor.getOperands().front()) {
      auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
                                                      newResultType, operand);
      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
          !layout.getEffectiveInstDataAsInt().empty())
        xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
                                       layout.dropSgLayoutAndData());
      newBroadcastOps.push_back(newBroadcast.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
    return success();
  }
};

// This pattern transforms elementwise ops to work at subgroup level.
struct WgToSgElementwiseOp : public ConversionPattern {
  WgToSgElementwiseOp(MLIRContext *ctx)
      : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only match ops with elementwise trait and single result.
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();

    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
    assert(resultType && "Expected result to be a VectorType");
    ArrayRef<int64_t> wgShape = resultType.getShape();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op->getResult(0));
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

    size_t numVariants = operands.empty() ? 0 : operands.front().size();

    if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
          return operandVec.size() != numVariants;
        }))
      return failure();

    SmallVector<Value> newResults;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    for (size_t i = 0; i < numVariants; ++i) {
      SmallVector<Value> opOperands;
      for (auto &operandVec : operands)
        opOperands.push_back(operandVec[i]);

      OperationState state(op->getLoc(), op->getName());
      state.addOperands(opOperands);
      state.addTypes(newResultType);
      // Copy all attributes, but update "layout_result_0" to drop
      // sgLayout/sgData.
      for (auto attr : op->getAttrs()) {
        if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue())) {
          if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
              !layout.getEffectiveInstDataAsInt().empty())
            state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
        } else {
          state.addAttribute(attr.getName(), attr.getValue());
        }
      }
      Operation *newOp = rewriter.create(state);
      newResults.push_back(newOp->getResult(0));
    }

    rewriter.replaceOpWithMultiple(op, {newResults});
    return success();
  }
};

// clang-format off
// Pattern for lowering ConvertLayoutOp based on sg_layout and sg_data.
// If input_layout and target_layout have identical sg_layout and sg_data,
// the op is rewritten to a subgroup-level ConvertLayoutOp with these fields
// dropped.
// For example:
//   #a = #xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], inst_data = [16, 16]>
//   #b = #xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], inst_data = [8, 16]>
//   xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
// becomes:
//   #a = #xegpu.layout<inst_data = [16, 16]>
//   #b = #xegpu.layout<inst_data = [8, 16]>
//   xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<16x16xf32>
//   (vector<16x16xf32> is determined by sg_data = [16, 16])
//
// If sg_layout or sg_data differ, SLM is used to redistribute data across
// subgroups. For example:
//   #a = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 16], inst_data = [16, 16]>
//   #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 32], inst_data = [8, 16]>
//   xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
// is lowered to:
//   #a = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 16], inst_data = [16, 16]>
//   #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 32], inst_data = [8, 16]>
//   store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32>
//   %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
//   xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
// clang-format on
struct WgToSgConvertLayoutOp
    : public OpConversionPattern<xegpu::ConvertLayoutOp> {
  using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // TODO: currently, we only support LayoutAttr.
    auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
    auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());

    if (!input || !target || !input.isForWorkgroup() ||
        !target.isForWorkgroup())
      return rewriter.notifyMatchFailure(
          op, "Input and target layouts must have subgroup layout");

    DenseI32ArrayAttr inputSgLayout = input.getSgLayout();
    DenseI32ArrayAttr inputSgData = input.getSgData();
    DenseI32ArrayAttr inputOrder = input.getOrder();
    DenseI32ArrayAttr targetSgLayout = target.getSgLayout();
    DenseI32ArrayAttr targetSgData = target.getSgData();
    DenseI32ArrayAttr targetOrder = target.getOrder();

    // TODO: currently we only support the optimal case, where input and
    // output have the same sg_layout and sg_data, so SLM is not involved.
    if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
        inputOrder != targetOrder)
      return failure();

    input = input.dropSgLayoutAndData();
    target = target.dropSgLayoutAndData();

    SmallVector<Value> newOps(adaptor.getSource());
    if (input && target) {
      // Keep the ConvertLayoutOp for the remaining fields, e.g., inst_data.
      for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
        auto newOp = xegpu::ConvertLayoutOp::create(
            rewriter, op.getLoc(), src.getType(), src, input, target);
        newOps[i] = newOp;
      }
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};

// Handles UnrealizedConversionCastOp generated during
// SCFStructuralTypeConversions (step 1). This op may appear as either a
// target or source materialization for Vector values, e.g.:
//   1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
//   2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
// It could be either a 1:N or an N:1 cast. In both cases, the pattern
// simply forwards the inputs to the outputs using the 1:1 or 1:N interface.
// For example, the following scf::ForOp
// ```
// %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) {
//   %n = use(%arg1): vector<128x128xf16>
//   scf.yield %n : vector<128x128xf16>
// }
// ```
// could be converted to:
// ```
// %1 = unrealized_conversion_cast %0
//     : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
// %for:2 = scf.for ...
//     iter_args(%arg1 = %1#1, %arg2 = %1#2)
//     -> (vector<16x16xf16>, vector<16x16xf16>) {
//   %m = unrealized_conversion_cast %arg1, %arg2
//       : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
//   %n = use(%m): vector<128x128xf16>
//   %b = unrealized_conversion_cast %n
//       : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
//   scf.yield %b#1, %b#2 : vector<16x16xf16>, vector<16x16xf16>
// }
// %cast = unrealized_conversion_cast %for:2
//     : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
// ```
// TODO: remove it when a context-aware type converter is ready.
struct UnrealizedConversionCastOpPattern
    : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
  using OpConversionPattern<
      mlir::UnrealizedConversionCastOp>::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());

    auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
    auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());

    if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
        !llvm::all_equal(ValueRange(inputs).getTypes()))
      return failure();

    // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
    // It is generated by source materialization (e.g., inits to scf forOp).
    // The input values provided by the adaptor should already be distributed,
    // and their types should correspond exactly to the result types of the
    // operation.
    if (op.getNumOperands() == 1 &&
        llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
      rewriter.replaceOp(op, inputs);
      return success();
    }

    // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
    // It is generated by target materialization (e.g., arguments/results
    // of scf forOp). All input values must have the same vector type, and
    // their shape must be evenly divisible by the output vector's shape
    // (determined by the nature of the workgroup to subgroup distribution).
    // TODO: it is not safe to do such forwarding, since such an N:1 cast
    // could come from other sources.
    if (op.getNumResults() == 1 &&
        computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
      rewriter.replaceOpWithMultiple(op, {inputs});
      return success();
    }

    return mlir::failure();
  }
};

// This pattern distributes arith.constant op into subgroup-level constants.
struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
    auto vecType = dyn_cast<VectorType>(op.getType());
    if (!vecAttr || !vecAttr.isSplat() || !vecType)
      return failure();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    ArrayRef<int64_t> wgShape = vecType.getShape();
    SmallVector<int64_t> sgShape;
    int count;
    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);

    // Current limitation: constant of vector with single value.
    // TODO: support more complex cases, e.g., vector with multiple values.
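    // As an illustrative sketch (shapes and layout values assumed): a splat
    //   arith.constant dense<1.0> : vector<256x128xf32>
    // with sg_layout = [8, 4] and sg_data = [32, 32] is replaced by `count`
    // copies of
    //   arith.constant dense<1.0> : vector<32x32xf32>
    // one per subgroup-local tile.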
    Attribute singleVal = vecAttr.getSplatValue<Attribute>();

    auto newType = VectorType::get(sgShape, vecType.getElementType());
    auto sgAttr = DenseElementsAttr::get(newType, singleVal);
    auto cstOp =
        arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
    if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
        !layout.getEffectiveInstDataAsInt().empty())
      xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
                                     layout.dropSgLayoutAndData());
    SmallVector<Value> newConsts(count, cstOp);

    rewriter.replaceOpWithMultiple(op, {newConsts});
    return success();
  }
};

// This pattern transforms the LoadGatherOp with explicit offsets to load
// subgroup data.
struct WgToSgLoadGatherOpWithOffset
    : public OpConversionPattern<xegpu::LoadGatherOp> {
  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getOffsets())
      return failure();

    Location loc = op.getLoc();
    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
    if (!resultType)
      return failure();
    ArrayRef<int64_t> wgShape = resultType.getShape();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

    // The offsets need to be distributed.
    auto offsetsVecType =
        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
    auto maskVecType =
        dyn_cast<VectorType>(adaptor.getMask().front().getType());
    if (!offsetsVecType || !maskVecType ||
        offsetsVecType.getShape() != maskVecType.getShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "offsets have not been distributed");
    }

    SmallVector<Value> newLoadOps;
    auto chunkSizeAttr =
        rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
    VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
    for (auto [offsets, mask] :
         llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
      auto newLoadOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
      xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
                                     layout.dropSgLayoutAndData());
      newLoadOps.push_back(newLoadOp);
    }
    rewriter.replaceOpWithMultiple(op, {newLoadOps});
    return success();
  }
};

// This pattern transforms the StoreScatterOp with explicit offsets to store
// subgroup data.
struct WgToSgStoreScatterOpWithOffset
    : public OpConversionPattern<xegpu::StoreScatterOp> {
  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getOffsets())
      return failure();

    Location loc = op.getLoc();
    VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
    if (!valueType)
      return failure();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getValue());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    // The offsets need to be distributed.
    auto offsetsVecType =
        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
    auto maskVecType =
        dyn_cast<VectorType>(adaptor.getMask().front().getType());
    if (!offsetsVecType || !maskVecType ||
        offsetsVecType.getShape() != maskVecType.getShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "offsets have not been distributed");
    }

    auto chunkSizeOpt = op.getChunkSize();
    int64_t chunkSize =
        chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
    auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
    for (auto [val, offs, mask] : llvm::zip(
             adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
      xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
                                    mask, chunkSizeAttr, op.getL1HintAttr(),
                                    op.getL2HintAttr(), op.getL3HintAttr());
      // Update the layout attribute to drop sg_layout and sg_data.
      if (auto newLayout = layout.dropSgLayoutAndData())
        op->setAttr("layout", newLayout);
    }
    rewriter.eraseOp(op);
    return success();
  }
};

struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    ArrayRef<int64_t> wgShape = op.getDataShape();
    VectorType valueTy = op.getRes().getType();
    Type elemTy = valueTy.getElementType();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResTy = VectorType::get(sgShape, elemTy);
    SmallVector<Value> newOps;
    for (auto offsets : offsetsList) {
      auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
                                               op.getMemDesc(), offsets,
                                               layout.dropSgLayoutAndData());
      newOps.push_back(newOp);
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};

struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
      xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
                                   offsets, layout.dropSgLayoutAndData());
    rewriter.eraseOp(op);
    return success();
  }
};

// This pattern distributes the vector.step ops to work at subgroup level.
struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
  using OpConversionPattern<vector::StepOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    Location loc = op.getLoc();
    VectorType type = op.getResult().getType();
    auto wgShape = type.getShape();
    std::optional<SmallVector<int64_t>> sgShape =
        getSgShapeAndCount(wgShape, layout).first;
    if (!sgShape)
      return failure();

    Value sgId =
        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
    auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
    if (failed(sgOffsets))
      return failure();

    VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
    auto steps = vector::StepOp::create(rewriter, loc, newTy);
    SmallVector<Value> newOps;
    for (auto offsets : *sgOffsets) {
      // Broadcast the offset scalar to a vector & add to the base steps.
      auto bcastOffset =
          vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
      auto finalSteps =
          arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
          !layout.getEffectiveInstDataAsInt().empty()) {
        xegpu::setDistributeLayoutAttr(steps->getResult(0),
                                       layout.dropSgLayoutAndData());
        xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
                                       layout.dropSgLayoutAndData());
        xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
                                       layout.dropSgLayoutAndData());
      }
      newOps.push_back(finalSteps);
    }

    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};

// This pattern transforms vector.shape_cast ops to work at subgroup level.
struct WgToSgVectorShapeCastOp
    : public OpConversionPattern<vector::ShapeCastOp> {
  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
    if (!resultType)
      return failure();

    ArrayRef<int64_t> wgShape = resultType.getShape();
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    // TODO: Add check for compatible layouts in the layout attr.
    auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
    if (!srcType)
      return failure();

    // Check that the shape_cast only adds/removes unit dimensions.
    auto onlyUnitDims = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
      // Remove all 1s from both shapes and compare the rest.
      SmallVector<int64_t> srcNonUnit, dstNonUnit;
      for (int64_t d : src)
        if (d != 1)
          srcNonUnit.push_back(d);
      for (int64_t d : dst)
        if (d != 1)
          dstNonUnit.push_back(d);
      return srcNonUnit == dstNonUnit;
    };

    if (!onlyUnitDims(srcType.getShape(), sgShape))
      return failure();

    // For rank-reducing or rank-increasing shape_cast ops, the lower-rank
    // layout must be a slice of the higher-rank layout.
    int64_t sourceRank = srcType.getRank();
    int64_t resultRank = sgShape.size();
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(op.getSource());
    if (sourceRank < resultRank && !sourceLayout.isSliceOf(layout))
      return failure();
    if (sourceRank > resultRank && !layout.isSliceOf(sourceLayout))
      return failure();

    SmallVector<Value> newShapeCastOps;
    for (auto src : adaptor.getSource()) {
      auto newShapeCast = rewriter.create<vector::ShapeCastOp>(
          op.getLoc(), newResultType, src);
      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
          !layout.getEffectiveInstDataAsInt().empty())
        xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
                                       layout.dropSgLayoutAndData());
      newShapeCastOps.push_back(newShapeCast.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
    return success();
  }
};

/// Pattern for lowering vector.multi_reduction op to subgroup level.
/// Current limitation: the sg_layout in the reduced dimension must be 1, so
/// that the reduction stays local to each subgroup and no cross-subgroup
/// communication is needed.
/// TODO: Add cases to handle more general situations which require SLM access.
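///
/// For illustration (shapes and layout values assumed): a workgroup-level
///   vector.multi_reduction <add>, %src, %acc [1]
///     : vector<256x64xf32> to vector<256xf32>
/// where the source is distributed with sg_layout = [8, 1] and
/// sg_data = [32, 64] satisfies the constraint (sg_layout[1] == 1 and
/// sg_data[1] == 64), so each subgroup reduces its own vector<32x64xf32>
/// tile to vector<32xf32> with no cross-subgroup communication.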
struct WgToSgMultiDimReductionOp
    : public OpConversionPattern<vector::MultiDimReductionOp> {
  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType srcType = op.getSourceVectorType();
    VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
    if (!dstType)
      return failure();

    auto srcShape = srcType.getShape();
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    auto reductionDims = llvm::to_vector(op.getReductionDims());

    SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
                                        .getParent()
                                        .getEffectiveSgLayoutAsInt();
    SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
                                      .getParent()
                                      .getEffectiveSgDataAsInt();

    // Check that the sgLayout in the reduced dimension is 1 and
    // each sg gets the entire slice to reduce.
    for (int64_t dim : reductionDims) {
      if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
        return rewriter.notifyMatchFailure(
            op,
            "sgLayout in each reduced dimension must be 1 and sgData in the "
            "reduced dim must match srcShape in that dim");
    }

    SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;

    VectorType newDstType =
        VectorType::get({sgShape}, dstType.getElementType());

    SmallVector<Value> newReductions;
    for (auto sgSrc : adaptor.getSource()) {
      auto newOp = rewriter.create<vector::MultiDimReductionOp>(
          op.getLoc(), newDstType, op.getKind(), sgSrc, adaptor.getAcc()[0],
          op.getReductionDims());
      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
          !layout.getEffectiveInstDataAsInt().empty())
        xegpu::setDistributeLayoutAttr(newOp->getResult(0),
                                       layout.dropSgLayoutAndData());
      newReductions.push_back(newOp.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newReductions});
    return success();
  }
};

} // namespace

namespace mlir {
namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
  patterns
      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
           WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
           WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
           WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
           WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
           WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
           WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
           WgToSgMultiDimReductionOp>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir

namespace {
struct XeGPUWgToSgDistributePass
    : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
  void runOnOperation() override;
};
} // namespace

void XeGPUWgToSgDistributePass::runOnOperation() {
  // Track existing UnrealizedConversionCastOps.
  SmallVector<Operation *> existingCastOps;
  getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
    existingCastOps.push_back(castOp.getOperation());
  });

  {
    // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
    // VectorType operands. This first converts such operands to
    // RankedTensorType, propagates the layout attribute into the encoding
    // attribute, and finally converts the RankedTensorType to VectorType based
    // on the encoding.
    TypeConverter converter;
    converter.addConversion([&](Type type) -> Type { return type; });
    converter.addConversion(
        [&](RankedTensorType type,
            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
          Type elemTy = type.getElementType();
          ArrayRef<int64_t> shape = type.getShape();

          int count;
          SmallVector<int64_t> subShape;
          std::tie(subShape, count) = getSgShapeAndCount(
              shape, dyn_cast_if_present<xegpu::DistributeLayoutAttr>(
                         type.getEncoding()));

          auto newTy = VectorType::get(subShape, elemTy);
          result.append(count, newTy);
          return success();
        });

    xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
                                                       converter);
  }

  // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
  // as well as XeGPU, Arith, and Vector operations.
  MLIRContext *ctx = &getContext();
  RewritePatternSet patterns(ctx);
  ConversionTarget target(*ctx);
  TypeConverter converter;
  converter.addConversion([&](Type type) -> Type { return type; });
  converter.addConversion(
      [&](xegpu::TensorDescType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        ArrayRef<int64_t> shape = type.getShape();

        int count;
        SmallVector<int64_t> subShape;
        xegpu::LayoutAttr layout = type.getLayoutAttr();
        std::tie(subShape, count) = getSgShapeAndCount(shape, layout);

        if (layout)
          layout = layout.dropSgLayoutAndData();

        auto newTy = xegpu::TensorDescType::get(
            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
        result.append(count, newTy);
        return success();
      });

  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
    if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
      return createOp.getType();
    if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
      return loadOp.getTensorDescType();
    if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
      return storeOp.getTensorDescType();
    if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
      return updateOp.getType();
    if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
      return prefetchOp.getTensorDescType();
    return xegpu::TensorDescType();
  };

  auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
    return !layout || !layout.isForWorkgroup();
  };

  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
                               xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
    auto tdescTy = getTensorDescType(op);
    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
    auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
      [=](xegpu::LoadMatrixOp op) -> bool {
        return isLegal(op.getLayoutAttr());
      });

  target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
      [=](xegpu::StoreMatrixOp op) -> bool {
        return isLegal(op.getLayoutAttr());
      });

  target.addDynamicallyLegalOp<arith::ConstantOp>(
      [=](arith::ConstantOp op) -> bool {
        auto vecType = dyn_cast<VectorType>(op.getType());
        if (!vecType)
          return true;

        auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
      [=](Operation *op) -> bool {
        // Check for either a SliceAttr or LayoutAttr on the result.
        auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
      [=](xegpu::LoadGatherOp op) -> bool {
        auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
      [=](xegpu::StoreScatterOp op) -> bool {
        // Check if the layout attribute is present on the result.
        auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
        if (!layout)
          return true;
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<vector::BroadcastOp>(
      [=](vector::BroadcastOp op) -> bool {
        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
      });

  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
      [=](vector::MultiDimReductionOp op) -> bool {
        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
      });

  target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
      [=](xegpu::ConvertLayoutOp op) -> bool {
        return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
      });

  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
      [=](Operation *op) -> std::optional<bool> {
        // Only handle elementwise mappable ops.
        if (!OpTrait::hasElementwiseMappableTraits(op))
          return true;

        VectorType resultType =
            dyn_cast<VectorType>(op->getResult(0).getType());
        if (!resultType)
          return true;

        // Check if all operands are vectors of the same shape.
        // TODO: Support other types.
        for (Value operand : op->getOperands()) {
          VectorType operandType = dyn_cast<VectorType>(operand.getType());
          if (!operandType ||
              operandType.getShape() != resultType.getShape()) {
            return true;
          }
        }

        xegpu::DistributeLayoutAttr layout =
            xegpu::getDistributeLayoutAttr(op->getResult(0));
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
      [=](UnrealizedConversionCastOp op) {
        return llvm::is_contained(existingCastOps, op.getOperation());
      });

  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

  scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                       target);
  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    return signalPassFailure();

  // Remove sg_layout and sg_data attributes from the Layout
  // attribute for each VectorType result of the operation.
  // For Structured Control Flow ops, the layout is simply removed,
  // since in the 1:N case the layouts for the new results are missing.
  // The layout propagation pass will be activated.
  getOperation()->walk([](Operation *op) {
    for (OpResult result : op->getOpResults()) {
      std::string name = xegpu::getLayoutName(result);
      if (auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(name)) {
        op->removeAttr(name);
        if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op)) {
          if (auto newLayout = layout.dropSgLayoutAndData())
            op->setAttr(name, newLayout);
        }
      }
    }
  });
}