aboutsummaryrefslogtreecommitdiff
path: root/mlir/tools/mlir-tblgen
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/tools/mlir-tblgen')
-rw-r--r--mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp6
-rw-r--r--mlir/tools/mlir-tblgen/CMakeLists.txt1
-rw-r--r--mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp164
-rw-r--r--mlir/tools/mlir-tblgen/EnumsGen.cpp79
-rw-r--r--mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp12
-rw-r--r--mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp63
-rw-r--r--mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp33
-rw-r--r--mlir/tools/mlir-tblgen/PassGen.cpp106
-rw-r--r--mlir/tools/mlir-tblgen/RewriterGen.cpp2
9 files changed, 316 insertions, 150 deletions
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 8ec2e03..2a513c3 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -637,8 +637,10 @@ void DefGen::emitTraitMethods(const InterfaceTrait &trait) {
for (auto &method : iface.getMethods()) {
// Don't declare if the method has a body. Or if the method has a default
// implementation and the def didn't request that it always be declared.
- if (method.getBody() || (method.getDefaultImplementation() &&
- !alwaysDeclared.count(method.getName()))) {
+ if (method.getBody())
+ continue;
+ if (method.getDefaultImplementation() &&
+ !alwaysDeclared.count(method.getName())) {
genTraitMethodUsingDecl(trait, method);
continue;
}
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index 2a7ef7e..d7087cb 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -12,6 +12,7 @@ add_tablegen(mlir-tblgen MLIR
AttrOrTypeFormatGen.cpp
BytecodeDialectGen.cpp
DialectGen.cpp
+ DialectInterfacesGen.cpp
DirectiveCommonGen.cpp
EnumsGen.cpp
EnumPythonBindingGen.cpp
diff --git a/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
new file mode 100644
index 0000000..1d3b24a
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
@@ -0,0 +1,164 @@
+//===- DialectInterfacesGen.cpp - MLIR dialect interface utility generator ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// DialectInterfaceGen generates definitions for Dialect interfaces.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CppGenUtilities.h"
+#include "DocGenUtilities.h"
+#include "mlir/Support/IndentedOstream.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Interfaces.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/CodeGenHelpers.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using namespace mlir;
+using llvm::Record;
+using llvm::RecordKeeper;
+using mlir::tblgen::Interface;
+using mlir::tblgen::InterfaceMethod;
+
+/// Emit a string corresponding to a C++ type, followed by a space if necessary.
+static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
+ type = type.trim();
+ os << type;
+ if (type.back() != '&' && type.back() != '*')
+ os << " ";
+ return os;
+}
+
+/// Emit the method name and argument list for the given method.
+static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name,
+ raw_ostream &os) {
+ os << name << '(';
+ llvm::interleaveComma(method.getArguments(), os,
+ [&](const InterfaceMethod::Argument &arg) {
+ os << arg.type << " " << arg.name;
+ });
+ os << ") const";
+}
+
+/// Get an array of all Dialect Interface definitions
+static std::vector<const Record *>
+getAllInterfaceDefinitions(const RecordKeeper &records) {
+ std::vector<const Record *> defs =
+ records.getAllDerivedDefinitions("DialectInterface");
+
+ llvm::erase_if(defs, [&](const Record *def) {
+ // Ignore interfaces defined outside of the top-level file.
+ return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
+ llvm::SrcMgr.getMainFileID();
+ });
+ return defs;
+}
+
+namespace {
+/// This struct is the generator used when processing tablegen dialect
+/// interfaces.
+class DialectInterfaceGenerator {
+public:
+ DialectInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
+ : defs(getAllInterfaceDefinitions(records)), os(os) {}
+
+ bool emitInterfaceDecls();
+
+protected:
+ void emitInterfaceDecl(const Interface &interface);
+
+ /// The set of interface records to emit.
+ std::vector<const Record *> defs;
+ // The stream to emit to.
+ raw_ostream &os;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// GEN: Interface declarations
+//===----------------------------------------------------------------------===//
+
+static void emitInterfaceMethodDoc(const InterfaceMethod &method,
+ raw_ostream &os, StringRef prefix = "") {
+ if (std::optional<StringRef> description = method.getDescription())
+ tblgen::emitDescriptionComment(*description, os, prefix);
+}
+
+static void emitInterfaceMethodsDef(const Interface &interface,
+ raw_ostream &os) {
+
+ raw_indented_ostream ios(os);
+ ios.indent(2);
+
+ for (auto &method : interface.getMethods()) {
+ emitInterfaceMethodDoc(method, ios);
+ ios << "virtual ";
+ emitCPPType(method.getReturnType(), ios);
+ emitMethodNameAndArgs(method, method.getName(), ios);
+ ios << " {";
+
+ if (auto body = method.getBody()) {
+ ios << "\n";
+ ios.indent(4);
+ ios << body << "\n";
+ ios.indent(2);
+ }
+ os << "}\n";
+ }
+}
+
+void DialectInterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
+ llvm::NamespaceEmitter ns(os, interface.getCppNamespace());
+
+ StringRef interfaceName = interface.getName();
+
+ tblgen::emitSummaryAndDescComments(os, "",
+ interface.getDescription().value_or(""));
+
+ // Emit the main interface class declaration.
+ os << llvm::formatv(
+ "class {0} : public ::mlir::DialectInterface::Base<{0}> {\n"
+ "public:\n"
+ " {0}(::mlir::Dialect *dialect) : Base(dialect) {{}\n",
+ interfaceName);
+
+ emitInterfaceMethodsDef(interface, os);
+
+ os << "};\n";
+}
+
+bool DialectInterfaceGenerator::emitInterfaceDecls() {
+
+ llvm::emitSourceFileHeader("Dialect Interface Declarations", os);
+
+ // Sort according to ID, so defs are emitted in the order in which they appear
+ // in the Tablegen file.
+ std::vector<const Record *> sortedDefs(defs);
+ llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) {
+ return lhs->getID() < rhs->getID();
+ });
+
+ for (const Record *def : sortedDefs)
+ emitInterfaceDecl(Interface(def));
+
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: Interface registration hooks
+//===----------------------------------------------------------------------===//
+
+static mlir::GenRegistration genDecls(
+ "gen-dialect-interface-decls", "Generate dialect interface declarations.",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return DialectInterfaceGenerator(records, os).emitInterfaceDecls();
+ });
diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index 11bf9ce..8c7f9f7 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -702,41 +702,45 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
StringRef underlyingToSymFnName = enumInfo.getUnderlyingToSymbolFnName();
auto enumerants = enumInfo.getAllCases();
- llvm::NamespaceEmitter ns(os, cppNamespace);
-
- // Emit the enum class definition
- emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os);
-
- // Emit conversion function declarations
- if (llvm::all_of(enumerants, [](EnumCase enumerant) {
- return enumerant.getValue() >= 0;
- })) {
- os << formatv(
- "::std::optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName,
- underlyingType.empty() ? std::string("unsigned") : underlyingType);
- }
- os << formatv("{2} {1}({0});\n", enumName, symToStrFnName, symToStrFnRetType);
- os << formatv("::std::optional<{0}> {1}(::llvm::StringRef);\n", enumName,
- strToSymFnName);
-
- if (enumInfo.isBitEnum()) {
- emitOperators(enumDef, os);
- } else {
- emitMaxValueFn(enumDef, os);
- }
+ {
+ llvm::NamespaceEmitter ns(os, cppNamespace);
+
+ // Emit the enum class definition
+ emitEnumClass(enumDef, enumName, underlyingType, description, enumerants,
+ os);
+
+ // Emit conversion function declarations
+ if (llvm::all_of(enumerants, [](EnumCase enumerant) {
+ return enumerant.getValue() >= 0;
+ })) {
+ os << formatv(
+ "::std::optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName,
+ underlyingType.empty() ? std::string("unsigned") : underlyingType);
+ }
+ os << formatv("{2} {1}({0});\n", enumName, symToStrFnName,
+ symToStrFnRetType);
+ os << formatv("::std::optional<{0}> {1}(::llvm::StringRef);\n", enumName,
+ strToSymFnName);
+
+ if (enumInfo.isBitEnum()) {
+ emitOperators(enumDef, os);
+ } else {
+ emitMaxValueFn(enumDef, os);
+ }
- // Generate a generic `stringifyEnum` function that forwards to the method
- // specified by the user.
- const char *const stringifyEnumStr = R"(
+ // Generate a generic `stringifyEnum` function that forwards to the method
+ // specified by the user.
+ const char *const stringifyEnumStr = R"(
inline {0} stringifyEnum({1} enumValue) {{
return {2}(enumValue);
}
)";
- os << formatv(stringifyEnumStr, symToStrFnRetType, enumName, symToStrFnName);
+ os << formatv(stringifyEnumStr, symToStrFnRetType, enumName,
+ symToStrFnName);
- // Generate a generic `symbolizeEnum` function that forwards to the method
- // specified by the user.
- const char *const symbolizeEnumStr = R"(
+ // Generate a generic `symbolizeEnum` function that forwards to the method
+ // specified by the user.
+ const char *const symbolizeEnumStr = R"(
template <typename EnumType>
::std::optional<EnumType> symbolizeEnum(::llvm::StringRef);
@@ -745,9 +749,9 @@ inline ::std::optional<{0}> symbolizeEnum<{0}>(::llvm::StringRef str) {
return {1}(str);
}
)";
- os << formatv(symbolizeEnumStr, enumName, strToSymFnName);
+ os << formatv(symbolizeEnumStr, enumName, strToSymFnName);
- const char *const attrClassDecl = R"(
+ const char *const attrClassDecl = R"(
class {1} : public ::mlir::{2} {
public:
using ValueType = {0};
@@ -757,13 +761,12 @@ public:
{0} getValue() const;
};
)";
- if (enumInfo.genSpecializedAttr()) {
- StringRef attrClassName = enumInfo.getSpecializedAttrClassName();
- StringRef baseAttrClassName = "IntegerAttr";
- os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName);
- }
-
- ns.close();
+ if (enumInfo.genSpecializedAttr()) {
+ StringRef attrClassName = enumInfo.getSpecializedAttrClassName();
+ StringRef baseAttrClassName = "IntegerAttr";
+ os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName);
+ }
+ } // close `ns`.
// Generate a generic parser and printer for the enum.
std::string qualName =
diff --git a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
index 525c8d6..54cc4b7 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
@@ -14,6 +14,7 @@
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/StringSwitch.h"
#include "llvm/CodeGenTypes/MachineValueType.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/PrettyStackTrace.h"
@@ -60,8 +61,13 @@ using IndicesTy = llvm::SmallBitVector;
/// Return a CodeGen value type entry from a type record.
static llvm::MVT::SimpleValueType getValueType(const Record *rec) {
- return (llvm::MVT::SimpleValueType)rec->getValueAsDef("VT")->getValueAsInt(
- "Value");
+ return StringSwitch<llvm::MVT::SimpleValueType>(
+ rec->getValueAsDef("VT")->getValueAsString("LLVMName"))
+#define GET_VT_ATTR(Ty, Sz, Any, Int, FP, Vec, Sc, Tup, NF, NElem, EltTy) \
+ .Case(#Ty, llvm::MVT::Ty)
+#include "llvm/CodeGen/GenVT.inc"
+#undef GET_VT_ATTR
+ .Case("INVALID_SIMPLE_VALUE_TYPE", llvm::MVT::INVALID_SIMPLE_VALUE_TYPE);
}
/// Return the indices of the definitions in a list of definitions that
@@ -191,7 +197,7 @@ private:
/// Prints the elements in "range" separated by commas and surrounded by "[]".
template <typename Range>
-void printBracketedRange(const Range &range, llvm::raw_ostream &os) {
+static void printBracketedRange(const Range &range, llvm::raw_ostream &os) {
os << '[';
llvm::interleaveComma(range, os);
os << ']';
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 3718648..dbae5d92 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -17,6 +17,7 @@
#include "OpGenHelpers.h"
#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/Builder.h"
#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
@@ -24,16 +25,24 @@
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Property.h"
+#include "mlir/TableGen/Region.h"
#include "mlir/TableGen/SideEffects.h"
+#include "mlir/TableGen/Successor.h"
#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/PointerUnion.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Signals.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/CodeGenHelpers.h"
@@ -380,9 +389,8 @@ public:
Formatter emitErrorPrefix() const {
return [this](raw_ostream &os) -> raw_ostream & {
if (emitForOp)
- return os << "emitOpError(";
- return os << formatv("emitError(loc, \"'{0}' op \"",
- op.getOperationName());
+ return os << "emitOpError(\"";
+ return os << formatv("emitError(loc, \"'{0}' op ", op.getOperationName());
};
}
@@ -940,7 +948,7 @@ genAttributeVerifier(const OpOrAdaptorHelper &emitHelper, FmtContext &ctx,
// {4}: Attribute/constraint description.
const char *const verifyAttrInline = R"(
if ({0} && !({1}))
- return {2}"attribute '{3}' failed to satisfy constraint: {4}");
+ return {2}attribute '{3}' failed to satisfy constraint: {4}");
)";
// Verify the attribute using a uniqued constraint. Can only be used within
// the context of an op.
@@ -993,10 +1001,11 @@ while (true) {{
(constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) {
body << formatv(verifyAttrUnique, *constraintFn, varName, attrName);
} else {
- body << formatv(verifyAttrInline, varName,
- tgfmt(condition, &ctx.withSelf(varName)),
- emitHelper.emitErrorPrefix(), attrName,
- escapeString(attr.getSummary()));
+ body << formatv(
+ verifyAttrInline, varName, tgfmt(condition, &ctx.withSelf(varName)),
+ emitHelper.emitErrorPrefix(), attrName,
+ buildErrorStreamingString(attr.getSummary(), ctx.withSelf(varName),
+ ErrorStreamType::InsideOpError));
}
};
@@ -1017,7 +1026,7 @@ while (true) {{
it.first);
if (metadata.isRequired)
body << formatv(
- "if (!tblgen_{0}) return {1}\"requires attribute '{0}'\");\n",
+ "if (!tblgen_{0}) return {1}requires attribute '{0}'\");\n",
it.first, emitHelper.emitErrorPrefix());
}
} else {
@@ -1099,7 +1108,7 @@ static void genPropertyVerifier(
// {3}: Property description.
const char *const verifyPropertyInline = R"(
if (!({0}))
- return {1}"property '{2}' failed to satisfy constraint: {3}");
+ return {1}property '{2}' failed to satisfy constraint: {3}");
)";
// Verify the property using a uniqued constraint. Can only be used
@@ -1143,9 +1152,12 @@ static void genPropertyVerifier(
if (uniquedFn.has_value() && emitHelper.isEmittingForOp())
body << formatv(verifyPropertyUniqued, *uniquedFn, varName, prop.name);
else
- body << formatv(
- verifyPropertyInline, tgfmt(rawCondition, &ctx.withSelf(varName)),
- emitHelper.emitErrorPrefix(), prop.name, prop.prop.getSummary());
+ body << formatv(verifyPropertyInline,
+ tgfmt(rawCondition, &ctx.withSelf(varName)),
+ emitHelper.emitErrorPrefix(), prop.name,
+ buildErrorStreamingString(
+ prop.prop.getSummary(), ctx.withSelf(varName),
+ ErrorStreamType::InsideOpError));
}
}
@@ -1629,7 +1641,7 @@ void OpEmitter::genPropertiesSupport() {
// Hashing for the property
const char *propHashFmt = R"decl(
- auto hash_{0} = [] (const auto &propStorage) -> llvm::hash_code {
+ auto hash_{0}_ = [] (const auto &propStorage) -> llvm::hash_code {
using ::llvm::hash_value;
return {1};
};
@@ -1655,7 +1667,7 @@ void OpEmitter::genPropertiesSupport() {
if (const auto *namedProperty =
llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
if (!namedProperty->prop.getHashPropertyCall().empty()) {
- hashMethod << "\n hash_" << namedProperty->name << "(prop."
+ hashMethod << "\n hash_" << namedProperty->name << "_(prop."
<< namedProperty->name << ")";
} else {
hashMethod << "\n hash_value(prop." << namedProperty->name
@@ -2629,14 +2641,23 @@ void OpEmitter::genInlineCreateBody(
std::string nonBuilderStateArgs = "";
if (!nonBuilderStateArgsList.empty()) {
llvm::raw_string_ostream nonBuilderStateArgsOS(nonBuilderStateArgs);
- interleaveComma(nonBuilderStateArgsList, nonBuilderStateArgsOS);
+ interleave(
+ nonBuilderStateArgsList,
+ [&](StringRef name) {
+ nonBuilderStateArgsOS << "std::forward<decltype(" << name << ")>("
+ << name << ')';
+ },
+ [&] { nonBuilderStateArgsOS << ", "; });
+
nonBuilderStateArgs = ", " + nonBuilderStateArgs;
}
- cWithLoc->body() << llvm::formatv(inlineCreateBody, locParamName,
- nonBuilderStateArgs,
- opClass.getClassName());
- cImplicitLoc->body() << llvm::formatv(inlineCreateBodyImplicitLoc,
- nonBuilderStateArgs);
+ if (cWithLoc)
+ cWithLoc->body() << llvm::formatv(inlineCreateBody, locParamName,
+ nonBuilderStateArgs,
+ opClass.getClassName());
+ if (cImplicitLoc)
+ cImplicitLoc->body() << llvm::formatv(inlineCreateBodyImplicitLoc,
+ nonBuilderStateArgs);
}
void OpEmitter::genSeparateArgParamBuilder() {
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0172b3f..2c33f4e 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -341,6 +341,22 @@ static std::string attrSizedTraitForKind(const char *kind) {
StringRef(kind).drop_front());
}
+static StringRef getPythonType(StringRef cppType) {
+ return llvm::StringSwitch<StringRef>(cppType)
+ .Case("::mlir::MemRefType", "_ods_ir.MemRefType")
+ .Case("::mlir::UnrankedMemRefType", "_ods_ir.UnrankedMemRefType")
+ .Case("::mlir::RankedTensorType", "_ods_ir.RankedTensorType")
+ .Case("::mlir::UnrankedTensorType", "_ods_ir.UnrankedTensorType")
+ .Case("::mlir::VectorType", "_ods_ir.VectorType")
+ .Case("::mlir::IntegerType", "_ods_ir.IntegerType")
+ .Case("::mlir::FloatType", "_ods_ir.FloatType")
+ .Case("::mlir::IndexType", "_ods_ir.IndexType")
+ .Case("::mlir::ComplexType", "_ods_ir.ComplexType")
+ .Case("::mlir::TupleType", "_ods_ir.TupleType")
+ .Case("::mlir::NoneType", "_ods_ir.NoneType")
+ .Default(StringRef());
+}
+
/// Emits accessors to "elements" of an Op definition. Currently, the supported
/// elements are operands and results, indicated by `kind`, which must be either
/// `operand` or `result` and is used verbatim in the emitted code.
@@ -370,8 +386,11 @@ static void emitElementAccessors(
seenVariableLength = true;
if (element.name.empty())
continue;
- const char *type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
+ std::string type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
: "_ods_ir.OpResult";
+ if (StringRef pythonType = getPythonType(element.constraint.getCppType());
+ !pythonType.empty())
+ type = llvm::formatv("{0}[{1}]", type, pythonType);
if (element.isVariableLength()) {
if (element.isOptional()) {
os << formatv(opOneOptionalTemplate, sanitizeName(element.name), kind,
@@ -418,6 +437,12 @@ static void emitElementAccessors(
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
: "_ods_ir.OpResult";
}
+ if (std::strcmp(type.c_str(), "_ods_ir.Value") == 0 ||
+ std::strcmp(type.c_str(), "_ods_ir.OpResult") == 0) {
+ StringRef pythonType = getPythonType(element.constraint.getCppType());
+ if (!pythonType.empty())
+ type += "[" + pythonType.str() + "]";
+ }
os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
kind, numSimpleLength, numVariadicGroups,
numPrecedingSimple, numPrecedingVariadic, type);
@@ -449,6 +474,12 @@ static void emitElementAccessors(
if (!element.isVariableLength() || element.isOptional()) {
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
: "_ods_ir.OpResult";
+ if (std::strcmp(type.c_str(), "_ods_ir.Value") == 0 ||
+ std::strcmp(type.c_str(), "_ods_ir.OpResult") == 0) {
+ StringRef pythonType = getPythonType(element.constraint.getCppType());
+ if (!pythonType.empty())
+ type += "[" + pythonType.str() + "]";
+ }
if (!element.isVariableLength()) {
trailing = "[0]";
} else if (element.isOptional()) {
diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
index f7134ce..e4ae78f 100644
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -57,19 +57,23 @@ const char *const passRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//
+#ifdef {1}
inline void register{0}() {{
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
- return {1};
+ return {2};
});
}
// Old registration code, kept for temporary backwards compatibility.
inline void register{0}Pass() {{
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
- return {1};
+ return {2};
});
}
+
+#undef {1}
+#endif // {1}
)";
/// The code snippet used to generate a function to register all passes in a
@@ -116,6 +120,10 @@ static std::string getPassDeclVarName(const Pass &pass) {
return "GEN_PASS_DECL_" + pass.getDef()->getName().upper();
}
+static std::string getPassRegistrationVarName(const Pass &pass) {
+ return "GEN_PASS_REGISTRATION_" + pass.getDef()->getName().upper();
+}
+
/// Emit the code to be included in the public header of the pass.
static void emitPassDecls(const Pass &pass, raw_ostream &os) {
StringRef passName = pass.getDef()->getName();
@@ -143,18 +151,25 @@ static void emitPassDecls(const Pass &pass, raw_ostream &os) {
/// PassRegistry.
static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
os << "#ifdef GEN_PASS_REGISTRATION\n";
+ os << "// Generate registrations for all passes.\n";
+ for (const Pass &pass : passes)
+ os << "#define " << getPassRegistrationVarName(pass) << "\n";
+ os << "#endif // GEN_PASS_REGISTRATION\n";
for (const Pass &pass : passes) {
+ std::string passName = pass.getDef()->getName().str();
+ std::string passEnableVarName = getPassRegistrationVarName(pass);
+
std::string constructorCall;
if (StringRef constructor = pass.getConstructor(); !constructor.empty())
constructorCall = constructor.str();
else
- constructorCall = formatv("create{0}()", pass.getDef()->getName()).str();
-
- os << formatv(passRegistrationCode, pass.getDef()->getName(),
+ constructorCall = formatv("create{0}()", passName).str();
+ os << formatv(passRegistrationCode, passName, passEnableVarName,
constructorCall);
}
+ os << "#ifdef GEN_PASS_REGISTRATION\n";
os << formatv(passGroupRegistrationCode, groupName);
for (const Pass &pass : passes)
@@ -372,81 +387,6 @@ static void emitPass(const Pass &pass, raw_ostream &os) {
emitPassDefs(pass, os);
}
-// TODO: Drop old pass declarations.
-// The old pass base class is being kept until all the passes have switched to
-// the new decls/defs design.
-const char *const oldPassDeclBegin = R"(
-template <typename DerivedT>
-class {0}Base : public {1} {
-public:
- using Base = {0}Base;
-
- {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
- {0}Base(const {0}Base &other) : {1}(other) {{}
- {0}Base& operator=(const {0}Base &) = delete;
- {0}Base({0}Base &&) = delete;
- {0}Base& operator=({0}Base &&) = delete;
- ~{0}Base() = default;
-
- /// Returns the command-line argument attached to this pass.
- static constexpr ::llvm::StringLiteral getArgumentName() {
- return ::llvm::StringLiteral("{2}");
- }
- ::llvm::StringRef getArgument() const override { return "{2}"; }
-
- ::llvm::StringRef getDescription() const override { return R"PD({3})PD"; }
-
- /// Returns the derived pass name.
- static constexpr ::llvm::StringLiteral getPassName() {
- return ::llvm::StringLiteral("{0}");
- }
- ::llvm::StringRef getName() const override { return "{0}"; }
-
- /// Support isa/dyn_cast functionality for the derived pass class.
- static bool classof(const ::mlir::Pass *pass) {{
- return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
- }
-
- /// A clone method to create a copy of this pass.
- std::unique_ptr<::mlir::Pass> clonePass() const override {{
- return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
- }
-
- /// Register the dialects that must be loaded in the context before this pass.
- void getDependentDialects(::mlir::DialectRegistry &registry) const override {
- {4}
- }
-
- /// Explicitly declare the TypeID for this class. We declare an explicit private
- /// instantiation because Pass classes should only be visible by the current
- /// library.
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)
-
-protected:
-)";
-
-// TODO: Drop old pass declarations.
-/// Emit a backward-compatible declaration of the pass base class.
-static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
- StringRef defName = pass.getDef()->getName();
- std::string dependentDialectRegistrations;
- {
- llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
- llvm::interleave(
- pass.getDependentDialects(), dialectsOs,
- [&](StringRef dependentDialect) {
- dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect);
- },
- "\n ");
- }
- os << formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
- pass.getArgument(), pass.getSummary().trim(),
- dependentDialectRegistrations);
- emitPassOptionDecls(pass, os);
- emitPassStatisticDecls(pass, os);
- os << "};\n";
-}
-
static void emitPasses(const RecordKeeper &records, raw_ostream &os) {
std::vector<Pass> passes = getPasses(records);
os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
@@ -464,12 +404,10 @@ static void emitPasses(const RecordKeeper &records, raw_ostream &os) {
emitRegistrations(passes, os);
- // TODO: Drop old pass declarations.
+ // TODO: Remove warning, kept in to make error understandable.
// Emit the old code until all the passes have switched to the new design.
- os << "// Deprecated. Please use the new per-pass macros.\n";
os << "#ifdef GEN_PASS_CLASSES\n";
- for (const Pass &pass : passes)
- emitOldPassDecl(pass, os);
+ os << "#error \"GEN_PASS_CLASSES is deprecated; use per-pass macros\"\n";
os << "#undef GEN_PASS_CLASSES\n";
os << "#endif // GEN_PASS_CLASSES\n";
}
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index c3034bb8..08d6483 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1129,7 +1129,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
// Emit RewritePattern for Pattern.
- auto locs = pattern.getLocation();
+ auto locs = pattern.getLocation(/*forSourceOutput=*/true);
os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n",
llvm::reverse(locs));
os << formatv(R"(struct {0} : public ::mlir::RewritePattern {