1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
|
//===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
namespace {
/// A pattern that converts the result and operand types, attributes, and region
/// arguments of an OpenMP operation to the LLVM dialect.
///
/// Attributes are copied verbatim by default, and only translated if they are
/// type attributes.
///
/// Region bodies, if any, are not modified and expected to either be processed
/// by the conversion infrastructure or already contain ops compatible with LLVM
/// dialect types.
template <typename T>
struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
  using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;

  /// Builds the pattern and registers an identity conversion for
  /// `omp::CanonicalLoopInfoType` on `typeConverter`.
  OpenMPOpConversion(LLVMTypeConverter &typeConverter,
                     PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<T>(typeConverter, benefit) {
    // Operations using CanonicalLoopInfoType are lowered only by
    // mlir::translateModuleToLLVMIR() using the OpenMPIRBuilder. Until then,
    // the type and operations using it must be preserved.
    // The callback is kept by the type converter beyond this constructor, so
    // it deliberately captures nothing.
    typeConverter.addConversion(
        [](::mlir::omp::CanonicalLoopInfoType type) { return type; });
  }

  /// Recreates `op` with converted result types, converted type attributes and
  /// the adaptor's (already converted) operands, then moves the original
  /// regions into the new op and converts their block argument types.
  LogicalResult
  matchAndRewrite(T op, typename T::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();

    // Translate result types.
    SmallVector<Type> resTypes;
    if (failed(converter->convertTypes(op->getResultTypes(), resTypes)))
      return failure();

    // Translate attributes. They are kept unmodified except if they are type
    // attributes, in which case the wrapped type is converted.
    SmallVector<NamedAttribute> convertedAttrs;
    convertedAttrs.reserve(op->getAttrs().size());
    for (NamedAttribute attr : op->getAttrs()) {
      auto typeAttr = dyn_cast<TypeAttr>(attr.getValue());
      if (!typeAttr) {
        convertedAttrs.push_back(attr);
        continue;
      }
      Type convertedType = converter->convertType(typeAttr.getValue());
      // A null result means the type could not be converted; wrapping it in a
      // TypeAttr would be invalid, so fail before any IR is created.
      if (!convertedType)
        return rewriter.notifyMatchFailure(op,
                                           "failed to convert type attribute");
      convertedAttrs.emplace_back(attr.getName(), TypeAttr::get(convertedType));
    }

    // Translate operands.
    SmallVector<Value> convertedOperands;
    convertedOperands.reserve(op->getNumOperands());
    for (auto [originalOperand, convertedOperand] :
         llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
      if (!originalOperand)
        return failure();

      // TODO: Revisit whether we need to trigger an error specifically for this
      // set of operations. Consider removing this check or updating the list.
      if constexpr (llvm::is_one_of<T, omp::AtomicUpdateOp, omp::AtomicWriteOp,
                                    omp::FlushOp, omp::MapBoundsOp,
                                    omp::ThreadprivateOp>::value) {
        if (isa<MemRefType>(originalOperand.getType())) {
          // TODO: Support memref type in variable operands
          return rewriter.notifyMatchFailure(op, "memref is not supported yet");
        }
      }
      convertedOperands.push_back(convertedOperand);
    }

    // Create new operation.
    auto newOp = T::create(rewriter, op.getLoc(), resTypes, convertedOperands,
                           convertedAttrs);

    // Translate regions: move each original region into the new op, then
    // convert the block argument types of the moved region.
    for (auto [originalRegion, convertedRegion] :
         llvm::zip_equal(op->getRegions(), newOp->getRegions())) {
      rewriter.inlineRegionBefore(originalRegion, convertedRegion,
                                  convertedRegion.end());
      if (failed(rewriter.convertRegionTypes(&convertedRegion,
                                             *this->getTypeConverter())))
        return failure();
    }

    // Delete old operation and replace result uses with those of the new one.
    rewriter.replaceOp(op, newOp->getResults());
    return success();
  }
};
} // namespace
/// Marks every OpenMP dialect operation as dynamically legal on `target`.
///
/// An OpenMP op is legal only when `typeConverter` accepts:
/// - all of its operand and result types;
/// - all of its regions;
/// - the type wrapped by each of its `TypeAttr` attributes.
void mlir::configureOpenMPToLLVMConversionLegality(
    ConversionTarget &target, const LLVMTypeConverter &typeConverter) {
  target.addDynamicallyLegalOp<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >([&](Operation *op) {
    return typeConverter.isLegal(op->getOperandTypes()) &&
           typeConverter.isLegal(op->getResultTypes()) &&
           llvm::all_of(op->getRegions(),
                        [&](Region &region) {
                          return typeConverter.isLegal(&region);
                        }) &&
           llvm::all_of(op->getAttrs(), [&](NamedAttribute attr) {
             // Non-type attributes never affect legality.
             auto typeAttr = dyn_cast<TypeAttr>(attr.getValue());
             return !typeAttr || typeConverter.isLegal(typeAttr.getValue());
           });
  });
}
/// Registers one `OpenMPOpConversion<Op>` pattern for every operation type in
/// the `OpTys` pack, all built from `converter`.
/// Returns `patterns` so calls can be chained.
template <typename... OpTys>
static inline RewritePatternSet &
addOpenMPOpConversions(LLVMTypeConverter &converter,
                       RewritePatternSet &patterns) {
  patterns.add<OpenMPOpConversion<OpTys>...>(converter);
  return patterns;
}
/// Adds an OpenMPOpConversion pattern for every OpenMP dialect operation to
/// `patterns`, and registers an identity conversion for omp::MapBoundsType on
/// `converter`.
void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // omp::MapBoundsType carries bounds information for map clauses. It stays
  // legal throughout the OpenMP -> LLVM dialect conversion. Both the type and
  // the operation using it are discarded when the OpenMP dialect is lowered to
  // LLVM IR, so it maps to itself here.
  converter.addConversion(
      [](omp::MapBoundsType boundsTy) -> Type { return boundsTy; });

  // One conversion pattern per OpenMP operation.
  addOpenMPOpConversions<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >(converter, patterns);
}
namespace {
/// Pass converting OpenMP operations, together with the arith, cf, memref and
/// func operations they contain, to work on LLVM dialect types.
struct ConvertOpenMPToLLVMPass
    : public impl::ConvertOpenMPToLLVMPassBase<ConvertOpenMPToLLVMPass> {
  // Reuse the constructors of the generated pass base class.
  using Base::Base;

  void runOnOperation() override;
};
} // namespace
void ConvertOpenMPToLLVMPass::runOnOperation() {
auto module = getOperation();
// Convert to OpenMP operations with LLVM IR dialect
RewritePatternSet patterns(&getContext());
LLVMTypeConverter converter(&getContext());
arith::populateArithToLLVMConversionPatterns(converter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
cf::populateAssertToLLVMConversionPattern(converter, patterns);
populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
populateFuncToLLVMConversionPatterns(converter, patterns);
populateOpenMPToLLVMConversionPatterns(converter, patterns);
LLVMConversionTarget target(getContext());
target.addLegalOp<omp::BarrierOp, omp::FlushOp, omp::TaskwaitOp,
omp::TaskyieldOp, omp::TerminatorOp>();
configureOpenMPToLLVMConversionLegality(target, converter);
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}
//===----------------------------------------------------------------------===//
// ConvertToLLVMPatternInterface implementation
//===----------------------------------------------------------------------===//
namespace {
/// Hooks the OpenMP-to-LLVM lowering into ConvertToLLVMPatternInterface.
struct OpenMPToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;

  /// Loads the LLVM dialect into `context`, since the conversion produces
  /// LLVM dialect types.
  void loadDependentDialects(MLIRContext *context) const final {
    context->loadDialect<LLVM::LLVMDialect>();
  }

  /// Sets up OpenMP op legality on `target` and appends the OpenMP conversion
  /// patterns to `patterns`, both driven by `typeConverter`.
  void populateConvertToLLVMConversionPatterns(
      ConversionTarget &target, LLVMTypeConverter &typeConverter,
      RewritePatternSet &patterns) const final {
    configureOpenMPToLLVMConversionLegality(target, typeConverter);
    populateOpenMPToLLVMConversionPatterns(typeConverter, patterns);
  }
};
} // namespace
void mlir::registerConvertOpenMPToLLVMInterface(DialectRegistry ®istry) {
registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
dialect->addInterfaces<OpenMPToLLVMDialectInterface>();
});
}
|