diff options
Diffstat (limited to 'mlir/tools/mlir-tblgen/PassGen.cpp')
| -rw-r--r-- | mlir/tools/mlir-tblgen/PassGen.cpp | 106 |
1 files changed, 22 insertions, 84 deletions
diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp index f7134ce..e4ae78f 100644 --- a/mlir/tools/mlir-tblgen/PassGen.cpp +++ b/mlir/tools/mlir-tblgen/PassGen.cpp @@ -57,19 +57,23 @@ const char *const passRegistrationCode = R"( //===----------------------------------------------------------------------===// // {0} Registration //===----------------------------------------------------------------------===// +#ifdef {1} inline void register{0}() {{ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{ - return {1}; + return {2}; }); } // Old registration code, kept for temporary backwards compatibility. inline void register{0}Pass() {{ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{ - return {1}; + return {2}; }); } + +#undef {1} +#endif // {1} )"; /// The code snippet used to generate a function to register all passes in a @@ -116,6 +120,10 @@ static std::string getPassDeclVarName(const Pass &pass) { return "GEN_PASS_DECL_" + pass.getDef()->getName().upper(); } +static std::string getPassRegistrationVarName(const Pass &pass) { + return "GEN_PASS_REGISTRATION_" + pass.getDef()->getName().upper(); +} + /// Emit the code to be included in the public header of the pass. static void emitPassDecls(const Pass &pass, raw_ostream &os) { StringRef passName = pass.getDef()->getName(); @@ -143,18 +151,25 @@ static void emitPassDecls(const Pass &pass, raw_ostream &os) { /// PassRegistry. static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) { os << "#ifdef GEN_PASS_REGISTRATION\n"; + os << "// Generate registrations for all passes.\n"; + for (const Pass &pass : passes) + os << "#define " << getPassRegistrationVarName(pass) << "\n"; + os << "#endif // GEN_PASS_REGISTRATION\n"; for (const Pass &pass : passes) { + std::string passName = pass.getDef()->getName().str(); + std::string passEnableVarName = getPassRegistrationVarName(pass); + std::string constructorCall; if (StringRef constructor = pass.getConstructor(); !constructor.empty()) constructorCall = constructor.str(); else - constructorCall = formatv("create{0}()", pass.getDef()->getName()).str(); - - os << formatv(passRegistrationCode, pass.getDef()->getName(), + constructorCall = formatv("create{0}()", passName).str(); + os << formatv(passRegistrationCode, passName, passEnableVarName, constructorCall); } + os << "#ifdef GEN_PASS_REGISTRATION\n"; os << formatv(passGroupRegistrationCode, groupName); for (const Pass &pass : passes) @@ -372,81 +387,6 @@ static void emitPass(const Pass &pass, raw_ostream &os) { emitPassDefs(pass, os); } -// TODO: Drop old pass declarations. -// The old pass base class is being kept until all the passes have switched to -// the new decls/defs design. -const char *const oldPassDeclBegin = R"( -template <typename DerivedT> -class {0}Base : public {1} { -public: - using Base = {0}Base; - - {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{} - {0}Base(const {0}Base &other) : {1}(other) {{} - {0}Base& operator=(const {0}Base &) = delete; - {0}Base({0}Base &&) = delete; - {0}Base& operator=({0}Base &&) = delete; - ~{0}Base() = default; - - /// Returns the command-line argument attached to this pass. - static constexpr ::llvm::StringLiteral getArgumentName() { - return ::llvm::StringLiteral("{2}"); - } - ::llvm::StringRef getArgument() const override { return "{2}"; } - - ::llvm::StringRef getDescription() const override { return R"PD({3})PD"; } - - /// Returns the derived pass name. - static constexpr ::llvm::StringLiteral getPassName() { - return ::llvm::StringLiteral("{0}"); - } - ::llvm::StringRef getName() const override { return "{0}"; } - - /// Support isa/dyn_cast functionality for the derived pass class. - static bool classof(const ::mlir::Pass *pass) {{ - return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); - } - - /// A clone method to create a copy of this pass. - std::unique_ptr<::mlir::Pass> clonePass() const override {{ - return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); - } - - /// Register the dialects that must be loaded in the context before this pass. - void getDependentDialects(::mlir::DialectRegistry ®istry) const override { - {4} - } - - /// Explicitly declare the TypeID for this class. We declare an explicit private - /// instantiation because Pass classes should only be visible by the current - /// library. - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>) - -protected: -)"; - -// TODO: Drop old pass declarations. -/// Emit a backward-compatible declaration of the pass base class. -static void emitOldPassDecl(const Pass &pass, raw_ostream &os) { - StringRef defName = pass.getDef()->getName(); - std::string dependentDialectRegistrations; - { - llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); - llvm::interleave( - pass.getDependentDialects(), dialectsOs, - [&](StringRef dependentDialect) { - dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect); - }, - "\n "); - } - os << formatv(oldPassDeclBegin, defName, pass.getBaseClass(), - pass.getArgument(), pass.getSummary().trim(), - dependentDialectRegistrations); - emitPassOptionDecls(pass, os); - emitPassStatisticDecls(pass, os); - os << "};\n"; -} - static void emitPasses(const RecordKeeper &records, raw_ostream &os) { std::vector<Pass> passes = getPasses(records); os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n"; @@ -464,12 +404,10 @@ static void emitPasses(const RecordKeeper &records, raw_ostream &os) { emitRegistrations(passes, os); - // TODO: Drop old pass declarations. + // TODO: Remove warning, kept in to make error understandable. // Emit the old code until all the passes have switched to the new design. - os << "// Deprecated. Please use the new per-pass macros.\n"; os << "#ifdef GEN_PASS_CLASSES\n"; - for (const Pass &pass : passes) - emitOldPassDecl(pass, os); + os << "#error \"GEN_PASS_CLASSES is deprecated; use per-pass macros\"\n"; os << "#undef GEN_PASS_CLASSES\n"; os << "#endif // GEN_PASS_CLASSES\n"; } |
