//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to XeGPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// Return true if value represents a zero constant.
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}

static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  // Validate only vector as the basic vector store and load ops guarantee
  // XeGPU-compatible memref source.
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  return success();
}

static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Perform common data transfer checks.
  VectorType vecTy = xferOp.getVectorType();
  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
    return failure();

  // Validate further transfer op semantics.
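  // The remaining checks mirror what an XeGPU nd-descriptor can express: a
  // source that is contiguous in its innermost dimension and a permutation
  // map that only touches the innermost (possibly transposed) dimensions.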
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) ||
      strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}

static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = srcTy.getStridesAndOffset();

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
                                           getAsOpFoldResult(offsets));
  } else {
    // In case of any dynamic shapes, source's shape and strides have to be
    // explicitly provided.
    SmallVector<Value> sourceDims;
    unsigned srcRank = srcTy.getRank();
    for (unsigned i = 0; i < srcRank; ++i)
      sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));

    SmallVector<int64_t> constOffsets;
    SmallVector<Value> dynOffsets;
    for (Value offset : offsets) {
      std::optional<int64_t> staticVal = getConstantIntValue(offset);
      if (!staticVal)
        dynOffsets.push_back(offset);
      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
    }

    SmallVector<Value> dynShapes;
    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
      if (shape == ShapedType::kDynamic)
        dynShapes.push_back(sourceDims[idx]);
    }

    // Compute strides in reverse order.
    SmallVector<Value> dynStrides;
    Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
    // Last stride is guaranteed to be static and unit.
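    // Note: the dynamic strides are rebuilt as suffix products of the source
    // dimension sizes, which is valid for a densely packed, row-major source
    // memref with a unit innermost stride.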
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
      accStride =
          arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
      if (strides[i] == ShapedType::kDynamic)
        dynStrides.push_back(accStride);
    }
    std::reverse(dynStrides.begin(), dynStrides.end());

    ndDesc = xegpu::CreateNdDescOp::create(
        rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
  }

  return ndDesc;
}

struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    VectorType vecTy = readOp.getVectorType();
    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If load is transposed, get the base shape for the tensor descriptor.
    SmallVector<int64_t> descShape(vecTy.getShape());
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
                           readOp.getIndices());

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadOp =
        xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
                                /*packed=*/nullptr, transposeAttr,
                                /*l1_hint=*/hint,
                                /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};

struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    VectorType vecTy = writeOp.getVectorType();
    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
                           writeOp.getIndices());

    // By default, no specific caching policy is assigned.
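    // An XeGPU cache hint per level (l1/l2/l3) could be attached here; the
    // null attributes leave the hardware's default caching behavior in place.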
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeOp =
        xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
                                 /*l1_hint=*/hint,
                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};

struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadNdOp = xegpu::LoadNdOp::create(
        rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr,
        /*transpose=*/nullptr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};

struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeNdOp =
        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
                                 /*l1_hint=*/hint,
                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};

struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");

    TypedValue<Type> acc = contractOp.getAcc();
    VectorType accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

    // Accept only plain 2D data layout.
    // VNNI packing is applied to DPAS as a separate lowering step.
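    // The indexing maps (checked below) must describe a plain row-major
    // matmul, i.e. roughly:
    //   (m, n, k) -> (m, k) for lhs, (k, n) for rhs, and (m, n) for acc.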
    TypedValue<VectorType> lhs = contractOp.getLhs();
    TypedValue<VectorType> rhs = contractOp.getRhs();
    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects lhs and rhs 2D vectors");

    if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");

    auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
                                        TypeRange{contractOp.getResultType()},
                                        ValueRange{lhs, rhs, acc});
    rewriter.replaceOp(contractOp, dpasOp);

    return success();
  }
};

struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
               StoreLowering, ContractionLowering>(patterns.getContext());
}
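// A minimal sketch of driving these patterns directly (e.g. from another
// pass), assuming the Vector and XeGPU dialects are loaded on the context:
//
//   RewritePatternSet patterns(op->getContext());
//   populateVectorToXeGPUConversionPatterns(patterns);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     signalPassFailure();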