aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/IR/Dialect.cpp
blob: 631dc410632ad42a64813f8d7f86cbe56db01190 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
//===- Dialect.cpp - Dialect implementation -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/Dialect.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Regex.h"

#define DEBUG_TYPE "dialect"

using namespace mlir;
using namespace detail;

//===----------------------------------------------------------------------===//
// DialectRegistry
//===----------------------------------------------------------------------===//

/// Create a registry that is pre-populated with the builtin dialect.
DialectRegistry::DialectRegistry() {
  insert<BuiltinDialect>();
}

/// Record a delayed dialect interface for the dialect registered under
/// `dialectName`. `allocator` is stored and later invoked to materialize the
/// interface. A second registration of the same `interfaceTypeID` for the same
/// dialect is ignored (and noted in the debug log).
///
/// The dialect must already be registered; this is checked with an assertion.
void DialectRegistry::addDialectInterface(
    StringRef dialectName, TypeID interfaceTypeID,
    DialectInterfaceAllocatorFunction allocator) {
  assert(allocator && "unexpected null interface allocation function");
  auto it = registry.find(dialectName.str());
  assert(it != registry.end() &&
         "adding an interface for an unregistered dialect");

  // Bail out if the interface with the given ID is already in the registry for
  // the given dialect. We expect a small number (dozens) of interfaces so a
  // linear search is fine here.
  auto &ifaces = interfaces[it->second.first];
  for (const auto &kvp : ifaces.dialectInterfaces) {
    if (kvp.first == interfaceTypeID) {
      // Terminate the line so this note does not run into later debug output.
      LLVM_DEBUG(llvm::dbgs()
                 << "[" DEBUG_TYPE
                    "] repeated interface registration for dialect "
                 << dialectName << "\n");
      return;
    }
  }

  // `allocator` is a by-value sink parameter; move it instead of copying.
  ifaces.dialectInterfaces.emplace_back(interfaceTypeID, std::move(allocator));
}

/// Record a delayed attribute/operation/type interface for the object
/// identified by `objectID` in the dialect registered under `dialectName`.
/// A duplicate (objectID, interfaceTypeID) pair for that dialect is ignored
/// (and noted in the debug log).
///
/// The dialect must already be registered; this is checked with an assertion.
void DialectRegistry::addObjectInterface(
    StringRef dialectName, TypeID objectID, TypeID interfaceTypeID,
    ObjectInterfaceAllocatorFunction allocator) {
  assert(allocator && "unexpected null interface allocation function");

  auto it = registry.find(dialectName.str());
  assert(it != registry.end() &&
         "adding an interface for an op from an unregistered dialect");

  auto dialectID = it->second.first;
  auto &ifaces = interfaces[dialectID];

  // Linear scan; the number of object interfaces per dialect is expected to be
  // small.
  for (const auto &info : ifaces.objectInterfaces) {
    if (std::get<0>(info) == objectID && std::get<1>(info) == interfaceTypeID) {
      // Name the dialect and terminate the line so the note is readable.
      LLVM_DEBUG(llvm::dbgs()
                 << "[" DEBUG_TYPE
                    "] repeated object interface registration for dialect "
                 << dialectName << "\n");
      return;
    }
  }

  // `allocator` is a by-value sink parameter; move it instead of copying.
  ifaces.objectInterfaces.emplace_back(objectID, interfaceTypeID,
                                       std::move(allocator));
}

/// Return the allocator registered for the dialect namespace `name`, or a null
/// reference when no dialect with that namespace has been registered.
DialectAllocatorFunctionRef
DialectRegistry::getDialectAllocator(StringRef name) const {
  const auto found = registry.find(name.str());
  if (found != registry.end())
    return found->second.second;
  return nullptr;
}

/// Register the dialect allocator `ctor` under namespace `name` with the given
/// `typeID`.
///
/// If `name` is already registered with the same `typeID`, the existing entry
/// is kept. If it is registered with a different `typeID`, this aborts with a
/// fatal error.
void DialectRegistry::insert(TypeID typeID, StringRef name,
                             DialectAllocatorFunction ctor) {
  // `ctor` is taken by value; move it into the entry instead of copying it.
  auto inserted = registry.insert(std::make_pair(
      std::string(name), std::make_pair(typeID, std::move(ctor))));
  if (!inserted.second && inserted.first->second.first != typeID) {
    llvm::report_fatal_error(
        "Trying to register different dialects for the same namespace: " +
        name);
  }
}

/// Attach every interface recorded in this registry for `dialect`'s TypeID.
/// Dialect interfaces that the dialect already has are skipped. Each recorded
/// object-interface callback is invoked with the dialect's context.
void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const {
  auto entry = interfaces.find(dialect->getTypeID());
  if (entry == interfaces.end())
    return;
  const auto &delayed = entry->getSecond();

  // Dialect interfaces: only allocate the ones not yet present.
  for (const auto &idAndAllocator : delayed.dialectInterfaces) {
    if (!dialect->getRegisteredInterface(idAndAllocator.first))
      dialect->addInterface(idAndAllocator.second(dialect));
  }

  // Attribute, operation and type interfaces are installed via the context.
  for (const auto &objectInfo : delayed.objectInterfaces) {
    const auto &registerFn = std::get<2>(objectInfo);
    registerFn(dialect->getContext());
  }
}

