aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Optimizer/Transforms/SetRuntimeCallAttributes.cpp
blob: 378037e9494f4a2000c5ce232ae97802b2b78e55 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
//===- SetRuntimeCallAttributes.cpp ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
/// \file
/// SetRuntimeCallAttributesPass looks for fir.call operations
/// that call into the Fortran runtime, and tries to set different
/// attributes on them to enable more optimizations in the LLVM backend
/// (provided that the attributes are preserved all the way to LLVM IR).
/// This pass currently attaches only fir.call-wide attributes,
/// such as the ones corresponding to llvm.memory, nosync, nocallback, etc.
/// It is not designed to attach attributes to the arguments and the results
/// of a call.
//===----------------------------------------------------------------------===//
#include "flang/Common/static-multimap-view.h"
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Support/InternalNames.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "flang/Runtime/io-api.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"

// Pull in the tablegen-generated pass definition, which provides the
// fir::impl::SetRuntimeCallAttributesBase class used below.
namespace fir {
#define GEN_PASS_DEF_SETRUNTIMECALLATTRIBUTES
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

// Tag for the LLVM_DEBUG output of this file
// (enabled with -debug-only=set-runtime-call-attrs).
#define DEBUG_TYPE "set-runtime-call-attrs"

// NOTE(review): presumably needed so that the runtime entry points named via
// IONAME/RTNAME in mkIOKey/mkRTKey resolve unqualified — confirm.
using namespace Fortran::runtime;
using namespace Fortran::runtime::io;

// Map a runtime entry point X to the key type used to specialize the
// attribute descriptors (MemoryAttrDesc, NosyncAttrDesc, NocallbackAttrDesc)
// and to build RuntimeFunction entries. mkIOKey is for I/O API entries
// (IONAME), mkRTKey for other runtime entries (RTNAME).
#define mkIOKey(X) FirmkKey(IONAME(X))
#define mkRTKey(X) FirmkKey(RTNAME(X))

// Return LLVM dialect MemoryEffectsAttr for the given Fortran runtime call.
// This function computes a generic value of this attribute by analyzing the
// call arguments and their types. It tries to figure out whether an
// "indirect" memory access (through a pointer that is not itself one of the
// call arguments) is possible during this call. If it is not possible, then
// the memory effects are:
//   * other = NoModRef
//   * argMem = ModRef
//   * inaccessibleMem = ModRef
//
// Otherwise, it returns an empty attribute meaning ModRef for all kinds
// of memory.
//
// An indirect access is assumed to be possible if any argument is:
//   * a box that is not produced by fir.zero_bits or fir.absent (the runtime
//     may follow the descriptor's base_addr);
//   * a FIR reference-like value (!fir.ref, !fir.ptr, !fir.heap or
//     !fir.llvm_ptr) whose element type is not trivial (the pointee may
//     itself contain pointers);
//   * an !llvm.ptr value (its pointee is unknown).
//
// The attribute deduction is conservative in the sense that it applies
// to most of the runtime calls, but it may still be incorrect for some
// runtime calls.
static mlir::LLVM::MemoryEffectsAttr getGenericMemoryAttr(fir::CallOp callOp) {
  // Returns true if passing `arg` may let the callee reach memory
  // other than the memory directly pointed to by its arguments.
  auto mayAccessIndirectly = [](mlir::Value arg) {
    mlir::Type argType = arg.getType();
    if (mlir::isa<fir::BaseBoxType>(argType)) {
      // If it is a null/absent box, then this particular call
      // cannot access memory indirectly through the box's
      // base_addr.
      return !mlir::isa_and_nonnull<fir::ZeroOp, fir::AbsentOp>(
          arg.getDefiningOp());
    }
    // dyn_cast_ptrEleTy covers every FIR reference-like type, not only
    // !fir.ref: a !fir.heap/!fir.ptr/!fir.llvm_ptr to a non-trivial type
    // (e.g. a box) allows indirect access just as well.
    if (mlir::Type eleTy = fir::dyn_cast_ptrEleTy(argType))
      return !fir::isa_trivial(eleTy);
    return mlir::isa<mlir::LLVM::LLVMPointerType>(argType);
  };

  for (mlir::Value arg : callOp.getArgOperands())
    if (mayAccessIndirectly(arg))
      return {};

  return mlir::LLVM::MemoryEffectsAttr::get(
      callOp->getContext(),
      {/*other=*/mlir::LLVM::ModRefInfo::NoModRef,
       /*argMem=*/mlir::LLVM::ModRefInfo::ModRef,
       /*inaccessibleMem=*/mlir::LLVM::ModRefInfo::ModRef});
}

namespace {
// Compile-time predicate: true iff T is the same type as at least one of Ts
// (false for an empty Ts).
template <typename T, typename... Ts>
constexpr bool IsAny = (std::is_same_v<T, Ts> || ...);

// Pass that runs on each func.func and attaches LLVM-related attributes to
// the fir.call operations in it that target known Fortran runtime
// functions.
class SetRuntimeCallAttributesPass
    : public fir::impl::SetRuntimeCallAttributesBase<
          SetRuntimeCallAttributesPass> {
public:
  void runOnOperation() override;
};
} // end anonymous namespace

namespace {
// MemoryAttrDesc<KEY>::get() computes the mlir::LLVM::MemoryEffectsAttr for
// a call of the Fortran runtime function identified by KEY. A null result
// means that no memory attribute is attached to the call.
//
// The primary template defers to getGenericMemoryAttr(), whose deduction may
// be wrong for some runtime calls; add a specialization for any runtime
// function that needs a different answer.
template <typename KEY, typename Enable = void>
struct MemoryAttrDesc {
  static mlir::LLVM::MemoryEffectsAttr get(fir::CallOp callOp) {
    mlir::LLVM::MemoryEffectsAttr effects = getGenericMemoryAttr(callOp);
    return effects;
  }
};
} // end anonymous namespace

namespace {
// NosyncAttrDesc<KEY>::get() yields the LLVM nosync attribute for a call of
// the Fortran runtime function identified by KEY, or std::nullopt when no
// such attribute should be set.
//
// The primary template always yields llvm.nosync, which is expected to hold
// for the majority of the Fortran runtime calls.
template <typename KEY, typename Enable = void>
struct NosyncAttrDesc {
  static std::optional<mlir::NamedAttribute> get(fir::CallOp callOp) {
    // TODO: replace llvm.nosync with an LLVM dialect callback.
    auto *ctx = callOp->getContext();
    auto unitAttr = mlir::UnitAttr::get(ctx);
    return mlir::NamedAttribute("llvm.nosync", unitAttr);
  }
};
} // end anonymous namespace

namespace {
// NocallbackAttrDesc<KEY>::get() yields the LLVM nocallback attribute for a
// call of the Fortran runtime function identified by KEY, or std::nullopt
// when no such attribute should be set.
//
// The primary template always yields llvm.nocallback. Runtime functions that
// may call user procedures while they execute (e.g. defined IO, defined
// assignment) must get a specialization that yields std::nullopt.
template <typename KEY, typename Enable = void>
struct NocallbackAttrDesc {
  static std::optional<mlir::NamedAttribute> get(fir::CallOp callOp) {
    // TODO: replace llvm.nocallback with an LLVM dialect callback.
    auto unitAttr = mlir::UnitAttr::get(callOp->getContext());
    return mlir::NamedAttribute("llvm.nocallback", unitAttr);
  }
};

// Derived-type and namelist IO may call back into a Fortran module, so no
// nocallback attribute is produced for them. For Input/OutputDerivedType
// this is conservative: it might be improved by checking whether the
// NonTbpDefinedIoTable pointer argument is null.
template <typename KEY>
struct NocallbackAttrDesc<
    KEY, std::enable_if_t<IsAny<KEY, mkIOKey(InputDerivedType),
                                mkIOKey(OutputDerivedType),
                                mkIOKey(InputNamelist),
                                mkIOKey(OutputNamelist)>>> {
  static std::optional<mlir::NamedAttribute> get(fir::CallOp /*callOp*/) {
    return {};
  }
};
} // end anonymous namespace

namespace {
// Descriptor of one recognized Fortran runtime function: its name plus the
// callbacks that compute the fir.call attributes for calls to it. It
// converts implicitly to its Key (the name), so a sorted table of
// descriptors can be searched by name (see runtimeFuncs).
struct RuntimeFunction {
  using Key = std::string_view;
  using MemoryAttrGeneratorTy = mlir::LLVM::MemoryEffectsAttr (*)(fir::CallOp);
  using NamedAttrGeneratorTy =
      std::optional<mlir::NamedAttribute> (*)(fir::CallOp);

