1 | //===- OmpOpGen.cpp - OpenMP dialect op specific generators ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // OmpOpGen defines OpenMP dialect operation specific generators. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/TableGen/GenInfo.h" |
14 | |
15 | #include "mlir/TableGen/CodeGenHelpers.h" |
16 | #include "llvm/ADT/StringExtras.h" |
17 | #include "llvm/ADT/StringSet.h" |
18 | #include "llvm/ADT/TypeSwitch.h" |
19 | #include "llvm/Support/FormatAdapters.h" |
20 | #include "llvm/TableGen/Error.h" |
21 | #include "llvm/TableGen/Record.h" |
22 | |
23 | using namespace llvm; |
24 | |
/// The code block defining the base mixin class for combining clause operand
/// structures.
///
/// `detail::Clauses<A, B, ...>` publicly inherits from every type passed as a
/// template argument. An operation's operand structure, defined as an alias of
/// it, therefore exposes the fields of all of its clause operand structures.
/// genClauseOps() emits this once, after all `<Clause>ClauseOps` structures
/// and before the per-operation aliases that refer to it.
static const char *const baseMixinClass = R"(
namespace detail {
template <typename... Mixins>
struct Clauses : public Mixins... {};
} // namespace detail
)";
33 | |
/// The code block defining operation argument structures.
///
/// This is a `formatv` format string with two placeholders:
/// - `{0}` is the operation name with the "OpenMP_" prefix and "Op" suffix
///   stripped.
/// - `{1}` is a comma-separated list of `<Clause>ClauseOps` structure names.
///
/// The result is an alias of the `detail::Clauses` mixin template.
static const char *const operationArgStruct = R"(
using {0}Operands = detail::Clauses<{1}>;
)";
38 | |
39 | /// Remove multiple optional prefixes and suffixes from \c str. |
40 | /// |
41 | /// Prefixes and suffixes are attempted to be removed once in the order they |
42 | /// appear in the \c prefixes and \c suffixes arguments. All prefixes are |
43 | /// processed before suffixes are. This means it will behave as shown in the |
44 | /// following example: |
45 | /// - str: "PrePreNameSuf1Suf2" |
46 | /// - prefixes: ["Pre"] |
47 | /// - suffixes: ["Suf1", "Suf2"] |
48 | /// - return: "PreNameSuf1" |
49 | static StringRef stripPrefixAndSuffix(StringRef str, |
50 | llvm::ArrayRef<StringRef> prefixes, |
51 | llvm::ArrayRef<StringRef> suffixes) { |
52 | for (StringRef prefix : prefixes) |
53 | if (str.starts_with(Prefix: prefix)) |
54 | str = str.drop_front(N: prefix.size()); |
55 | |
56 | for (StringRef suffix : suffixes) |
57 | if (str.ends_with(Suffix: suffix)) |
58 | str = str.drop_back(N: suffix.size()); |
59 | |
60 | return str; |
61 | } |
62 | |
63 | /// Obtain the name of the OpenMP clause a given record inheriting |
64 | /// `OpenMP_Clause` refers to. |
65 | /// |
66 | /// It supports direct and indirect `OpenMP_Clause` superclasses. Once the |
67 | /// `OpenMP_Clause` class the record is based on is found, the optional |
68 | /// "OpenMP_" prefix and "Skip" and "Clause" suffixes are removed to return only |
69 | /// the clause name, i.e. "OpenMP_CollapseClauseSkip" is returned as "Collapse". |
70 | static StringRef (const Record *clause) { |
71 | const Record *ompClause = clause->getRecords().getClass(Name: "OpenMP_Clause" ); |
72 | assert(ompClause && "base OpenMP records expected to be defined" ); |
73 | |
74 | StringRef clauseClassName; |
75 | |
76 | // Check if OpenMP_Clause is a direct superclass. |
77 | for (const Record *superClass : |
78 | llvm::make_first_range(c: clause->getDirectSuperClasses())) { |
79 | if (superClass == ompClause) { |
80 | clauseClassName = clause->getName(); |
81 | break; |
82 | } |
83 | } |
84 | |
85 | // Support indirectly-inherited OpenMP_Clauses. |
86 | if (clauseClassName.empty()) { |
87 | for (const Record *superClass : clause->getSuperClasses()) { |
88 | if (superClass->isSubClassOf(R: ompClause)) { |
89 | clauseClassName = superClass->getName(); |
90 | break; |
91 | } |
92 | } |
93 | } |
94 | |
95 | assert(!clauseClassName.empty() && "clause name must be found" ); |
96 | |
97 | // Keep only the OpenMP clause name itself for reporting purposes. |
98 | return stripPrefixAndSuffix(str: clauseClassName, /*prefixes=*/{"OpenMP_" }, |
99 | /*suffixes=*/{"Skip" , "Clause" }); |
100 | } |
101 | |
102 | /// Check that the given argument, identified by its name and initialization |
103 | /// value, is present in the \c arguments `dag`. |
104 | static bool verifyArgument(const DagInit *arguments, StringRef argName, |
105 | const Init *argInit) { |
106 | auto range = zip_equal(t: arguments->getArgNames(), u: arguments->getArgs()); |
107 | return llvm::any_of( |
108 | Range&: range, P: [&](std::tuple<const llvm::StringInit *, const llvm::Init *> v) { |
109 | return std::get<0>(t&: v)->getAsUnquotedString() == argName && |
110 | std::get<1>(t&: v) == argInit; |
111 | }); |
112 | } |
113 | |
114 | /// Check that the given string record value, identified by its \c opValueName, |
115 | /// is either undefined or empty in both the given operation and clause record |
116 | /// or its contents for the clause record are contained in the operation record. |
117 | /// Passing a non-empty \c clauseValueName enables checking values named |
118 | /// differently in the operation and clause records. |
119 | static bool verifyStringValue(const Record *op, const Record *clause, |
120 | StringRef opValueName, |
121 | StringRef clauseValueName = {}) { |
122 | auto opValue = op->getValueAsOptionalString(FieldName: opValueName); |
123 | auto clauseValue = clause->getValueAsOptionalString( |
124 | FieldName: clauseValueName.empty() ? opValueName : clauseValueName); |
125 | |
126 | bool opHasValue = opValue && !opValue->trim().empty(); |
127 | bool clauseHasValue = clauseValue && !clauseValue->trim().empty(); |
128 | |
129 | if (!opHasValue) |
130 | return !clauseHasValue; |
131 | |
132 | return !clauseHasValue || opValue->contains(Other: clauseValue->trim()); |
133 | } |
134 | |
135 | /// Verify that all fields of the given clause not explicitly ignored are |
136 | /// present in the corresponding operation field. |
137 | /// |
138 | /// Print warnings or errors where this is not the case. |
139 | static void verifyClause(const Record *op, const Record *clause) { |
140 | StringRef clauseClassName = extractOmpClauseName(clause); |
141 | |
142 | if (!clause->getValueAsBit(FieldName: "ignoreArgs" )) { |
143 | const DagInit *opArguments = op->getValueAsDag(FieldName: "arguments" ); |
144 | const DagInit *arguments = clause->getValueAsDag(FieldName: "arguments" ); |
145 | |
146 | for (auto [name, arg] : |
147 | zip(t: arguments->getArgNames(), u: arguments->getArgs())) { |
148 | if (!verifyArgument(arguments: opArguments, argName: name->getAsUnquotedString(), argInit: arg)) |
149 | PrintWarning( |
150 | WarningLoc: op->getLoc(), |
151 | Msg: "'" + clauseClassName + "' clause-defined argument '" + |
152 | arg->getAsUnquotedString() + ":$" + |
153 | name->getAsUnquotedString() + |
154 | "' not present in operation. Consider `dag arguments = " |
155 | "!con(clausesArgs, ...)` or explicitly skipping this field." ); |
156 | } |
157 | } |
158 | |
159 | if (!clause->getValueAsBit(FieldName: "ignoreAsmFormat" ) && |
160 | !verifyStringValue(op, clause, opValueName: "assemblyFormat" , clauseValueName: "reqAssemblyFormat" )) |
161 | PrintWarning( |
162 | WarningLoc: op->getLoc(), |
163 | Msg: "'" + clauseClassName + |
164 | "' clause-defined `reqAssemblyFormat` not present in operation. " |
165 | "Consider concatenating `clauses[{Req,Opt}]AssemblyFormat` or " |
166 | "explicitly skipping this field." ); |
167 | |
168 | if (!clause->getValueAsBit(FieldName: "ignoreAsmFormat" ) && |
169 | !verifyStringValue(op, clause, opValueName: "assemblyFormat" , clauseValueName: "optAssemblyFormat" )) |
170 | PrintWarning( |
171 | WarningLoc: op->getLoc(), |
172 | Msg: "'" + clauseClassName + |
173 | "' clause-defined `optAssemblyFormat` not present in operation. " |
174 | "Consider concatenating `clauses[{Req,Opt}]AssemblyFormat` or " |
175 | "explicitly skipping this field." ); |
176 | |
177 | if (!clause->getValueAsBit(FieldName: "ignoreDesc" ) && |
178 | !verifyStringValue(op, clause, opValueName: "description" )) |
179 | PrintError(ErrorLoc: op->getLoc(), |
180 | Msg: "'" + clauseClassName + |
181 | "' clause-defined `description` not present in operation. " |
182 | "Consider concatenating `clausesDescription` or explicitly " |
183 | "skipping this field." ); |
184 | |
185 | if (!clause->getValueAsBit(FieldName: "ignoreExtraDecl" ) && |
186 | !verifyStringValue(op, clause, opValueName: "extraClassDeclaration" )) |
187 | PrintWarning( |
188 | WarningLoc: op->getLoc(), |
189 | Msg: "'" + clauseClassName + |
190 | "' clause-defined `extraClassDeclaration` not present in " |
191 | "operation. Consider concatenating `clausesExtraClassDeclaration` " |
192 | "or explicitly skipping this field." ); |
193 | } |
194 | |
195 | /// Translate the type of an OpenMP clause's argument to its corresponding |
196 | /// representation for clause operand structures. |
197 | /// |
198 | /// All kinds of values are represented as `mlir::Value` fields, whereas |
199 | /// attributes are represented based on their `storageType`. |
200 | /// |
201 | /// \param[in] name The name of the argument. |
202 | /// \param[in] init The `DefInit` object representing the argument. |
203 | /// \param[out] nest Number of levels of array nesting associated with the |
204 | /// type. Must be initially set to 0. |
205 | /// \param[out] rank Rank (number of dimensions, if an array type) of the base |
206 | /// type. Must be initially set to 1. |
207 | /// |
208 | /// \return the name of the base type to represent elements of the argument |
209 | /// type. |
210 | static StringRef translateArgumentType(ArrayRef<SMLoc> loc, |
211 | const StringInit *name, const Init *init, |
212 | int &nest, int &rank) { |
213 | const Record *def = cast<DefInit>(Val: init)->getDef(); |
214 | |
215 | llvm::StringSet<> superClasses; |
216 | for (const Record *sc : def->getSuperClasses()) |
217 | superClasses.insert(key: sc->getNameInitAsString()); |
218 | |
219 | // Handle wrapper-style superclasses. |
220 | if (superClasses.contains(key: "OptionalAttr" )) |
221 | return translateArgumentType( |
222 | loc, name, init: def->getValue(Name: "baseAttr" )->getValue(), nest, rank); |
223 | |
224 | if (superClasses.contains(key: "TypedArrayAttrBase" )) |
225 | return translateArgumentType( |
226 | loc, name, init: def->getValue(Name: "elementAttr" )->getValue(), nest&: ++nest, rank); |
227 | |
228 | // Handle ElementsAttrBase superclasses. |
229 | if (superClasses.contains(key: "ElementsAttrBase" )) { |
230 | // TODO: Obtain the rank from ranked types. |
231 | ++nest; |
232 | |
233 | if (superClasses.contains(key: "IntElementsAttrBase" )) |
234 | return "::llvm::APInt" ; |
235 | if (superClasses.contains(key: "FloatElementsAttr" ) || |
236 | superClasses.contains(key: "RankedFloatElementsAttr" )) |
237 | return "::llvm::APFloat" ; |
238 | if (superClasses.contains(key: "DenseArrayAttrBase" )) |
239 | return stripPrefixAndSuffix(str: def->getValueAsString(FieldName: "returnType" ), |
240 | prefixes: {"::llvm::ArrayRef<" }, suffixes: {">" }); |
241 | |
242 | // Decrease the nesting depth in the case where the base type cannot be |
243 | // inferred, so that the bare storageType is used instead of a vector. |
244 | --nest; |
245 | PrintWarning( |
246 | WarningLoc: loc, |
247 | Msg: "could not infer array-like attribute element type for argument '" + |
248 | name->getAsUnquotedString() + "', will use bare `storageType`" ); |
249 | } |
250 | |
251 | // Handle simple attribute and value types. |
252 | [[maybe_unused]] bool isAttr = superClasses.contains(key: "Attr" ); |
253 | bool isValue = superClasses.contains(key: "TypeConstraint" ); |
254 | if (superClasses.contains(key: "Variadic" )) |
255 | ++nest; |
256 | |
257 | if (isValue) { |
258 | assert(!isAttr && |
259 | "argument can't be simultaneously a value and an attribute" ); |
260 | return "::mlir::Value" ; |
261 | } |
262 | |
263 | assert(isAttr && "argument must be an attribute if it's not a value" ); |
264 | return nest > 0 ? "::mlir::Attribute" |
265 | : def->getValueAsString(FieldName: "storageType" ).trim(); |
266 | } |
267 | |
268 | /// Generate the structure that represents the arguments of the given \c clause |
269 | /// record of type \c OpenMP_Clause. |
270 | /// |
271 | /// It will contain a field for each argument, using the same name translated to |
272 | /// camel case and the corresponding base type as returned by |
273 | /// translateArgumentType() optionally wrapped in one or more llvm::SmallVector. |
274 | /// |
275 | /// An additional field containing a tuple of integers to hold the size of each |
276 | /// dimension will also be created for multi-rank types. This is not yet |
277 | /// supported. |
278 | static void genClauseOpsStruct(const Record *clause, raw_ostream &os) { |
279 | if (clause->isAnonymous()) |
280 | return; |
281 | |
282 | StringRef clauseName = extractOmpClauseName(clause); |
283 | os << "struct " << clauseName << "ClauseOps {\n" ; |
284 | |
285 | const DagInit *arguments = clause->getValueAsDag(FieldName: "arguments" ); |
286 | for (auto [name, arg] : |
287 | zip_equal(t: arguments->getArgNames(), u: arguments->getArgs())) { |
288 | int nest = 0, rank = 1; |
289 | StringRef baseType = |
290 | translateArgumentType(loc: clause->getLoc(), name, init: arg, nest, rank); |
291 | std::string fieldName = |
292 | convertToCamelFromSnakeCase(input: name->getAsUnquotedString(), |
293 | /*capitalizeFirst=*/false); |
294 | |
295 | os << formatv(Fmt: " {0}{1}{2} {3};\n" , |
296 | Vals: fmt_repeat(Item: "::llvm::SmallVector<" , Count: nest), Vals&: baseType, |
297 | Vals: fmt_repeat(Item: ">" , Count: nest), Vals&: fieldName); |
298 | |
299 | if (rank > 1) { |
300 | assert(nest >= 1 && "must be nested if it's a ranked type" ); |
301 | os << formatv(Fmt: " {0}::std::tuple<{1}int>{2} {3}Dims;\n" , |
302 | Vals: fmt_repeat(Item: "::llvm::SmallVector<" , Count: nest - 1), |
303 | Vals: fmt_repeat(Item: "int, " , Count: rank - 1), Vals: fmt_repeat(Item: ">" , Count: nest - 1), |
304 | Vals&: fieldName); |
305 | } |
306 | } |
307 | |
308 | os << "};\n" ; |
309 | } |
310 | |
311 | /// Generate the structure that represents the clause-related arguments of the |
312 | /// given \c op record of type \c OpenMP_Op. |
313 | /// |
314 | /// This structure will be defined in terms of the clause operand structures |
315 | /// associated to the clauses of the operation. |
316 | static void genOperandsDef(const Record *op, raw_ostream &os) { |
317 | if (op->isAnonymous()) |
318 | return; |
319 | |
320 | SmallVector<std::string> clauseNames; |
321 | for (const Record *clause : op->getValueAsListOfDefs(FieldName: "clauseList" )) |
322 | clauseNames.push_back(Elt: (extractOmpClauseName(clause) + "ClauseOps" ).str()); |
323 | |
324 | StringRef opName = stripPrefixAndSuffix( |
325 | str: op->getName(), /*prefixes=*/{"OpenMP_" }, /*suffixes=*/{"Op" }); |
326 | os << formatv(Fmt: operationArgStruct, Vals&: opName, Vals: join(R&: clauseNames, Separator: ", " )); |
327 | } |
328 | |
329 | /// Verify that all properties of `OpenMP_Clause`s of records deriving from |
330 | /// `OpenMP_Op`s have been inherited by the latter. |
331 | static bool verifyDecls(const RecordKeeper &records, raw_ostream &) { |
332 | for (const Record *op : records.getAllDerivedDefinitions(ClassName: "OpenMP_Op" )) { |
333 | for (const Record *clause : op->getValueAsListOfDefs(FieldName: "clauseList" )) |
334 | verifyClause(op, clause); |
335 | } |
336 | |
337 | return false; |
338 | } |
339 | |
340 | /// Generate structures to represent clause-related operands, based on existing |
341 | /// `OpenMP_Clause` definitions and aggregate them into operation-specific |
342 | /// structures according to the `clauses` argument of each definition deriving |
343 | /// from `OpenMP_Op`. |
344 | static bool genClauseOps(const RecordKeeper &records, raw_ostream &os) { |
345 | mlir::tblgen::NamespaceEmitter ns(os, "mlir::omp" ); |
346 | for (const Record *clause : records.getAllDerivedDefinitions(ClassName: "OpenMP_Clause" )) |
347 | genClauseOpsStruct(clause, os); |
348 | |
349 | // Produce base mixin class. |
350 | os << baseMixinClass; |
351 | |
352 | for (const Record *op : records.getAllDerivedDefinitions(ClassName: "OpenMP_Op" )) |
353 | genOperandsDef(op, os); |
354 | |
355 | return false; |
356 | } |
357 | |
// Registers the generators to mlir-tblgen.
//
// - "verify-openmp-ops" runs verifyDecls(). It only reports diagnostics for
//   OpenMP operations missing clause-defined fields and writes no output.
// - "gen-openmp-clause-ops" runs genClauseOps() to emit the clause operand
//   structures and per-operation operand aliases.
static mlir::GenRegistration
    verifyOpenmpOps("verify-openmp-ops",
                    "Verify OpenMP operations (produce no output file)",
                    verifyDecls);

static mlir::GenRegistration
    genOpenmpClauseOps("gen-openmp-clause-ops",
                       "Generate OpenMP clause operand structures",
                       genClauseOps);
368 | |