1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
|
//===- TestPatterns.cpp - LLVM dialect test patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace {
/// Conversion pattern anchored on "test.direct_replacement".
///
/// When the matched op has exactly one result, the op is replaced with the
/// full list of operands handed to the pattern. Any other result count makes
/// the pattern report failure without touching the IR.
struct TestDirectReplacementOp : public ConversionPattern {
  TestDirectReplacementOp(MLIRContext *ctx, const TypeConverter &converter)
      : ConversionPattern(converter, "test.direct_replacement", 1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    const bool hasSingleResult = op->getNumResults() == 1;
    if (hasSingleResult) {
      // The single result is substituted by *all* operands at once.
      rewriter.replaceOpWithMultiple(op, {operands});
      return success();
    }
    return failure();
  }
};
/// Test pass that applies a partial dialect conversion to the operation it
/// runs on, using the Arith, Func and ControlFlow to-LLVM pattern sets plus
/// TestDirectReplacementOp.
///
/// Pass argument: `test-llvm-legalize-patterns`.
/// Option: `allow-pattern-rollback` (bool, initialized to true), copied into
/// ConversionConfig::allowPatternRollback before the conversion is run.
struct TestLLVMLegalizePatternsPass
    : public PassWrapper<TestLLVMLegalizePatternsPass, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLLVMLegalizePatternsPass)

  TestLLVMLegalizePatternsPass() = default;

  // Only the PassWrapper base is copy-constructed from `other`; the
  // `allowPatternRollback` member is set up by its default member initializer
  // and therefore bound to the new instance.
  TestLLVMLegalizePatternsPass(const TestLLVMLegalizePatternsPass &other)
      : PassWrapper(other) {}

  StringRef getArgument() const final { return "test-llvm-legalize-patterns"; }

  StringRef getDescription() const final {
    return "Run LLVM dialect legalization patterns";
  }

  /// Declares the LLVM dialect as a dependency of this pass.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect>();
  }

  /// Builds the type converter, pattern set and conversion target, then runs
  /// applyPartialConversion on the current operation.
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // Set up type converter.
    LLVMTypeConverter converter(ctx);
    converter.addConversion(
        [&](IntegerType type, SmallVectorImpl<Type> &result) {
          if (type.isInteger(17)) {
            // Convert i17 -> (i18, i18).
            result.append(2, Builder(ctx).getIntegerType(18));
            return success();
          }
          // Every other integer type is passed through unchanged.
          result.push_back(type);
          return success();
        });

    // Populate patterns.
    mlir::RewritePatternSet patterns(ctx);
    patterns.add<TestDirectReplacementOp>(ctx, converter);
    arith::populateArithToLLVMConversionPatterns(converter, patterns);
    populateFuncToLLVMConversionPatterns(converter, patterns);
    cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);

    // Define the conversion target used for the test: "test.legal_op" and the
    // whole LLVM dialect are legal; func.func is legal only when it carries
    // an "is_legal" attribute.
    ConversionTarget target(*ctx);
    target.addLegalOp(OperationName("test.legal_op", ctx));
    target.addLegalDialect<LLVM::LLVMDialect>();
    target.addDynamicallyLegalOp<func::FuncOp>(
        [&](func::FuncOp funcOp) { return funcOp->hasAttr("is_legal"); });

    // Handle a partial conversion.
    DenseSet<Operation *> unlegalizedOps;
    ConversionConfig config;
    config.unlegalizedOps = &unlegalizedOps;
    config.allowPatternRollback = allowPatternRollback;
    // On failure only a diagnostic is attached to the root operation;
    // signalPassFailure() is not called here.
    // NOTE(review): presumably intentional for this test pass -- confirm.
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns), config)))
      getOperation()->emitError() << "applyPartialConversion failed";
  }

  Option<bool> allowPatternRollback{*this, "allow-pattern-rollback",
                                    llvm::cl::desc("Allow pattern rollback"),
                                    llvm::cl::init(true)};
};
} // namespace
//===----------------------------------------------------------------------===//
// PassRegistration
//===----------------------------------------------------------------------===//
namespace mlir {
namespace test {
void registerTestLLVMLegalizePatternsPass() {
PassRegistration<TestLLVMLegalizePatternsPass>();
}
} // namespace test
} // namespace mlir
|