aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Bindings/Python/Globals.h
blob: 71a051cb3d9f51b5dc91739c4e96869194f0f5c9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
//===- Globals.h - MLIR Python extension globals --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
#define MLIR_BINDINGS_PYTHON_GLOBALS_H

#include <optional>
#include <regex>
#include <string>
#include <unordered_set>
#include <vector>

#include "NanobindUtils.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/Support.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Regex.h"

namespace mlir {
namespace python {

/// Globals that are always accessible once the extension has been initialized.
/// Methods of this class are thread-safe.
///
/// The object is intended to be a singleton reached through the static `get()`
/// accessor, which returns the object recorded in `instance`. It is therefore
/// neither copyable nor movable: a copy would hold a second, independent set of
/// registrations (and Python object references) that `get()` never returns.
class PyGlobals {
public:
  PyGlobals();
  ~PyGlobals();

  // Explicitly non-copyable and non-movable (rule of five: the destructor is
  // user-declared). Without this, builds where `nanobind::ft_mutex` is a
  // copyable no-op would make an accidental `auto g = PyGlobals::get();`
  // silently duplicate the whole registry instead of failing to compile.
  PyGlobals(const PyGlobals &) = delete;
  PyGlobals &operator=(const PyGlobals &) = delete;
  PyGlobals(PyGlobals &&) = delete;
  PyGlobals &operator=(PyGlobals &&) = delete;

  /// Most code should get the globals via this static accessor.
  /// Asserts (in debug builds) that the singleton has been created.
  static PyGlobals &get() {
    assert(instance && "PyGlobals is null");
    return *instance;
  }

  /// Get and set the list of parent modules to search for dialect
  /// implementation classes. The getter returns a copy taken under the lock,
  /// so callers never observe a list that is being modified concurrently.
  std::vector<std::string> getDialectSearchPrefixes() {
    nanobind::ft_lock_guard lock(mutex);
    return dialectSearchPrefixes;
  }
  /// Replaces the whole search-prefix list with `newValues`.
  void setDialectSearchPrefixes(std::vector<std::string> newValues) {
    nanobind::ft_lock_guard lock(mutex);
    dialectSearchPrefixes.swap(newValues);
  }
  /// Appends a single prefix to the end of the search list.
  void addDialectSearchPrefix(std::string value) {
    nanobind::ft_lock_guard lock(mutex);
    dialectSearchPrefixes.push_back(std::move(value));
  }

  /// Loads a python module corresponding to the given dialect namespace.
  /// No-ops if the module has already been loaded or is not found. Raises
  /// an error on any evaluation issues.
  /// The module is expected to contain calls to decorators and helpers that
  /// register the salient entities, so no module object is handed back; only
  /// a status is returned. Returns true if the dialect is successfully loaded.
  bool loadDialectModule(llvm::StringRef dialectNamespace);

  /// Adds a user-friendly Attribute builder.
  /// Raises an exception if the mapping already exists and replace == false.
  /// This is intended to be called by implementation code.
  void registerAttributeBuilder(const std::string &attributeKind,
                                nanobind::callable pyFunc,
                                bool replace = false);

  /// Adds a user-friendly type caster. Raises an exception if the mapping
  /// already exists and replace == false. This is intended to be called by
  /// implementation code.
  void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster,
                          bool replace = false);

  /// Adds a user-friendly value caster. Raises an exception if the mapping
  /// already exists and replace == false. This is intended to be called by
  /// implementation code.
  void registerValueCaster(MlirTypeID mlirTypeID,
                           nanobind::callable valueCaster,
                           bool replace = false);

  /// Adds a concrete implementation dialect class.
  /// Raises an exception if the mapping already exists.
  /// This is intended to be called by implementation code.
  void registerDialectImpl(const std::string &dialectNamespace,
                           nanobind::object pyClass);

  /// Adds a concrete implementation operation class.
  /// Raises an exception if the mapping already exists and replace == false.
  /// This is intended to be called by implementation code.
  void registerOperationImpl(const std::string &operationName,
                             nanobind::object pyClass, bool replace = false);

  /// Returns the custom Attribute builder for Attribute kind, or std::nullopt
  /// if none is registered.
  std::optional<nanobind::callable>
  lookupAttributeBuilder(const std::string &attributeKind);

  /// Returns the custom type caster for MlirTypeID mlirTypeID, or
  /// std::nullopt if none is registered.
  /// NOTE(review): `dialect` is presumably the dialect owning the type, used to
  /// load its Python module (so its casters get registered) before the
  /// lookup — confirm against the out-of-line definition.
  std::optional<nanobind::callable> lookupTypeCaster(MlirTypeID mlirTypeID,
                                                     MlirDialect dialect);

