1 | //===- offload-tblgen/APIGen.cpp - Tablegen backend for Offload printing --===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This is a Tablegen backend that produces print functions for the Offload API |
10 | // entry point functions. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "llvm/Support/FormatVariadic.h" |
15 | #include "llvm/TableGen/Record.h" |
16 | |
17 | #include "GenCommon.hpp" |
18 | #include "RecordTypes.hpp" |
19 | |
20 | using namespace llvm; |
21 | using namespace offload::tblgen; |
22 | |
/// Doxygen banner emitted above each generated `operator<<` overload.
/// `{0}` is a formatv placeholder for the name of the printed type.
constexpr auto PrintTypeHeader =
    R"(///////////////////////////////////////////////////////////////////////////////
/// @brief Print operator for the {0} type
/// @returns llvm::raw_ostream &
)";
28 | |
/// Doxygen banner emitted above each generated `printTagged` specialization.
/// `{0}` is a formatv placeholder for the name of the typed enum.
constexpr auto PrintTaggedEnumHeader =
    R"(///////////////////////////////////////////////////////////////////////////////
/// @brief Print type-tagged {0} enum value
/// @returns llvm::raw_ostream &
)";
34 | |
35 | static void ProcessEnum(const EnumRec &Enum, raw_ostream &OS) { |
36 | OS << formatv(Fmt: PrintTypeHeader, Vals: Enum.getName()); |
37 | OS << formatv(Fmt: "inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, " |
38 | "enum {0} value) " |
39 | "{{\n" TAB_1 "switch (value) {{\n" , |
40 | Vals: Enum.getName()); |
41 | |
42 | for (const auto &Val : Enum.getValues()) { |
43 | auto Name = Enum.getEnumValNamePrefix() + "_" + Val.getName(); |
44 | OS << formatv(TAB_1 "case {0}:\n" , Vals&: Name); |
45 | OS << formatv(TAB_2 "os << \"{0}\";\n" , Vals&: Name); |
46 | OS << formatv(TAB_2 "break;\n" ); |
47 | } |
48 | |
49 | OS << TAB_1 "default:\n" TAB_2 "os << \"unknown enumerator\";\n" TAB_2 |
50 | "break;\n" TAB_1 "}\n" TAB_1 "return os;\n}\n\n" ; |
51 | |
52 | if (!Enum.isTyped()) { |
53 | return; |
54 | } |
55 | |
56 | OS << formatv(Fmt: PrintTaggedEnumHeader, Vals: Enum.getName()); |
57 | |
58 | OS << formatv(Fmt: R"""(template <> |
59 | inline void printTagged(llvm::raw_ostream &os, const void *ptr, {0} value, size_t size) {{ |
60 | if (ptr == NULL) {{ |
61 | printPtr(os, ptr); |
62 | return; |
63 | } |
64 | |
65 | switch (value) {{ |
66 | )""" , |
67 | Vals: Enum.getName()); |
68 | |
69 | for (const auto &Val : Enum.getValues()) { |
70 | auto Name = Enum.getEnumValNamePrefix() + "_" + Val.getName(); |
71 | auto Type = Val.getTaggedType(); |
72 | OS << formatv(TAB_1 "case {0}: {{\n" , Vals&: Name); |
73 | // Special case for strings |
74 | if (Type == "char[]" ) { |
75 | OS << formatv(TAB_2 "printPtr(os, (const char*) ptr);\n" ); |
76 | } else { |
77 | OS << formatv(TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n" , |
78 | Vals&: Type); |
79 | // TODO: Handle other cases here |
80 | OS << TAB_2 "os << (const void *)tptr << \" (\";\n" ; |
81 | if (Type.ends_with(Suffix: "*" )) { |
82 | OS << TAB_2 "os << printPtr(os, tptr);\n" ; |
83 | } else { |
84 | OS << TAB_2 "os << *tptr;\n" ; |
85 | } |
86 | OS << TAB_2 "os << \")\";\n" ; |
87 | } |
88 | OS << formatv(TAB_2 "break;\n" TAB_1 "}\n" ); |
89 | } |
90 | |
91 | OS << TAB_1 "default:\n" TAB_2 "os << \"unknown enumerator\";\n" TAB_2 |
92 | "break;\n" TAB_1 "}\n" ; |
93 | |
94 | OS << "}\n" ; |
95 | } |
96 | |
97 | static void EmitResultPrint(raw_ostream &OS) { |
98 | OS << R""( |
99 | inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, |
100 | const ol_error_struct_t *Err) { |
101 | if (Err == nullptr) { |
102 | os << "OL_SUCCESS"; |
103 | } else { |
104 | os << Err->Code; |
105 | } |
106 | return os; |
107 | } |
108 | )"" ; |
109 | } |
110 | |
111 | static void EmitFunctionParamStructPrint(const FunctionRec &Func, |
112 | raw_ostream &OS) { |
113 | if (Func.getParams().size() == 0) { |
114 | return; |
115 | } |
116 | |
117 | OS << formatv(Fmt: R"( |
118 | inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const struct {0} *params) {{ |
119 | )" , |
120 | Vals: Func.getParamStructName()); |
121 | |
122 | for (const auto &Param : Func.getParams()) { |
123 | OS << formatv(TAB_1 "os << \".{0} = \";\n" , Vals: Param.getName()); |
124 | if (auto Range = Param.getRange()) { |
125 | OS << formatv(TAB_1 "os << \"{{\";\n" ); |
126 | OS << formatv(TAB_1 "for (size_t i = {0}; i < *params->p{1}; i++) {{\n" , |
127 | Vals&: Range->first, Vals&: Range->second); |
128 | OS << TAB_2 "if (i > 0) {\n" ; |
129 | OS << TAB_3 " os << \", \";\n" ; |
130 | OS << TAB_2 "}\n" ; |
131 | OS << formatv(TAB_2 "printPtr(os, (*params->p{0})[i]);\n" , |
132 | Vals: Param.getName()); |
133 | OS << formatv(TAB_1 "}\n" ); |
134 | OS << formatv(TAB_1 "os << \"}\";\n" ); |
135 | } else if (auto TypeInfo = Param.getTypeInfo()) { |
136 | OS << formatv( |
137 | TAB_1 |
138 | "printTagged(os, *params->p{0}, *params->p{1}, *params->p{2});\n" , |
139 | Vals: Param.getName(), Vals&: TypeInfo->first, Vals&: TypeInfo->second); |
140 | } else if (Param.isPointerType() || Param.isHandleType()) { |
141 | OS << formatv(TAB_1 "printPtr(os, *params->p{0});\n" , Vals: Param.getName()); |
142 | } else if (Param.isFptrType()) { |
143 | OS << formatv(TAB_1 "os << reinterpret_cast<void*>(*params->p{0});\n" , |
144 | Vals: Param.getName()); |
145 | } else { |
146 | OS << formatv(TAB_1 "os << *params->p{0};\n" , Vals: Param.getName()); |
147 | } |
148 | if (Param != Func.getParams().back()) { |
149 | OS << TAB_1 "os << \", \";\n" ; |
150 | } |
151 | } |
152 | |
153 | OS << TAB_1 "return os;\n}\n" ; |
154 | } |
155 | |
156 | void ProcessStruct(const StructRec &Struct, raw_ostream &OS) { |
157 | if (Struct.getName() == "ol_error_struct_t" ) { |
158 | return; |
159 | } |
160 | OS << formatv(Fmt: PrintTypeHeader, Vals: Struct.getName()); |
161 | OS << formatv(Fmt: R"( |
162 | inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const struct {0} params) {{ |
163 | )" , |
164 | Vals: Struct.getName()); |
165 | OS << formatv(TAB_1 "os << \"(struct {0}){{\";\n" , Vals: Struct.getName()); |
166 | for (const auto &Member : Struct.getMembers()) { |
167 | OS << formatv(TAB_1 "os << \".{0} = \";\n" , Vals: Member.getName()); |
168 | if (Member.isPointerType() || Member.isHandleType()) { |
169 | OS << formatv(TAB_1 "printPtr(os, params.{0});\n" , Vals: Member.getName()); |
170 | } else { |
171 | OS << formatv(TAB_1 "os << params.{0};\n" , Vals: Member.getName()); |
172 | } |
173 | if (Member.getName() != Struct.getMembers().back().getName()) { |
174 | OS << TAB_1 "os << \", \";\n" ; |
175 | } |
176 | } |
177 | OS << TAB_1 "os << \"}\";\n" ; |
178 | OS << TAB_1 "return os;\n" ; |
179 | OS << "}\n" ; |
180 | } |
181 | |
182 | void (const RecordKeeper &Records, raw_ostream &OS) { |
183 | OS << GenericHeader; |
184 | OS << R"""( |
185 | // Auto-generated file, do not manually edit. |
186 | |
187 | #pragma once |
188 | |
189 | #include <OffloadAPI.h> |
190 | #include <llvm/Support/raw_ostream.h> |
191 | |
192 | |
193 | template <typename T> inline ol_result_t printPtr(llvm::raw_ostream &os, const T *ptr); |
194 | template <typename T> inline void printTagged(llvm::raw_ostream &os, const void *ptr, T value, size_t size); |
195 | )""" ; |
196 | |
197 | // ========== |
198 | OS << "template <typename T> struct is_handle : std::false_type {};\n" ; |
199 | for (auto *R : Records.getAllDerivedDefinitions(ClassName: "Handle" )) { |
200 | HandleRec H{R}; |
201 | OS << formatv(Fmt: "template <> struct is_handle<{0}> : std::true_type {{};\n" , |
202 | Vals: H.getName()); |
203 | } |
204 | OS << "template <typename T> inline constexpr bool is_handle_v = " |
205 | "is_handle<T>::value;\n" ; |
206 | // ========= |
207 | |
208 | // Forward declare the operator<< overloads so their implementations can |
209 | // use each other. |
210 | OS << "\n" ; |
211 | for (auto *R : Records.getAllDerivedDefinitions(ClassName: "Enum" )) { |
212 | OS << formatv(Fmt: "inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, " |
213 | "enum {0} value);\n" , |
214 | Vals: EnumRec{R}.getName()); |
215 | } |
216 | OS << "\n" ; |
217 | |
218 | // Create definitions |
219 | for (auto *R : Records.getAllDerivedDefinitions(ClassName: "Enum" )) { |
220 | EnumRec E{R}; |
221 | ProcessEnum(Enum: E, OS); |
222 | } |
223 | EmitResultPrint(OS); |
224 | |
225 | for (auto *R : Records.getAllDerivedDefinitions(ClassName: "Struct" )) { |
226 | StructRec S{R}; |
227 | ProcessStruct(Struct: S, OS); |
228 | } |
229 | |
230 | // Emit print functions for the function param structs |
231 | for (auto *R : Records.getAllDerivedDefinitions(ClassName: "Function" )) { |
232 | EmitFunctionParamStructPrint(Func: FunctionRec{R}, OS); |
233 | } |
234 | |
235 | OS << R"""( |
236 | /////////////////////////////////////////////////////////////////////////////// |
237 | // @brief Print pointer value |
238 | template <typename T> inline ol_result_t printPtr(llvm::raw_ostream &os, const T *ptr) { |
239 | if (ptr == nullptr) { |
240 | os << "nullptr"; |
241 | } else if constexpr (std::is_pointer_v<T>) { |
242 | os << (const void *)(ptr) << " ("; |
243 | printPtr(os, *ptr); |
244 | os << ")"; |
245 | } else if constexpr (std::is_void_v<T> || is_handle_v<T *>) { |
246 | os << (const void *)ptr; |
247 | } else if constexpr (std::is_same_v<std::remove_cv_t< T >, char>) { |
248 | os << (const void *)(ptr) << " ("; |
249 | os << ptr; |
250 | os << ")"; |
251 | } else { |
252 | os << (const void *)(ptr) << " ("; |
253 | os << *ptr; |
254 | os << ")"; |
255 | } |
256 | |
257 | return OL_SUCCESS; |
258 | } |
259 | )""" ; |
260 | } |
261 | |