diff options
Diffstat (limited to 'mlir/tools/mlir-tblgen')
| -rw-r--r-- | mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 6 | ||||
| -rw-r--r-- | mlir/tools/mlir-tblgen/CMakeLists.txt | 1 | ||||
| -rw-r--r-- | mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp | 164 | ||||
| -rw-r--r-- | mlir/tools/mlir-tblgen/EnumsGen.cpp | 79 | ||||
| -rw-r--r-- | mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp | 12 | ||||
| -rw-r--r-- | mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 63 | ||||
| -rw-r--r-- | mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 33 | ||||
| -rw-r--r-- | mlir/tools/mlir-tblgen/PassGen.cpp | 106 | ||||
| -rw-r--r-- | mlir/tools/mlir-tblgen/RewriterGen.cpp | 2 |
9 files changed, 316 insertions, 150 deletions
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index 8ec2e03..2a513c3 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -637,8 +637,10 @@ void DefGen::emitTraitMethods(const InterfaceTrait &trait) { for (auto &method : iface.getMethods()) { // Don't declare if the method has a body. Or if the method has a default // implementation and the def didn't request that it always be declared. - if (method.getBody() || (method.getDefaultImplementation() && - !alwaysDeclared.count(method.getName()))) { + if (method.getBody()) + continue; + if (method.getDefaultImplementation() && + !alwaysDeclared.count(method.getName())) { genTraitMethodUsingDecl(trait, method); continue; } diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt index 2a7ef7e..d7087cb 100644 --- a/mlir/tools/mlir-tblgen/CMakeLists.txt +++ b/mlir/tools/mlir-tblgen/CMakeLists.txt @@ -12,6 +12,7 @@ add_tablegen(mlir-tblgen MLIR AttrOrTypeFormatGen.cpp BytecodeDialectGen.cpp DialectGen.cpp + DialectInterfacesGen.cpp DirectiveCommonGen.cpp EnumsGen.cpp EnumPythonBindingGen.cpp diff --git a/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp new file mode 100644 index 0000000..1d3b24a --- /dev/null +++ b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp @@ -0,0 +1,164 @@ +//===- DialectInterfacesGen.cpp - MLIR dialect interface utility generator ===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// DialectInterfaceGen generates definitions for Dialect interfaces. +// +//===----------------------------------------------------------------------===// + +#include "CppGenUtilities.h" +#include "DocGenUtilities.h" +#include "mlir/Support/IndentedOstream.h" +#include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/Interfaces.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/TableGen/CodeGenHelpers.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" + +using namespace mlir; +using llvm::Record; +using llvm::RecordKeeper; +using mlir::tblgen::Interface; +using mlir::tblgen::InterfaceMethod; + +/// Emit a string corresponding to a C++ type, followed by a space if necessary. +static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) { + type = type.trim(); + os << type; + if (type.back() != '&' && type.back() != '*') + os << " "; + return os; +} + +/// Emit the method name and argument list for the given method. +static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name, + raw_ostream &os) { + os << name << '('; + llvm::interleaveComma(method.getArguments(), os, + [&](const InterfaceMethod::Argument &arg) { + os << arg.type << " " << arg.name; + }); + os << ") const"; +} + +/// Get an array of all Dialect Interface definitions +static std::vector<const Record *> +getAllInterfaceDefinitions(const RecordKeeper &records) { + std::vector<const Record *> defs = + records.getAllDerivedDefinitions("DialectInterface"); + + llvm::erase_if(defs, [&](const Record *def) { + // Ignore interfaces defined outside of the top-level file. + return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) != + llvm::SrcMgr.getMainFileID(); + }); + return defs; +} + +namespace { +/// This struct is the generator used when processing tablegen dialect +/// interfaces. +class DialectInterfaceGenerator { +public: + DialectInterfaceGenerator(const RecordKeeper &records, raw_ostream &os) + : defs(getAllInterfaceDefinitions(records)), os(os) {} + + bool emitInterfaceDecls(); + +protected: + void emitInterfaceDecl(const Interface &interface); + + /// The set of interface records to emit. + std::vector<const Record *> defs; + // The stream to emit to. + raw_ostream &os; +}; +} // namespace + +//===----------------------------------------------------------------------===// +// GEN: Interface declarations +//===----------------------------------------------------------------------===// + +static void emitInterfaceMethodDoc(const InterfaceMethod &method, + raw_ostream &os, StringRef prefix = "") { + if (std::optional<StringRef> description = method.getDescription()) + tblgen::emitDescriptionComment(*description, os, prefix); +} + +static void emitInterfaceMethodsDef(const Interface &interface, + raw_ostream &os) { + + raw_indented_ostream ios(os); + ios.indent(2); + + for (auto &method : interface.getMethods()) { + emitInterfaceMethodDoc(method, ios); + ios << "virtual "; + emitCPPType(method.getReturnType(), ios); + emitMethodNameAndArgs(method, method.getName(), ios); + ios << " {"; + + if (auto body = method.getBody()) { + ios << "\n"; + ios.indent(4); + ios << body << "\n"; + ios.indent(2); + } + os << "}\n"; + } +} + +void DialectInterfaceGenerator::emitInterfaceDecl(const Interface &interface) { + llvm::NamespaceEmitter ns(os, interface.getCppNamespace()); + + StringRef interfaceName = interface.getName(); + + tblgen::emitSummaryAndDescComments(os, "", + interface.getDescription().value_or("")); + + // Emit the main interface class declaration. + os << llvm::formatv( + "class {0} : public ::mlir::DialectInterface::Base<{0}> {\n" + "public:\n" + " {0}(::mlir::Dialect *dialect) : Base(dialect) {{}\n", + interfaceName); + + emitInterfaceMethodsDef(interface, os); + + os << "};\n"; +} + +bool DialectInterfaceGenerator::emitInterfaceDecls() { + + llvm::emitSourceFileHeader("Dialect Interface Declarations", os); + + // Sort according to ID, so defs are emitted in the order in which they appear + // in the Tablegen file. + std::vector<const Record *> sortedDefs(defs); + llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) { + return lhs->getID() < rhs->getID(); + }); + + for (const Record *def : sortedDefs) + emitInterfaceDecl(Interface(def)); + + return false; +} + +//===----------------------------------------------------------------------===// +// GEN: Interface registration hooks +//===----------------------------------------------------------------------===// + +static mlir::GenRegistration genDecls( + "gen-dialect-interface-decls", "Generate dialect interface declarations.", + [](const RecordKeeper &records, raw_ostream &os) { + return DialectInterfaceGenerator(records, os).emitInterfaceDecls(); + }); diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp index 11bf9ce..8c7f9f7 100644 --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -702,41 +702,45 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { StringRef underlyingToSymFnName = enumInfo.getUnderlyingToSymbolFnName(); auto enumerants = enumInfo.getAllCases(); - llvm::NamespaceEmitter ns(os, cppNamespace); - - // Emit the enum class definition - emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os); - - // Emit conversion function declarations - if (llvm::all_of(enumerants, [](EnumCase enumerant) { - return enumerant.getValue() >= 0; - })) { - os << formatv( - "::std::optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName, - underlyingType.empty() ? std::string("unsigned") : underlyingType); - } - os << formatv("{2} {1}({0});\n", enumName, symToStrFnName, symToStrFnRetType); - os << formatv("::std::optional<{0}> {1}(::llvm::StringRef);\n", enumName, - strToSymFnName); - - if (enumInfo.isBitEnum()) { - emitOperators(enumDef, os); - } else { - emitMaxValueFn(enumDef, os); - } + { + llvm::NamespaceEmitter ns(os, cppNamespace); + + // Emit the enum class definition + emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, + os); + + // Emit conversion function declarations + if (llvm::all_of(enumerants, [](EnumCase enumerant) { + return enumerant.getValue() >= 0; + })) { + os << formatv( + "::std::optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName, + underlyingType.empty() ? std::string("unsigned") : underlyingType); + } + os << formatv("{2} {1}({0});\n", enumName, symToStrFnName, + symToStrFnRetType); + os << formatv("::std::optional<{0}> {1}(::llvm::StringRef);\n", enumName, + strToSymFnName); + + if (enumInfo.isBitEnum()) { + emitOperators(enumDef, os); + } else { + emitMaxValueFn(enumDef, os); + } - // Generate a generic `stringifyEnum` function that forwards to the method - // specified by the user. - const char *const stringifyEnumStr = R"( + // Generate a generic `stringifyEnum` function that forwards to the method + // specified by the user. + const char *const stringifyEnumStr = R"( inline {0} stringifyEnum({1} enumValue) {{ return {2}(enumValue); } )"; - os << formatv(stringifyEnumStr, symToStrFnRetType, enumName, symToStrFnName); + os << formatv(stringifyEnumStr, symToStrFnRetType, enumName, + symToStrFnName); - // Generate a generic `symbolizeEnum` function that forwards to the method - // specified by the user. - const char *const symbolizeEnumStr = R"( + // Generate a generic `symbolizeEnum` function that forwards to the method + // specified by the user. + const char *const symbolizeEnumStr = R"( template <typename EnumType> ::std::optional<EnumType> symbolizeEnum(::llvm::StringRef); @@ -745,9 +749,9 @@ inline ::std::optional<{0}> symbolizeEnum<{0}>(::llvm::StringRef str) { return {1}(str); } )"; - os << formatv(symbolizeEnumStr, enumName, strToSymFnName); + os << formatv(symbolizeEnumStr, enumName, strToSymFnName); - const char *const attrClassDecl = R"( + const char *const attrClassDecl = R"( class {1} : public ::mlir::{2} { public: using ValueType = {0}; @@ -757,13 +761,12 @@ public: {0} getValue() const; }; )"; - if (enumInfo.genSpecializedAttr()) { - StringRef attrClassName = enumInfo.getSpecializedAttrClassName(); - StringRef baseAttrClassName = "IntegerAttr"; - os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName); - } - - ns.close(); + if (enumInfo.genSpecializedAttr()) { + StringRef attrClassName = enumInfo.getSpecializedAttrClassName(); + StringRef baseAttrClassName = "IntegerAttr"; + os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName); + } + } // close `ns`. // Generate a generic parser and printer for the enum. std::string qualName = diff --git a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp index 525c8d6..54cc4b7 100644 --- a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp @@ -14,6 +14,7 @@ #include "mlir/TableGen/GenInfo.h" #include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/StringSwitch.h" #include "llvm/CodeGenTypes/MachineValueType.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/PrettyStackTrace.h" @@ -60,8 +61,13 @@ using IndicesTy = llvm::SmallBitVector; /// Return a CodeGen value type entry from a type record. static llvm::MVT::SimpleValueType getValueType(const Record *rec) { - return (llvm::MVT::SimpleValueType)rec->getValueAsDef("VT")->getValueAsInt( - "Value"); + return StringSwitch<llvm::MVT::SimpleValueType>( + rec->getValueAsDef("VT")->getValueAsString("LLVMName")) +#define GET_VT_ATTR(Ty, Sz, Any, Int, FP, Vec, Sc, Tup, NF, NElem, EltTy) \ + .Case(#Ty, llvm::MVT::Ty) +#include "llvm/CodeGen/GenVT.inc" +#undef GET_VT_ATTR + .Case("INVALID_SIMPLE_VALUE_TYPE", llvm::MVT::INVALID_SIMPLE_VALUE_TYPE); } /// Return the indices of the definitions in a list of definitions that @@ -191,7 +197,7 @@ private: /// Prints the elements in "range" separated by commas and surrounded by "[]". template <typename Range> -void printBracketedRange(const Range &range, llvm::raw_ostream &os) { +static void printBracketedRange(const Range &range, llvm::raw_ostream &os) { os << '['; llvm::interleaveComma(range, os); os << ']'; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 3718648..dbae5d92 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -17,6 +17,7 @@ #include "OpGenHelpers.h" #include "mlir/TableGen/Argument.h" #include "mlir/TableGen/Attribute.h" +#include "mlir/TableGen/Builder.h" #include "mlir/TableGen/Class.h" #include "mlir/TableGen/CodeGenHelpers.h" #include "mlir/TableGen/Format.h" @@ -24,16 +25,24 @@ #include "mlir/TableGen/Interfaces.h" #include "mlir/TableGen/Operator.h" #include "mlir/TableGen/Property.h" +#include "mlir/TableGen/Region.h" #include "mlir/TableGen/SideEffects.h" +#include "mlir/TableGen/Successor.h" #include "mlir/TableGen/Trait.h" #include "llvm/ADT/BitVector.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/PointerUnion.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/Signals.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TableGen/CodeGenHelpers.h" @@ -380,9 +389,8 @@ public: Formatter emitErrorPrefix() const { return [this](raw_ostream &os) -> raw_ostream & { if (emitForOp) - return os << "emitOpError("; - return os << formatv("emitError(loc, \"'{0}' op \"", - op.getOperationName()); + return os << "emitOpError(\""; + return os << formatv("emitError(loc, \"'{0}' op ", op.getOperationName()); }; } @@ -940,7 +948,7 @@ genAttributeVerifier(const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, // {4}: Attribute/constraint description. const char *const verifyAttrInline = R"( if ({0} && !({1})) - return {2}"attribute '{3}' failed to satisfy constraint: {4}"); + return {2}attribute '{3}' failed to satisfy constraint: {4}"); )"; // Verify the attribute using a uniqued constraint. Can only be used within // the context of an op. @@ -993,10 +1001,11 @@ while (true) {{ (constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) { body << formatv(verifyAttrUnique, *constraintFn, varName, attrName); } else { - body << formatv(verifyAttrInline, varName, - tgfmt(condition, &ctx.withSelf(varName)), - emitHelper.emitErrorPrefix(), attrName, - escapeString(attr.getSummary())); + body << formatv( + verifyAttrInline, varName, tgfmt(condition, &ctx.withSelf(varName)), + emitHelper.emitErrorPrefix(), attrName, + buildErrorStreamingString(attr.getSummary(), ctx.withSelf(varName), + ErrorStreamType::InsideOpError)); } }; @@ -1017,7 +1026,7 @@ while (true) {{ it.first); if (metadata.isRequired) body << formatv( - "if (!tblgen_{0}) return {1}\"requires attribute '{0}'\");\n", + "if (!tblgen_{0}) return {1}requires attribute '{0}'\");\n", it.first, emitHelper.emitErrorPrefix()); } } else { @@ -1099,7 +1108,7 @@ static void genPropertyVerifier( // {3}: Property description. const char *const verifyPropertyInline = R"( if (!({0})) - return {1}"property '{2}' failed to satisfy constraint: {3}"); + return {1}property '{2}' failed to satisfy constraint: {3}"); )"; // Verify the property using a uniqued constraint. Can only be used @@ -1143,9 +1152,12 @@ static void genPropertyVerifier( if (uniquedFn.has_value() && emitHelper.isEmittingForOp()) body << formatv(verifyPropertyUniqued, *uniquedFn, varName, prop.name); else - body << formatv( - verifyPropertyInline, tgfmt(rawCondition, &ctx.withSelf(varName)), - emitHelper.emitErrorPrefix(), prop.name, prop.prop.getSummary()); + body << formatv(verifyPropertyInline, + tgfmt(rawCondition, &ctx.withSelf(varName)), + emitHelper.emitErrorPrefix(), prop.name, + buildErrorStreamingString( + prop.prop.getSummary(), ctx.withSelf(varName), + ErrorStreamType::InsideOpError)); } } @@ -1629,7 +1641,7 @@ void OpEmitter::genPropertiesSupport() { // Hashing for the property const char *propHashFmt = R"decl( - auto hash_{0} = [] (const auto &propStorage) -> llvm::hash_code { + auto hash_{0}_ = [] (const auto &propStorage) -> llvm::hash_code { using ::llvm::hash_value; return {1}; }; @@ -1655,7 +1667,7 @@ void OpEmitter::genPropertiesSupport() { if (const auto *namedProperty = llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) { if (!namedProperty->prop.getHashPropertyCall().empty()) { - hashMethod << "\n hash_" << namedProperty->name << "(prop." + hashMethod << "\n hash_" << namedProperty->name << "_(prop." << namedProperty->name << ")"; } else { hashMethod << "\n hash_value(prop." << namedProperty->name @@ -2629,14 +2641,23 @@ void OpEmitter::genInlineCreateBody( std::string nonBuilderStateArgs = ""; if (!nonBuilderStateArgsList.empty()) { llvm::raw_string_ostream nonBuilderStateArgsOS(nonBuilderStateArgs); - interleaveComma(nonBuilderStateArgsList, nonBuilderStateArgsOS); + interleave( + nonBuilderStateArgsList, + [&](StringRef name) { + nonBuilderStateArgsOS << "std::forward<decltype(" << name << ")>(" + << name << ')'; + }, + [&] { nonBuilderStateArgsOS << ", "; }); + nonBuilderStateArgs = ", " + nonBuilderStateArgs; } - cWithLoc->body() << llvm::formatv(inlineCreateBody, locParamName, - nonBuilderStateArgs, - opClass.getClassName()); - cImplicitLoc->body() << llvm::formatv(inlineCreateBodyImplicitLoc, - nonBuilderStateArgs); + if (cWithLoc) + cWithLoc->body() << llvm::formatv(inlineCreateBody, locParamName, + nonBuilderStateArgs, + opClass.getClassName()); + if (cImplicitLoc) + cImplicitLoc->body() << llvm::formatv(inlineCreateBodyImplicitLoc, + nonBuilderStateArgs); } void OpEmitter::genSeparateArgParamBuilder() { diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 0172b3f..2c33f4e 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -341,6 +341,22 @@ static std::string attrSizedTraitForKind(const char *kind) { StringRef(kind).drop_front()); } +static StringRef getPythonType(StringRef cppType) { + return llvm::StringSwitch<StringRef>(cppType) + .Case("::mlir::MemRefType", "_ods_ir.MemRefType") + .Case("::mlir::UnrankedMemRefType", "_ods_ir.UnrankedMemRefType") + .Case("::mlir::RankedTensorType", "_ods_ir.RankedTensorType") + .Case("::mlir::UnrankedTensorType", "_ods_ir.UnrankedTensorType") + .Case("::mlir::VectorType", "_ods_ir.VectorType") + .Case("::mlir::IntegerType", "_ods_ir.IntegerType") + .Case("::mlir::FloatType", "_ods_ir.FloatType") + .Case("::mlir::IndexType", "_ods_ir.IndexType") + .Case("::mlir::ComplexType", "_ods_ir.ComplexType") + .Case("::mlir::TupleType", "_ods_ir.TupleType") + .Case("::mlir::NoneType", "_ods_ir.NoneType") + .Default(StringRef()); +} + /// Emits accessors to "elements" of an Op definition. Currently, the supported /// elements are operands and results, indicated by `kind`, which must be either /// `operand` or `result` and is used verbatim in the emitted code. @@ -370,8 +386,11 @@ static void emitElementAccessors( seenVariableLength = true; if (element.name.empty()) continue; - const char *type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value" + std::string type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value" : "_ods_ir.OpResult"; + if (StringRef pythonType = getPythonType(element.constraint.getCppType()); + !pythonType.empty()) + type = llvm::formatv("{0}[{1}]", type, pythonType); if (element.isVariableLength()) { if (element.isOptional()) { os << formatv(opOneOptionalTemplate, sanitizeName(element.name), kind, @@ -418,6 +437,12 @@ static void emitElementAccessors( type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value" : "_ods_ir.OpResult"; } + if (std::strcmp(type.c_str(), "_ods_ir.Value") == 0 || + std::strcmp(type.c_str(), "_ods_ir.OpResult") == 0) { + StringRef pythonType = getPythonType(element.constraint.getCppType()); + if (!pythonType.empty()) + type += "[" + pythonType.str() + "]"; + } os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name), kind, numSimpleLength, numVariadicGroups, numPrecedingSimple, numPrecedingVariadic, type); @@ -449,6 +474,12 @@ static void emitElementAccessors( if (!element.isVariableLength() || element.isOptional()) { type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value" : "_ods_ir.OpResult"; + if (std::strcmp(type.c_str(), "_ods_ir.Value") == 0 || + std::strcmp(type.c_str(), "_ods_ir.OpResult") == 0) { + StringRef pythonType = getPythonType(element.constraint.getCppType()); + if (!pythonType.empty()) + type += "[" + pythonType.str() + "]"; + } if (!element.isVariableLength()) { trailing = "[0]"; } else if (element.isOptional()) { diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp index f7134ce..e4ae78f 100644 --- a/mlir/tools/mlir-tblgen/PassGen.cpp +++ b/mlir/tools/mlir-tblgen/PassGen.cpp @@ -57,19 +57,23 @@ const char *const passRegistrationCode = R"( //===----------------------------------------------------------------------===// // {0} Registration //===----------------------------------------------------------------------===// +#ifdef {1} inline void register{0}() {{ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{ - return {1}; + return {2}; }); } // Old registration code, kept for temporary backwards compatibility. inline void register{0}Pass() {{ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{ - return {1}; + return {2}; }); } + +#undef {1} +#endif // {1} )"; /// The code snippet used to generate a function to register all passes in a @@ -116,6 +120,10 @@ static std::string getPassDeclVarName(const Pass &pass) { return "GEN_PASS_DECL_" + pass.getDef()->getName().upper(); } +static std::string getPassRegistrationVarName(const Pass &pass) { + return "GEN_PASS_REGISTRATION_" + pass.getDef()->getName().upper(); +} + /// Emit the code to be included in the public header of the pass. static void emitPassDecls(const Pass &pass, raw_ostream &os) { StringRef passName = pass.getDef()->getName(); @@ -143,18 +151,25 @@ static void emitPassDecls(const Pass &pass, raw_ostream &os) { /// PassRegistry. static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) { os << "#ifdef GEN_PASS_REGISTRATION\n"; + os << "// Generate registrations for all passes.\n"; + for (const Pass &pass : passes) + os << "#define " << getPassRegistrationVarName(pass) << "\n"; + os << "#endif // GEN_PASS_REGISTRATION\n"; for (const Pass &pass : passes) { + std::string passName = pass.getDef()->getName().str(); + std::string passEnableVarName = getPassRegistrationVarName(pass); + std::string constructorCall; if (StringRef constructor = pass.getConstructor(); !constructor.empty()) constructorCall = constructor.str(); else - constructorCall = formatv("create{0}()", pass.getDef()->getName()).str(); - - os << formatv(passRegistrationCode, pass.getDef()->getName(), + constructorCall = formatv("create{0}()", passName).str(); + os << formatv(passRegistrationCode, passName, passEnableVarName, constructorCall); } + os << "#ifdef GEN_PASS_REGISTRATION\n"; os << formatv(passGroupRegistrationCode, groupName); for (const Pass &pass : passes) @@ -372,81 +387,6 @@ static void emitPass(const Pass &pass, raw_ostream &os) { emitPassDefs(pass, os); } -// TODO: Drop old pass declarations. -// The old pass base class is being kept until all the passes have switched to -// the new decls/defs design. -const char *const oldPassDeclBegin = R"( -template <typename DerivedT> -class {0}Base : public {1} { -public: - using Base = {0}Base; - - {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{} - {0}Base(const {0}Base &other) : {1}(other) {{} - {0}Base& operator=(const {0}Base &) = delete; - {0}Base({0}Base &&) = delete; - {0}Base& operator=({0}Base &&) = delete; - ~{0}Base() = default; - - /// Returns the command-line argument attached to this pass. - static constexpr ::llvm::StringLiteral getArgumentName() { - return ::llvm::StringLiteral("{2}"); - } - ::llvm::StringRef getArgument() const override { return "{2}"; } - - ::llvm::StringRef getDescription() const override { return R"PD({3})PD"; } - - /// Returns the derived pass name. - static constexpr ::llvm::StringLiteral getPassName() { - return ::llvm::StringLiteral("{0}"); - } - ::llvm::StringRef getName() const override { return "{0}"; } - - /// Support isa/dyn_cast functionality for the derived pass class. - static bool classof(const ::mlir::Pass *pass) {{ - return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); - } - - /// A clone method to create a copy of this pass. - std::unique_ptr<::mlir::Pass> clonePass() const override {{ - return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); - } - - /// Register the dialects that must be loaded in the context before this pass. - void getDependentDialects(::mlir::DialectRegistry ®istry) const override { - {4} - } - - /// Explicitly declare the TypeID for this class. We declare an explicit private - /// instantiation because Pass classes should only be visible by the current - /// library. - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>) - -protected: -)"; - -// TODO: Drop old pass declarations. -/// Emit a backward-compatible declaration of the pass base class. -static void emitOldPassDecl(const Pass &pass, raw_ostream &os) { - StringRef defName = pass.getDef()->getName(); - std::string dependentDialectRegistrations; - { - llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); - llvm::interleave( - pass.getDependentDialects(), dialectsOs, - [&](StringRef dependentDialect) { - dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect); - }, - "\n "); - } - os << formatv(oldPassDeclBegin, defName, pass.getBaseClass(), - pass.getArgument(), pass.getSummary().trim(), - dependentDialectRegistrations); - emitPassOptionDecls(pass, os); - emitPassStatisticDecls(pass, os); - os << "};\n"; -} - static void emitPasses(const RecordKeeper &records, raw_ostream &os) { std::vector<Pass> passes = getPasses(records); os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n"; @@ -464,12 +404,10 @@ static void emitPasses(const RecordKeeper &records, raw_ostream &os) { emitRegistrations(passes, os); - // TODO: Drop old pass declarations. + // TODO: Remove warning, kept in to make error understandable. // Emit the old code until all the passes have switched to the new design. - os << "// Deprecated. Please use the new per-pass macros.\n"; os << "#ifdef GEN_PASS_CLASSES\n"; - for (const Pass &pass : passes) - emitOldPassDecl(pass, os); + os << "#error \"GEN_PASS_CLASSES is deprecated; use per-pass macros\"\n"; os << "#undef GEN_PASS_CLASSES\n"; os << "#endif // GEN_PASS_CLASSES\n"; } diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index c3034bb8..08d6483 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -1129,7 +1129,7 @@ void PatternEmitter::emit(StringRef rewriteName) { LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n"); // Emit RewritePattern for Pattern. - auto locs = pattern.getLocation(); + auto locs = pattern.getLocation(/*forSourceOutput=*/true); os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n", llvm::reverse(locs)); os << formatv(R"(struct {0} : public ::mlir::RewritePattern { |
