//===---- XeGPUBlocking.cpp ---- XeGPU Blocking Pass ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUBLOCKING
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-blocking"

using namespace mlir;

namespace {

// Resolve the unrealized conversion cast ops generated when doing SCF
// structural type conversion. The cast comes in two forms: an N:1 vector
// cast and a 1:N vector cast. vector::insert_strided_slice ops are used
// to resolve the former, and vector::extract_strided_slice ops are used
// to resolve the latter.
static void
resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
  ValueRange inputs = castOp.getInputs();
  ValueRange outputs = castOp.getOutputs();

  auto hasIdenticalVectorTypes = [](ValueRange values) {
    auto types = values.getTypes();
    return llvm::all_of(types, [&](Type type) {
      return isa<VectorType>(type) && type == types.front();
    });
  };

  // We are only interested in the case where all inputs and outputs have
  // identical VectorTypes.
  if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) {
    LDBG() << "skip unrealized conversion cast op not emulating pack/unpack.";
    return;
  }

  VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
  OpBuilder builder(castOp);
  if (inputs.size() > 1 && outputs.size() == 1) {
    // The castOp is emulating an unpack op.
    ArrayRef<int64_t> shape = outputTy.getShape();
    Value result = xegpu::createVectorWithShapeFromValues(
        builder, castOp.getLoc(), inputs, shape);
    castOp->replaceAllUsesWith(ValueRange(result));
    castOp->erase();
  } else if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
    // The castOp is emulating a pack op.
    ArrayRef<int64_t> tileShape = outputTy.getShape();
    SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue(
        builder, castOp.getLoc(), inputs[0], tileShape);
    castOp->replaceAllUsesWith(results);
    castOp->erase();
  }
}

// This pattern lowers ConvertLayoutOp by removing the inst_data field from
// the layout attributes. Since both producer and consumer operations handle
// data partitioning based on their own inst_data, while maintaining the
// original input and output shapes, ConvertLayoutOp does not need to manage
// inst_data.
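// For example (an illustrative sketch; the concrete layout values are made
// up), an input or target layout such as
//   #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16],
//                 lane_data = [1, 1]>
// is rewritten to keep only the lane-level fields:
//   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>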
struct ConvertLayoutOpPattern
    : public OpRewritePattern<xegpu::ConvertLayoutOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr inputLayout = op.getInputLayoutAttr();
    xegpu::DistributeLayoutAttr targetLayout = op.getTargetLayoutAttr();
    if (inputLayout.getEffectiveInstDataAsInt().empty() ||
        targetLayout.getEffectiveInstDataAsInt().empty())
      return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp.");

    inputLayout = inputLayout.dropInstData();
    targetLayout = targetLayout.dropInstData();
    auto newOp = rewriter.createOrFold<xegpu::ConvertLayoutOp>(
        op.getLoc(), op.getType(), op.getSource(), inputLayout, targetLayout);
    rewriter.replaceOp(op, newOp);
    return success();
  }
};

//===------------------------------------------------------------------------===//
// The XeGPUBlockingPass leverages the unroll patterns for XeGPU and Vector ops
// to partition operations that process large shapes into multiple operations
// on smaller shapes, as specified by the inst_data in the layout attribute.
// This enables each resulting operation to be efficiently mapped to a hardware
// instruction.
//===------------------------------------------------------------------------===//

class XeGPUBlockingPass final
    : public xegpu::impl::XeGPUBlockingBase<XeGPUBlockingPass> {
public:
  void runOnOperation() override;

private:
  // Get the tile shape for a given OpOperand or OpResult by examining the
  // corresponding layout attribute. If the layout is not present or is not
  // a subgroup-level layout, it returns std::nullopt.
  template <typename T,
            typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
                                        std::is_same_v<T, OpResult>>>
  std::optional<SmallVector<int64_t>>
  getTileShape(const T &operandOrResult) const;

  // Get the tile shape for a given operation.
  std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;

  // Determine if the operation requires unrolling. Return false if all
  // operands and results have tile shapes identical to their original types.
  // Otherwise, return true.
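  // For example (illustrative numbers only): a vector<32x32xf16> operand
  // whose layout carries inst_data = [8, 16] requires unrolling, since the
  // tile shape differs from the full shape; the op is eventually split into
  // (32 * 32) / (8 * 16) = 8 ops on vector<8x16xf16>.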
  bool needsUnroll(Operation *op) const;
};
} // namespace

template <typename T, typename>
std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
  Value value;
  if constexpr (std::is_same_v<T, OpOperand>)
    value = operandOrResult.get();
  else
    value = (Value)operandOrResult;

  xegpu::DistributeLayoutAttr layout =
      xegpu::getDistributeLayoutAttr(operandOrResult);
  if (layout && layout.isForSubgroup()) {
    if (!layout.getEffectiveInstDataAsInt().empty())
      return layout.getEffectiveInstDataAsInt();

    if (auto type = dyn_cast<ShapedType>(value.getType()))
      return llvm::to_vector(type.getShape());
  }
  LDBG() << "failed to getTileShape for: " << value;
  return std::nullopt;
}

std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(Operation *op) const {
  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
          xegpu::UpdateOffsetOp>(op))
    return getTileShape(op->getOpResult(0));
  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
          xegpu::LoadGatherOp>(op))
    return getTileShape(op->getOpOperand(0));
  if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
    return getTileShape(op->getOpOperand(1));

  if (isa<xegpu::DpasOp>(op)) {
    std::optional<SmallVector<int64_t>> aTile =
        getTileShape(op->getOpOperand(0));
    std::optional<SmallVector<int64_t>> bTile =
        getTileShape(op->getOpOperand(1));

    if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
      return std::nullopt;

    // Semantic check for A and B: the reduction dimensions must match.
    if ((*aTile)[1] != (*bTile)[0])
      return std::nullopt;

    // Semantic check for C: its tile must match the result tile of A x B.
    if (op->getNumOperands() == 3) {
      std::optional<SmallVector<int64_t>> cTile =
          getTileShape(op->getOpOperand(2));
      int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
      if (!cTile || !llvm::equal(*cTile, expectedCTile))
        return std::nullopt;
    }

    return SmallVector<int64_t>({(*aTile)[0], (*aTile)[1], (*bTile)[1]});
  }

  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
    return getTileShape(op->getOpResult(0));

  if (isa<vector::MultiDimReductionOp>(op))
    return getTileShape(op->getOpOperand(0));

  if (isa<vector::TransposeOp, vector::BroadcastOp>(op))
    return getTileShape(op->getOpResult(0));

  return std::nullopt;
}

bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
  // Skip the op if any of its operands or results has a workgroup-level
  // layout.
  bool hasWgLayoutOperands =
      llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
        xegpu::DistributeLayoutAttr layout =
            xegpu::getDistributeLayoutAttr(opr);
        return layout && layout.isForWorkgroup();
      });
  bool hasWgLayoutResults =
      llvm::any_of(op->getOpResults(), [](OpResult result) {
        xegpu::DistributeLayoutAttr layout =
            xegpu::getDistributeLayoutAttr(result);
        return layout && layout.isForWorkgroup();
      });
  if (hasWgLayoutOperands || hasWgLayoutResults) {
    LDBG() << "skip unrolling for op with workgroup level layout: " << *op;
    return false;
  }

  auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
    Type valTy = value.getType();
    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
      xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
      return layout && !layout.getEffectiveInstDataAsInt().empty();
    }
    auto shapedType = dyn_cast<ShapedType>(valTy);
    return shapedType && !llvm::equal(tileShape, shapedType.getShape());
  };

  bool hasUnrollableOperands =
      llvm::any_of(op->getOpOperands(), [&](OpOperand &opr) {
        std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
        return tileShape.has_value() && isUnrollable(opr.get(), *tileShape);
      });
  bool hasUnrollableResults =
      llvm::any_of(op->getOpResults(), [&](OpResult result) {
        std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
        return tileShape.has_value() && isUnrollable(result, *tileShape);
      });
  return hasUnrollableOperands || hasUnrollableResults;
}

void XeGPUBlockingPass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  Operation *op = getOperation();

  // Preserve the LayoutAttr for each operand to the owner's DictionaryAttr.
  // This ensures that the LayoutAttr remains accessible even if the defining
  // operation is replaced.
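  // (The layouts are cached as discardable attributes under per-operand and
  // per-result names produced by xegpu::getLayoutName; the walk at the end
  // of this function uses the same names to clean them up again.)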
  xegpu::setDistributeLayoutAttrs(
      op, [](Value v) { return xegpu::getDistributeLayoutAttr(v); });

  auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
                                 xegpu::LayoutAttr layout) {
    int count = 1;
    SmallVector<int64_t> tileShape(shape);
    if (layout && layout.getInstData()) {
      DenseI32ArrayAttr instData = layout.getInstData();
      tileShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
      count = computeProduct(shape) / computeProduct(tileShape);
    }
    return std::make_pair(tileShape, count);
  };

  // Perform type conversion for SCF control flow ops.
  TypeConverter converter;
  converter.addConversion([](Type type) -> Type { return type; });
  converter.addConversion(
      [&](RankedTensorType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        ArrayRef<int64_t> shape = type.getShape();

        auto layout =
            llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
        if (layout && layout.isForWorkgroup())
          return failure();

        int count;
        SmallVector<int64_t> subShape;
        std::tie(subShape, count) = getTileShapeAndCount(shape, layout);

        auto newTy = VectorType::get(subShape, elemTy);
        result.append(count, newTy);
        return success();
      });
  converter.addConversion(
      [&](xegpu::TensorDescType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        ArrayRef<int64_t> shape = type.getShape();

        xegpu::LayoutAttr layout = type.getLayoutAttr();
        if (layout && layout.isForWorkgroup())
          return failure();

        int count;
        SmallVector<int64_t> subShape;
        std::tie(subShape, count) = getTileShapeAndCount(shape, layout);

        if (layout)
          layout = layout.dropInstData();

        auto newTy = xegpu::TensorDescType::get(
            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
        result.append(count, newTy);
        return success();
      });

  xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter);

  xegpu::UnrollOptions options;
  options.setFilterConstraint([&](Operation *op) -> LogicalResult {
    return success(needsUnroll(op));
  });

  options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });

  options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape,
                                 bool returnSingleType = false) {
    Type elemTy = type.getElementType();
    Type newTy;

    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {

      Attribute encoding = tdescTy.getEncoding();
      // If the encoding is a ScatterTensorDescAttr, we need to
      // potentially adjust the chunk size based on the inst_data.
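      // For example (illustrative numbers only): a scattered tensor_desc
      // with chunk_size = 8 whose layout carries inst_data = [16, 2] gets a
      // blocked chunk_size of 2, taken from the innermost inst_data
      // dimension.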
      if (tdescTy.isScattered()) {
        int64_t chunkSize = tdescTy.getChunkSizeAsInt();

        if (chunkSize > 1) {
          int64_t blockedChunkSize = chunkSize;
          auto instData = tdescTy.getLayoutAttr().getInstData();
          if (!instData.empty())
            blockedChunkSize = instData.asArrayRef().back();

          // To create a new attribute with a different chunk_size:
          auto newEncoding = xegpu::ScatterTensorDescAttr::get(
              ctx, tdescTy.getMemorySpace(), blockedChunkSize);

          encoding = newEncoding;
        }
      }

      newTy =
          xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
                                     tdescTy.getLayoutAttr().dropInstData());
    } else {
      newTy = type.clone(tileShape, elemTy);
    }

    if (returnSingleType)
      return SmallVector<Type>{newTy};
    std::optional<SmallVector<int64_t>> ratio =
        computeShapeRatio(type.getShape(), tileShape);
    assert(ratio && "The shape of the type must be a multiple of tileShape.");
    return SmallVector<Type>(computeProduct(*ratio), newTy);
  });

  RewritePatternSet patterns(ctx);
  patterns.add<ConvertLayoutOpPattern>(ctx);

  vector::UnrollVectorOptions vectorOptions;
  vectorOptions.setNativeShapeFn(options.nativeShape);
  populateXeGPUUnrollPatterns(patterns, options);
  vector::populateVectorUnrollPatterns(patterns, vectorOptions);

  (void)applyPatternsGreedily(op, std::move(patterns));

  op->walk([](Operation *op) {
    // Remove the layout attributes cached per operand.
    for (OpOperand &opr : op->getOpOperands()) {
      std::string name = xegpu::getLayoutName(opr);
      if (op->hasAttrOfType<xegpu::LayoutAttr>(name))
        op->removeAttr(name);
    }

    // Update the layout attributes per result.
    for (OpResult result : op->getOpResults()) {
      std::string name = xegpu::getLayoutName(result);
      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
        op->removeAttr(name);
        if (!isa<LoopLikeOpInterface>(op))
          xegpu::setDistributeLayoutAttr(result, layout.dropInstData());
      }
    }

    // Resolve unrealized conversion cast ops emulating pack/unpack.
    if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
      resolveUnrealizedConversionCastOp(castOp);
  });
}