aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
blob: afafbd8179affd755de950b00c08b8149cb97343 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
//===- ConstantArgumentGlobalisation.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace fir {
#define GEN_PASS_DEF_CONSTANTARGUMENTGLOBALISATIONOPT
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

#define DEBUG_TYPE "flang-constant-argument-globalisation-opt"

namespace {
/// Counter used to derive names for the generated globals
/// ("_global_const_.<N>"). It lives at namespace scope, so it keeps counting
/// across every pattern application and pass run in the process; uniqueness
/// within a module is additionally checked against the module's symbols.
unsigned uniqueLitId = 1;

/// Rewrites a `fir.call` so that an argument produced by a by-value adaptor
/// alloca (an alloca carrying `fir::getAdaptToByRefAttrName()`), whose single
/// dominating store writes an `arith.constant`, is replaced by the address of
/// a new internal global constant holding that value. When the alloca is left
/// with the store as its only use, both the store and the alloca are erased.
class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
protected:
  /// Dominance information used to select the store that initialises an
  /// alloca before the call.
  const mlir::DominanceInfo &di;

public:
  // The base-class constructors are deliberately not inherited: they cannot
  // initialise the `di` reference member, so they were never usable.
  CallOpRewriter(mlir::MLIRContext *ctx, const mlir::DominanceInfo &_di)
      : OpRewritePattern(ctx), di(_di) {}

  /// Rebuilds \p callOp with constant by-value arguments replaced by
  /// `fir.address_of` a global constant.
  /// \returns success if at least one argument was replaced, failure
  /// otherwise (meaning "nothing to do", not an error).
  llvm::LogicalResult
  matchAndRewrite(fir::CallOp callOp,
                  mlir::PatternRewriter &rewriter) const override {
    LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
    auto module = callOp->getParentOfType<mlir::ModuleOp>();
    bool needUpdate = false;
    fir::FirOpBuilder builder(rewriter, module);
    llvm::SmallVector<mlir::Value> newOperands;
    // (alloca, store) pairs for allocas replaced in at least one operand;
    // they are candidates for erasure once the call has been rebuilt.
    llvm::SmallVector<std::pair<mlir::Operation *, mlir::Operation *>> allocas;
    for (mlir::Value a : callOp.getArgs()) {
      auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp());
      // We can convert arguments that are alloca, and that has
      // the value by reference attribute. All else is just added
      // to the argument list.
      if (!alloca || !alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
        newOperands.push_back(a);
        continue;
      }

      mlir::Type varTy = alloca.getInType();
      assert(!fir::hasDynamicSize(varTy) &&
             "only expect statically sized scalars to be by value");

      // Find immediate store with const argument
      mlir::Operation *store = nullptr;
      for (mlir::Operation *s : alloca->getUsers()) {
        if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp)) {
          // We can only deal with ONE store - if already found one,
          // set to nullptr and exit the loop.
          if (store) {
            store = nullptr;
            break;
          }
          store = s;
        }
      }

      // If we didn't find any store, or multiple stores, add argument as is
      // and move on.
      if (!store) {
        newOperands.push_back(a);
        continue;
      }

      LLVM_DEBUG(llvm::dbgs() << " found store " << *store << "\n");

      // The stored value may be a block argument, which has no defining op;
      // isa_and_nonnull keeps that case from asserting on a null pointer.
      mlir::Operation *definingOp = store->getOperand(0).getDefiningOp();
      if (!mlir::isa_and_nonnull<mlir::arith::ConstantOp>(definingOp)) {
        // Unable to remove alloca arg
        newOperands.push_back(a);
        continue;
      }

      LLVM_DEBUG(llvm::dbgs() << " found define " << *definingOp << "\n");

      // The counter alone does not guarantee a fresh name (the module may
      // already hold a global with it), and an assert would be compiled out
      // in release builds, so skip any name already present in the module.
      std::string globalName;
      do {
        globalName = "_global_const_." + std::to_string(uniqueLitId++);
      } while (builder.getNamedGlobal(globalName));

      if (llvm::none_of(allocas,
                        [alloca](const auto &x) { return x.first == alloca; })) {
        allocas.push_back(std::make_pair(alloca, store));
      }

      auto loc = callOp.getLoc();
      fir::GlobalOp global = builder.createGlobalConstant(
          loc, varTy, globalName,
          [&](fir::FirOpBuilder &builder) {
            // Materialise a copy of the constant inside the global's
            // initializer region and convert it to the alloca's type.
            mlir::Operation *cln = definingOp->clone();
            builder.insert(cln);
            mlir::Value val =
                builder.createConvert(loc, varTy, cln->getResult(0));
            fir::HasValueOp::create(builder, loc, val);
          },
          builder.createInternalLinkage());
      mlir::Value addr = fir::AddrOfOp::create(
          builder, loc, global.resultType(), global.getSymbol());
      newOperands.push_back(addr);
      needUpdate = true;
    }

    // Failure here just means "we couldn't do the conversion", which is
    // perfectly acceptable to the upper layers of this function.
    if (!needUpdate)
      return mlir::failure();

    auto loc = callOp.getLoc();
    llvm::SmallVector<mlir::Type> newResultTypes;
    newResultTypes.append(callOp.getResultTypes().begin(),
                          callOp.getResultTypes().end());
    fir::CallOp newOp = fir::CallOp::create(
        builder, loc, callOp.getCallee().value_or(mlir::SymbolRefAttr{}),
        newResultTypes, newOperands);
    // Copy all the attributes from the old to new op.
    newOp->setAttrs(callOp->getAttrs());

    // Print before replaceOp: the replacement erases callOp, and printing an
    // erased op afterwards would be a use-after-free.
    LLVM_DEBUG(llvm::dbgs() << "global constant for " << callOp << " as "
                            << newOp << '\n');
    rewriter.replaceOp(callOp, newOp);

    for (auto [allocaOp, storeOp] : allocas) {
      if (allocaOp->hasOneUse()) {
        // If the alloca is only used for a store and the call operand, the
        // store is no longer required.
        rewriter.eraseOp(storeOp);
        rewriter.eraseOp(allocaOp);
      }
    }
    return mlir::success();
  }
};