  /// Returns the custom value caster for MlirTypeID mlirTypeID, or
  /// std::nullopt if none is registered.
  /// NOTE(review): `dialect` presumably plays the same role as in
  /// lookupTypeCaster — confirm against the out-of-line definition.
  std::optional<nanobind::callable> lookupValueCaster(MlirTypeID mlirTypeID,
                                                      MlirDialect dialect);

  /// Looks up a registered dialect class by namespace. Note that this may
  /// trigger loading of the defining module and can arbitrarily re-enter.
  std::optional<nanobind::object>
  lookupDialectClass(const std::string &dialectNamespace);

  /// Looks up a registered operation class (deriving from OpView) by operation
  /// name. Note that this may trigger a load of the dialect, which can
  /// arbitrarily re-enter.
  std::optional<nanobind::object>
  lookupOperationClass(llvm::StringRef operationName);

  /// Settings and filename filters for Python-traceback-based locations.
  /// Holds its own mutex, independent of the PyGlobals one.
  class TracebackLoc {
  public:
    /// Returns whether traceback-based locations are enabled (default: false).
    bool locTracebacksEnabled();

    /// Enables or disables traceback-based locations.
    void setLocTracebacksEnabled(bool value);

    /// Returns the number of traceback frames to use (default: 10).
    size_t locTracebackFramesLimit();

    /// Sets the number of traceback frames to use.
    /// NOTE(review): presumably capped at kMaxFrames — confirm in the
    /// out-of-line definition.
    void setLocTracebackFramesLimit(size_t value);

    /// Adds `file` to the set of traceback filenames to include; presumably
    /// this marks the include regex for rebuilding.
    void registerTracebackFileInclusion(const std::string &file);

    /// Adds `file` to the set of traceback filenames to exclude; presumably
    /// this marks the exclude regex for rebuilding.
    void registerTracebackFileExclusion(const std::string &file);

    /// Returns whether `file` should be treated as a user (non-filtered)
    /// traceback filename, presumably consulting the include/exclude filters.
    bool isUserTracebackFilename(llvm::StringRef file);

    /// Upper bound on the number of traceback frames.
    static constexpr size_t kMaxFrames = 512;

  private:
    /// Guards the state below.
    nanobind::ft_mutex mutex;
    bool locTracebackEnabled_ = false;
    size_t locTracebackFramesLimit_ = 10;
    /// Registered include/exclude filenames. Each set has a compiled regex and
    /// a "rebuild" flag, presumably set when the set changes so the regex can
    /// be regenerated lazily.
    std::unordered_set<std::string> userTracebackIncludeFiles;
    std::unordered_set<std::string> userTracebackExcludeFiles;
    std::regex userTracebackIncludeRegex;
    bool rebuildUserTracebackIncludeRegex = false;
    std::regex userTracebackExcludeRegex;
    bool rebuildUserTracebackExcludeRegex = false;
    /// Per-filename result cache, presumably for isUserTracebackFilename.
    llvm::StringMap<bool> isUserTracebackFilenameCache;
  };

  /// Returns the traceback-location settings. No PyGlobals lock is taken
  /// here; TracebackLoc carries its own mutex for its state.
  TracebackLoc &getTracebackLoc() { return tracebackLoc; }

private:
  /// The object returned by get().
  /// NOTE(review): presumably assigned and cleared by the out-of-line
  /// constructor and destructor — confirm.
  static PyGlobals *instance;

  /// Taken by the dialect-search-prefix accessors above; presumably also
  /// guards the registries below in the out-of-line methods.
  nanobind::ft_mutex mutex;

  /// Module name prefixes to search under for dialect implementation modules.
  std::vector<std::string> dialectSearchPrefixes;
  /// Map of dialect namespace to external dialect class object.
  llvm::StringMap<nanobind::object> dialectClassMap;
  /// Map of full operation name to external operation class object.
  llvm::StringMap<nanobind::object> operationClassMap;
  /// Map of attribute ODS name to custom builder.
  llvm::StringMap<nanobind::callable> attributeBuilderMap;
  /// Map of MlirTypeID to custom type caster.
  llvm::DenseMap<MlirTypeID, nanobind::callable> typeCasterMap;
  /// Map of MlirTypeID to custom value caster.
  llvm::DenseMap<MlirTypeID, nanobind::callable> valueCasterMap;
  /// Set of dialect namespaces that we have attempted to import implementation
  /// modules for.
  llvm::StringSet<> loadedDialectModules;

  /// Traceback-location settings, exposed through getTracebackLoc().
  TracebackLoc tracebackLoc;
};

} // namespace python
} // namespace mlir

#endif // MLIR_BINDINGS_PYTHON_GLOBALS_H