1//===- OmpOpGen.cpp - OpenMP dialect op specific generators ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// OmpOpGen defines OpenMP dialect operation specific generators.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/TableGen/GenInfo.h"
14
15#include "mlir/TableGen/CodeGenHelpers.h"
16#include "llvm/ADT/StringExtras.h"
17#include "llvm/ADT/StringSet.h"
18#include "llvm/ADT/TypeSwitch.h"
19#include "llvm/Support/FormatAdapters.h"
20#include "llvm/TableGen/Error.h"
21#include "llvm/TableGen/Record.h"
22
23using namespace llvm;
24
/// Code emitted once, before any per-operation aliases, defining the variadic
/// mixin template used to combine several clause operand structures into one.
static constexpr const char *baseMixinClass = R"(
namespace detail {
template <typename... Mixins>
struct Clauses : public Mixins... {};
} // namespace detail
)";

/// `formatv` template for an operation's operand structure alias. `{0}` is the
/// operation name and `{1}` the comma-separated list of clause structures.
static constexpr const char *operationArgStruct = R"(
using {0}Operands = detail::Clauses<{1}>;
)";
38
39/// Remove multiple optional prefixes and suffixes from \c str.
40///
41/// Prefixes and suffixes are attempted to be removed once in the order they
42/// appear in the \c prefixes and \c suffixes arguments. All prefixes are
43/// processed before suffixes are. This means it will behave as shown in the
44/// following example:
45/// - str: "PrePreNameSuf1Suf2"
46/// - prefixes: ["Pre"]
47/// - suffixes: ["Suf1", "Suf2"]
48/// - return: "PreNameSuf1"
49static StringRef stripPrefixAndSuffix(StringRef str,
50 llvm::ArrayRef<StringRef> prefixes,
51 llvm::ArrayRef<StringRef> suffixes) {
52 for (StringRef prefix : prefixes)
53 if (str.starts_with(Prefix: prefix))
54 str = str.drop_front(N: prefix.size());
55
56 for (StringRef suffix : suffixes)
57 if (str.ends_with(Suffix: suffix))
58 str = str.drop_back(N: suffix.size());
59
60 return str;
61}
62
63/// Obtain the name of the OpenMP clause a given record inheriting
64/// `OpenMP_Clause` refers to.
65///
66/// It supports direct and indirect `OpenMP_Clause` superclasses. Once the
67/// `OpenMP_Clause` class the record is based on is found, the optional
68/// "OpenMP_" prefix and "Skip" and "Clause" suffixes are removed to return only
69/// the clause name, i.e. "OpenMP_CollapseClauseSkip" is returned as "Collapse".
70static StringRef extractOmpClauseName(const Record *clause) {
71 const Record *ompClause = clause->getRecords().getClass(Name: "OpenMP_Clause");
72 assert(ompClause && "base OpenMP records expected to be defined");
73
74 StringRef clauseClassName;
75
76 // Check if OpenMP_Clause is a direct superclass.
77 for (const Record *superClass :
78 llvm::make_first_range(c: clause->getDirectSuperClasses())) {
79 if (superClass == ompClause) {
80 clauseClassName = clause->getName();
81 break;
82 }
83 }
84
85 // Support indirectly-inherited OpenMP_Clauses.
86 if (clauseClassName.empty()) {
87 for (const Record *superClass : clause->getSuperClasses()) {
88 if (superClass->isSubClassOf(R: ompClause)) {
89 clauseClassName = superClass->getName();
90 break;
91 }
92 }
93 }
94
95 assert(!clauseClassName.empty() && "clause name must be found");
96
97 // Keep only the OpenMP clause name itself for reporting purposes.
98 return stripPrefixAndSuffix(str: clauseClassName, /*prefixes=*/{"OpenMP_"},
99 /*suffixes=*/{"Skip", "Clause"});
100}
101
102/// Check that the given argument, identified by its name and initialization
103/// value, is present in the \c arguments `dag`.
104static bool verifyArgument(const DagInit *arguments, StringRef argName,
105 const Init *argInit) {
106 auto range = zip_equal(t: arguments->getArgNames(), u: arguments->getArgs());
107 return llvm::any_of(
108 Range&: range, P: [&](std::tuple<const llvm::StringInit *, const llvm::Init *> v) {
109 return std::get<0>(t&: v)->getAsUnquotedString() == argName &&
110 std::get<1>(t&: v) == argInit;
111 });
112}
113
114/// Check that the given string record value, identified by its \c opValueName,
115/// is either undefined or empty in both the given operation and clause record
116/// or its contents for the clause record are contained in the operation record.
117/// Passing a non-empty \c clauseValueName enables checking values named
118/// differently in the operation and clause records.
119static bool verifyStringValue(const Record *op, const Record *clause,
120 StringRef opValueName,
121 StringRef clauseValueName = {}) {
122 auto opValue = op->getValueAsOptionalString(FieldName: opValueName);
123 auto clauseValue = clause->getValueAsOptionalString(
124 FieldName: clauseValueName.empty() ? opValueName : clauseValueName);
125
126 bool opHasValue = opValue && !opValue->trim().empty();
127 bool clauseHasValue = clauseValue && !clauseValue->trim().empty();
128
129 if (!opHasValue)
130 return !clauseHasValue;
131
132 return !clauseHasValue || opValue->contains(Other: clauseValue->trim());
133}
134
135/// Verify that all fields of the given clause not explicitly ignored are
136/// present in the corresponding operation field.
137///
138/// Print warnings or errors where this is not the case.
139static void verifyClause(const Record *op, const Record *clause) {
140 StringRef clauseClassName = extractOmpClauseName(clause);
141
142 if (!clause->getValueAsBit(FieldName: "ignoreArgs")) {
143 const DagInit *opArguments = op->getValueAsDag(FieldName: "arguments");
144 const DagInit *arguments = clause->getValueAsDag(FieldName: "arguments");
145
146 for (auto [name, arg] :
147 zip(t: arguments->getArgNames(), u: arguments->getArgs())) {
148 if (!verifyArgument(arguments: opArguments, argName: name->getAsUnquotedString(), argInit: arg))
149 PrintWarning(
150 WarningLoc: op->getLoc(),
151 Msg: "'" + clauseClassName + "' clause-defined argument '" +
152 arg->getAsUnquotedString() + ":$" +
153 name->getAsUnquotedString() +
154 "' not present in operation. Consider `dag arguments = "
155 "!con(clausesArgs, ...)` or explicitly skipping this field.");
156 }
157 }
158
159 if (!clause->getValueAsBit(FieldName: "ignoreAsmFormat") &&
160 !verifyStringValue(op, clause, opValueName: "assemblyFormat", clauseValueName: "reqAssemblyFormat"))
161 PrintWarning(
162 WarningLoc: op->getLoc(),
163 Msg: "'" + clauseClassName +
164 "' clause-defined `reqAssemblyFormat` not present in operation. "
165 "Consider concatenating `clauses[{Req,Opt}]AssemblyFormat` or "
166 "explicitly skipping this field.");
167
168 if (!clause->getValueAsBit(FieldName: "ignoreAsmFormat") &&
169 !verifyStringValue(op, clause, opValueName: "assemblyFormat", clauseValueName: "optAssemblyFormat"))
170 PrintWarning(
171 WarningLoc: op->getLoc(),
172 Msg: "'" + clauseClassName +
173 "' clause-defined `optAssemblyFormat` not present in operation. "
174 "Consider concatenating `clauses[{Req,Opt}]AssemblyFormat` or "
175 "explicitly skipping this field.");
176
177 if (!clause->getValueAsBit(FieldName: "ignoreDesc") &&
178 !verifyStringValue(op, clause, opValueName: "description"))
179 PrintError(ErrorLoc: op->getLoc(),
180 Msg: "'" + clauseClassName +
181 "' clause-defined `description` not present in operation. "
182 "Consider concatenating `clausesDescription` or explicitly "
183 "skipping this field.");
184
185 if (!clause->getValueAsBit(FieldName: "ignoreExtraDecl") &&
186 !verifyStringValue(op, clause, opValueName: "extraClassDeclaration"))
187 PrintWarning(
188 WarningLoc: op->getLoc(),
189 Msg: "'" + clauseClassName +
190 "' clause-defined `extraClassDeclaration` not present in "
191 "operation. Consider concatenating `clausesExtraClassDeclaration` "
192 "or explicitly skipping this field.");
193}
194
195/// Translate the type of an OpenMP clause's argument to its corresponding
196/// representation for clause operand structures.
197///
198/// All kinds of values are represented as `mlir::Value` fields, whereas
199/// attributes are represented based on their `storageType`.
200///
201/// \param[in] name The name of the argument.
202/// \param[in] init The `DefInit` object representing the argument.
203/// \param[out] nest Number of levels of array nesting associated with the
204/// type. Must be initially set to 0.
205/// \param[out] rank Rank (number of dimensions, if an array type) of the base
206/// type. Must be initially set to 1.
207///
208/// \return the name of the base type to represent elements of the argument
209/// type.
210static StringRef translateArgumentType(ArrayRef<SMLoc> loc,
211 const StringInit *name, const Init *init,
212 int &nest, int &rank) {
213 const Record *def = cast<DefInit>(Val: init)->getDef();
214
215 llvm::StringSet<> superClasses;
216 for (const Record *sc : def->getSuperClasses())
217 superClasses.insert(key: sc->getNameInitAsString());
218
219 // Handle wrapper-style superclasses.
220 if (superClasses.contains(key: "OptionalAttr"))
221 return translateArgumentType(
222 loc, name, init: def->getValue(Name: "baseAttr")->getValue(), nest, rank);
223
224 if (superClasses.contains(key: "TypedArrayAttrBase"))
225 return translateArgumentType(
226 loc, name, init: def->getValue(Name: "elementAttr")->getValue(), nest&: ++nest, rank);
227
228 // Handle ElementsAttrBase superclasses.
229 if (superClasses.contains(key: "ElementsAttrBase")) {
230 // TODO: Obtain the rank from ranked types.
231 ++nest;
232
233 if (superClasses.contains(key: "IntElementsAttrBase"))
234 return "::llvm::APInt";
235 if (superClasses.contains(key: "FloatElementsAttr") ||
236 superClasses.contains(key: "RankedFloatElementsAttr"))
237 return "::llvm::APFloat";
238 if (superClasses.contains(key: "DenseArrayAttrBase"))
239 return stripPrefixAndSuffix(str: def->getValueAsString(FieldName: "returnType"),
240 prefixes: {"::llvm::ArrayRef<"}, suffixes: {">"});
241
242 // Decrease the nesting depth in the case where the base type cannot be
243 // inferred, so that the bare storageType is used instead of a vector.
244 --nest;
245 PrintWarning(
246 WarningLoc: loc,
247 Msg: "could not infer array-like attribute element type for argument '" +
248 name->getAsUnquotedString() + "', will use bare `storageType`");
249 }
250
251 // Handle simple attribute and value types.
252 [[maybe_unused]] bool isAttr = superClasses.contains(key: "Attr");
253 bool isValue = superClasses.contains(key: "TypeConstraint");
254 if (superClasses.contains(key: "Variadic"))
255 ++nest;
256
257 if (isValue) {
258 assert(!isAttr &&
259 "argument can't be simultaneously a value and an attribute");
260 return "::mlir::Value";
261 }
262
263 assert(isAttr && "argument must be an attribute if it's not a value");
264 return nest > 0 ? "::mlir::Attribute"
265 : def->getValueAsString(FieldName: "storageType").trim();
266}
267
268/// Generate the structure that represents the arguments of the given \c clause
269/// record of type \c OpenMP_Clause.
270///
271/// It will contain a field for each argument, using the same name translated to
272/// camel case and the corresponding base type as returned by
273/// translateArgumentType() optionally wrapped in one or more llvm::SmallVector.
274///
275/// An additional field containing a tuple of integers to hold the size of each
276/// dimension will also be created for multi-rank types. This is not yet
277/// supported.
278static void genClauseOpsStruct(const Record *clause, raw_ostream &os) {
279 if (clause->isAnonymous())
280 return;
281
282 StringRef clauseName = extractOmpClauseName(clause);
283 os << "struct " << clauseName << "ClauseOps {\n";
284
285 const DagInit *arguments = clause->getValueAsDag(FieldName: "arguments");
286 for (auto [name, arg] :
287 zip_equal(t: arguments->getArgNames(), u: arguments->getArgs())) {
288 int nest = 0, rank = 1;
289 StringRef baseType =
290 translateArgumentType(loc: clause->getLoc(), name, init: arg, nest, rank);
291 std::string fieldName =
292 convertToCamelFromSnakeCase(input: name->getAsUnquotedString(),
293 /*capitalizeFirst=*/false);
294
295 os << formatv(Fmt: " {0}{1}{2} {3};\n",
296 Vals: fmt_repeat(Item: "::llvm::SmallVector<", Count: nest), Vals&: baseType,
297 Vals: fmt_repeat(Item: ">", Count: nest), Vals&: fieldName);
298
299 if (rank > 1) {
300 assert(nest >= 1 && "must be nested if it's a ranked type");
301 os << formatv(Fmt: " {0}::std::tuple<{1}int>{2} {3}Dims;\n",
302 Vals: fmt_repeat(Item: "::llvm::SmallVector<", Count: nest - 1),
303 Vals: fmt_repeat(Item: "int, ", Count: rank - 1), Vals: fmt_repeat(Item: ">", Count: nest - 1),
304 Vals&: fieldName);
305 }
306 }
307
308 os << "};\n";
309}
310
311/// Generate the structure that represents the clause-related arguments of the
312/// given \c op record of type \c OpenMP_Op.
313///
314/// This structure will be defined in terms of the clause operand structures
315/// associated to the clauses of the operation.
316static void genOperandsDef(const Record *op, raw_ostream &os) {
317 if (op->isAnonymous())
318 return;
319
320 SmallVector<std::string> clauseNames;
321 for (const Record *clause : op->getValueAsListOfDefs(FieldName: "clauseList"))
322 clauseNames.push_back(Elt: (extractOmpClauseName(clause) + "ClauseOps").str());
323
324 StringRef opName = stripPrefixAndSuffix(
325 str: op->getName(), /*prefixes=*/{"OpenMP_"}, /*suffixes=*/{"Op"});
326 os << formatv(Fmt: operationArgStruct, Vals&: opName, Vals: join(R&: clauseNames, Separator: ", "));
327}
328
329/// Verify that all properties of `OpenMP_Clause`s of records deriving from
330/// `OpenMP_Op`s have been inherited by the latter.
331static bool verifyDecls(const RecordKeeper &records, raw_ostream &) {
332 for (const Record *op : records.getAllDerivedDefinitions(ClassName: "OpenMP_Op")) {
333 for (const Record *clause : op->getValueAsListOfDefs(FieldName: "clauseList"))
334 verifyClause(op, clause);
335 }
336
337 return false;
338}
339
340/// Generate structures to represent clause-related operands, based on existing
341/// `OpenMP_Clause` definitions and aggregate them into operation-specific
342/// structures according to the `clauses` argument of each definition deriving
343/// from `OpenMP_Op`.
344static bool genClauseOps(const RecordKeeper &records, raw_ostream &os) {
345 mlir::tblgen::NamespaceEmitter ns(os, "mlir::omp");
346 for (const Record *clause : records.getAllDerivedDefinitions(ClassName: "OpenMP_Clause"))
347 genClauseOpsStruct(clause, os);
348
349 // Produce base mixin class.
350 os << baseMixinClass;
351
352 for (const Record *op : records.getAllDerivedDefinitions(ClassName: "OpenMP_Op"))
353 genOperandsDef(op, os);
354
355 return false;
356}
357
358// Registers the generator to mlir-tblgen.
359static mlir::GenRegistration
360 verifyOpenmpOps("verify-openmp-ops",
361 "Verify OpenMP operations (produce no output file)",
362 verifyDecls);
363
364static mlir::GenRegistration
365 genOpenmpClauseOps("gen-openmp-clause-ops",
366 "Generate OpenMP clause operand structures",
367 genClauseOps);
368

source code of mlir/tools/mlir-tblgen/OmpOpGen.cpp