//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for unrolling XeGPU operations. It follows a
// concept and design similar to the vector unroll patterns, serving as a
// complement to them.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUUNROLL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-unroll"

using namespace mlir;

namespace {

template <typename SourceOp>
struct UnrollPattern : public OpRewritePattern<SourceOp> {
  UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
                PatternBenefit benefit = 1)
      : OpRewritePattern<SourceOp>(context, benefit), options(options) {}

protected:
  /// Return the target shape for the given `op`. Return std::nullopt if the
  /// op shouldn't be or cannot be unrolled.
  std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
    LDBG() << "Get unroll shape for: " << *op;

    if (options.filterConstraint && failed(options.filterConstraint(op))) {
      LDBG() << "--no filter constraint -> BAIL";
      return std::nullopt;
    }

    assert(options.nativeShape &&
           "expects the native shape for the native shape callback function.");

    auto nativeShape = options.nativeShape(op);
    return nativeShape;
  }

  SmallVector<Type> getUnrolledTypes(ShapedType type,
                                     ArrayRef<int64_t> tileShape,
                                     bool returnSingleType = false) const {
    return options.getUnrolledTypes(type, tileShape, returnSingleType);
  }

  /// Emulate the unpack behavior using insert_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
  Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
               Location loc, PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of destTy.");
      auto shape = vecTy.getShape();
      return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
    }

    if (isa<xegpu::TensorDescType>(destTy)) {
      auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = UnrealizedConversionCastOp::create(
          rewriter, loc, destTy, srcs,
          ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResult(0);
    }

    llvm_unreachable("Unexpected destTy.");
    return Value();
  }
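  // Note on pack/unpack (illustrative, not tied to a specific lit test): for
  // vectors, a vector<32x32xf16> value with a 16x16 block size is split into
  // four vector<16x16xf16> pieces via extract_strided_slice (pack), and the
  // pieces are recombined via insert_strided_slice (unpack). For tensor_desc
  // values, pack/unpack are modeled with unrealized_conversion_cast ops tagged
  // with the attributes below, which are expected to be resolved by a later
  // blocking/cleanup step rather than lowered directly.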
  /// Emulate the pack behavior using extract_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
  SmallVector<Value> pack(Value src, TypeRange destTypes,
                          ArrayRef<int64_t> blockSize, Location loc,
                          PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of src.");
      return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                     blockSize);
    }

    if (isa<xegpu::TensorDescType>(src.getType())) {
      auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = UnrealizedConversionCastOp::create(
          rewriter, loc, destTypes, src,
          ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResults();
    }

    llvm_unreachable("Unexpected src type.");
    return SmallVector<Value>();
  }

private:
  const char *const packAttrName = "__xegpu_blocking_pack__";
  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";

  xegpu::UnrollOptions options;
};

// Generic helper function for unrolling operations with offsets.
//
// Iterates over tile offsets within the tensor descriptor shape and calls
// the provided createOp function for each computed offset. This is used by
// operations like LoadNd, StoreNd, CreateNdDesc, and PrefetchNd when they
// have explicit offsets that need to be adjusted for each unrolled tile.
SmallVector<Value> computeUnrolledOffsets(
    SmallVector<OpFoldResult> mixedOffsets, xegpu::TensorDescType tdescTy,
    ArrayRef<int64_t> targetShape,
    const std::function<Value(SmallVector<OpFoldResult>)> &createOp,
    Location loc, PatternRewriter &rewriter) {
  int64_t rank = tdescTy.getRank();
  ArrayRef<int64_t> shape = tdescTy.getShape();

  auto addi = [&](OpFoldResult a, int64_t b) -> Value {
    std::optional<int64_t> maybeInt = getConstantIntValue(a);
    if (maybeInt) {
      return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
    } else {
      auto aV = llvm::cast<Value>(a);
      auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
      return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
    }
  };

  SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
      llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
  auto validIdxes =
      llvm::seq(mixedOffsets.size() - rank, mixedOffsets.size());

  SmallVector<Value> newOps;
  for (SmallVector<int64_t> offsets :
       StaticTileOffsetRange(shape, targetShape)) {

    for (auto [idx, oldOff, offset] :
         llvm::zip(validIdxes, oldOffsets, offsets))
      mixedOffsets[idx] = addi(oldOff, offset);

    auto newOp = createOp(mixedOffsets);
    newOps.push_back(newOp);
  }
  return newOps;
}
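// Illustrative example for computeUnrolledOffsets (hypothetical shapes): with
// a tensor_desc shape of 32x32, a target shape of 16x16, and base offsets
// (%x, %y), StaticTileOffsetRange yields the relative tile origins (0, 0),
// (0, 16), (16, 0), and (16, 16), so createOp is invoked with the adjusted
// offsets (%x, %y), (%x, %y + 16), (%x + 16, %y), and (%x + 16, %y + 16).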
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Value> newOps;

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

    bool hasOffsets = op.getMixedOffsets().size() != 0;
    if (!hasOffsets) {
      auto newOp = xegpu::CreateNdDescOp::create(
          rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
          op.getMixedStrides());
      newOps.push_back(newOp);
    } else {
      auto createOp = [&](SmallVector<OpFoldResult> offsets) -> Value {
        return xegpu::CreateNdDescOp::create(
            rewriter, loc, newTdescTy, op.getSource(), offsets,
            op.getMixedSizes(), op.getMixedStrides());
      };

      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
                                      *targetShape, createOp, loc, rewriter);
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);

    return success();
  }
};

struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdesc) {
      auto newOp = xegpu::UpdateNdOffsetOp::create(
          rewriter, loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
        tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);

    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    if (!hasOffsets) {
      for (auto t : convertedTdesc)
        xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
                                    op->getAttrs());
    } else {
      auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
        xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
                                    op.getL1HintAttr(), op.getL2HintAttr(),
                                    op.getL3HintAttr());
        // Return a dummy Value to satisfy the function's signature.
        return nullptr;
      };
      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
                             createPrefetch, loc, rewriter);
    }

    rewriter.eraseOp(op);
    return success();
  }
};
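/// Unrolls xegpu::LoadNdOp into multiple loads of the target (native) shape,
/// recombining the loaded pieces with `unpack`. Illustrative example: a load
/// producing vector<32x32xf32> with a native shape of 16x16 becomes four loads
/// producing vector<16x16xf32>. When the op carries explicit offsets, a single
/// unrolled tensor_desc is reused and only the per-tile offsets change.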
struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
        tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    if (!hasOffsets) {
      for (auto t : convertedTdescs) {
        auto newOp = xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t,
                                             op->getAttrs());
        newOps.push_back(newOp);
      }
    } else {
      auto createLoad = [&](SmallVector<OpFoldResult> offsets) {
        return xegpu::LoadNdOp::create(
            rewriter, loc, newValueTy, convertedTdescs[0], offsets,
            op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
            op.getL2HintAttr(), op.getL3HintAttr());
      };
      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
                                      *targetShape, createLoad, loc, rewriter);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);

    return success();
  }
};

struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
        tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);

    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    if (!hasOffsets) {
      for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
        xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
                                 op.getL2HintAttr(), op.getL3HintAttr());
    } else {
      size_t valueIndex = 0;
      auto createStore = [&](SmallVector<OpFoldResult> offsets) {
        xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
                                 convertedTdescs[0], offsets,
                                 op.getL1HintAttr(), op.getL2HintAttr(),
                                 op.getL3HintAttr());
        // Return a dummy Value to satisfy the function's signature.
        return nullptr;
      };

      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
                             createStore, loc, rewriter);
    }

    rewriter.eraseOp(op);
    return success();
  }
};
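/// Unrolls xegpu::DpasOp using a 3-element native shape [M, K, N].
/// Illustrative example: with A = vector<32x32xf16>, B = vector<32x32xf16>,
/// and native shape [8, 16, 16], A is blocked into 8x16 tiles, B into 16x16
/// tiles, and the result into 8x16 tiles; each result tile is computed by
/// chaining dpas ops over the K blocks, feeding the previous partial result
/// in as the accumulator.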
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Expecting every operand to be a 2D vector.
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))
      return failure();

    // A vector of 3 elements should be returned, representing M, K, N
    // respectively.
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();

    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

    auto packWrapper = [&](TypedValue<VectorType> val,
                           ArrayRef<int64_t> blockSize) {
      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      auto numNewOps = computeProduct(*grids);
      if (numNewOps == 1)
        return SmallVector<Value>({val});
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
      SmallVector<Type> convertedTypes(numNewOps, newVecTy);
      SmallVector<Value> values =
          pack(val, convertedTypes, blockSize, loc, rewriter);
      return values;
    };

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    SmallVector<Value> aVals, bVals, cVals;
    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);

    if (c)
      cVals = packWrapper(c, cBlockSize);

    // Skip the operation if every operand has an invalid blocking size (empty)
    // or if the original shape matches the blocking size (size == 1).
    auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
                    : SmallVector<ValueRange>({aVals, bVals});
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    SmallVector<Value> newOps;
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j]; // init with acc

        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);

          tmpC = xegpu::DpasOp::create(rewriter, loc, vecTy, operands,
                                       op->getAttrs());
        }
        newOps.push_back(tmpC);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
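/// Unrolls xegpu::CreateDescOp for scattered tensor descriptors. When
/// chunk_size > 1, each original chunk may be split into several blocked
/// chunks, so the unrolled index vectors are replicated with an added offset
/// of i * blockedChunkSize. Illustrative example: chunk_size = 8 blocked to 2
/// yields four descriptors per index tile, with index increments 0, 2, 4,
/// and 6.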
struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
    VectorType indiceVecTy = indiceVec.getType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetIndiceShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    // IndiceVec is 1 dim lower than tdescTy when chunkSize is larger than 1.
    if (originalChunkSize > 1)
      targetIndiceShape.pop_back();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
    SmallVector<Type> convertedIndiceTypes =
        getUnrolledTypes(indiceVecTy, targetIndiceShape);
    SmallVector<Value> convertedIndiceVec =
        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);

    SmallVector<Value> newOps;

    // More indices are needed when chunkSize > 1, since a big load from one
    // address could be broken into multiple smaller loads.
    if (originalChunkSize > 1) {
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto [indice, indiceType] :
           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          // Compute the offset.
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, indiceType, inc);
          Value offsetIndice =
              arith::AddIOp::create(rewriter, loc, indice, incVec);

          auto newOp = xegpu::CreateDescOp::create(
              rewriter, loc, newTdescTy, op.getSource(), offsetIndice);

          newOps.push_back(newOp);
        }
      }
    } else {
      for (auto indice : convertedIndiceVec) {
        auto newOp = xegpu::CreateDescOp::create(rewriter, loc, newTdescTy,
                                                 op.getSource(), indice);
        newOps.push_back(newOp);
      }
    }

    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);

    return success();
  }
};

struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // TODO: handle the unstructured source case (!tdescTy).
    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;

    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      // The mask is reused across the chunk_size dimension.
      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);

      newValueTy = valueTy.cloneWith(*targetShape, elemTy);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Value> newOps;
    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, t, m, op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

/// This pattern handles the unrolling of LoadGatherOp with offsets (gathered
/// load). It unrolls the offsets and mask operands accordingly and creates
/// multiple LoadGatherOps with the unrolled operands.
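/// Illustrative example: with chunk_size = 4 blocked to 2, each unrolled mask
/// value is reused for both sub-chunks, while the corresponding offsets are
/// duplicated and shifted by the blocked chunk size, yielding twice as many
/// loads as offset tiles.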
struct UnrollLoadGatherOpWithOffset
    : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();

    // Only handle the case where offsets are present (scattered load).
    if (!offsets)
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t chunkSize = 1;
    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();
    }

    // Unroll mask and offsets with the correct shape.
    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
    Type elemTy = valueTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsets;

    if (chunkSize > 1) {
      // For chunked loads, mask and offsets have one less dimension.
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

      SmallVector<Value> convertedMasksBase =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
      SmallVector<Value> convertedOffsetsBase =
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
          Value offsetVal =
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);
        }
      }
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
      convertedOffsets =
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    SmallVector<Value> newOps;
    for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
      auto newOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newValueTy, op.getSource(), o, m,
          rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

/// This pattern handles the unrolling of StoreScatterOp with offsets
/// (scattered store). It unrolls the offsets and mask operands accordingly
/// and creates multiple StoreScatterOps with the unrolled operands.
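/// The value operand is unrolled to the target shape as well, and one
/// StoreScatterOp is emitted per (value, offsets, mask) triple; as with the
/// gathered load above, masks are reused across sub-chunks when the chunk
/// size is blocked.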
struct UnrollStoreScatterOpWithOffsets
    : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    Value offsets = op.getOffsets();
    Value mask = op.getMask();

    // Only handle the case where offsets are present (scattered store).
    if (!offsets)
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    int64_t chunkSize = 1;
    if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
      if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
        chunkSize = intAttr.getInt();
    }

    SmallVector<int64_t> targetMaskShape(*targetShape);
    VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
    VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsets;

    if (chunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = chunkSize / blockedChunkSize;
      chunkSize = blockedChunkSize;

      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

      SmallVector<Value> convertedMasksBase =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
      SmallVector<Value> convertedOffsetsBase =
          pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

      for (auto maskVal : convertedMasksBase)
        convertedMasks.append(numNewChunks, maskVal);

      for (auto [baseOffset, offsetType] :
           llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                     i * blockedChunkSize);
          Value incVec =
              vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
          Value offsetVal =
              arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
          convertedOffsets.push_back(offsetVal);
        }
      }
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks =
          pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

      convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
      convertedOffsets =
          pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (auto [v, o, m] :
         llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
      xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
                                    rewriter.getI64IntegerAttr(chunkSize),
                                    op.getL1HintAttr(), op.getL2HintAttr(),
                                    op.getL3HintAttr());
    }

    rewriter.eraseOp(op);
    return success();
  }
};

struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // TODO: handle the unstructured source case (!tdescTy).
    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      xegpu::PrefetchOp::create(rewriter, loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};
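/// Unrolls the tensor_desc-based form of StoreScatterOp (no explicit offsets):
/// the stored value and the tensor descriptor are split to the target shape,
/// and the mask, when present, is reused across the chunk_size dimension in
/// the same way as in UnrollLoadGatherOp.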
struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // TODO: handle the unstructured source case (!tdescTy).
    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;

    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);

      // The mask is reused across the chunk_size dimension.
      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (size_t i = 0; i < convertedValues.size(); ++i) {
      Value v = convertedValues[i];
      Value t = convertedTdescs[i];
      Value m = op.getMask() ? convertedMasks[i] : nullptr;
      xegpu::StoreScatterOp::create(rewriter, loc, v, t, m, op.getL1HintAttr(),
                                    op.getL2HintAttr(), op.getL3HintAttr());
    }

    rewriter.eraseOp(op);
    return success();
  }
};

struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
    VectorType offsetVecTy = offsetVec.getType();
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsetVec;
    SmallVector<Value> newOps;
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    if (originalChunkSize > 1) {
      auto targetOffsetShape = ArrayRef<int64_t>(*targetShape).drop_back();
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, targetOffsetShape);

      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      // The offset is reused across the chunk_size dimension.
      for (auto offset : pack(offsetVec, convertedOffsetTypes,
                              targetOffsetShape, loc, rewriter))
        convertedOffsetVec.append(numNewChunks, offset);

    } else {
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
      convertedOffsetVec =
          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
      auto newOp =
          xegpu::UpdateOffsetOp::create(rewriter, loc, t.getType(), t, o);
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
  using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getType();
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
      return failure();

    Type elemTy = valueTy.getElementType();
    ArrayRef<int64_t> shape = valueTy.getShape();
    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());

    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(shape, *targetShape)) {
      auto adds = xegpu::addElementwise(
          rewriter, loc, mixedOffsets,
          getAsIndexOpFoldResult(op.getContext(), offsets));
      offsetsList.push_back(adds);
    }

    SmallVector<Value> newOps;
    layout = layout.dropInstData();
    for (SmallVector<OpFoldResult> offsets : offsetsList) {
      auto newOp = xegpu::LoadMatrixOp::create(
          rewriter, op.getLoc(), newValueTy, op.getMemDesc(), offsets, layout);
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
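/// Unrolls xegpu::StoreMatrixOp analogously to UnrollLoadMatrixOp: the data
/// vector is split into target-shape tiles and each tile is stored at the
/// original offsets shifted by that tile's relative position, with the
/// inst_data dropped from the layout attribute.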
struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
  using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    Location loc = op.getLoc();
    VectorType valueTy = op.getData().getType();
    ArrayRef<int64_t> shape = valueTy.getShape();
    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getData(), convertedValTypes, *targetShape, loc, rewriter);

    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(shape, *targetShape)) {
      auto adds = xegpu::addElementwise(
          rewriter, loc, mixedOffsets,
          getAsIndexOpFoldResult(op.getContext(), offsets));
      offsetsList.push_back(adds);
    }

    for (auto [v, offsets] : llvm::zip_equal(convertedValues, offsetsList))
      xegpu::StoreMatrixOp::create(rewriter, loc, v, op.getMemDesc(), offsets,
                                   layout.dropInstData());

    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

void mlir::xegpu::populateXeGPUUnrollPatterns(
    RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
  patterns
      .add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
           UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, UnrollCreateDescOp,
           UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
           UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp,
           UnrollLoadGatherOpWithOffset, UnrollStoreScatterOpWithOffsets>(
          patterns.getContext(), options);
}