//===- VectorToAMX.cpp - Convert vector to AMX dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToAMX/VectorToAMX.h"

#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/DebugLog.h"

#include <numeric>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOAMX
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

#define DEBUG_TYPE "vector-to-amx"

namespace {

/// Return true if the vector shape is compatible with AMX tiles.
/// The validation accounts for VNNI packing.
static bool verifyAmxShape(VectorType vec) {
  // Check overall shape:
  //   - 2D for plain layout input or output
  //   - 3D for VNNI packed input
  if (vec.getRank() != 2 && vec.getRank() != 3)
    return false;

  ArrayRef<int64_t> shape = vec.getShape();
  int64_t rows = shape[0];
  int64_t cols = shape[1];
  unsigned elemBitWidth = vec.getElementType().getIntOrFloatBitWidth();

  // 3D shape indicates VNNI packed layout.
  if (vec.getRank() == 3) {
    int64_t vnniFactor = 32 / elemBitWidth;
    if (shape.back() != vnniFactor) {
      LDBG() << "invalid VNNI packing factor";
      return false;
    }
    cols *= vnniFactor;
  }

  // AMX tile supports up to 16 rows of 64 bytes each.
  constexpr unsigned maxRows = 16;
  constexpr unsigned maxBitsPerRow = 64 * 8;
  return rows <= maxRows && (cols * elemBitWidth) <= maxBitsPerRow;
}

/// Check if contraction operands are in AMX-compatible packed VNNI layout.
static LogicalResult isAmxVnniLayout(PatternRewriter &rewriter,
                                     vector::ContractionOp contractOp) {
  VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
  if (!accType || accType.getRank() != 2)
    return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

  // Expect 3D inputs for VNNI packed data.
  VectorType lhsType = contractOp.getLhs().getType();
  VectorType rhsType = contractOp.getRhs().getType();
  if (lhsType.getRank() != 3 || rhsType.getRank() != 3)
    return rewriter.notifyMatchFailure(contractOp,
                                       "Expects lhs and rhs 3D vectors");

  // Check if shapes are compatible with AMX tile.
  if (!verifyAmxShape(lhsType) || !verifyAmxShape(rhsType) ||
      !verifyAmxShape(accType))
    return rewriter.notifyMatchFailure(contractOp, "Invalid operand shape");

  // Validate affine maps.
  //
  // Iterators can be ordered arbitrarily. Indexing map positions are based on
  // operands' target shapes.
  // The matrix layouts must match the following:
  //   - matrix A - [M]x[K/vnniFactor]x[vnniFactor]
  //   - matrix B - [K/vnniFactor]x[N]x[vnniFactor]
  //   - matrix C - [M]x[N]
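  //
  // For illustration only (the exact shapes and dimension order may vary), a
  // matching bf16 contraction could look like:
  //
  //   #map_a = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>  // M, K, vnni
  //   #map_b = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>  // K, N, vnni
  //   #map_c = affine_map<(d0, d1, d2, d3) -> (d0, d1)>      // M, N
  //   %res = vector.contract {
  //            indexing_maps = [#map_a, #map_b, #map_c],
  //            iterator_types = ["parallel", "parallel",
  //                              "reduction", "reduction"],
  //            kind = #vector.kind<add>}
  //     %lhs, %rhs, %acc
  //     : vector<16x16x2xbf16>, vector<16x16x2xbf16> into vector<16x16xf32>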
  SmallVector<AffineMap, 4> indexingMaps = contractOp.getIndexingMapsArray();
  AffineMap mapA = indexingMaps[0];
  AffineMap mapB = indexingMaps[1];
  if (mapA.getNumInputs() != 4 || mapA.getNumResults() != 3 ||
      mapB.getNumResults() != 3)
    return rewriter.notifyMatchFailure(contractOp,
                                       "Invalid input indexing maps");
  FailureOr<linalg::ContractionDimensions> dims =
      linalg::inferContractionDims(indexingMaps);
  if (failed(dims))
    return rewriter.notifyMatchFailure(contractOp,
                                       "Failed to infer contraction dims");
  // Two reduction dimensions are expected:
  //   - one for the K dimension
  //   - one for the VNNI factor
  if (dims->k.size() != 2)
    return rewriter.notifyMatchFailure(contractOp,
                                       "Expected two reduction dims");
  assert(dims->m.size() == 1 && dims->n.size() == 1 &&
         "Invalid parallel contraction dims");

  SmallVector<vector::IteratorType> iteratorTypes =
      contractOp.getIteratorTypesArray();

  // Check VNNI dim maps - the innermost dim for A and B inputs.
  auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(2));
  auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(2));
  if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
      iteratorTypes[vnniDimA.getPosition()] != vector::IteratorType::reduction)
    return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI dim map");
  // Check K dim maps - non-transposed row-major layout.
  auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(1));
  auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(0));
  if (!redDimA || !redDimB || redDimA != redDimB ||
      iteratorTypes[redDimA.getPosition()] != vector::IteratorType::reduction)
    return rewriter.notifyMatchFailure(contractOp, "Invalid K dim map");
  // Check M and N dim maps - map to non-transposed output.
  AffineMap mapC = indexingMaps[2];
  auto mDimC = dyn_cast<AffineDimExpr>(mapC.getResult(0));
  auto nDimC = dyn_cast<AffineDimExpr>(mapC.getResult(1));
  if (!mDimC || !nDimC)
    return rewriter.notifyMatchFailure(contractOp, "Invalid acc maps");
  auto parallelDimA = dyn_cast<AffineDimExpr>(mapA.getResult(0));
  if (!parallelDimA ||
      iteratorTypes[parallelDimA.getPosition()] !=
          vector::IteratorType::parallel ||
      parallelDimA != mDimC)
    return rewriter.notifyMatchFailure(contractOp, "Invalid M dim map");
  auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(1));
  if (!parallelDimB ||
      iteratorTypes[parallelDimB.getPosition()] !=
          vector::IteratorType::parallel ||
      parallelDimB != nDimC)
    return rewriter.notifyMatchFailure(contractOp, "Invalid N dim map");

  return success();
}

/// Validate contraction operands for AMX lowering.
static LogicalResult validateOperands(PatternRewriter &rewriter,
                                      vector::ContractionOp contractOp) {
  VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
  if (!accType)
    return rewriter.notifyMatchFailure(contractOp, "Expects vector acc");

  // Check if operand types are compatible with AMX compute ops.
  bool validElemTypes = false;
  Type lhsElemType = contractOp.getLhs().getType().getElementType();
  Type rhsElemType = contractOp.getRhs().getType().getElementType();
  Type accElemType = accType.getElementType();
  if (accElemType.isInteger(32)) {
    validElemTypes = lhsElemType.isInteger(8) && rhsElemType.isInteger(8);
  } else if (accElemType.isF32()) {
    validElemTypes = (lhsElemType.isF16() && rhsElemType.isF16()) ||
                     (lhsElemType.isBF16() && rhsElemType.isBF16());
  }
  if (!validElemTypes)
    return rewriter.notifyMatchFailure(contractOp,
                                       "Invalid combination of operand types");

  if (failed(isAmxVnniLayout(rewriter, contractOp)))
    return failure();

  return success();
}

/// Collapse the two innermost dimensions together.
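/// For example, a memref<2x8x16x2xbf16> collapses into memref<2x8x32xbf16>.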
static TypedValue<MemRefType>
collapseLastDim(PatternRewriter &rewriter, TypedValue<MemRefType> memref) {
  int64_t rank = memref.getType().getRank();
  SmallVector<ReassociationIndices> reassocIndices;
  for (auto i : llvm::seq<int64_t>(0, rank - 2))
    reassocIndices.push_back({i});
  reassocIndices.push_back({rank - 2, rank - 1});
  return memref::CollapseShapeOp::create(rewriter, memref.getLoc(), memref,
                                         reassocIndices);
}

/// Attempt to create an AMX tile load/store operation equivalent to the given
/// vector transfer `xferOp`.
/// This allows skipping the longer route through registers and a temporary
/// buffer otherwise required to move data to/from an AMX tile.
static Operation *
loadStoreFromTransfer(PatternRewriter &rewriter,
                      VectorTransferOpInterface xferOp, bool isPacked,
                      TypedValue<amx::TileType> tileToStore = nullptr) {
  if (!xferOp ||
      !isa<vector::TransferReadOp, vector::TransferWriteOp>(xferOp))
    return nullptr;
  if (xferOp.hasOutOfBoundsDim() ||
      !xferOp.getPermutationMap().isMinorIdentity())
    return nullptr;

  // Extra checks in case of a write op.
  // Stores must not be packed.
  if (isa<vector::TransferWriteOp>(xferOp) &&
      (!tileToStore || isPacked ||
       tileToStore.getType().getShape() != xferOp.getVectorType().getShape()))
    return nullptr;

  // Check for a memref source buffer.
  // AMX data transfer requires at least a 2D shape to correctly
  // infer the stride between rows.
  Value base = xferOp.getBase();
  auto memTy = dyn_cast<MemRefType>(base.getType());
  if (!memTy)
    return nullptr;
  int64_t memRank = memTy.getRank();
  if (memRank < 2)
    return nullptr;

  // Check that the source buffer has enough contiguous elements to load a
  // whole AMX tile row.
  //
  // To ensure correctness, the validation is conservative and expects the
  // buffer's innermost dimensions to be statically known, equal to or larger
  // than the vector row length, and equal to the VNNI dimension if applicable.
  //
  // This check could be relaxed to accept more arbitrarily shaped buffers as
  // long as there are enough contiguous elements to load a whole row.
  if (!memTy.areTrailingDimsContiguous(isPacked ? 2 : 1))
    return nullptr;
  VectorType vecTy = xferOp.getVectorType();
  ArrayRef<int64_t> vecShape = vecTy.getShape();
  ArrayRef<int64_t> memShape = memTy.getShape();
  if (memShape.back() == ShapedType::kDynamic ||
      memShape.back() < vecShape.back())
    return nullptr;
  if (isPacked &&
      (memShape.back() != vecShape.back() ||
       memShape[memShape.size() - 2] == ShapedType::kDynamic ||
       memShape[memShape.size() - 2] < vecShape[vecShape.size() - 2]))
    return nullptr;

  // Load values directly from the buffer to an AMX tile.
  PatternRewriter::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(xferOp);
  Location loc = xferOp.getLoc();

  // Create a subview of the source buffer based on the transfer op to resolve
  // offsets.
  SmallVector<OpFoldResult> strides(memRank, rewriter.getIndexAttr(1));
  int64_t vecRank = vecTy.getRank();
  assert(memRank >= vecRank &&
         "Expects buffer to be the same or greater rank than vector");
  SmallVector<int64_t> shape(memRank - vecRank, 1);
  shape.append(vecShape.begin(), vecShape.end());
  TypedValue<MemRefType> src =
      memref::SubViewOp::create(
          rewriter, loc, base, getAsOpFoldResult(xferOp.getIndices()),
          getAsOpFoldResult(rewriter.getI64ArrayAttr(shape)), strides)
          .getResult();
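  // As an illustrative example: for a packed vector<16x16x2xbf16> read from a
  // statically shaped, contiguous memref<32x16x2xbf16>, the subview selects a
  // 16x16x2 slice which maps to a single 16x32 bf16 AMX tile (16 rows of
  // 64 bytes) once the VNNI dimension is folded away.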
  // Collapse the VNNI dimension in case of packing.
  if (isPacked)
    src = collapseLastDim(rewriter, src);

  int64_t rows = vecShape[0];
  int64_t cols = std::accumulate(vecShape.begin() + 1, vecShape.end(), 1,
                                 std::multiplies<int64_t>());
  auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());

  Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> tileIndices(src.getType().getRank(), zeroIndex);

  Operation *amxTileOp = nullptr;
  if (isa<vector::TransferReadOp>(xferOp)) {
    amxTileOp =
        amx::TileLoadOp::create(rewriter, loc, tileType, src, tileIndices);
  } else if (isa<vector::TransferWriteOp>(xferOp)) {
    amxTileOp =
        amx::TileStoreOp::create(rewriter, loc, src, tileIndices, tileToStore);
  } else {
    llvm_unreachable("unsupported vector transfer op");
  }

  return amxTileOp;
}

/// Attempt to create an AMX tile load operation equivalent to the given
/// vector transfer `readOp`.
/// Returns the loaded AMX tile if successful.
static FailureOr<TypedValue<amx::TileType>>
loadFromTransfer(PatternRewriter &rewriter, vector::TransferReadOp readOp,
                 bool isPacked) {
  amx::TileLoadOp loadOp = dyn_cast_if_present<amx::TileLoadOp>(
      loadStoreFromTransfer(rewriter, readOp, isPacked));
  if (!loadOp)
    return failure();
  return loadOp.getRes();
}

/// Attempt to create an AMX tile store operation equivalent to the given
/// vector transfer `writeOp`.
static LogicalResult storeFromTransfer(PatternRewriter &rewriter,
                                       vector::TransferWriteOp writeOp,
                                       TypedValue<amx::TileType> tileToStore) {
  return success(loadStoreFromTransfer(rewriter, writeOp, /*isPacked=*/false,
                                       tileToStore));
}

/// Load vector values to an AMX tile.
static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
                                          TypedValue<VectorType> vec) {
  Location loc = vec.getLoc();
  VectorType vecTy = vec.getType();
  bool isPacked = vecTy.getRank() == 3;

  // Try to load the tile directly from the vector producer's buffer.
  auto readOp = vec.getDefiningOp<vector::TransferReadOp>();
  FailureOr<TypedValue<amx::TileType>> tile =
      loadFromTransfer(rewriter, readOp, isPacked);
  if (succeeded(tile))
    return *tile;

  // Transfer the vector to a tile through an intermediate buffer.
  Value buf = memref::AllocaOp::create(
      rewriter, loc,
      MemRefType::get(vecTy.getShape(), vecTy.getElementType()));
  Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> indices(vecTy.getRank(), zeroIndex);
  vector::TransferWriteOp::create(rewriter, loc, vec, buf, indices);

  // Collapse the VNNI dimension in case of packing.
  if (isPacked)
    buf = collapseLastDim(rewriter, cast<TypedValue<MemRefType>>(buf));

  ArrayRef<int64_t> shape = vecTy.getShape();
  int64_t rows = shape[0];
  int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), 1,
                                 std::multiplies<int64_t>());
  auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());

  return amx::TileLoadOp::create(rewriter, loc, tileType, buf,
                                 {zeroIndex, zeroIndex});
}

/// Store an AMX tile in a vector.
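/// A minimal sketch of the produced IR, assuming a !amx.tile<16x32xbf16>
/// input (exact assembly may differ):
///   %buf = memref.alloca() : memref<16x32xbf16>
///   amx.tile_store %buf[%c0, %c0], %tile
///     : memref<16x32xbf16>, !amx.tile<16x32xbf16>
///   %vec = vector.transfer_read %buf[%c0, %c0], %pad
///     : memref<16x32xbf16>, vector<16x32xbf16>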
static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
                                        TypedValue<amx::TileType> tile) {
  Location loc = tile.getLoc();

  // Transfer the tile to a vector through an intermediate buffer.
  amx::TileType tileTy = tile.getType();
  Value buf = memref::AllocaOp::create(
      rewriter, loc,
      MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
  Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> indices(2, zeroIndex);
  amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);

  auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
  return vector::TransferReadOp::create(rewriter, loc, vecTy, buf, indices,
                                        {});
}

/// Lower a vector contraction in AMX-compatible VNNI layout to AMX tile
/// operations.
struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");
    if (failed(validateOperands(rewriter, contractOp)))
      return failure();

    TypedValue<amx::TileType> lhsTile = loadTile(rewriter, contractOp.getLhs());
    TypedValue<amx::TileType> rhsTile = loadTile(rewriter, contractOp.getRhs());
    auto acc = dyn_cast<TypedValue<VectorType>>(contractOp.getAcc());
    assert(acc && "Invalid accumulator type");
    TypedValue<amx::TileType> accTile = loadTile(rewriter, acc);

    TypedValue<amx::TileType> tileMul;
    if (acc.getType().getElementType().isFloat()) {
      tileMul = amx::TileMulFOp::create(rewriter, loc, accTile.getType(),
                                        lhsTile, rhsTile, accTile);
    } else {
      tileMul = amx::TileMulIOp::create(rewriter, loc, accTile.getType(),
                                        lhsTile, rhsTile, accTile);
    }

    // If the contraction result is only written back to memory, try to replace
    // the vector op with an AMX store directly.
    Value res = contractOp.getResult();
    if (res.hasOneUse()) {
      auto writeOp = dyn_cast<vector::TransferWriteOp>(*res.getUsers().begin());
      LogicalResult storeRes = storeFromTransfer(rewriter, writeOp, tileMul);
      if (succeeded(storeRes)) {
        rewriter.eraseOp(writeOp);
        rewriter.eraseOp(contractOp);
        return success();
      }
    }

    // Load the result back into a vector.
    Value newResult = storeTile(rewriter, tileMul);
    rewriter.replaceOp(contractOp, newResult);

    return success();
  }
};

struct ConvertVectorToAMXPass
    : public impl::ConvertVectorToAMXBase<ConvertVectorToAMXPass> {
  void runOnOperation() override {
    MLIRContext &ctx = getContext();
    RewritePatternSet patterns(&ctx);
    populateVectorToAMXConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToAMXConversionPatterns(RewritePatternSet &patterns) {
  patterns.add<ContractionToAMX>(patterns.getContext());
}