//===-- XeGPUVectorLinearize.cpp - Linearizes n-D vectors to 1-D vectors --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Dialect/XeGPU/Transforms/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Debug.h" #include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include namespace mlir { namespace xegpu { #define GEN_PASS_DEF_XEGPUVECTORLINEARIZE #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc" } // namespace xegpu } // namespace mlir #define DEBUG_TYPE "xegpu-vector-linearize" using namespace mlir; namespace { struct XeGPUVectorLinearizePass final : public xegpu::impl::XeGPUVectorLinearizeBase { void runOnOperation() override { // vector.broadcast and vector.gather requires progressive lowering { RewritePatternSet patterns(&getContext()); vector::populateVectorBroadcastLoweringPatterns(patterns); vector::populateVectorGatherLoweringPatterns(patterns); vector::populateVectorGatherToConditionalLoadPatterns(patterns); // vector.transpose lowering // Shuffle16x16 will fallback to Shuffle1D for non 16x16 sizes. vector::populateVectorTransposeLoweringPatterns( patterns, vector::VectorTransposeLowering::Shuffle16x16); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } // Unroll load/store from to (d1*d2*...*d(k-1)) slices of // <1x1x...x1xdk>. { RewritePatternSet patterns(&getContext()); vector::UnrollVectorOptions vectorOptions; vectorOptions.setNativeShapeFn( [](Operation *op) -> std::optional> { auto extractVectorType = [](Operation *op) -> VectorType { if (auto loadOp = dyn_cast(op)) return loadOp.getVectorType(); if (auto storeOp = dyn_cast(op)) return storeOp.getVectorType(); return nullptr; }; VectorType vecType = extractVectorType(op); if (!vecType) return std::nullopt; // Only handle rank >= 2 so we actually unroll something. int64_t rank = vecType.getRank(); if (rank < 2) return std::nullopt; ArrayRef shape = vecType.getShape(); // Produce native shape: 1 x 1 x ... x (original last dim). SmallVector native(rank, 1); native.back() = shape.back(); return native; }); vector::populateVectorUnrollPatterns(patterns, vectorOptions); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { LDBG() << "Unroll failed."; return signalPassFailure(); } } // Use vector linearization patterns { MLIRContext &context = getContext(); TypeConverter converter; RewritePatternSet patterns(&context); ConversionTarget target(context); vector::populateForVectorLinearize(converter, target); vector::populateVectorLinearizeBasePatterns(converter, target, patterns); vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target, patterns); scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, target); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { LDBG() << "Linearization failed."; return signalPassFailure(); } } } }; } // namespace