1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
|
//===-- XeGPUVectorLinearize.cpp - Linearizes n-D vectors to 1-D vectors --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUVECTORLINEARIZE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir
#define DEBUG_TYPE "xegpu-vector-linearize"
using namespace mlir;
namespace {
/// Pass that rewrites n-D vector operations into 1-D form.
///
/// The work is done in three stages, each applied to the operation the pass
/// runs on:
///   1. Greedy lowering of vector.broadcast, vector.gather and
///      vector.transpose.
///   2. Greedy unrolling of vector load/store ops of rank >= 2 into slices
///      that keep only the innermost dimension.
///   3. Partial dialect conversion with the vector linearization patterns,
///      with the SCF structural type conversions registered alongside.
/// A failure in any stage emits a debug message, signals pass failure and
/// skips the remaining stages.
struct XeGPUVectorLinearizePass final
    : public xegpu::impl::XeGPUVectorLinearizeBase<XeGPUVectorLinearizePass> {
  void runOnOperation() override {
    // Stage 1: vector.broadcast and vector.gather require progressive
    // lowering before they can be linearized.
    {
      RewritePatternSet patterns(&getContext());
      vector::populateVectorBroadcastLoweringPatterns(patterns);
      vector::populateVectorGatherLoweringPatterns(patterns);
      vector::populateVectorGatherToConditionalLoadPatterns(patterns);
      // vector.transpose lowering.
      // Shuffle16x16 will fall back to Shuffle1D for non 16x16 sizes.
      vector::populateVectorTransposeLoweringPatterns(
          patterns, vector::VectorTransposeLowering::Shuffle16x16);
      if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
        // Report the failing stage, like the unroll and linearization stages
        // below do.
        LDBG() << "Lowering failed.";
        return signalPassFailure();
      }
    }

    // Stage 2: unroll load/store from <d1xd2x...xdk> to (d1*d2*...*d(k-1))
    // slices of <1x1x...x1xdk>.
    {
      RewritePatternSet patterns(&getContext());
      vector::UnrollVectorOptions vectorOptions;
      vectorOptions.setNativeShapeFn(
          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
            // Only vector.load / vector.store are unrolled here; any other
            // op yields a null type and therefore no native shape.
            auto extractVectorType = [](Operation *op) -> VectorType {
              if (auto loadOp = dyn_cast<vector::LoadOp>(op))
                return loadOp.getVectorType();
              if (auto storeOp = dyn_cast<vector::StoreOp>(op))
                return storeOp.getVectorType();
              return nullptr;
            };

            VectorType vecType = extractVectorType(op);
            if (!vecType)
              return std::nullopt;

            // Only handle rank >= 2 so we actually unroll something.
            int64_t rank = vecType.getRank();
            if (rank < 2)
              return std::nullopt;

            ArrayRef<int64_t> shape = vecType.getShape();
            // Produce native shape: 1 x 1 x ... x (original last dim).
            SmallVector<int64_t> native(rank, 1);
            native.back() = shape.back();
            return native;
          });
      vector::populateVectorUnrollPatterns(patterns, vectorOptions);
      if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
        LDBG() << "Unroll failed.";
        return signalPassFailure();
      }
    }

    // Stage 3: use the vector linearization patterns.
    {
      MLIRContext &context = getContext();
      TypeConverter converter;
      RewritePatternSet patterns(&context);
      ConversionTarget target(context);
      vector::populateForVectorLinearize(converter, target);
      vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
      vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
                                                            patterns);
      scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                           target);
      if (failed(applyPartialConversion(getOperation(), target,
                                        std::move(patterns)))) {
        LDBG() << "Linearization failed.";
        return signalPassFailure();
      }
    }
  }
};
} // namespace
|