// This pass attempts to convert immediate scalar literals in function calls
// to global constants to allow transformations such as Dead Argument
// Elimination.
//
// It runs CallOpRewriter greedily over the module, restricted to ops that
// exist when the driver starts (ExistingOps) and with region simplification
// disabled. If the greedy driver reports failure, an error is emitted at the
// module location and the pass fails.
class ConstantArgumentGlobalisationOpt
    : public fir::impl::ConstantArgumentGlobalisationOptBase<
          ConstantArgumentGlobalisationOpt> {
public:
  ConstantArgumentGlobalisationOpt() = default;

  void runOnOperation() override {
    mlir::MLIRContext *ctx = &getContext();
    mlir::ModuleOp module = getOperation();
    const mlir::DominanceInfo &dominance = getAnalysis<mlir::DominanceInfo>();

    mlir::GreedyRewriteConfig rewriteConfig;
    rewriteConfig.setStrictness(mlir::GreedyRewriteStrictness::ExistingOps);
    rewriteConfig.setRegionSimplificationLevel(
        mlir::GreedySimplifyRegionLevel::Disabled);

    mlir::RewritePatternSet rewritePatterns(ctx);
    rewritePatterns.insert<CallOpRewriter>(ctx, dominance);

    if (mlir::succeeded(mlir::applyPatternsGreedily(
            module, std::move(rewritePatterns), rewriteConfig)))
      return;

    mlir::emitError(module.getLoc(),
                    "error in constant globalisation optimization\n");
    signalPassFailure();
  }
};
} // namespace