  constexpr operator Key() const { return key; }

  Key key;
  MemoryAttrGeneratorTy memoryAttrGenerator;
  NamedAttrGeneratorTy nosyncAttrGenerator;
  NamedAttrGeneratorTy nocallbackAttrGenerator;
};

// Builds the RuntimeFunction descriptor named `name` for the runtime function
// identified by KEY, wiring in the KEY-specific MemoryAttrDesc,
// NosyncAttrDesc and NocallbackAttrDesc callbacks.
template <typename KEY>
struct RuntimeFactory {
  static constexpr RuntimeFunction create(const char name[]) {
    // The length is passed explicitly on purpose. The single-argument
    // basic_string_view(const char *) constructor performs the comparison
    //   ((const char *)RuntimeFunction<>::name) == nullptr
    // which GCC 7 does not recognize as a constant expression.
    const auto nameLength = std::char_traits<char>::length(name);
    return RuntimeFunction{RuntimeFunction::Key{name, nameLength},
                           MemoryAttrDesc<KEY>::get, NosyncAttrDesc<KEY>::get,
                           NocallbackAttrDesc<KEY>::get};
  }
};
} // end anonymous namespace

// Build a RuntimeFunction descriptor for runtime entry point X:
// KNOWN_IO_FUNC for I/O API entries (mkIOKey/IONAME), KNOWN_RUNTIME_FUNC for
// other runtime entries (mkRTKey/RTNAME).
// NOTE(review): presumably these are the macros expanded by the entries of
// RuntimeFunctions.inc included below — confirm.
#define KNOWN_IO_FUNC(X) RuntimeFactory<mkIOKey(X)>::create(mkIOKey(X)::name)
#define KNOWN_RUNTIME_FUNC(X)                                                  \
  RuntimeFactory<mkRTKey(X)>::create(mkRTKey(X)::name)

// A table of RuntimeFunction descriptors for all recognized
// Fortran runtime functions. Lookups go through equal_range on the
// StaticMultimapView below, so the entries must be sorted by key; the
// static_assert enforces this at compile time.
static constexpr RuntimeFunction runtimeFuncsTable[] = {
#include "flang/Optimizer/Transforms/RuntimeFunctions.inc"
};

// View over runtimeFuncsTable used by setRuntimeCallAttributes() to find the
// descriptor for a callee name.
static constexpr Fortran::common::StaticMultimapView<RuntimeFunction>
    runtimeFuncs(runtimeFuncsTable);
static_assert(runtimeFuncs.Verify() && "map must be sorted");

// Attach to callOp the attributes computed by the RuntimeFunction descriptor
// registered for its callee. Nothing changes unless the call resolves to a
// func.func that carries the FIR runtime attribute and whose name appears in
// runtimeFuncs. The symbolTable is used to cache the name lookups in the
// module.
static void setRuntimeCallAttributes(fir::CallOp callOp,
                                     mlir::SymbolTableCollection &symbolTable) {
  auto callIface = mlir::cast<mlir::CallOpInterface>(callOp.getOperation());
  auto callee = mlir::dyn_cast_or_null<mlir::func::FuncOp>(
      callIface.resolveCallableInTable(&symbolTable));
  if (!callee)
    return;
  if (!callee->hasAttrOfType<mlir::UnitAttr>(
          fir::FIROpsDialect::getFirRuntimeAttrName()))
    return;

  auto [first, last] = runtimeFuncs.equal_range(callee.getName());
  if (first == last)
    return;
  // Each runtime function must be registered at most once.
  assert(first + 1 == last);

  const RuntimeFunction &entry = *first;
  LLVM_DEBUG(llvm::dbgs()
             << "Identified runtime function call: " << entry.key << '\n');

  // A null attribute means the generator could not restrict the memory
  // effects, so nothing is attached.
  mlir::LLVM::MemoryEffectsAttr effects = entry.memoryAttrGenerator(callOp);
  if (effects)
    callOp->setAttr(fir::FIROpsDialect::getFirCallMemoryAttrName(), effects);

  // Apply the nosync attribute first, then the nocallback attribute.
  for (auto generator :
       {entry.nosyncAttrGenerator, entry.nocallbackAttrGenerator}) {
    if (std::optional<mlir::NamedAttribute> namedAttr = generator(callOp))
      callOp->setAttr(namedAttr->getName(), namedAttr->getValue());
  }
  LLVM_DEBUG(llvm::dbgs() << "Operation with attrs: " << callOp << '\n');
}

void SetRuntimeCallAttributesPass::runOnOperation() {
  mlir::func::FuncOp funcOp = getOperation();
  // Exit early for declarations to skip the debug output for them.
  if (funcOp.isDeclaration())
    return;
  LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
  LLVM_DEBUG(llvm::dbgs() << "Func-name:" << funcOp.getSymName() << "\n");

  mlir::SymbolTableCollection symbolTable;
  funcOp.walk([&](fir::CallOp callOp) {
    setRuntimeCallAttributes(callOp, symbolTable);
  });
  LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
}