diff options
| author | Callum Fare <callum@codeplay.com> | 2024-12-05 08:34:04 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-12-05 09:34:04 +0100 |
| commit | fd3907ccb583df99e9c19d2fe84e4e7c52d75de9 (patch) | |
| tree | deaffb6b369c1ec87261df173b32717b07f7525c /offload/tools | |
| parent | 636beb6a2833ee0290935f679252c1b662721b31 (diff) | |
| download | llvm-fd3907ccb583df99e9c19d2fe84e4e7c52d75de9.zip llvm-fd3907ccb583df99e9c19d2fe84e4e7c52d75de9.tar.gz llvm-fd3907ccb583df99e9c19d2fe84e4e7c52d75de9.tar.bz2 | |
Reland #118503: [Offload] Introduce offload-tblgen and initial new API implementation (#118614)
Reland #118503. Added a fix for builds with `-DBUILD_SHARED_LIBS=ON`
(see last commit). Otherwise the changes are identical.
---
### New API
Previous discussions at the LLVM/Offload meeting have brought up the
need for a new API for exposing the functionality of the plugins. This
change introduces a very small subset of a new API, which is primarily
for testing the offload tooling and demonstrating how a new API can fit
into the existing code base without being too disruptive. Exact designs
for these entry points and future additions can be worked out over time.
The new API does however introduce the bare minimum functionality to
implement device discovery for Unified Runtime and SYCL. This means that
the `urinfo` and `sycl-ls` tools can be used on top of Offload. A
(rough) implementation of a Unified Runtime adapter (aka plugin) for
Offload is available
[here](https://github.com/callumfare/unified-runtime/tree/offload_adapter).
Our intention is to maintain this and use it to implement and test
Offload API changes with SYCL.
### Demoing the new API
```sh
# From the runtime build directory
$ ninja LibomptUnitTests
$ OFFLOAD_TRACE=1 ./offload/unittests/OffloadAPI/offload.unittests
```
### Open questions and future work
* Only some of the available device info is exposed, and not all the
possible device queries needed for SYCL are implemented by the plugins.
A sensible next step would be to refactor and extend the existing device
info queries in the plugins. The existing info queries are all strings,
but the new API introduces the ability to return any arbitrary type.
* It may be sensible at some point for the plugins to implement the new
API directly, and the higher level code on top of it could be made
generic, but this is more of a long-term possibility.
Diffstat (limited to 'offload/tools')
| -rw-r--r-- | offload/tools/offload-tblgen/APIGen.cpp | 229 | ||||
| -rw-r--r-- | offload/tools/offload-tblgen/CMakeLists.txt | 26 | ||||
| -rw-r--r-- | offload/tools/offload-tblgen/EntryPointGen.cpp | 138 | ||||
| -rw-r--r-- | offload/tools/offload-tblgen/FuncsGen.cpp | 74 | ||||
| -rw-r--r-- | offload/tools/offload-tblgen/GenCommon.hpp | 67 | ||||
| -rw-r--r-- | offload/tools/offload-tblgen/Generators.hpp | 23 | ||||
| -rw-r--r-- | offload/tools/offload-tblgen/PrintGen.cpp | 226 | ||||
| -rw-r--r-- | offload/tools/offload-tblgen/RecordTypes.hpp | 227 | ||||
| -rw-r--r-- | offload/tools/offload-tblgen/offload-tblgen.cpp | 101 |
9 files changed, 1111 insertions, 0 deletions
diff --git a/offload/tools/offload-tblgen/APIGen.cpp b/offload/tools/offload-tblgen/APIGen.cpp new file mode 100644 index 0000000..97a2464 --- /dev/null +++ b/offload/tools/offload-tblgen/APIGen.cpp @@ -0,0 +1,229 @@ +//===- offload-tblgen/APIGen.cpp - Tablegen backend for Offload header ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is a Tablegen backend that produces the contents of the Offload API +// header. The generated comments are Doxygen compatible. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" + +#include "GenCommon.hpp" +#include "RecordTypes.hpp" + +using namespace llvm; +using namespace offload::tblgen; + +// Produce a possibly multi-line comment from the input string +static std::string MakeComment(StringRef in) { + std::string out = ""; + size_t LineStart = 0; + size_t LineBreak = 0; + while (LineBreak < in.size()) { + LineBreak = in.find_first_of("\n", LineStart); + if (LineBreak - LineStart <= 1) { + break; + } + out += std::string("/// ") + + in.substr(LineStart, LineBreak - LineStart).str() + "\n"; + LineStart = LineBreak + 1; + } + + return out; +} + +static void ProcessHandle(const HandleRec &H, raw_ostream &OS) { + OS << CommentsHeader; + OS << formatv("/// @brief {0}\n", H.getDesc()); + OS << formatv("typedef struct {0}_ *{0};\n", H.getName()); +} + +static void ProcessTypedef(const TypedefRec &T, raw_ostream &OS) { + OS << CommentsHeader; + OS << formatv("/// @brief {0}\n", T.getDesc()); + OS << formatv("typedef {0} {1};\n", T.getValue(), T.getName()); +} + +static void ProcessMacro(const MacroRec &M, raw_ostream &OS) { + OS << CommentsHeader; + OS << formatv("#ifndef {0}\n", M.getName()); + if (auto Condition = M.getCondition()) { + OS << formatv("#if {0}\n", *Condition); + } + OS << "/// @brief " << M.getDesc() << "\n"; + OS << formatv("#define {0} {1}\n", M.getNameWithArgs(), M.getValue()); + if (auto AltValue = M.getAltValue()) { + OS << "#else\n"; + OS << formatv("#define {0} {1}\n", M.getNameWithArgs(), *AltValue); + } + if (auto Condition = M.getCondition()) { + OS << formatv("#endif // {0}\n", *Condition); + } + OS << formatv("#endif // {0}\n", M.getName()); +} + +static void ProcessFunction(const FunctionRec &F, raw_ostream &OS) { + OS << CommentsHeader; + OS << formatv("/// @brief {0}\n", F.getDesc()); + OS << CommentsBreak; + + OS << "/// @details\n"; + for (auto &Detail : F.getDetails()) { + OS << formatv("/// - {0}\n", Detail); + } + OS << CommentsBreak; + + // Emit analogue remarks + auto Analogues = F.getAnalogues(); + if (!Analogues.empty()) { + OS << "/// @remarks\n/// _Analogues_\n"; + for (auto &Analogue : Analogues) { + OS << formatv("/// - **{0}**\n", Analogue); + } + OS << CommentsBreak; + } + + OS << "/// @returns\n"; + auto Returns = F.getReturns(); + for (auto &Ret : Returns) { + OS << formatv("/// - ::{0}\n", Ret.getValue()); + auto RetConditions = Ret.getConditions(); + for (auto &RetCondition : RetConditions) { + OS << formatv("/// + {0}\n", RetCondition); + } + } + + OS << formatv("{0}_APIEXPORT {1}_result_t {0}_APICALL ", PrefixUpper, + PrefixLower); + OS << F.getName(); + OS << "(\n"; + auto Params = F.getParams(); + for (auto &Param : Params) { + OS << MakeParamComment(Param) << "\n"; + OS << " " << Param.getType() << " " << Param.getName(); + if (Param != Params.back()) { + OS << ",\n"; + } else { + OS << "\n"; + } + } + OS << ");\n\n"; +} + +static void ProcessEnum(const EnumRec &Enum, raw_ostream &OS) { + OS << CommentsHeader; + OS << formatv("/// @brief {0}\n", Enum.getDesc()); + OS << formatv("typedef enum {0} {{\n", Enum.getName()); + + uint32_t EtorVal = 0; + for (const auto &EnumVal : Enum.getValues()) { + if (Enum.isTyped()) { + OS << MakeComment( + formatv("[{0}] {1}", EnumVal.getTaggedType(), EnumVal.getDesc()) + .str()); + } else { + OS << MakeComment(EnumVal.getDesc()); + } + OS << formatv(TAB_1 "{0}_{1} = {2},\n", Enum.getEnumValNamePrefix(), + EnumVal.getName(), EtorVal++); + } + + // Add force uint32 val + OS << formatv(TAB_1 "/// @cond\n" TAB_1 + "{0}_FORCE_UINT32 = 0x7fffffff\n" TAB_1 + "/// @endcond\n\n", + Enum.getEnumValNamePrefix()); + + OS << formatv("} {0};\n", Enum.getName()); +} + +static void ProcessStruct(const StructRec &Struct, raw_ostream &OS) { + OS << CommentsHeader; + OS << formatv("/// @brief {0}\n", Struct.getDesc()); + OS << formatv("typedef struct {0} {{\n", Struct.getName()); + + for (const auto &Member : Struct.getMembers()) { + OS << formatv(TAB_1 "{0} {1}; {2}", Member.getType(), Member.getName(), + MakeComment(Member.getDesc())); + } + + OS << formatv("} {0};\n\n", Struct.getName()); +} + +static void ProcessFuncParamStruct(const FunctionRec &Func, raw_ostream &OS) { + if (Func.getParams().size() == 0) { + return; + } + + auto FuncParamStructBegin = R"( +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function parameters for {0} +/// @details Each entry is a pointer to the parameter passed to the function; +typedef struct {1} {{ +)"; + + OS << formatv(FuncParamStructBegin, Func.getName(), + Func.getParamStructName()); + for (const auto &Param : Func.getParams()) { + OS << TAB_1 << Param.getType() << "* p" << Param.getName() << ";\n"; + } + OS << formatv("} {0};\n", Func.getParamStructName()); +} + +static void ProcessFuncWithCodeLocVariant(const FunctionRec &Func, + raw_ostream &OS) { + + auto FuncWithCodeLocBegin = R"( +/////////////////////////////////////////////////////////////////////////////// +/// @brief Variant of {0} that also sets source code location information +/// @details See also ::{0} +OL_APIEXPORT ol_result_t OL_APICALL {0}WithCodeLoc( +)"; + OS << formatv(FuncWithCodeLocBegin, Func.getName()); + auto Params = Func.getParams(); + for (auto &Param : Params) { + OS << " " << Param.getType() << " " << Param.getName(); + OS << ",\n"; + } + OS << "ol_code_location_t *CodeLocation);\n\n"; +} + +void EmitOffloadAPI(const RecordKeeper &Records, raw_ostream &OS) { + OS << GenericHeader; + OS << FileHeader; + // Generate main API definitions + for (auto *R : Records.getAllDerivedDefinitions("APIObject")) { + if (R->isSubClassOf("Macro")) { + ProcessMacro(MacroRec{R}, OS); + } else if (R->isSubClassOf("Typedef")) { + ProcessTypedef(TypedefRec{R}, OS); + } else if (R->isSubClassOf("Handle")) { + ProcessHandle(HandleRec{R}, OS); + } else if (R->isSubClassOf("Function")) { + ProcessFunction(FunctionRec{R}, OS); + } else if (R->isSubClassOf("Enum")) { + ProcessEnum(EnumRec{R}, OS); + } else if (R->isSubClassOf("Struct")) { + ProcessStruct(StructRec{R}, OS); + } + } + + // Generate auxiliary definitions (func param structs etc) + for (auto *R : Records.getAllDerivedDefinitions("Function")) { + ProcessFuncParamStruct(FunctionRec{R}, OS); + } + + for (auto *R : Records.getAllDerivedDefinitions("Function")) { + ProcessFuncWithCodeLocVariant(FunctionRec{R}, OS); + } + + OS << FileFooter; +} diff --git a/offload/tools/offload-tblgen/CMakeLists.txt b/offload/tools/offload-tblgen/CMakeLists.txt new file mode 100644 index 0000000..e7e7c85 --- /dev/null +++ b/offload/tools/offload-tblgen/CMakeLists.txt @@ -0,0 +1,26 @@ +##===----------------------------------------------------------------------===## +# +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +##===----------------------------------------------------------------------===## +include(TableGen) + +set(LLVM_LINK_COMPONENTS Support) + +add_tablegen(offload-tblgen OFFLOAD + EXPORT OFFLOAD + APIGen.cpp + EntryPointGen.cpp + FuncsGen.cpp + GenCommon.hpp + Generators.hpp + offload-tblgen.cpp + PrintGen.cpp + RecordTypes.hpp + ) + +set(OFFLOAD_TABLEGEN_EXE "${OFFLOAD_TABLEGEN_EXE}" CACHE INTERNAL "") +set(OFFLOAD_TABLEGEN_TARGET "${OFFLOAD_TABLEGEN_TARGET}" CACHE INTERNAL "") + diff --git a/offload/tools/offload-tblgen/EntryPointGen.cpp b/offload/tools/offload-tblgen/EntryPointGen.cpp new file mode 100644 index 0000000..990ff96 --- /dev/null +++ b/offload/tools/offload-tblgen/EntryPointGen.cpp @@ -0,0 +1,138 @@ +//===- offload-tblgen/EntryPointGen.cpp - Tablegen backend for Offload ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is a Tablegen backend that produces the actual entry points for the +// Offload API. It serves as a place to integrate functionality like tracing +// and validation before dispatching to the actual implementations. +//===----------------------------------------------------------------------===// + +#include "llvm/Support/FormatVariadic.h" +#include "llvm/TableGen/Record.h" + +#include "GenCommon.hpp" +#include "RecordTypes.hpp" + +using namespace llvm; +using namespace offload::tblgen; + +static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) { + OS << CommentsHeader; + // Emit preamble + OS << formatv("{0}_impl_result_t {1}_val(\n ", PrefixLower, F.getName()); + // Emit arguments + std::string ParamNameList = ""; + for (auto &Param : F.getParams()) { + OS << Param.getType() << " " << Param.getName(); + if (Param != F.getParams().back()) { + OS << ", "; + } + ParamNameList += Param.getName().str() + ", "; + } + OS << ") {\n"; + + OS << TAB_1 "if (true /*enableParameterValidation*/) {\n"; + // Emit validation checks + for (const auto &Return : F.getReturns()) { + for (auto &Condition : Return.getConditions()) { + if (Condition.starts_with("`") && Condition.ends_with("`")) { + auto ConditionString = Condition.substr(1, Condition.size() - 2); + OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString); + OS << formatv(TAB_3 "return {0};\n", Return.getValue()); + OS << TAB_2 "}\n\n"; + } + } + } + OS << TAB_1 "}\n\n"; + + // Perform actual function call to the implementation + ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2); + OS << formatv(TAB_1 "return {0}_impl({1});\n\n", F.getName(), ParamNameList); + OS << "}\n"; +} + +static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) { + // Emit preamble + OS << formatv("{1}_APIEXPORT {0}_result_t {1}_APICALL {2}(\n ", PrefixLower, + PrefixUpper, F.getName()); + // Emit arguments + std::string ParamNameList = ""; + for (auto &Param : F.getParams()) { + OS << Param.getType() << " " << Param.getName(); + if (Param != F.getParams().back()) { + OS << ", "; + } + ParamNameList += Param.getName().str() + ", "; + } + OS << ") {\n"; + + // Emit pre-call prints + OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n"; + OS << formatv(TAB_2 "std::cout << \"---> {0}\";\n", F.getName()); + OS << TAB_1 "}\n\n"; + + // Perform actual function call to the validation wrapper + ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2); + OS << formatv(TAB_1 "{0}_result_t Result = {1}_val({2});\n\n", PrefixLower, + F.getName(), ParamNameList); + + // Emit post-call prints + OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n"; + if (F.getParams().size() > 0) { + OS << formatv(TAB_2 "{0} Params = {{", F.getParamStructName()); + for (const auto &Param : F.getParams()) { + OS << "&" << Param.getName(); + if (Param != F.getParams().back()) { + OS << ", "; + } + } + OS << formatv("};\n"); + OS << TAB_2 "std::cout << \"(\" << &Params << \")\";\n"; + } else { + OS << TAB_2 "std::cout << \"()\";\n"; + } + OS << TAB_2 "std::cout << \"-> \" << Result << \"\\n\";\n"; + OS << TAB_2 "if (Result && Result->Details) {\n"; + OS << TAB_3 "std::cout << \" *Error Details* \" << Result->Details " + "<< \" \\n\";\n"; + OS << TAB_2 "}\n"; + OS << TAB_1 "}\n"; + + OS << TAB_1 "return Result;\n"; + OS << "}\n"; +} + +static void EmitCodeLocWrapper(const FunctionRec &F, raw_ostream &OS) { + // Emit preamble + OS << formatv("{0}_result_t {1}WithCodeLoc(\n ", PrefixLower, F.getName()); + // Emit arguments + std::string ParamNameList = ""; + for (auto &Param : F.getParams()) { + OS << Param.getType() << " " << Param.getName() << ", "; + ParamNameList += Param.getName().str(); + if (Param != F.getParams().back()) { + ParamNameList += ", "; + } + } + OS << "ol_code_location_t *CodeLocation"; + OS << ") {\n"; + OS << TAB_1 "currentCodeLocation() = CodeLocation;\n"; + OS << formatv(TAB_1 "{0}_result_t Result = {1}({2});\n\n", PrefixLower, + F.getName(), ParamNameList); + OS << TAB_1 "currentCodeLocation() = nullptr;\n"; + OS << TAB_1 "return Result;\n"; + OS << "}\n"; +} + +void EmitOffloadEntryPoints(const RecordKeeper &Records, raw_ostream &OS) { + OS << GenericHeader; + for (auto *R : Records.getAllDerivedDefinitions("Function")) { + EmitValidationFunc(FunctionRec{R}, OS); + EmitEntryPointFunc(FunctionRec{R}, OS); + EmitCodeLocWrapper(FunctionRec{R}, OS); + } +} diff --git a/offload/tools/offload-tblgen/FuncsGen.cpp b/offload/tools/offload-tblgen/FuncsGen.cpp new file mode 100644 index 0000000..3238652 --- /dev/null +++ b/offload/tools/offload-tblgen/FuncsGen.cpp @@ -0,0 +1,74 @@ +//===- offload-tblgen/APIGen.cpp - Tablegen backend for Offload functions -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is a Tablegen backend that handles generation of various small files +// pertaining to the API functions. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/FormatVariadic.h" +#include "llvm/TableGen/Record.h" + +#include "GenCommon.hpp" +#include "RecordTypes.hpp" + +using namespace llvm; +using namespace offload::tblgen; + +// Emit a list of just the API function names +void EmitOffloadFuncNames(const RecordKeeper &Records, raw_ostream &OS) { + OS << GenericHeader; + OS << R"( +#ifndef OFFLOAD_FUNC +#error Please define the macro OFFLOAD_FUNC(Function) +#endif + +)"; + for (auto *R : Records.getAllDerivedDefinitions("Function")) { + FunctionRec FR{R}; + OS << formatv("OFFLOAD_FUNC({0})", FR.getName()) << "\n"; + } + for (auto *R : Records.getAllDerivedDefinitions("Function")) { + FunctionRec FR{R}; + OS << formatv("OFFLOAD_FUNC({0}WithCodeLoc)", FR.getName()) << "\n"; + } + + OS << "\n#undef OFFLOAD_FUNC\n"; +} + +void EmitOffloadExports(const RecordKeeper &Records, raw_ostream &OS) { + OS << "VERS1.0 {\n"; + OS << TAB_1 "global:\n"; + + for (auto *R : Records.getAllDerivedDefinitions("Function")) { + OS << formatv(TAB_2 "{0};\n", FunctionRec(R).getName()); + } + for (auto *R : Records.getAllDerivedDefinitions("Function")) { + OS << formatv(TAB_2 "{0}WithCodeLoc;\n", FunctionRec(R).getName()); + } + OS << TAB_1 "local:\n"; + OS << TAB_2 "*;\n"; + OS << "};\n"; +} + +// Emit declarations for every implementation function +void EmitOffloadImplFuncDecls(const RecordKeeper &Records, raw_ostream &OS) { + OS << GenericHeader; + for (auto *R : Records.getAllDerivedDefinitions("Function")) { + FunctionRec F{R}; + OS << formatv("{0}_impl_result_t {1}_impl(", PrefixLower, F.getName()); + auto Params = F.getParams(); + for (auto &Param : Params) { + OS << Param.getType() << " " << Param.getName(); + if (Param != Params.back()) { + OS << ", "; + } + } + OS << ");\n\n"; + } +} diff --git a/offload/tools/offload-tblgen/GenCommon.hpp b/offload/tools/offload-tblgen/GenCommon.hpp new file mode 100644 index 0000000..db432e9 --- /dev/null +++ b/offload/tools/offload-tblgen/GenCommon.hpp @@ -0,0 +1,67 @@ +//===- offload-tblgen/GenCommon.cpp - Common defs for Offload generators --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "RecordTypes.hpp" +#include "llvm/Support/FormatVariadic.h" + +// Having inline bits of tabbed code is hard to read, provide some definitions +// so we can keep things tidier +#define TAB_1 " " +#define TAB_2 " " +#define TAB_3 " " +#define TAB_4 " " +#define TAB_5 " " + +constexpr auto GenericHeader = + R"(//===- Auto-generated file, part of the LLVM/Offload project --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +)"; + +constexpr auto FileHeader = R"( +// Auto-generated file, do not manually edit. + +#pragma once + +#include <stddef.h> +#include <stdint.h> + +#if defined(__cplusplus) +extern "C" { +#endif + +)"; + +constexpr auto FileFooter = R"( +#if defined(__cplusplus) +} // extern "C" +#endif + +)"; + +constexpr auto CommentsHeader = R"( +/////////////////////////////////////////////////////////////////////////////// +)"; + +constexpr auto CommentsBreak = "///\n"; + +constexpr auto PrefixLower = "ol"; +constexpr auto PrefixUpper = "OL"; + +inline std::string +MakeParamComment(const llvm::offload::tblgen::ParamRec &Param) { + return llvm::formatv("// {0}{1}{2} {3}", (Param.isIn() ? "[in]" : ""), + (Param.isOut() ? "[out]" : ""), + (Param.isOpt() ? "[optional]" : ""), Param.getDesc()); +} diff --git a/offload/tools/offload-tblgen/Generators.hpp b/offload/tools/offload-tblgen/Generators.hpp new file mode 100644 index 0000000..8b6104c --- /dev/null +++ b/offload/tools/offload-tblgen/Generators.hpp @@ -0,0 +1,23 @@ +//===- offload-tblgen/Generators.hpp - Offload generator declarations -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "llvm/TableGen/Record.h" + +void EmitOffloadAPI(const llvm::RecordKeeper &Records, llvm::raw_ostream &OS); +void EmitOffloadFuncNames(const llvm::RecordKeeper &Records, + llvm::raw_ostream &OS); +void EmitOffloadImplFuncDecls(const llvm::RecordKeeper &Records, + llvm::raw_ostream &OS); +void EmitOffloadEntryPoints(const llvm::RecordKeeper &Records, + llvm::raw_ostream &OS); +void EmitOffloadPrintHeader(const llvm::RecordKeeper &Records, + llvm::raw_ostream &OS); +void EmitOffloadExports(const llvm::RecordKeeper &Records, + llvm::raw_ostream &OS); diff --git a/offload/tools/offload-tblgen/PrintGen.cpp b/offload/tools/offload-tblgen/PrintGen.cpp new file mode 100644 index 0000000..2a7c63c --- /dev/null +++ b/offload/tools/offload-tblgen/PrintGen.cpp @@ -0,0 +1,226 @@ +//===- offload-tblgen/APIGen.cpp - Tablegen backend for Offload printing --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is a Tablegen backend that produces print functions for the Offload API +// entry point functions. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/FormatVariadic.h" +#include "llvm/TableGen/Record.h" + +#include "GenCommon.hpp" +#include "RecordTypes.hpp" + +using namespace llvm; +using namespace offload::tblgen; + +constexpr auto PrintEnumHeader = + R"(/////////////////////////////////////////////////////////////////////////////// +/// @brief Print operator for the {0} type +/// @returns std::ostream & +)"; + +constexpr auto PrintTaggedEnumHeader = + R"(/////////////////////////////////////////////////////////////////////////////// +/// @brief Print type-tagged {0} enum value +/// @returns std::ostream & +)"; + +static void ProcessEnum(const EnumRec &Enum, raw_ostream &OS) { + OS << formatv(PrintEnumHeader, Enum.getName()); + OS << formatv( + "inline std::ostream &operator<<(std::ostream &os, enum {0} value) " + "{{\n" TAB_1 "switch (value) {{\n", + Enum.getName()); + + for (const auto &Val : Enum.getValues()) { + auto Name = Enum.getEnumValNamePrefix() + "_" + Val.getName(); + OS << formatv(TAB_1 "case {0}:\n", Name); + OS << formatv(TAB_2 "os << \"{0}\";\n", Name); + OS << formatv(TAB_2 "break;\n"); + } + + OS << TAB_1 "default:\n" TAB_2 "os << \"unknown enumerator\";\n" TAB_2 + "break;\n" TAB_1 "}\n" TAB_1 "return os;\n}\n\n"; + + if (!Enum.isTyped()) { + return; + } + + OS << formatv(PrintTaggedEnumHeader, Enum.getName()); + + OS << formatv(R"""(template <> +inline void printTagged(std::ostream &os, const void *ptr, {0} value, size_t size) {{ + if (ptr == NULL) {{ + printPtr(os, ptr); + return; + } + + switch (value) {{ +)""", + Enum.getName()); + + for (const auto &Val : Enum.getValues()) { + auto Name = Enum.getEnumValNamePrefix() + "_" + Val.getName(); + auto Type = Val.getTaggedType(); + OS << formatv(TAB_1 "case {0}: {{\n", Name); + // Special case for strings + if (Type == "char[]") { + OS << formatv(TAB_2 "printPtr(os, (const char*) ptr);\n"); + } else { + OS << formatv(TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n", + Type); + // TODO: Handle other cases here + OS << TAB_2 "os << (const void *)tptr << \" (\";\n"; + if (Type.ends_with("*")) { + OS << TAB_2 "os << printPtr(os, tptr);\n"; + } else { + OS << TAB_2 "os << *tptr;\n"; + } + OS << TAB_2 "os << \")\";\n"; + } + OS << formatv(TAB_2 "break;\n" TAB_1 "}\n"); + } + + OS << TAB_1 "default:\n" TAB_2 "os << \"unknown enumerator\";\n" TAB_2 + "break;\n" TAB_1 "}\n"; + + OS << "}\n"; +} + +static void EmitResultPrint(raw_ostream &OS) { + OS << R""( +inline std::ostream &operator<<(std::ostream &os, + const ol_error_struct_t *Err) { + if (Err == nullptr) { + os << "OL_SUCCESS"; + } else { + os << Err->Code; + } + return os; +} +)""; +} + +static void EmitFunctionParamStructPrint(const FunctionRec &Func, + raw_ostream &OS) { + if (Func.getParams().size() == 0) { + return; + } + + OS << formatv(R"( +inline std::ostream &operator<<(std::ostream &os, const struct {0} *params) {{ +)", + Func.getParamStructName()); + + for (const auto &Param : Func.getParams()) { + OS << formatv(TAB_1 "os << \".{0} = \";\n", Param.getName()); + if (auto Range = Param.getRange()) { + OS << formatv(TAB_1 "os << \"{{\";\n"); + OS << formatv(TAB_1 "for (size_t i = {0}; i < *params->p{1}; i++) {{\n", + Range->first, Range->second); + OS << TAB_2 "if (i > 0) {\n"; + OS << TAB_3 " os << \", \";\n"; + OS << TAB_2 "}\n"; + OS << formatv(TAB_2 "printPtr(os, (*params->p{0})[i]);\n", + Param.getName()); + OS << formatv(TAB_1 "}\n"); + OS << formatv(TAB_1 "os << \"}\";\n"); + } else if (auto TypeInfo = Param.getTypeInfo()) { + OS << formatv( + TAB_1 + "printTagged(os, *params->p{0}, *params->p{1}, *params->p{2});\n", + Param.getName(), TypeInfo->first, TypeInfo->second); + } else if (Param.isPointerType() || Param.isHandleType()) { + OS << formatv(TAB_1 "printPtr(os, *params->p{0});\n", Param.getName()); + } else { + OS << formatv(TAB_1 "os << *params->p{0};\n", Param.getName()); + } + if (Param != Func.getParams().back()) { + OS << TAB_1 "os << \", \";\n"; + } + } + + OS << TAB_1 "return os;\n}\n"; +} + +void EmitOffloadPrintHeader(const RecordKeeper &Records, raw_ostream &OS) { + OS << GenericHeader; + OS << R"""( +// Auto-generated file, do not manually edit. + +#pragma once + +#include <OffloadAPI.h> +#include <ostream> + + +template <typename T> inline ol_result_t printPtr(std::ostream &os, const T *ptr); +template <typename T> inline void printTagged(std::ostream &os, const void *ptr, T value, size_t size); +)"""; + + // ========== + OS << "template <typename T> struct is_handle : std::false_type {};\n"; + for (auto *R : Records.getAllDerivedDefinitions("Handle")) { + HandleRec H{R}; + OS << formatv("template <> struct is_handle<{0}> : std::true_type {{};\n", + H.getName()); + } + OS << "template <typename T> inline constexpr bool is_handle_v = " + "is_handle<T>::value;\n"; + // ========= + + // Forward declare the operator<< overloads so their implementations can + // use each other. + OS << "\n"; + for (auto *R : Records.getAllDerivedDefinitions("Enum")) { + OS << formatv( + "inline std::ostream &operator<<(std::ostream &os, enum {0} value);\n", + EnumRec{R}.getName()); + } + OS << "\n"; + + // Create definitions + for (auto *R : Records.getAllDerivedDefinitions("Enum")) { + EnumRec E{R}; + ProcessEnum(E, OS); + } + EmitResultPrint(OS); + + // Emit print functions for the function param structs + for (auto *R : Records.getAllDerivedDefinitions("Function")) { + EmitFunctionParamStructPrint(FunctionRec{R}, OS); + } + + OS << R"""( +/////////////////////////////////////////////////////////////////////////////// +// @brief Print pointer value +template <typename T> inline ol_result_t printPtr(std::ostream &os, const T *ptr) { + if (ptr == nullptr) { + os << "nullptr"; + } else if constexpr (std::is_pointer_v<T>) { + os << (const void *)(ptr) << " ("; + printPtr(os, *ptr); + os << ")"; + } else if constexpr (std::is_void_v<T> || is_handle_v<T *>) { + os << (const void *)ptr; + } else if constexpr (std::is_same_v<std::remove_cv_t< T >, char>) { + os << (const void *)(ptr) << " ("; + os << ptr; + os << ")"; + } else { + os << (const void *)(ptr) << " ("; + os << *ptr; + os << ")"; + } + + return OL_SUCCESS; +} + )"""; +} diff --git a/offload/tools/offload-tblgen/RecordTypes.hpp b/offload/tools/offload-tblgen/RecordTypes.hpp new file mode 100644 index 0000000..0bf3256c --- /dev/null +++ b/offload/tools/offload-tblgen/RecordTypes.hpp @@ -0,0 +1,227 @@ +//===- offload-tblgen/RecordTypes.cpp - Offload record type wrappers -----===-// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include <string> + +#include "llvm/TableGen/Record.h" + +namespace llvm { +namespace offload { +namespace tblgen { + +class HandleRec { +public: + explicit HandleRec(const Record *rec) : rec(rec) {} + StringRef getName() const { return rec->getValueAsString("name"); } + StringRef getDesc() const { return rec->getValueAsString("desc"); } + +private: + const Record *rec; +}; + +class MacroRec { +public: + explicit MacroRec(const Record *rec) : rec(rec) { + auto Name = rec->getValueAsString("name"); + auto OpenBrace = Name.find_first_of("("); + nameWithoutArgs = Name.substr(0, OpenBrace); + } + StringRef getName() const { return nameWithoutArgs; } + StringRef getNameWithArgs() const { return rec->getValueAsString("name"); } + StringRef getDesc() const { return rec->getValueAsString("desc"); } + + std::optional<StringRef> getCondition() const { + return rec->getValueAsOptionalString("condition"); + } + StringRef getValue() const { return rec->getValueAsString("value"); } + std::optional<StringRef> getAltValue() const { + return rec->getValueAsOptionalString("alt_value"); + } + +private: + const Record *rec; + std::string nameWithoutArgs; +}; + +class TypedefRec { +public: + explicit TypedefRec(const Record *rec) : rec(rec) {} + StringRef getName() const { return rec->getValueAsString("name"); } + StringRef getDesc() const { return rec->getValueAsString("desc"); } + StringRef getValue() const { return rec->getValueAsString("value"); } + +private: + const Record *rec; +}; + +class EnumValueRec { +public: + explicit EnumValueRec(const Record *rec) : rec(rec) {} + std::string getName() const { return rec->getValueAsString("name").upper(); } + StringRef getDesc() const { return rec->getValueAsString("desc"); } + StringRef getTaggedType() const { + return rec->getValueAsString("tagged_type"); + } + +private: + const Record *rec; +}; + +class EnumRec { +public: + explicit EnumRec(const Record *rec) : rec(rec) { + for (const auto *Val : rec->getValueAsListOfDefs("etors")) { + vals.emplace_back(EnumValueRec{Val}); + } + } + StringRef getName() const { return rec->getValueAsString("name"); } + StringRef getDesc() const { return rec->getValueAsString("desc"); } + const std::vector<EnumValueRec> &getValues() const { return vals; } + + std::string getEnumValNamePrefix() const { + return StringRef(getName().str().substr(0, getName().str().length() - 2)) + .upper(); + } + + bool isTyped() const { return rec->getValueAsBit("is_typed"); } + +private: + const Record *rec; + std::vector<EnumValueRec> vals; +}; + +class StructMemberRec { +public: + explicit StructMemberRec(const Record *rec) : rec(rec) {} + StringRef getType() const { return rec->getValueAsString("type"); } + StringRef getName() const { return rec->getValueAsString("name"); } + StringRef getDesc() const { return rec->getValueAsString("desc"); } + +private: + const Record *rec; +}; + +class StructRec { +public: + explicit StructRec(const Record *rec) : rec(rec) { + for (auto *Member : rec->getValueAsListOfDefs("all_members")) { + members.emplace_back(StructMemberRec(Member)); + } + } + StringRef getName() const { return rec->getValueAsString("name"); } + StringRef getDesc() const { return rec->getValueAsString("desc"); } + std::optional<StringRef> getBaseClass() const { + return rec->getValueAsOptionalString("base_class"); + } + const std::vector<StructMemberRec> &getMembers() const { return members; } + +private: + const Record *rec; + std::vector<StructMemberRec> members; +}; + +class ParamRec { +public: + explicit ParamRec(const Record *rec) : rec(rec) { + flags = rec->getValueAsBitsInit("flags"); + auto *Range = rec->getValueAsDef("range"); + auto RangeBegin = Range->getValueAsString("begin"); + auto RangeEnd = Range->getValueAsString("end"); + if (RangeBegin != "" && RangeEnd != "") { + range = {RangeBegin, RangeEnd}; + } else { + range = std::nullopt; + } + + auto *TypeInfo = rec->getValueAsDef("type_info"); + auto TypeInfoEnum = TypeInfo->getValueAsString("enum"); + auto TypeInfoSize = TypeInfo->getValueAsString("size"); + if (TypeInfoEnum != "" && TypeInfoSize != "") { + typeinfo = {TypeInfoEnum, TypeInfoSize}; + } else { + typeinfo = std::nullopt; + } + } + StringRef getName() const { return rec->getValueAsString("name"); } + StringRef getType() const { return rec->getValueAsString("type"); } + bool isPointerType() const { return getType().ends_with('*'); } + bool isHandleType() const { return getType().ends_with("_handle_t"); } + StringRef getDesc() const { return rec->getValueAsString("desc"); } + bool isIn() const { return dyn_cast<BitInit>(flags->getBit(0))->getValue(); } + bool isOut() const { return dyn_cast<BitInit>(flags->getBit(1))->getValue(); } + bool isOpt() const { return dyn_cast<BitInit>(flags->getBit(2))->getValue(); } + + const Record *getRec() const { return rec; } + std::optional<std::pair<StringRef, StringRef>> getRange() const { + return range; + } + + std::optional<std::pair<StringRef, StringRef>> getTypeInfo() const { + return typeinfo; + } + + // Needed to check whether we're at the back of a vector of params + bool operator!=(const ParamRec &p) const { return rec != p.getRec(); } + +private: + const Record *rec; + const BitsInit *flags; + std::optional<std::pair<StringRef, StringRef>> range; + std::optional<std::pair<StringRef, StringRef>> typeinfo; +}; + +class ReturnRec { +public: + ReturnRec(const Record *rec) : rec(rec) {} + StringRef getValue() const { return rec->getValueAsString("value"); } + std::vector<StringRef> getConditions() const { + return rec->getValueAsListOfStrings("conditions"); + } + +private: + const Record *rec; +}; + +class FunctionRec { +public: + FunctionRec(const Record *rec) : rec(rec) { + for (auto &Ret : rec->getValueAsListOfDefs("all_returns")) + rets.emplace_back(Ret); + for (auto &Param : rec->getValueAsListOfDefs("params")) + params.emplace_back(Param); + } + + std::string getParamStructName() const { + return llvm::formatv("{0}_params_t", + llvm::convertToSnakeFromCamelCase(getName())); + } + + StringRef getName() const { return rec->getValueAsString("name"); } + StringRef getClass() const { return rec->getValueAsString("api_class"); } + const std::vector<ReturnRec> &getReturns() const { return rets; } + const std::vector<ParamRec> &getParams() const { return params; } + StringRef getDesc() const { return rec->getValueAsString("desc"); } + std::vector<StringRef> getDetails() const { + return rec->getValueAsListOfStrings("details"); + } + std::vector<StringRef> getAnalogues() const { + return rec->getValueAsListOfStrings("analogues"); + } + +private: + std::vector<ReturnRec> rets; + std::vector<ParamRec> params; + + const Record *rec; +}; + +} // namespace tblgen +} // namespace offload +} // namespace llvm diff --git a/offload/tools/offload-tblgen/offload-tblgen.cpp b/offload/tools/offload-tblgen/offload-tblgen.cpp new file mode 100644 index 0000000..1912abf --- /dev/null +++ b/offload/tools/offload-tblgen/offload-tblgen.cpp @@ -0,0 +1,101 @@ +//===- offload-tblgen/offload-tblgen.cpp ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is a Tablegen tool that produces source files for the Offload project. +// See offload/API/README.md for more information. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/TableGen/Main.h" +#include "llvm/TableGen/Record.h" + +#include "Generators.hpp" + +namespace llvm { +namespace offload { +namespace tblgen { + +enum ActionType { + PrintRecords, + DumpJSON, + GenAPI, + GenFuncNames, + GenImplFuncDecls, + GenEntryPoints, + GenPrintHeader, + GenExports +}; + +namespace { +cl::opt<ActionType> Action( + cl::desc("Action to perform:"), + cl::values( + clEnumValN(PrintRecords, "print-records", + "Print all records to stdout (default)"), + clEnumValN(DumpJSON, "dump-json", + "Dump all records as machine-readable JSON"), + clEnumValN(GenAPI, "gen-api", "Generate Offload API header contents"), + clEnumValN(GenFuncNames, "gen-func-names", + "Generate a list of all Offload API function names"), + clEnumValN( + GenImplFuncDecls, "gen-impl-func-decls", + "Generate declarations for Offload API implementation functions"), + clEnumValN(GenEntryPoints, "gen-entry-points", + "Generate Offload API wrapper function definitions"), + clEnumValN(GenPrintHeader, "gen-print-header", + "Generate Offload API print header"), + clEnumValN(GenExports, "gen-exports", + "Generate export file for the Offload library"))); +} + +static bool OffloadTableGenMain(raw_ostream &OS, const RecordKeeper &Records) { + switch (Action) { + case PrintRecords: + OS << Records; + break; + case DumpJSON: + EmitJSON(Records, OS); + break; + case GenAPI: + EmitOffloadAPI(Records, OS); + break; + case GenFuncNames: + EmitOffloadFuncNames(Records, OS); + break; + case GenImplFuncDecls: + EmitOffloadImplFuncDecls(Records, OS); + break; + case GenEntryPoints: + EmitOffloadEntryPoints(Records, OS); + break; + case GenPrintHeader: + EmitOffloadPrintHeader(Records, OS); + break; + case GenExports: + EmitOffloadExports(Records, OS); + break; + } + + return false; +} + +int OffloadTblgenMain(int argc, char **argv) { + InitLLVM y(argc, argv); + cl::ParseCommandLineOptions(argc, argv); + return TableGenMain(argv[0], &OffloadTableGenMain); +} +} // namespace tblgen +} // namespace offload +} // namespace llvm + +using namespace llvm; +using namespace offload::tblgen; + +int main(int argc, char **argv) { return OffloadTblgenMain(argc, argv); } |
