aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
blob: 69a3d98bc09e4d6a2cf4a6fb2b021e678a113615 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
//===- TestPatterns.cpp - LLVM dialect test patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

namespace {

/// Conversion pattern rooted at "test.direct_replacement".
///
/// The matched op must have exactly one result. That result is replaced by the
/// complete list of adaptor operands in a single replaceOpWithMultiple call,
/// so an op with N operands has its one result replaced by N values.
struct TestDirectReplacementOp : public ConversionPattern {
  TestDirectReplacementOp(MLIRContext *ctx, const TypeConverter &converter)
      : ConversionPattern(converter, "test.direct_replacement", /*benefit=*/1,
                          ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // Bail out on anything that does not have exactly one result.
    const bool hasSingleResult = op->getNumResults() == 1;
    if (!hasSingleResult)
      return failure();

    // Result #0 is mapped to all of the operands at once.
    rewriter.replaceOpWithMultiple(op, {operands});
    return success();
  }
};

/// Test pass that applies TestDirectReplacementOp together with the
/// Arith-, Func- and ControlFlow-to-LLVM conversion patterns as a partial
/// dialect conversion of the root operation.
///
/// The type converter is an LLVMTypeConverter extended with a 1:N rule that
/// splits i17 into two i18 values.
struct TestLLVMLegalizePatternsPass
    : public PassWrapper<TestLLVMLegalizePatternsPass, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLLVMLegalizePatternsPass)

  TestLLVMLegalizePatternsPass() = default;
  // Only the PassWrapper base is copied; the option member below is rebuilt
  // from its in-class initializer. NOTE(review): option values are presumably
  // transferred by the pass infrastructure when the pass is cloned — confirm.
  TestLLVMLegalizePatternsPass(const TestLLVMLegalizePatternsPass &other)
      : PassWrapper(other) {}

  /// Value is forwarded to ConversionConfig::allowPatternRollback
  /// (initialized to true).
  Option<bool> allowPatternRollback{*this, "allow-pattern-rollback",
                                    llvm::cl::desc("Allow pattern rollback"),
                                    llvm::cl::init(true)};

  StringRef getArgument() const final { return "test-llvm-legalize-patterns"; }
  StringRef getDescription() const final {
    return "Run LLVM dialect legalization patterns";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect>();
  }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    Operation *root = getOperation();

    // Type conversion: LLVMTypeConverter defaults plus an IntegerType rule.
    // The rule maps i17 to (i18, i18) and every other integer type to itself.
    LLVMTypeConverter typeConverter(ctx);
    typeConverter.addConversion(
        [ctx](IntegerType intTy, SmallVectorImpl<Type> &results) {
          if (!intTy.isInteger(17)) {
            results.push_back(intTy);
            return success();
          }
          Type i18 = IntegerType::get(ctx, 18);
          results.append(2, i18);
          return success();
        });

    // Patterns under test. The custom pattern is registered first, followed by
    // the upstream populate functions.
    RewritePatternSet patterns(ctx);
    patterns.add<TestDirectReplacementOp>(ctx, typeConverter);
    arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
    populateFuncToLLVMConversionPatterns(typeConverter, patterns);
    cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);

    // Legal ops: "test.legal_op", all of the LLVM dialect, and func.func ops
    // that carry an `is_legal` attribute.
    ConversionTarget target(*ctx);
    target.addLegalOp(OperationName("test.legal_op", ctx));
    target.addLegalDialect<LLVM::LLVMDialect>();
    target.addDynamicallyLegalOp<func::FuncOp>(
        [](func::FuncOp funcOp) { return funcOp->hasAttr("is_legal"); });

    // The conversion fills this set with ops it could not legalize. The pass
    // does not read the set afterwards.
    DenseSet<Operation *> unlegalizedOps;
    ConversionConfig config;
    config.unlegalizedOps = &unlegalizedOps;
    config.allowPatternRollback = allowPatternRollback;

    // A failed conversion is only reported as an error diagnostic on the root
    // op; signalPassFailure() is not called.
    LogicalResult converted = applyPartialConversion(
        root, target, std::move(patterns), config);
    if (failed(converted))
      root->emitError() << "applyPartialConversion failed";
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// PassRegistration
//===----------------------------------------------------------------------===//

namespace mlir {
namespace test {
void registerTestLLVMLegalizePatternsPass() {
  PassRegistration<TestLLVMLegalizePatternsPass>();
}
} // namespace test
} // namespace mlir