aboutsummaryrefslogtreecommitdiff
path: root/mlir/tools/mlir-tblgen/PassGen.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/tools/mlir-tblgen/PassGen.cpp')
-rw-r--r--mlir/tools/mlir-tblgen/PassGen.cpp106
1 files changed, 22 insertions, 84 deletions
diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
index f7134ce..e4ae78f 100644
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -57,19 +57,23 @@ const char *const passRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//
+#ifdef {1}
inline void register{0}() {{
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
- return {1};
+ return {2};
});
}
// Old registration code, kept for temporary backwards compatibility.
inline void register{0}Pass() {{
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
- return {1};
+ return {2};
});
}
+
+#undef {1}
+#endif // {1}
)";
/// The code snippet used to generate a function to register all passes in a
@@ -116,6 +120,10 @@ static std::string getPassDeclVarName(const Pass &pass) {
return "GEN_PASS_DECL_" + pass.getDef()->getName().upper();
}
+static std::string getPassRegistrationVarName(const Pass &pass) {
+ return "GEN_PASS_REGISTRATION_" + pass.getDef()->getName().upper();
+}
+
/// Emit the code to be included in the public header of the pass.
static void emitPassDecls(const Pass &pass, raw_ostream &os) {
StringRef passName = pass.getDef()->getName();
@@ -143,18 +151,25 @@ static void emitPassDecls(const Pass &pass, raw_ostream &os) {
/// PassRegistry.
static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
os << "#ifdef GEN_PASS_REGISTRATION\n";
+ os << "// Generate registrations for all passes.\n";
+ for (const Pass &pass : passes)
+ os << "#define " << getPassRegistrationVarName(pass) << "\n";
+ os << "#endif // GEN_PASS_REGISTRATION\n";
for (const Pass &pass : passes) {
+ std::string passName = pass.getDef()->getName().str();
+ std::string passEnableVarName = getPassRegistrationVarName(pass);
+
std::string constructorCall;
if (StringRef constructor = pass.getConstructor(); !constructor.empty())
constructorCall = constructor.str();
else
- constructorCall = formatv("create{0}()", pass.getDef()->getName()).str();
-
- os << formatv(passRegistrationCode, pass.getDef()->getName(),
+ constructorCall = formatv("create{0}()", passName).str();
+ os << formatv(passRegistrationCode, passName, passEnableVarName,
constructorCall);
}
+ os << "#ifdef GEN_PASS_REGISTRATION\n";
os << formatv(passGroupRegistrationCode, groupName);
for (const Pass &pass : passes)
@@ -372,81 +387,6 @@ static void emitPass(const Pass &pass, raw_ostream &os) {
emitPassDefs(pass, os);
}
-// TODO: Drop old pass declarations.
-// The old pass base class is being kept until all the passes have switched to
-// the new decls/defs design.
-const char *const oldPassDeclBegin = R"(
-template <typename DerivedT>
-class {0}Base : public {1} {
-public:
- using Base = {0}Base;
-
- {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
- {0}Base(const {0}Base &other) : {1}(other) {{}
- {0}Base& operator=(const {0}Base &) = delete;
- {0}Base({0}Base &&) = delete;
- {0}Base& operator=({0}Base &&) = delete;
- ~{0}Base() = default;
-
- /// Returns the command-line argument attached to this pass.
- static constexpr ::llvm::StringLiteral getArgumentName() {
- return ::llvm::StringLiteral("{2}");
- }
- ::llvm::StringRef getArgument() const override { return "{2}"; }
-
- ::llvm::StringRef getDescription() const override { return R"PD({3})PD"; }
-
- /// Returns the derived pass name.
- static constexpr ::llvm::StringLiteral getPassName() {
- return ::llvm::StringLiteral("{0}");
- }
- ::llvm::StringRef getName() const override { return "{0}"; }
-
- /// Support isa/dyn_cast functionality for the derived pass class.
- static bool classof(const ::mlir::Pass *pass) {{
- return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
- }
-
- /// A clone method to create a copy of this pass.
- std::unique_ptr<::mlir::Pass> clonePass() const override {{
- return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
- }
-
- /// Register the dialects that must be loaded in the context before this pass.
- void getDependentDialects(::mlir::DialectRegistry &registry) const override {
- {4}
- }
-
- /// Explicitly declare the TypeID for this class. We declare an explicit private
- /// instantiation because Pass classes should only be visible by the current
- /// library.
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)
-
-protected:
-)";
-
-// TODO: Drop old pass declarations.
-/// Emit a backward-compatible declaration of the pass base class.
-static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
- StringRef defName = pass.getDef()->getName();
- std::string dependentDialectRegistrations;
- {
- llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
- llvm::interleave(
- pass.getDependentDialects(), dialectsOs,
- [&](StringRef dependentDialect) {
- dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect);
- },
- "\n ");
- }
- os << formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
- pass.getArgument(), pass.getSummary().trim(),
- dependentDialectRegistrations);
- emitPassOptionDecls(pass, os);
- emitPassStatisticDecls(pass, os);
- os << "};\n";
-}
-
static void emitPasses(const RecordKeeper &records, raw_ostream &os) {
std::vector<Pass> passes = getPasses(records);
os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
@@ -464,12 +404,10 @@ static void emitPasses(const RecordKeeper &records, raw_ostream &os) {
emitRegistrations(passes, os);
- // TODO: Drop old pass declarations.
+ // TODO: Remove warning, kept in to make error understandable.
// Emit the old code until all the passes have switched to the new design.
- os << "// Deprecated. Please use the new per-pass macros.\n";
os << "#ifdef GEN_PASS_CLASSES\n";
- for (const Pass &pass : passes)
- emitOldPassDecl(pass, os);
+ os << "#error \"GEN_PASS_CLASSES is deprecated; use per-pass macros\"\n";
os << "#undef GEN_PASS_CLASSES\n";
os << "#endif // GEN_PASS_CLASSES\n";
}