1//===- offload-tblgen/EntryPointGen.cpp - Tablegen backend for Offload ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is a Tablegen backend that produces the actual entry points for the
10// Offload API. It serves as a place to integrate functionality like tracing
11// and validation before dispatching to the actual implementations.
12//===----------------------------------------------------------------------===//
13
14#include "llvm/Support/FormatVariadic.h"
15#include "llvm/TableGen/Record.h"
16
17#include "GenCommon.hpp"
18#include "RecordTypes.hpp"
19
20using namespace llvm;
21using namespace offload::tblgen;
22
23static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) {
24 OS << CommentsHeader;
25 // Emit preamble
26 OS << formatv(Fmt: "llvm::Error {0}_val(\n ", Vals: F.getName());
27 // Emit arguments
28 std::string ParamNameList = "";
29 for (auto &Param : F.getParams()) {
30 OS << Param.getType() << " " << Param.getName();
31 if (Param != F.getParams().back()) {
32 OS << ", ";
33 }
34 ParamNameList += Param.getName().str() + ", ";
35 }
36 OS << ") {\n";
37
38 bool HasValidation = llvm::any_of(Range: F.getReturns(), P: [](auto &R) {
39 return llvm::any_of(R.getConditions(), [](auto &C) {
40 return C.starts_with("`") && C.ends_with("`");
41 });
42 });
43
44 if (HasValidation) {
45 OS << TAB_1 "if (llvm::offload::isValidationEnabled()) {\n";
46 // Emit validation checks
47 for (const auto &Return : F.getReturns()) {
48 for (auto &Condition : Return.getConditions()) {
49 if (Condition.starts_with(Prefix: "`") && Condition.ends_with(Suffix: "`")) {
50 auto ConditionString = Condition.substr(Start: 1, N: Condition.size() - 2);
51 OS << formatv(TAB_2 "if ({0}) {{\n", Vals&: ConditionString);
52 OS << formatv(TAB_3
53 "return createOffloadError(error::ErrorCode::{0}, "
54 "\"validation failure: {1}\");\n",
55 Vals: Return.getUnprefixedValue(), Vals&: ConditionString);
56 OS << TAB_2 "}\n\n";
57 }
58 }
59 }
60 OS << TAB_1 "}\n\n";
61 }
62
63 // Perform actual function call to the implementation
64 ParamNameList = ParamNameList.substr(pos: 0, n: ParamNameList.size() - 2);
65 OS << formatv(TAB_1 "return llvm::offload::{0}_impl({1});\n\n", Vals: F.getName(),
66 Vals&: ParamNameList);
67 OS << "}\n";
68}
69
70static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) {
71 // Emit preamble
72 OS << formatv(Fmt: "{1}_APIEXPORT {0}_result_t {1}_APICALL {2}(\n ", Vals: PrefixLower,
73 Vals: PrefixUpper, Vals: F.getName());
74 // Emit arguments
75 std::string ParamNameList = "";
76 for (auto &Param : F.getParams()) {
77 OS << Param.getType() << " " << Param.getName();
78 if (Param != F.getParams().back()) {
79 OS << ", ";
80 }
81 ParamNameList += Param.getName().str() + ", ";
82 }
83 OS << ") {\n";
84
85 // Check offload is initialized
86 if (F.getName() != "olInit")
87 OS << "if (!llvm::offload::isOffloadInitialized()) return &UninitError;";
88
89 // Emit pre-call prints
90 OS << TAB_1 "if (llvm::offload::isTracingEnabled()) {\n";
91 OS << formatv(TAB_2 "llvm::errs() << \"---> {0}\";\n", Vals: F.getName());
92 OS << TAB_1 "}\n\n";
93
94 // Perform actual function call to the validation wrapper
95 ParamNameList = ParamNameList.substr(pos: 0, n: ParamNameList.size() - 2);
96 OS << formatv(
97 TAB_1 "{0}_result_t Result = llvmErrorToOffloadError({1}_val({2}));\n\n",
98 Vals: PrefixLower, Vals: F.getName(), Vals&: ParamNameList);
99
100 // Emit post-call prints
101 OS << TAB_1 "if (llvm::offload::isTracingEnabled()) {\n";
102 if (F.getParams().size() > 0) {
103 OS << formatv(TAB_2 "{0} Params = {{", Vals: F.getParamStructName());
104 for (const auto &Param : F.getParams()) {
105 OS << "&" << Param.getName();
106 if (Param != F.getParams().back()) {
107 OS << ", ";
108 }
109 }
110 OS << formatv(Fmt: "};\n");
111 OS << TAB_2 "llvm::errs() << \"(\" << &Params << \")\";\n";
112 } else {
113 OS << TAB_2 "llvm::errs() << \"()\";\n";
114 }
115 OS << TAB_2 "llvm::errs() << \"-> \" << Result << \"\\n\";\n";
116 OS << TAB_2 "if (Result && Result->Details) {\n";
117 OS << TAB_3 "llvm::errs() << \" *Error Details* \" << Result->Details "
118 "<< \" \\n\";\n";
119 OS << TAB_2 "}\n";
120 OS << TAB_1 "}\n";
121
122 OS << TAB_1 "return Result;\n";
123 OS << "}\n";
124}
125
126static void EmitCodeLocWrapper(const FunctionRec &F, raw_ostream &OS) {
127 // Emit preamble
128 OS << formatv(Fmt: "{0}_result_t {1}WithCodeLoc(\n ", Vals: PrefixLower, Vals: F.getName());
129 // Emit arguments
130 std::string ParamNameList = "";
131 for (auto &Param : F.getParams()) {
132 OS << Param.getType() << " " << Param.getName() << ", ";
133 ParamNameList += Param.getName().str();
134 if (Param != F.getParams().back()) {
135 ParamNameList += ", ";
136 }
137 }
138 OS << "ol_code_location_t *CodeLocation";
139 OS << ") {\n";
140 OS << TAB_1 "currentCodeLocation() = CodeLocation;\n";
141 OS << formatv(TAB_1 "{0}_result_t Result = ::{1}({2});\n\n", Vals: PrefixLower,
142 Vals: F.getName(), Vals&: ParamNameList);
143 OS << TAB_1 "currentCodeLocation() = nullptr;\n";
144 OS << TAB_1 "return Result;\n";
145 OS << "}\n";
146}
147
148void EmitOffloadEntryPoints(const RecordKeeper &Records, raw_ostream &OS) {
149 OS << GenericHeader;
150
151 constexpr const char *UninitMessage =
152 "liboffload has not been initialized - please call olInit before using "
153 "this API";
154 OS << formatv(Fmt: "static {0}_error_struct_t UninitError = "
155 "{{{1}_ERRC_UNINITIALIZED, \"{2}\"};",
156 Vals: PrefixLower, Vals: PrefixUpper, Vals: UninitMessage);
157
158 for (auto *R : Records.getAllDerivedDefinitions(ClassName: "Function")) {
159 EmitValidationFunc(F: FunctionRec{R}, OS);
160 EmitEntryPointFunc(F: FunctionRec{R}, OS);
161 EmitCodeLocWrapper(F: FunctionRec{R}, OS);
162 }
163}
164

source code of offload/tools/offload-tblgen/EntryPointGen.cpp