//===----------------------------------------------------------------------===//
// Dialect
//===----------------------------------------------------------------------===//

/// Construct a dialect with namespace `name`, associated with `context` and
/// identified by `id`. The namespace is checked with `isValidNamespace` in an
/// assertion.
Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
    : name(name), dialectID(id), context(context) {
  assert(isValidNamespace(name) && "invalid dialect namespace");
}

Dialect::~Dialect() = default;

/// Default hook for verifying a dialect attribute attached to a region
/// argument. The base implementation accepts every attribute; dialects that
/// need checks override it.
LogicalResult Dialect::verifyRegionArgAttribute(Operation * /*op*/,
                                                unsigned /*regionIndex*/,
                                                unsigned /*argIndex*/,
                                                NamedAttribute /*attribute*/) {
  return success();
}

/// Default hook for verifying a dialect attribute attached to a region result.
/// The base implementation accepts every attribute; dialects that need checks
/// override it.
LogicalResult Dialect::verifyRegionResultAttribute(
    Operation * /*op*/, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
    NamedAttribute /*attribute*/) {
  return success();
}

/// Default attribute parser: this dialect has no attribute syntax, so report
/// an error at the current name location and return a null attribute.
Attribute Dialect::parseAttribute(DialectAsmParser &parser,
                                  Type /*type*/) const {
  parser.emitError(parser.getNameLoc()) << "dialect '" << getNamespace()
                                        << "' provides no attribute parsing hook";
  return Attribute();
}

/// Default type parser. If the dialect allows unknown types, the full symbol
/// text is wrapped in an OpaqueType under this dialect's namespace. Otherwise
/// an error is emitted and a null type is returned.
Type Dialect::parseType(DialectAsmParser &parser) const {
  if (!allowsUnknownTypes()) {
    parser.emitError(parser.getNameLoc())
        << "dialect '" << getNamespace() << "' provides no type parsing hook";
    return Type();
  }

  StringAttr dialectNamespace = StringAttr::get(getContext(), getNamespace());
  return OpaqueType::get(dialectNamespace, parser.getFullSymbolSpec());
}

/// Default: the dialect provides no custom parse hook for any operation name.
Optional<Dialect::ParseOpHook>
Dialect::getParseOperationHook(StringRef /*opName*/) const {
  return None;
}

/// Default: the dialect provides no custom printer, so an empty function is
/// returned. `op` is asserted to belong to this dialect.
llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
Dialect::getOperationPrinter(Operation *op) const {
  assert(op->getDialect() == this &&
         "Dialect hook invoked on non-dialect owned operation");
  return nullptr;
}

/// Return true if `str` is a valid dialect namespace. A valid namespace is a
/// non-empty string matching `[a-zA-Z_][a-zA-Z_0-9$]*` (ASCII only).
///
/// This is checked with a direct character scan, not an llvm::Regex. The old
/// code compiled a regex on every call just to test a simple character class.
bool Dialect::isValidNamespace(StringRef str) {
  auto isLeadingChar = [](char c) {
    return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_';
  };
  auto isTrailingChar = [&](char c) {
    return isLeadingChar(c) || (c >= '0' && c <= '9') || c == '$';
  };

  // An empty string, or one starting with a digit or '$', is rejected.
  if (str.empty() || !isLeadingChar(str.front()))
    return false;
  for (char c : str.drop_front())
    if (!isTrailingChar(c))
      return false;
  return true;
}

/// Register a set of dialect interfaces with this dialect instance.
void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
  auto it = registeredInterfaces.try_emplace(interface->getID(),
                                             std::move(interface));
  (void)it;
  assert(it.second && "interface kind has already been registered");
}

//===----------------------------------------------------------------------===//
// Dialect Interface
//===----------------------------------------------------------------------===//

/// Out-of-line destructor definition for the dialect interface base class.
DialectInterface::~DialectInterface() = default;

/// Gather the interface of kind `interfaceKind` from every dialect currently
/// loaded in `ctx`. Each found interface is recorded both in the lookup set and
/// in `orderedInterfaces`, which keeps the order returned by
/// `getLoadedDialects()`.
DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
    MLIRContext *ctx, TypeID interfaceKind) {
  for (auto *loadedDialect : ctx->getLoadedDialects()) {
    auto *found = loadedDialect->getRegisteredInterface(interfaceKind);
    if (!found)
      continue;
    interfaces.insert(found);
    orderedInterfaces.push_back(found);
  }
}

DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() = default;

/// Return the collected interface for the dialect that owns `op`, or null if
/// that dialect has no such interface registered.
const DialectInterface *
DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
  auto *owningDialect = op->getDialect();
  return getInterfaceFor(owningDialect);
}