diff options
-rw-r--r-- | mlir/include/mlir/Tools/mlir-translate/Translation.h | 51 | ||||
-rw-r--r-- | mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp | 8 | ||||
-rw-r--r-- | mlir/lib/Tools/mlir-translate/Translation.cpp | 51 |
3 files changed, 74 insertions, 36 deletions
diff --git a/mlir/include/mlir/Tools/mlir-translate/Translation.h b/mlir/include/mlir/Tools/mlir-translate/Translation.h index d3cd817..80c4e37 100644 --- a/mlir/include/mlir/Tools/mlir-translate/Translation.h +++ b/mlir/include/mlir/Tools/mlir-translate/Translation.h @@ -47,9 +47,44 @@ using TranslateFromMLIRFunction = using TranslateFunction = std::function<LogicalResult( llvm::SourceMgr &sourceMgr, llvm::raw_ostream &output, MLIRContext *)>; +/// This class contains all of the components necessary for performing a +/// translation. +class Translation { +public: + Translation() = default; + Translation(TranslateFunction function, StringRef description, + Optional<llvm::Align> inputAlignment) + : function(std::move(function)), description(description), + inputAlignment(inputAlignment) {} + + /// Return the description of this translation. + StringRef getDescription() const { return description; } + + /// Return the optional alignment desired for the input of the translation. + Optional<llvm::Align> getInputAlignment() const { return inputAlignment; } + + /// Invoke the translation function with the given input and output streams. + LogicalResult operator()(llvm::SourceMgr &sourceMgr, + llvm::raw_ostream &output, + MLIRContext *context) const { + return function(sourceMgr, output, context); + } + +private: + /// The underlying translation function. + TranslateFunction function; + + /// The description of the translation. + StringRef description; + + /// An optional alignment desired for the input of the translation. + Optional<llvm::Align> inputAlignment; +}; + /// Use Translate[ToMLIR|FromMLIR]Registration as an initializer that /// registers a function and associates it with name. This requires that a -/// translation has not been registered to a given name. +/// translation has not been registered to a given name. `inputAlign` is an +/// optional expected alignment for the input data. /// /// Usage: /// @@ -62,10 +97,14 @@ using TranslateFunction = std::function<LogicalResult( /// /// \{ struct TranslateToMLIRRegistration { - TranslateToMLIRRegistration(llvm::StringRef name, llvm::StringRef description, - const TranslateSourceMgrToMLIRFunction &function); - TranslateToMLIRRegistration(llvm::StringRef name, llvm::StringRef description, - const TranslateStringRefToMLIRFunction &function); + TranslateToMLIRRegistration( + llvm::StringRef name, llvm::StringRef description, + const TranslateSourceMgrToMLIRFunction &function, + Optional<llvm::Align> inputAlignment = llvm::None); + TranslateToMLIRRegistration( + llvm::StringRef name, llvm::StringRef description, + const TranslateStringRefToMLIRFunction &function, + Optional<llvm::Align> inputAlignment = llvm::None); }; struct TranslateFromMLIRRegistration { @@ -99,7 +138,7 @@ struct TranslateRegistration { /// \} /// A command line parser for translation functions. -struct TranslationParser : public llvm::cl::parser<const TranslateFunction *> { +struct TranslationParser : public llvm::cl::parser<const Translation *> { TranslationParser(llvm::cl::Option &opt); void printOptionInfo(const llvm::cl::Option &o, diff --git a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp index ef2545b..51b21f2 100644 --- a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp +++ b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp @@ -56,7 +56,7 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv, llvm::InitLLVM y(argc, argv); // Add flags for all the registered translations. - llvm::cl::opt<const TranslateFunction *, false, TranslationParser> + llvm::cl::opt<const Translation *, false, TranslationParser> translationRequested("", llvm::cl::desc("Translation to perform"), llvm::cl::Required); registerAsmPrinterCLOptions(); @@ -65,7 +65,11 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv, llvm::cl::ParseCommandLineOptions(argc, argv, toolName); std::string errorMessage; - auto input = openInputFile(inputFilename, &errorMessage); + std::unique_ptr<llvm::MemoryBuffer> input; + if (auto inputAlignment = translationRequested->getInputAlignment()) + input = openInputFile(inputFilename, *inputAlignment, &errorMessage); + else + input = openInputFile(inputFilename, &errorMessage); if (!input) { llvm::errs() << errorMessage << "\n"; return failure(); diff --git a/mlir/lib/Tools/mlir-translate/Translation.cpp b/mlir/lib/Tools/mlir-translate/Translation.cpp index ab86cd0..548e3f9 100644 --- a/mlir/lib/Tools/mlir-translate/Translation.cpp +++ b/mlir/lib/Tools/mlir-translate/Translation.cpp @@ -40,34 +40,30 @@ void mlir::registerTranslationCLOptions() { *clOptions; } // Translation Registry //===----------------------------------------------------------------------===// -struct TranslationBundle { - TranslateFunction translateFunction; - StringRef translateDescription; -}; - -/// Get the mutable static map between registered file-to-file MLIR translations -/// and TranslateFunctions with its description that perform those translations. -static llvm::StringMap<TranslationBundle> &getTranslationRegistry() { - static llvm::StringMap<TranslationBundle> translationBundle; +/// Get the mutable static map between registered file-to-file MLIR +/// translations. +static llvm::StringMap<Translation> &getTranslationRegistry() { + static llvm::StringMap<Translation> translationBundle; return translationBundle; } /// Register the given translation. static void registerTranslation(StringRef name, StringRef description, + Optional<llvm::Align> inputAlignment, const TranslateFunction &function) { - auto &translationRegistry = getTranslationRegistry(); - if (translationRegistry.find(name) != translationRegistry.end()) + auto ®istry = getTranslationRegistry(); + if (registry.count(name)) llvm::report_fatal_error( "Attempting to overwrite an existing <file-to-file> function"); assert(function && "Attempting to register an empty translate <file-to-file> function"); - translationRegistry[name].translateFunction = function; - translationRegistry[name].translateDescription = description; + registry[name] = Translation(function, description, inputAlignment); } TranslateRegistration::TranslateRegistration( StringRef name, StringRef description, const TranslateFunction &function) { - registerTranslation(name, description, function); + registerTranslation(name, description, /*inputAlignment=*/llvm::None, + function); } //===----------------------------------------------------------------------===// @@ -77,7 +73,7 @@ TranslateRegistration::TranslateRegistration( // Puts `function` into the to-MLIR translation registry unless there is already // a function registered for the same name. static void registerTranslateToMLIRFunction( - StringRef name, StringRef description, + StringRef name, StringRef description, Optional<llvm::Align> inputAlignment, const TranslateSourceMgrToMLIRFunction &function) { auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output, MLIRContext *context) { @@ -87,21 +83,23 @@ static void registerTranslateToMLIRFunction( op.get()->print(output); return success(); }; - registerTranslation(name, description, wrappedFn); + registerTranslation(name, description, inputAlignment, wrappedFn); } TranslateToMLIRRegistration::TranslateToMLIRRegistration( StringRef name, StringRef description, - const TranslateSourceMgrToMLIRFunction &function) { - registerTranslateToMLIRFunction(name, description, function); + const TranslateSourceMgrToMLIRFunction &function, + Optional<llvm::Align> inputAlignment) { + registerTranslateToMLIRFunction(name, description, inputAlignment, function); } /// Wraps `function` with a lambda that extracts a StringRef from a source /// manager and registers the wrapper lambda as a to-MLIR conversion. TranslateToMLIRRegistration::TranslateToMLIRRegistration( StringRef name, StringRef description, - const TranslateStringRefToMLIRFunction &function) { + const TranslateStringRefToMLIRFunction &function, + Optional<llvm::Align> inputAlignment) { registerTranslateToMLIRFunction( - name, description, + name, description, inputAlignment, [function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) { const llvm::MemoryBuffer *buffer = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); @@ -117,9 +115,8 @@ TranslateFromMLIRRegistration::TranslateFromMLIRRegistration( StringRef name, StringRef description, const TranslateFromMLIRFunction &function, const std::function<void(DialectRegistry &)> &dialectRegistration) { - registerTranslation( - name, description, + name, description, /*inputAlignment=*/llvm::None, [function, dialectRegistration](llvm::SourceMgr &sourceMgr, raw_ostream &output, MLIRContext *context) { @@ -141,11 +138,9 @@ TranslateFromMLIRRegistration::TranslateFromMLIRRegistration( //===----------------------------------------------------------------------===// TranslationParser::TranslationParser(llvm::cl::Option &opt) - : llvm::cl::parser<const TranslateFunction *>(opt) { - for (const auto &kv : getTranslationRegistry()) { - addLiteralOption(kv.first(), &kv.second.translateFunction, - kv.second.translateDescription); - } + : llvm::cl::parser<const Translation *>(opt) { + for (const auto &kv : getTranslationRegistry()) + addLiteralOption(kv.first(), &kv.second, kv.second.getDescription()); } void TranslationParser::printOptionInfo(const llvm::cl::Option &o, @@ -156,5 +151,5 @@ void TranslationParser::printOptionInfo(const llvm::cl::Option &o, const TranslationParser::OptionInfo *rhs) { return lhs->Name.compare(rhs->Name); }); - llvm::cl::parser<const TranslateFunction *>::printOptionInfo(o, globalWidth); + llvm::cl::parser<const Translation *>::printOptionInfo(o, globalWidth); } |