| 1 | //===- offload-tblgen/APIGen.cpp - Tablegen backend for Offload printing --===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This is a Tablegen backend that produces print functions for the Offload API |
| 10 | // entry point functions. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "llvm/Support/FormatVariadic.h" |
| 15 | #include "llvm/TableGen/Record.h" |
| 16 | |
| 17 | #include "GenCommon.hpp" |
| 18 | #include "RecordTypes.hpp" |
| 19 | |
| 20 | using namespace llvm; |
| 21 | using namespace offload::tblgen; |
| 22 | |
/// Doxygen banner emitted in front of every generated `operator<<`.
/// `{0}` is a formatv placeholder that callers fill with the type's name.
constexpr auto PrintTypeHeader =
    R"(///////////////////////////////////////////////////////////////////////////////
/// @brief Print operator for the {0} type
/// @returns llvm::raw_ostream &
)";
| 28 | |
/// Doxygen banner emitted in front of every generated `printTagged`
/// specialization. `{0}` is a formatv placeholder for the enum's name.
constexpr auto PrintTaggedEnumHeader =
    R"(///////////////////////////////////////////////////////////////////////////////
/// @brief Print type-tagged {0} enum value
/// @returns llvm::raw_ostream &
)";
| 34 | |
| 35 | static void ProcessEnum(const EnumRec &Enum, raw_ostream &OS) { |
| 36 | OS << formatv(Fmt: PrintTypeHeader, Vals: Enum.getName()); |
| 37 | OS << formatv(Fmt: "inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, " |
| 38 | "enum {0} value) " |
| 39 | "{{\n" TAB_1 "switch (value) {{\n" , |
| 40 | Vals: Enum.getName()); |
| 41 | |
| 42 | for (const auto &Val : Enum.getValues()) { |
| 43 | auto Name = Enum.getEnumValNamePrefix() + "_" + Val.getName(); |
| 44 | OS << formatv(TAB_1 "case {0}:\n" , Vals&: Name); |
| 45 | OS << formatv(TAB_2 "os << \"{0}\";\n" , Vals&: Name); |
| 46 | OS << formatv(TAB_2 "break;\n" ); |
| 47 | } |
| 48 | |
| 49 | OS << TAB_1 "default:\n" TAB_2 "os << \"unknown enumerator\";\n" TAB_2 |
| 50 | "break;\n" TAB_1 "}\n" TAB_1 "return os;\n}\n\n" ; |
| 51 | |
| 52 | if (!Enum.isTyped()) { |
| 53 | return; |
| 54 | } |
| 55 | |
| 56 | OS << formatv(Fmt: PrintTaggedEnumHeader, Vals: Enum.getName()); |
| 57 | |
| 58 | OS << formatv(Fmt: R"""(template <> |
| 59 | inline void printTagged(llvm::raw_ostream &os, const void *ptr, {0} value, size_t size) {{ |
| 60 | if (ptr == NULL) {{ |
| 61 | printPtr(os, ptr); |
| 62 | return; |
| 63 | } |
| 64 | |
| 65 | switch (value) {{ |
| 66 | )""" , |
| 67 | Vals: Enum.getName()); |
| 68 | |
| 69 | for (const auto &Val : Enum.getValues()) { |
| 70 | auto Name = Enum.getEnumValNamePrefix() + "_" + Val.getName(); |
| 71 | auto Type = Val.getTaggedType(); |
| 72 | OS << formatv(TAB_1 "case {0}: {{\n" , Vals&: Name); |
| 73 | // Special case for strings |
| 74 | if (Type == "char[]" ) { |
| 75 | OS << formatv(TAB_2 "printPtr(os, (const char*) ptr);\n" ); |
| 76 | } else { |
| 77 | OS << formatv(TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n" , |
| 78 | Vals&: Type); |
| 79 | // TODO: Handle other cases here |
| 80 | OS << TAB_2 "os << (const void *)tptr << \" (\";\n" ; |
| 81 | if (Type.ends_with(Suffix: "*" )) { |
| 82 | OS << TAB_2 "os << printPtr(os, tptr);\n" ; |
| 83 | } else { |
| 84 | OS << TAB_2 "os << *tptr;\n" ; |
| 85 | } |
| 86 | OS << TAB_2 "os << \")\";\n" ; |
| 87 | } |
| 88 | OS << formatv(TAB_2 "break;\n" TAB_1 "}\n" ); |
| 89 | } |
| 90 | |
| 91 | OS << TAB_1 "default:\n" TAB_2 "os << \"unknown enumerator\";\n" TAB_2 |
| 92 | "break;\n" TAB_1 "}\n" ; |
| 93 | |
| 94 | OS << "}\n" ; |
| 95 | } |
| 96 | |
| 97 | static void EmitResultPrint(raw_ostream &OS) { |
| 98 | OS << R""( |
| 99 | inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, |
| 100 | const ol_error_struct_t *Err) { |
| 101 | if (Err == nullptr) { |
| 102 | os << "OL_SUCCESS"; |
| 103 | } else { |
| 104 | os << Err->Code; |
| 105 | } |
| 106 | return os; |
| 107 | } |
| 108 | )"" ; |
| 109 | } |
| 110 | |
| 111 | static void EmitFunctionParamStructPrint(const FunctionRec &Func, |
| 112 | raw_ostream &OS) { |
| 113 | if (Func.getParams().size() == 0) { |
| 114 | return; |
| 115 | } |
| 116 | |
| 117 | OS << formatv(Fmt: R"( |
| 118 | inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const struct {0} *params) {{ |
| 119 | )" , |
| 120 | Vals: Func.getParamStructName()); |
| 121 | |
| 122 | for (const auto &Param : Func.getParams()) { |
| 123 | OS << formatv(TAB_1 "os << \".{0} = \";\n" , Vals: Param.getName()); |
| 124 | if (auto Range = Param.getRange()) { |
| 125 | OS << formatv(TAB_1 "os << \"{{\";\n" ); |
| 126 | OS << formatv(TAB_1 "for (size_t i = {0}; i < *params->p{1}; i++) {{\n" , |
| 127 | Vals&: Range->first, Vals&: Range->second); |
| 128 | OS << TAB_2 "if (i > 0) {\n" ; |
| 129 | OS << TAB_3 " os << \", \";\n" ; |
| 130 | OS << TAB_2 "}\n" ; |
| 131 | OS << formatv(TAB_2 "printPtr(os, (*params->p{0})[i]);\n" , |
| 132 | Vals: Param.getName()); |
| 133 | OS << formatv(TAB_1 "}\n" ); |
| 134 | OS << formatv(TAB_1 "os << \"}\";\n" ); |
| 135 | } else if (auto TypeInfo = Param.getTypeInfo()) { |
| 136 | OS << formatv( |
| 137 | TAB_1 |
| 138 | "printTagged(os, *params->p{0}, *params->p{1}, *params->p{2});\n" , |
| 139 | Vals: Param.getName(), Vals&: TypeInfo->first, Vals&: TypeInfo->second); |
| 140 | } else if (Param.isPointerType() || Param.isHandleType()) { |
| 141 | OS << formatv(TAB_1 "printPtr(os, *params->p{0});\n" , Vals: Param.getName()); |
| 142 | } else if (Param.isFptrType()) { |
| 143 | OS << formatv(TAB_1 "os << reinterpret_cast<void*>(*params->p{0});\n" , |
| 144 | Vals: Param.getName()); |
| 145 | } else { |
| 146 | OS << formatv(TAB_1 "os << *params->p{0};\n" , Vals: Param.getName()); |
| 147 | } |
| 148 | if (Param != Func.getParams().back()) { |
| 149 | OS << TAB_1 "os << \", \";\n" ; |
| 150 | } |
| 151 | } |
| 152 | |
| 153 | OS << TAB_1 "return os;\n}\n" ; |
| 154 | } |
| 155 | |
| 156 | void ProcessStruct(const StructRec &Struct, raw_ostream &OS) { |
| 157 | if (Struct.getName() == "ol_error_struct_t" ) { |
| 158 | return; |
| 159 | } |
| 160 | OS << formatv(Fmt: PrintTypeHeader, Vals: Struct.getName()); |
| 161 | OS << formatv(Fmt: R"( |
| 162 | inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const struct {0} params) {{ |
| 163 | )" , |
| 164 | Vals: Struct.getName()); |
| 165 | OS << formatv(TAB_1 "os << \"(struct {0}){{\";\n" , Vals: Struct.getName()); |
| 166 | for (const auto &Member : Struct.getMembers()) { |
| 167 | OS << formatv(TAB_1 "os << \".{0} = \";\n" , Vals: Member.getName()); |
| 168 | if (Member.isPointerType() || Member.isHandleType()) { |
| 169 | OS << formatv(TAB_1 "printPtr(os, params.{0});\n" , Vals: Member.getName()); |
| 170 | } else { |
| 171 | OS << formatv(TAB_1 "os << params.{0};\n" , Vals: Member.getName()); |
| 172 | } |
| 173 | if (Member.getName() != Struct.getMembers().back().getName()) { |
| 174 | OS << TAB_1 "os << \", \";\n" ; |
| 175 | } |
| 176 | } |
| 177 | OS << TAB_1 "os << \"}\";\n" ; |
| 178 | OS << TAB_1 "return os;\n" ; |
| 179 | OS << "}\n" ; |
| 180 | } |
| 181 | |
| 182 | void (const RecordKeeper &Records, raw_ostream &OS) { |
| 183 | OS << GenericHeader; |
| 184 | OS << R"""( |
| 185 | // Auto-generated file, do not manually edit. |
| 186 | |
| 187 | #pragma once |
| 188 | |
| 189 | #include <OffloadAPI.h> |
| 190 | #include <llvm/Support/raw_ostream.h> |
| 191 | |
| 192 | |
| 193 | template <typename T> inline ol_result_t printPtr(llvm::raw_ostream &os, const T *ptr); |
| 194 | template <typename T> inline void printTagged(llvm::raw_ostream &os, const void *ptr, T value, size_t size); |
| 195 | )""" ; |
| 196 | |
| 197 | // ========== |
| 198 | OS << "template <typename T> struct is_handle : std::false_type {};\n" ; |
| 199 | for (auto *R : Records.getAllDerivedDefinitions(ClassName: "Handle" )) { |
| 200 | HandleRec H{R}; |
| 201 | OS << formatv(Fmt: "template <> struct is_handle<{0}> : std::true_type {{};\n" , |
| 202 | Vals: H.getName()); |
| 203 | } |
| 204 | OS << "template <typename T> inline constexpr bool is_handle_v = " |
| 205 | "is_handle<T>::value;\n" ; |
| 206 | // ========= |
| 207 | |
| 208 | // Forward declare the operator<< overloads so their implementations can |
| 209 | // use each other. |
| 210 | OS << "\n" ; |
| 211 | for (auto *R : Records.getAllDerivedDefinitions(ClassName: "Enum" )) { |
| 212 | OS << formatv(Fmt: "inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, " |
| 213 | "enum {0} value);\n" , |
| 214 | Vals: EnumRec{R}.getName()); |
| 215 | } |
| 216 | OS << "\n" ; |
| 217 | |
| 218 | // Create definitions |
| 219 | for (auto *R : Records.getAllDerivedDefinitions(ClassName: "Enum" )) { |
| 220 | EnumRec E{R}; |
| 221 | ProcessEnum(Enum: E, OS); |
| 222 | } |
| 223 | EmitResultPrint(OS); |
| 224 | |
| 225 | for (auto *R : Records.getAllDerivedDefinitions(ClassName: "Struct" )) { |
| 226 | StructRec S{R}; |
| 227 | ProcessStruct(Struct: S, OS); |
| 228 | } |
| 229 | |
| 230 | // Emit print functions for the function param structs |
| 231 | for (auto *R : Records.getAllDerivedDefinitions(ClassName: "Function" )) { |
| 232 | EmitFunctionParamStructPrint(Func: FunctionRec{R}, OS); |
| 233 | } |
| 234 | |
| 235 | OS << R"""( |
| 236 | /////////////////////////////////////////////////////////////////////////////// |
| 237 | // @brief Print pointer value |
| 238 | template <typename T> inline ol_result_t printPtr(llvm::raw_ostream &os, const T *ptr) { |
| 239 | if (ptr == nullptr) { |
| 240 | os << "nullptr"; |
| 241 | } else if constexpr (std::is_pointer_v<T>) { |
| 242 | os << (const void *)(ptr) << " ("; |
| 243 | printPtr(os, *ptr); |
| 244 | os << ")"; |
| 245 | } else if constexpr (std::is_void_v<T> || is_handle_v<T *>) { |
| 246 | os << (const void *)ptr; |
| 247 | } else if constexpr (std::is_same_v<std::remove_cv_t< T >, char>) { |
| 248 | os << (const void *)(ptr) << " ("; |
| 249 | os << ptr; |
| 250 | os << ")"; |
| 251 | } else { |
| 252 | os << (const void *)(ptr) << " ("; |
| 253 | os << *ptr; |
| 254 | os << ")"; |
| 255 | } |
| 256 | |
| 257 | return OL_SUCCESS; |
| 258 | } |
| 259 | )""" ; |
| 260 | } |
| 261 | |