1 | //===- offload-tblgen/EntryPointGen.cpp - Tablegen backend for Offload ----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This is a Tablegen backend that produces the actual entry points for the |
10 | // Offload API. It serves as a place to integrate functionality like tracing |
11 | // and validation before dispatching to the actual implementations. |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "llvm/Support/FormatVariadic.h" |
15 | #include "llvm/TableGen/Record.h" |
16 | |
17 | #include "GenCommon.hpp" |
18 | #include "RecordTypes.hpp" |
19 | |
20 | using namespace llvm; |
21 | using namespace offload::tblgen; |
22 | |
23 | static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) { |
24 | OS << CommentsHeader; |
25 | // Emit preamble |
26 | OS << formatv(Fmt: "llvm::Error {0}_val(\n " , Vals: F.getName()); |
27 | // Emit arguments |
28 | std::string ParamNameList = "" ; |
29 | for (auto &Param : F.getParams()) { |
30 | OS << Param.getType() << " " << Param.getName(); |
31 | if (Param != F.getParams().back()) { |
32 | OS << ", " ; |
33 | } |
34 | ParamNameList += Param.getName().str() + ", " ; |
35 | } |
36 | OS << ") {\n" ; |
37 | |
38 | OS << TAB_1 "if (offloadConfig().ValidationEnabled) {\n" ; |
39 | // Emit validation checks |
40 | for (const auto &Return : F.getReturns()) { |
41 | for (auto &Condition : Return.getConditions()) { |
42 | if (Condition.starts_with(Prefix: "`" ) && Condition.ends_with(Suffix: "`" )) { |
43 | auto ConditionString = Condition.substr(Start: 1, N: Condition.size() - 2); |
44 | OS << formatv(TAB_2 "if ({0}) {{\n" , Vals&: ConditionString); |
45 | OS << formatv(TAB_3 "return createOffloadError(error::ErrorCode::{0}, " |
46 | "\"validation failure: {1}\");\n" , |
47 | Vals: Return.getUnprefixedValue(), Vals&: ConditionString); |
48 | OS << TAB_2 "}\n\n" ; |
49 | } |
50 | } |
51 | } |
52 | OS << TAB_1 "}\n\n" ; |
53 | |
54 | // Perform actual function call to the implementation |
55 | ParamNameList = ParamNameList.substr(pos: 0, n: ParamNameList.size() - 2); |
56 | OS << formatv(TAB_1 "return llvm::offload::{0}_impl({1});\n\n" , Vals: F.getName(), |
57 | Vals&: ParamNameList); |
58 | OS << "}\n" ; |
59 | } |
60 | |
61 | static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) { |
62 | // Emit preamble |
63 | OS << formatv(Fmt: "{1}_APIEXPORT {0}_result_t {1}_APICALL {2}(\n " , Vals: PrefixLower, |
64 | Vals: PrefixUpper, Vals: F.getName()); |
65 | // Emit arguments |
66 | std::string ParamNameList = "" ; |
67 | for (auto &Param : F.getParams()) { |
68 | OS << Param.getType() << " " << Param.getName(); |
69 | if (Param != F.getParams().back()) { |
70 | OS << ", " ; |
71 | } |
72 | ParamNameList += Param.getName().str() + ", " ; |
73 | } |
74 | OS << ") {\n" ; |
75 | |
76 | // Emit pre-call prints |
77 | OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n" ; |
78 | OS << formatv(TAB_2 "llvm::errs() << \"---> {0}\";\n" , Vals: F.getName()); |
79 | OS << TAB_1 "}\n\n" ; |
80 | |
81 | // Perform actual function call to the validation wrapper |
82 | ParamNameList = ParamNameList.substr(pos: 0, n: ParamNameList.size() - 2); |
83 | OS << formatv( |
84 | TAB_1 "{0}_result_t Result = llvmErrorToOffloadError({1}_val({2}));\n\n" , |
85 | Vals: PrefixLower, Vals: F.getName(), Vals&: ParamNameList); |
86 | |
87 | // Emit post-call prints |
88 | OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n" ; |
89 | if (F.getParams().size() > 0) { |
90 | OS << formatv(TAB_2 "{0} Params = {{" , Vals: F.getParamStructName()); |
91 | for (const auto &Param : F.getParams()) { |
92 | OS << "&" << Param.getName(); |
93 | if (Param != F.getParams().back()) { |
94 | OS << ", " ; |
95 | } |
96 | } |
97 | OS << formatv(Fmt: "};\n" ); |
98 | OS << TAB_2 "llvm::errs() << \"(\" << &Params << \")\";\n" ; |
99 | } else { |
100 | OS << TAB_2 "llvm::errs() << \"()\";\n" ; |
101 | } |
102 | OS << TAB_2 "llvm::errs() << \"-> \" << Result << \"\\n\";\n" ; |
103 | OS << TAB_2 "if (Result && Result->Details) {\n" ; |
104 | OS << TAB_3 "llvm::errs() << \" *Error Details* \" << Result->Details " |
105 | "<< \" \\n\";\n" ; |
106 | OS << TAB_2 "}\n" ; |
107 | OS << TAB_1 "}\n" ; |
108 | |
109 | OS << TAB_1 "return Result;\n" ; |
110 | OS << "}\n" ; |
111 | } |
112 | |
113 | static void EmitCodeLocWrapper(const FunctionRec &F, raw_ostream &OS) { |
114 | // Emit preamble |
115 | OS << formatv(Fmt: "{0}_result_t {1}WithCodeLoc(\n " , Vals: PrefixLower, Vals: F.getName()); |
116 | // Emit arguments |
117 | std::string ParamNameList = "" ; |
118 | for (auto &Param : F.getParams()) { |
119 | OS << Param.getType() << " " << Param.getName() << ", " ; |
120 | ParamNameList += Param.getName().str(); |
121 | if (Param != F.getParams().back()) { |
122 | ParamNameList += ", " ; |
123 | } |
124 | } |
125 | OS << "ol_code_location_t *CodeLocation" ; |
126 | OS << ") {\n" ; |
127 | OS << TAB_1 "currentCodeLocation() = CodeLocation;\n" ; |
128 | OS << formatv(TAB_1 "{0}_result_t Result = ::{1}({2});\n\n" , Vals: PrefixLower, |
129 | Vals: F.getName(), Vals&: ParamNameList); |
130 | OS << TAB_1 "currentCodeLocation() = nullptr;\n" ; |
131 | OS << TAB_1 "return Result;\n" ; |
132 | OS << "}\n" ; |
133 | } |
134 | |
135 | void EmitOffloadEntryPoints(const RecordKeeper &Records, raw_ostream &OS) { |
136 | OS << GenericHeader; |
137 | for (auto *R : Records.getAllDerivedDefinitions(ClassName: "Function" )) { |
138 | EmitValidationFunc(F: FunctionRec{R}, OS); |
139 | EmitEntryPointFunc(F: FunctionRec{R}, OS); |
140 | EmitCodeLocWrapper(F: FunctionRec{R}, OS); |
141 | } |
142 | } |
143 | |