//===-- CIRLoweringEmitter.cpp - Generate CIR lowering patterns -----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This TableGen backend emits CIR operation lowering patterns. // //===----------------------------------------------------------------------===// #include "TableGenBackends.h" #include "llvm/TableGen/TableGenBackend.h" #include #include #include using namespace llvm; using namespace clang; namespace { std::vector LLVMLoweringPatterns; std::vector LLVMLoweringPatternsList; // Adapted from mlir/lib/TableGen/Operator.cpp // Returns the C++ class name of the operation, which is the name of the // operation with the dialect prefix removed and the first underscore removed. // If the operation name starts with an underscore, the underscore is considered // part of the class name. std::string GetOpCppClassName(const Record *OpRecord) { StringRef Name = OpRecord->getName(); StringRef Prefix; StringRef CppClassName; std::tie(Prefix, CppClassName) = Name.split('_'); if (Prefix.empty()) { // Class name with a leading underscore and without dialect prefix return Name.str(); } if (CppClassName.empty()) { // Class name without dialect prefix return Prefix.str(); } return CppClassName.str(); } std::string GetOpLLVMLoweringPatternName(llvm::StringRef OpName) { std::string Name = "CIRToLLVM"; Name += OpName; Name += "Lowering"; return Name; } void GenerateLLVMLoweringPattern(llvm::StringRef OpName, llvm::StringRef PatternName, bool IsRecursive, llvm::StringRef ExtraDecl) { std::string CodeBuffer; llvm::raw_string_ostream Code(CodeBuffer); Code << "class " << PatternName << " : public mlir::OpConversionPattern {\n"; Code << " [[maybe_unused]] mlir::DataLayout const &dataLayout;\n"; Code << "\n"; Code << "public:\n"; Code << " using mlir::OpConversionPattern::OpConversionPattern;\n"; Code << " " << PatternName << "(mlir::TypeConverter const " "&typeConverter, mlir::MLIRContext *context, mlir::DataLayout const " "&dataLayout)\n"; Code << " : OpConversionPattern(typeConverter, context), dataLayout(dataLayout)"; if (IsRecursive) { Code << " {\n"; Code << " setHasBoundedRewriteRecursion();\n"; Code << " }\n"; } else { Code << " {}\n"; } Code << "\n"; Code << " mlir::LogicalResult matchAndRewrite(cir::" << OpName << " op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) " "const override;\n"; if (!ExtraDecl.empty()) { Code << "\nprivate:\n"; Code << ExtraDecl << "\n"; } Code << "};\n"; LLVMLoweringPatterns.push_back(std::move(CodeBuffer)); } void Generate(const Record *OpRecord) { std::string OpName = GetOpCppClassName(OpRecord); if (OpRecord->getValueAsBit("hasLLVMLowering")) { std::string PatternName = GetOpLLVMLoweringPatternName(OpName); bool IsRecursive = OpRecord->getValueAsBit("isLLVMLoweringRecursive"); llvm::StringRef ExtraDecl = OpRecord->getValueAsString("extraLLVMLoweringPatternDecl"); GenerateLLVMLoweringPattern(OpName, PatternName, IsRecursive, ExtraDecl); LLVMLoweringPatternsList.push_back(std::move(PatternName)); } } } // namespace void clang::EmitCIRLowering(const llvm::RecordKeeper &RK, llvm::raw_ostream &OS) { emitSourceFileHeader("Lowering patterns for CIR operations", OS); for (const auto *OpRecord : RK.getAllDerivedDefinitions("CIR_Op")) Generate(OpRecord); OS << "#ifdef GET_LLVM_LOWERING_PATTERNS\n" << llvm::join(LLVMLoweringPatterns, "\n") << "#endif\n\n"; OS << "#ifdef GET_LLVM_LOWERING_PATTERNS_LIST\n" << llvm::join(LLVMLoweringPatternsList, ",\n") << "\n#endif\n\n"; }