//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to XeGPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// Return true if value represents a zero constant.
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}

static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  // Validate only the vector type, as the basic vector store and load ops
  // guarantee an XeGPU-compatible memref source.
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  return success();
}

static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Validate further transfer op semantics.
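  // The checks below require the source to be contiguous in its innermost
  // dimension, allow out-of-bounds dimensions only for 2D (block) accesses,
  // and accept only projected permutation maps whose results access the
  // innermost source dimensions.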
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) ||
      strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  VectorType vecTy = xferOp.getVectorType();
  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}

static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = srcTy.getStridesAndOffset();

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
                                           getAsOpFoldResult(offsets));
  } else {
    // In case of any dynamic shapes, the source's shape and strides have to
    // be explicitly provided.
    SmallVector<Value> sourceDims;
    unsigned srcRank = srcTy.getRank();
    for (unsigned i = 0; i < srcRank; ++i)
      sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));

    SmallVector<int64_t> constOffsets;
    SmallVector<Value> dynOffsets;
    for (Value offset : offsets) {
      std::optional<int64_t> staticVal = getConstantIntValue(offset);
      if (!staticVal)
        dynOffsets.push_back(offset);
      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
    }

    SmallVector<Value> dynShapes;
    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
      if (shape == ShapedType::kDynamic)
        dynShapes.push_back(sourceDims[idx]);
    }

    // Compute strides in reverse order.
    SmallVector<Value> dynStrides;
    Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
    // Last stride is guaranteed to be static and unit.
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
      accStride =
          arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
      if (strides[i] == ShapedType::kDynamic)
        dynStrides.push_back(accStride);
    }
    std::reverse(dynStrides.begin(), dynStrides.end());

    ndDesc = xegpu::CreateNdDescOp::create(
        rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
  }

  return ndDesc;
}
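
// Illustrative example (hand-written, not actual pass output): for a static
// source and offsets [%i, %j], the descriptor creation above is expected to
// produce IR roughly of the form:
//   %desc = xegpu.create_nd_tdesc %src[%i, %j]
//       : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>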

// Adjusts the strides of a memref according to a given permutation map for
// vector operations.
//
// This function updates the innermost strides in the `strides` array to
// reflect the permutation specified by `permMap`. The permutation is computed
// using the inverse and broadcasting-aware version of the permutation map,
// and it is applied to the relevant strides. This ensures that memory accesses
// are consistent with the logical permutation of vector elements.
//
// Example:
//   Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`.
//   If the permutation map swaps the last two dimensions (e.g., [0, 1] ->
//   [1, 0]), then after calling this function, the last two strides will be
//   swapped:
//     Original strides:  [s0, s1, s2, s3]
//     After permutation: [s0, s1, s3, s2]
//
static void adjustStridesForPermutation(AffineMap permMap,
                                        SmallVectorImpl<Value> &strides) {
  AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
  SmallVector<unsigned> perms;
  invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
  SmallVector<int64_t> perms64(perms.begin(), perms.end());
  strides = applyPermutation(strides, perms64);
}

// Computes memory strides and a memref offset for vector transfer operations,
// handling both static and dynamic memrefs while applying permutation
// transformations for XeGPU lowering.
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
        vector::GatherOp, vector::ScatterOp>::value>>
static std::pair<SmallVector<Value>, Value>
computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
  SmallVector<Value> strides;
  Value baseMemref = xferOp.getBase();
  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
  Location loc = xferOp.getLoc();

  Value offsetVal = nullptr;
  if (memrefType.hasStaticShape()) {
    int64_t offset;
    SmallVector<int64_t> intStrides;
    if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
      return {{}, offsetVal};
    bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
      return ShapedType::isDynamic(strideVal);
    });

    if (!hasDynamicStrides)
      for (int64_t s : intStrides)
        strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));

    if (!ShapedType::isDynamic(offset))
      offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
  }

  if (strides.empty() || !offsetVal) {
    // For a dynamic shape memref, use memref.extract_strided_metadata to get
    // the stride values.
    unsigned rank = memrefType.getRank();
    Type indexType = rewriter.getIndexType();

    // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
    // size0, size1, ..., sizeN-1]
    SmallVector<Type> resultTypes;
    resultTypes.push_back(MemRefType::get(
        {}, memrefType.getElementType())); // base memref (rank 0)
    resultTypes.push_back(indexType);      // offset
    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType); // strides
    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType); // sizes

    auto meta = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, resultTypes, baseMemref);

    if (strides.empty())
      strides.append(meta.getStrides().begin(), meta.getStrides().end());

    if (!offsetVal)
      offsetVal = meta.getOffset();
  }

  if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
                                vector::TransferWriteOp>::value) {
    AffineMap permMap = xferOp.getPermutationMap();
    // Adjust strides according to the permutation map (e.g., for transpose).
    adjustStridesForPermutation(permMap, strides);
  }

  return {strides, offsetVal};
}
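
// Illustrative example (hand-written, not actual pass output): for a dynamic
// 2D source, the metadata query above roughly corresponds to:
//   %base, %offset, %sizes:2, %strides:2 =
//       memref.extract_strided_metadata %src
//       : memref<?x?xf32, strided<[?, 1], offset: ?>>
//       -> memref<f32>, index, index, index, index, index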

// This function computes the vector of localOffsets for scattered
// loads/stores. It is used in the lowering of vector.transfer_read/write to
// load_gather/store_scatter.
// Example:
//   %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
//              %cst {in_bounds = [true, true, true, true]} :
//              memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
//
//   %6 = vector.step : vector<4xindex>
//   %7 = vector.step : vector<2xindex>
//   %8 = vector.step : vector<6xindex>
//   %9 = vector.step : vector<32xindex>
//   %10 = arith.mul %6, 384
//   %11 = arith.mul %7, 192
//   %12 = arith.mul %8, 32
//   %13 = arith.mul %9, 1
//   %14 = vector.shape_cast %10 : vector<4xindex> -> vector<4x1x1x1xindex>
//   %15 = vector.shape_cast %11 : vector<2xindex> -> vector<1x2x1x1xindex>
//   %16 = vector.shape_cast %12 : vector<6xindex> -> vector<1x1x6x1xindex>
//   %17 = vector.shape_cast %13 : vector<32xindex> -> vector<1x1x1x32xindex>
//   %18 = vector.broadcast %14 : vector<4x1x1x1xindex> -> vector<4x2x6x32xindex>
//   %19 = vector.broadcast %15 : vector<1x2x1x1xindex> -> vector<4x2x6x32xindex>
//   %20 = vector.broadcast %16 : vector<1x1x6x1xindex> -> vector<4x2x6x32xindex>
//   %21 = vector.broadcast %17 : vector<1x1x1x32xindex> -> vector<4x2x6x32xindex>
//   %22 = arith.add %18, %19
//   %23 = arith.add %20, %21
//   %local_offsets = arith.add %22, %23
//   %orig_offset = %block_id_y * (4 * 2 * 6 * 32)  // consider using affine map
//   %offsets = memref_offset + orig_offset + local_offsets
static Value computeOffsets(VectorTransferOpInterface xferOp,
                            PatternRewriter &rewriter, ArrayRef<Value> strides,
                            Value baseOffset) {
  Location loc = xferOp.getLoc();
  VectorType vectorType = xferOp.getVectorType();
  SmallVector<Value> indices(xferOp.getIndices().begin(),
                             xferOp.getIndices().end());
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  // Create vector.step operations for each dimension.
  SmallVector<Value> stepVectors;
  llvm::map_to_vector(vectorShape, [&](int64_t dim) {
    auto stepType = VectorType::get({dim}, rewriter.getIndexType());
    auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
    stepVectors.push_back(stepOp);
    return stepOp;
  });

  // Multiply step vectors by the corresponding strides.
  size_t memrefRank = strides.size();
  size_t vectorRank = vectorShape.size();
  SmallVector<Value> strideMultiplied;
  for (size_t i = 0; i < vectorRank; ++i) {
    size_t memrefDim = memrefRank - vectorRank + i;
    Value strideValue = strides[memrefDim];
    auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
    auto bcastOp =
        vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
    auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
    strideMultiplied.push_back(mulOp);
  }

  // Shape cast each multiplied vector to add singleton dimensions.
  SmallVector<Value> shapeCasted;
  for (size_t i = 0; i < vectorRank; ++i) {
    SmallVector<int64_t> newShape(vectorRank, 1);
    newShape[i] = vectorShape[i];
    auto newType = VectorType::get(newShape, rewriter.getIndexType());
    auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
                                              strideMultiplied[i]);
    shapeCasted.push_back(castOp);
  }

  // Broadcast each shape-casted vector to the full vector shape.
  SmallVector<Value> broadcasted;
  auto fullIndexVectorType =
      VectorType::get(vectorShape, rewriter.getIndexType());
  for (Value shapeCastVal : shapeCasted) {
    auto broadcastOp = vector::BroadcastOp::create(
        rewriter, loc, fullIndexVectorType, shapeCastVal);
    broadcasted.push_back(broadcastOp);
  }

  // Add all broadcasted vectors together to compute local offsets.
  Value localOffsets = broadcasted[0];
  for (size_t i = 1; i < broadcasted.size(); ++i)
    localOffsets =
        arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
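
  // At this point 'localOffsets' holds, for each vector element, the sum of
  // per-dimension contributions i_k * stride_k over the accessed (innermost)
  // memref dimensions.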
  // Compute the base offset from the transfer indices.
  for (size_t i = 0; i < indices.size(); ++i) {
    Value strideVal = strides[i];
    Value offsetContrib =
        arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
    baseOffset =
        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
  }
  // Broadcast the base offset to match the vector shape.
  Value bcastBase = vector::BroadcastOp::create(
      rewriter, loc, fullIndexVectorType, baseOffset);
  localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
  return localOffsets;
}

// Compute the element-wise offsets for vector.gather or vector.scatter ops.
//
// This function linearizes the base offsets of the gather/scatter operation
// and combines them with the per-element indices to produce a final vector of
// memory offsets.
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
                            ArrayRef<Value> strides, Value baseOffset) {
  Location loc = gatScatOp.getLoc();
  SmallVector<Value> offsets = gatScatOp.getOffsets();

  for (size_t i = 0; i < offsets.size(); ++i) {
    Value offsetContrib =
        arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
    baseOffset =
        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
  }
  Value indices = gatScatOp.getIndices();
  VectorType vecType = cast<VectorType>(indices.getType());

  Value strideVector =
      vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
          .getResult();
  Value stridedIndices =
      arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();

  Value baseVector =
      vector::BroadcastOp::create(
          rewriter, loc,
          VectorType::get(vecType.getShape(), rewriter.getIndexType()),
          baseOffset)
          .getResult();
  return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
      .getResult();
}

// Convert a memref source to a flat i64 base pointer.
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
        vector::GatherOp, vector::ScatterOp>::value>>
static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
  Location loc = xferOp.getLoc();
  auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
                      rewriter, loc, xferOp.getBase())
                      .getResult();
  return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
                                    indexPtr)
      .getResult();
}

static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
                                            PatternRewriter &rewriter) {
  Location loc = readOp.getLoc();
  VectorType vectorType = readOp.getVectorType();
  ArrayRef<int64_t> vectorShape = vectorType.getShape();
  auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(readOp, "Expected memref source");

  auto meta = computeMemrefMeta(readOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(readOp, rewriter, meta.first, meta.second);
  Value flatMemref = memrefToIndexPtr(readOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
      vectorShape);
  auto gatherOp = xegpu::LoadGatherOp::create(
      rewriter, loc, vectorType, flatMemref, localOffsets, mask,
      /*chunk_size=*/IntegerAttr{},
      /*l1_hint=*/xegpu::CachePolicyAttr{},
      /*l2_hint=*/xegpu::CachePolicyAttr{},
      /*l3_hint=*/xegpu::CachePolicyAttr{});
  rewriter.replaceOp(readOp, gatherOp.getResult());
  return success();
}
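
// Lower a vector.transfer_write to a scattered store when 2D block stores are
// not available: element offsets, a flat i64 base pointer, and an all-true
// mask are materialized the same way as in the scattered load path above.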
static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
                                             PatternRewriter &rewriter) {
  Location loc = writeOp.getLoc();
  VectorType vectorType = writeOp.getVectorType();
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(writeOp, "Expected memref source");

  auto meta = computeMemrefMeta(writeOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(writeOp, rewriter, meta.first, meta.second);
  Value flatMemref = memrefToIndexPtr(writeOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
      vectorShape);
  xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
                                localOffsets, mask,
                                /*chunk_size=*/IntegerAttr{},
                                /*l1_hint=*/xegpu::CachePolicyAttr{},
                                /*l2_hint=*/xegpu::CachePolicyAttr{},
                                /*l3_hint=*/xegpu::CachePolicyAttr{});
  rewriter.eraseOp(writeOp);
  return success();
}

struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    // TODO: This check needs to be replaced with a proper uArch capability
    // check.
    auto chip = xegpu::getChipStr(readOp);
    if (chip != "pvc" && chip != "bmg") {
      // Lower to a scattered load op if the target HW doesn't have 2D block
      // load support.
      // TODO: add support for out-of-bounds access.
      if (readOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredLoadOp(readOp, rewriter);
    }

    // Perform common data transfer checks.
    VectorType vecTy = readOp.getVectorType();
    if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If the load is transposed, get the base shape for the tensor descriptor.
    SmallVector<int64_t> descShape(vecTy.getShape());
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
                           readOp.getIndices());

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
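
    // For a transposed read, the descriptor keeps the memref-ordered
    // (reversed) shape and the [1, 0] transpose is applied by the load
    // itself, so the result matches the requested vector type.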
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadOp =
        xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
                                /*packed=*/nullptr, transposeAttr,
                                /*l1_hint=*/hint,
                                /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};

struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    // TODO: This check needs to be replaced with a proper uArch capability
    // check.
    auto chip = xegpu::getChipStr(writeOp);
    if (chip != "pvc" && chip != "bmg") {
      // Lower to a scattered store op if the target HW doesn't have 2D block
      // store support.
      // TODO: add support for out-of-bounds access.
      if (writeOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredStoreOp(writeOp, rewriter);
    }

    // Perform common data transfer checks.
    VectorType vecTy = writeOp.getVectorType();
    if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
      return failure();

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
                           writeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeOp =
        xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
                                 /*l1_hint=*/hint,
                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};

struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern<vector::GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");

    Location loc = gatherOp.getLoc();
    VectorType vectorType = gatherOp.getVectorType();

    auto meta = computeMemrefMeta(gatherOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, gatherOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);

    auto xeGatherOp = xegpu::LoadGatherOp::create(
        rewriter, loc, vectorType, flatMemref, localOffsets,
        gatherOp.getMask(),
        /*chunk_size=*/IntegerAttr{},
        /*l1_hint=*/xegpu::CachePolicyAttr{},
        /*l2_hint=*/xegpu::CachePolicyAttr{},
        /*l3_hint=*/xegpu::CachePolicyAttr{});

    auto selectOp =
        arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
                                xeGatherOp.getResult(), gatherOp.getPassThru());
    rewriter.replaceOp(gatherOp, selectOp.getResult());
    return success();
  }
};
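
// Lower vector.scatter to a scattered XeGPU store (xegpu::StoreScatterOp),
// reusing the same element-offset linearization as the gather lowering above
// and forwarding the original mask.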
struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
  using OpRewritePattern<vector::ScatterOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");

    Location loc = scatterOp.getLoc();
    auto meta = computeMemrefMeta(scatterOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(scatterOp,
                                         "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, scatterOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);

    xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
                                  flatMemref, localOffsets, scatterOp.getMask(),
                                  /*chunk_size=*/IntegerAttr{},
                                  /*l1_hint=*/xegpu::CachePolicyAttr{},
                                  /*l2_hint=*/xegpu::CachePolicyAttr{},
                                  /*l3_hint=*/xegpu::CachePolicyAttr{});
    rewriter.eraseOp(scatterOp);
    return success();
  }
};

struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadNdOp = xegpu::LoadNdOp::create(
        rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr,
        /*transpose=*/nullptr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};

struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeNdOp =
        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
                                 /*l1_hint=*/hint,
                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};

struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");

    TypedValue<Type> acc = contractOp.getAcc();
    VectorType accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

    // Accept only plain 2D data layout.
    // VNNI packing is applied to DPAS as a separate lowering step.
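    // The contraction must use row-major matmul indexing maps, i.e.
    // (illustrative):
    //   affine_map<(m, n, k) -> (m, k)>   // lhs
    //   affine_map<(m, n, k) -> (k, n)>   // rhs
    //   affine_map<(m, n, k) -> (m, n)>   // acc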
    TypedValue<VectorType> lhs = contractOp.getLhs();
    TypedValue<VectorType> rhs = contractOp.getRhs();
    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects lhs and rhs 2D vectors");

    if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");

    auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
                                        TypeRange{contractOp.getResultType()},
                                        ValueRange{lhs, rhs, acc});
    rewriter.replaceOp(contractOp, dpasOp);

    return success();
  }
};

struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
               StoreLowering, ContractionLowering, GatherLowering,
               ScatterLowering>(patterns.getContext());
}