| 1 | //===- OmpOpGen.cpp - OpenMP dialect op specific generators ---------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // OmpOpGen defines OpenMP dialect operation specific generators. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/TableGen/GenInfo.h" |
| 14 | |
| 15 | #include "mlir/TableGen/CodeGenHelpers.h" |
| 16 | #include "llvm/ADT/StringExtras.h" |
| 17 | #include "llvm/ADT/StringSet.h" |
| 18 | #include "llvm/ADT/TypeSwitch.h" |
| 19 | #include "llvm/Support/FormatAdapters.h" |
| 20 | #include "llvm/TableGen/Error.h" |
| 21 | #include "llvm/TableGen/Record.h" |
| 22 | |
| 23 | using namespace llvm; |
| 24 | |
/// Code emitted once, between the clause operand structures and the
/// per-operation operand structures. It defines `detail::Clauses`, a class
/// template that inherits from every clause operand structure it is given, so
/// that an operation's operand structure aggregates the fields of all its
/// clauses.
static const char *const baseMixinClass = "\n"
                                          "namespace detail {\n"
                                          "template <typename... Mixins>\n"
                                          "struct Clauses : public Mixins... {};\n"
                                          "} // namespace detail\n";
| 33 | |
/// `formatv` template for an operation's operand structure: `{0}` receives the
/// operation name and `{1}` the comma-separated list of clause operand
/// structures that are combined through `detail::Clauses`.
static const char *const operationArgStruct =
    "\nusing {0}Operands = detail::Clauses<{1}>;\n";
| 38 | |
| 39 | /// Remove multiple optional prefixes and suffixes from \c str. |
| 40 | /// |
| 41 | /// Prefixes and suffixes are attempted to be removed once in the order they |
| 42 | /// appear in the \c prefixes and \c suffixes arguments. All prefixes are |
| 43 | /// processed before suffixes are. This means it will behave as shown in the |
| 44 | /// following example: |
| 45 | /// - str: "PrePreNameSuf1Suf2" |
| 46 | /// - prefixes: ["Pre"] |
| 47 | /// - suffixes: ["Suf1", "Suf2"] |
| 48 | /// - return: "PreNameSuf1" |
| 49 | static StringRef stripPrefixAndSuffix(StringRef str, |
| 50 | llvm::ArrayRef<StringRef> prefixes, |
| 51 | llvm::ArrayRef<StringRef> suffixes) { |
| 52 | for (StringRef prefix : prefixes) |
| 53 | if (str.starts_with(Prefix: prefix)) |
| 54 | str = str.drop_front(N: prefix.size()); |
| 55 | |
| 56 | for (StringRef suffix : suffixes) |
| 57 | if (str.ends_with(Suffix: suffix)) |
| 58 | str = str.drop_back(N: suffix.size()); |
| 59 | |
| 60 | return str; |
| 61 | } |
| 62 | |
| 63 | /// Obtain the name of the OpenMP clause a given record inheriting |
| 64 | /// `OpenMP_Clause` refers to. |
| 65 | /// |
| 66 | /// It supports direct and indirect `OpenMP_Clause` superclasses. Once the |
| 67 | /// `OpenMP_Clause` class the record is based on is found, the optional |
| 68 | /// "OpenMP_" prefix and "Skip" and "Clause" suffixes are removed to return only |
| 69 | /// the clause name, i.e. "OpenMP_CollapseClauseSkip" is returned as "Collapse". |
| 70 | static StringRef (const Record *clause) { |
| 71 | const Record *ompClause = clause->getRecords().getClass(Name: "OpenMP_Clause" ); |
| 72 | assert(ompClause && "base OpenMP records expected to be defined" ); |
| 73 | |
| 74 | StringRef clauseClassName; |
| 75 | |
| 76 | // Check if OpenMP_Clause is a direct superclass. |
| 77 | for (const Record *superClass : |
| 78 | llvm::make_first_range(c: clause->getDirectSuperClasses())) { |
| 79 | if (superClass == ompClause) { |
| 80 | clauseClassName = clause->getName(); |
| 81 | break; |
| 82 | } |
| 83 | } |
| 84 | |
| 85 | // Support indirectly-inherited OpenMP_Clauses. |
| 86 | if (clauseClassName.empty()) { |
| 87 | for (const Record *superClass : clause->getSuperClasses()) { |
| 88 | if (superClass->isSubClassOf(R: ompClause)) { |
| 89 | clauseClassName = superClass->getName(); |
| 90 | break; |
| 91 | } |
| 92 | } |
| 93 | } |
| 94 | |
| 95 | assert(!clauseClassName.empty() && "clause name must be found" ); |
| 96 | |
| 97 | // Keep only the OpenMP clause name itself for reporting purposes. |
| 98 | return stripPrefixAndSuffix(str: clauseClassName, /*prefixes=*/{"OpenMP_" }, |
| 99 | /*suffixes=*/{"Skip" , "Clause" }); |
| 100 | } |
| 101 | |
| 102 | /// Check that the given argument, identified by its name and initialization |
| 103 | /// value, is present in the \c arguments `dag`. |
| 104 | static bool verifyArgument(const DagInit *arguments, StringRef argName, |
| 105 | const Init *argInit) { |
| 106 | auto range = zip_equal(t: arguments->getArgNames(), u: arguments->getArgs()); |
| 107 | return llvm::any_of( |
| 108 | Range&: range, P: [&](std::tuple<const llvm::StringInit *, const llvm::Init *> v) { |
| 109 | return std::get<0>(t&: v)->getAsUnquotedString() == argName && |
| 110 | std::get<1>(t&: v) == argInit; |
| 111 | }); |
| 112 | } |
| 113 | |
| 114 | /// Check that the given string record value, identified by its \c opValueName, |
| 115 | /// is either undefined or empty in both the given operation and clause record |
| 116 | /// or its contents for the clause record are contained in the operation record. |
| 117 | /// Passing a non-empty \c clauseValueName enables checking values named |
| 118 | /// differently in the operation and clause records. |
| 119 | static bool verifyStringValue(const Record *op, const Record *clause, |
| 120 | StringRef opValueName, |
| 121 | StringRef clauseValueName = {}) { |
| 122 | auto opValue = op->getValueAsOptionalString(FieldName: opValueName); |
| 123 | auto clauseValue = clause->getValueAsOptionalString( |
| 124 | FieldName: clauseValueName.empty() ? opValueName : clauseValueName); |
| 125 | |
| 126 | bool opHasValue = opValue && !opValue->trim().empty(); |
| 127 | bool clauseHasValue = clauseValue && !clauseValue->trim().empty(); |
| 128 | |
| 129 | if (!opHasValue) |
| 130 | return !clauseHasValue; |
| 131 | |
| 132 | return !clauseHasValue || opValue->contains(Other: clauseValue->trim()); |
| 133 | } |
| 134 | |
| 135 | /// Verify that all fields of the given clause not explicitly ignored are |
| 136 | /// present in the corresponding operation field. |
| 137 | /// |
| 138 | /// Print warnings or errors where this is not the case. |
| 139 | static void verifyClause(const Record *op, const Record *clause) { |
| 140 | StringRef clauseClassName = extractOmpClauseName(clause); |
| 141 | |
| 142 | if (!clause->getValueAsBit(FieldName: "ignoreArgs" )) { |
| 143 | const DagInit *opArguments = op->getValueAsDag(FieldName: "arguments" ); |
| 144 | const DagInit *arguments = clause->getValueAsDag(FieldName: "arguments" ); |
| 145 | |
| 146 | for (auto [name, arg] : |
| 147 | zip(t: arguments->getArgNames(), u: arguments->getArgs())) { |
| 148 | if (!verifyArgument(arguments: opArguments, argName: name->getAsUnquotedString(), argInit: arg)) |
| 149 | PrintWarning( |
| 150 | WarningLoc: op->getLoc(), |
| 151 | Msg: "'" + clauseClassName + "' clause-defined argument '" + |
| 152 | arg->getAsUnquotedString() + ":$" + |
| 153 | name->getAsUnquotedString() + |
| 154 | "' not present in operation. Consider `dag arguments = " |
| 155 | "!con(clausesArgs, ...)` or explicitly skipping this field." ); |
| 156 | } |
| 157 | } |
| 158 | |
| 159 | if (!clause->getValueAsBit(FieldName: "ignoreAsmFormat" ) && |
| 160 | !verifyStringValue(op, clause, opValueName: "assemblyFormat" , clauseValueName: "reqAssemblyFormat" )) |
| 161 | PrintWarning( |
| 162 | WarningLoc: op->getLoc(), |
| 163 | Msg: "'" + clauseClassName + |
| 164 | "' clause-defined `reqAssemblyFormat` not present in operation. " |
| 165 | "Consider concatenating `clauses[{Req,Opt}]AssemblyFormat` or " |
| 166 | "explicitly skipping this field." ); |
| 167 | |
| 168 | if (!clause->getValueAsBit(FieldName: "ignoreAsmFormat" ) && |
| 169 | !verifyStringValue(op, clause, opValueName: "assemblyFormat" , clauseValueName: "optAssemblyFormat" )) |
| 170 | PrintWarning( |
| 171 | WarningLoc: op->getLoc(), |
| 172 | Msg: "'" + clauseClassName + |
| 173 | "' clause-defined `optAssemblyFormat` not present in operation. " |
| 174 | "Consider concatenating `clauses[{Req,Opt}]AssemblyFormat` or " |
| 175 | "explicitly skipping this field." ); |
| 176 | |
| 177 | if (!clause->getValueAsBit(FieldName: "ignoreDesc" ) && |
| 178 | !verifyStringValue(op, clause, opValueName: "description" )) |
| 179 | PrintError(ErrorLoc: op->getLoc(), |
| 180 | Msg: "'" + clauseClassName + |
| 181 | "' clause-defined `description` not present in operation. " |
| 182 | "Consider concatenating `clausesDescription` or explicitly " |
| 183 | "skipping this field." ); |
| 184 | |
| 185 | if (!clause->getValueAsBit(FieldName: "ignoreExtraDecl" ) && |
| 186 | !verifyStringValue(op, clause, opValueName: "extraClassDeclaration" )) |
| 187 | PrintWarning( |
| 188 | WarningLoc: op->getLoc(), |
| 189 | Msg: "'" + clauseClassName + |
| 190 | "' clause-defined `extraClassDeclaration` not present in " |
| 191 | "operation. Consider concatenating `clausesExtraClassDeclaration` " |
| 192 | "or explicitly skipping this field." ); |
| 193 | } |
| 194 | |
| 195 | /// Translate the type of an OpenMP clause's argument to its corresponding |
| 196 | /// representation for clause operand structures. |
| 197 | /// |
| 198 | /// All kinds of values are represented as `mlir::Value` fields, whereas |
| 199 | /// attributes are represented based on their `storageType`. |
| 200 | /// |
| 201 | /// \param[in] name The name of the argument. |
| 202 | /// \param[in] init The `DefInit` object representing the argument. |
| 203 | /// \param[out] nest Number of levels of array nesting associated with the |
| 204 | /// type. Must be initially set to 0. |
| 205 | /// \param[out] rank Rank (number of dimensions, if an array type) of the base |
| 206 | /// type. Must be initially set to 1. |
| 207 | /// |
| 208 | /// \return the name of the base type to represent elements of the argument |
| 209 | /// type. |
| 210 | static StringRef translateArgumentType(ArrayRef<SMLoc> loc, |
| 211 | const StringInit *name, const Init *init, |
| 212 | int &nest, int &rank) { |
| 213 | const Record *def = cast<DefInit>(Val: init)->getDef(); |
| 214 | |
| 215 | llvm::StringSet<> superClasses; |
| 216 | for (const Record *sc : def->getSuperClasses()) |
| 217 | superClasses.insert(key: sc->getNameInitAsString()); |
| 218 | |
| 219 | // Handle wrapper-style superclasses. |
| 220 | if (superClasses.contains(key: "OptionalAttr" )) |
| 221 | return translateArgumentType( |
| 222 | loc, name, init: def->getValue(Name: "baseAttr" )->getValue(), nest, rank); |
| 223 | |
| 224 | if (superClasses.contains(key: "TypedArrayAttrBase" )) |
| 225 | return translateArgumentType( |
| 226 | loc, name, init: def->getValue(Name: "elementAttr" )->getValue(), nest&: ++nest, rank); |
| 227 | |
| 228 | // Handle ElementsAttrBase superclasses. |
| 229 | if (superClasses.contains(key: "ElementsAttrBase" )) { |
| 230 | // TODO: Obtain the rank from ranked types. |
| 231 | ++nest; |
| 232 | |
| 233 | if (superClasses.contains(key: "IntElementsAttrBase" )) |
| 234 | return "::llvm::APInt" ; |
| 235 | if (superClasses.contains(key: "FloatElementsAttr" ) || |
| 236 | superClasses.contains(key: "RankedFloatElementsAttr" )) |
| 237 | return "::llvm::APFloat" ; |
| 238 | if (superClasses.contains(key: "DenseArrayAttrBase" )) |
| 239 | return stripPrefixAndSuffix(str: def->getValueAsString(FieldName: "returnType" ), |
| 240 | prefixes: {"::llvm::ArrayRef<" }, suffixes: {">" }); |
| 241 | |
| 242 | // Decrease the nesting depth in the case where the base type cannot be |
| 243 | // inferred, so that the bare storageType is used instead of a vector. |
| 244 | --nest; |
| 245 | PrintWarning( |
| 246 | WarningLoc: loc, |
| 247 | Msg: "could not infer array-like attribute element type for argument '" + |
| 248 | name->getAsUnquotedString() + "', will use bare `storageType`" ); |
| 249 | } |
| 250 | |
| 251 | // Handle simple attribute and value types. |
| 252 | [[maybe_unused]] bool isAttr = superClasses.contains(key: "Attr" ); |
| 253 | bool isValue = superClasses.contains(key: "TypeConstraint" ); |
| 254 | if (superClasses.contains(key: "Variadic" )) |
| 255 | ++nest; |
| 256 | |
| 257 | if (isValue) { |
| 258 | assert(!isAttr && |
| 259 | "argument can't be simultaneously a value and an attribute" ); |
| 260 | return "::mlir::Value" ; |
| 261 | } |
| 262 | |
| 263 | assert(isAttr && "argument must be an attribute if it's not a value" ); |
| 264 | return nest > 0 ? "::mlir::Attribute" |
| 265 | : def->getValueAsString(FieldName: "storageType" ).trim(); |
| 266 | } |
| 267 | |
| 268 | /// Generate the structure that represents the arguments of the given \c clause |
| 269 | /// record of type \c OpenMP_Clause. |
| 270 | /// |
| 271 | /// It will contain a field for each argument, using the same name translated to |
| 272 | /// camel case and the corresponding base type as returned by |
| 273 | /// translateArgumentType() optionally wrapped in one or more llvm::SmallVector. |
| 274 | /// |
| 275 | /// An additional field containing a tuple of integers to hold the size of each |
| 276 | /// dimension will also be created for multi-rank types. This is not yet |
| 277 | /// supported. |
| 278 | static void genClauseOpsStruct(const Record *clause, raw_ostream &os) { |
| 279 | if (clause->isAnonymous()) |
| 280 | return; |
| 281 | |
| 282 | StringRef clauseName = extractOmpClauseName(clause); |
| 283 | os << "struct " << clauseName << "ClauseOps {\n" ; |
| 284 | |
| 285 | const DagInit *arguments = clause->getValueAsDag(FieldName: "arguments" ); |
| 286 | for (auto [name, arg] : |
| 287 | zip_equal(t: arguments->getArgNames(), u: arguments->getArgs())) { |
| 288 | int nest = 0, rank = 1; |
| 289 | StringRef baseType = |
| 290 | translateArgumentType(loc: clause->getLoc(), name, init: arg, nest, rank); |
| 291 | std::string fieldName = |
| 292 | convertToCamelFromSnakeCase(input: name->getAsUnquotedString(), |
| 293 | /*capitalizeFirst=*/false); |
| 294 | |
| 295 | os << formatv(Fmt: " {0}{1}{2} {3};\n" , |
| 296 | Vals: fmt_repeat(Item: "::llvm::SmallVector<" , Count: nest), Vals&: baseType, |
| 297 | Vals: fmt_repeat(Item: ">" , Count: nest), Vals&: fieldName); |
| 298 | |
| 299 | if (rank > 1) { |
| 300 | assert(nest >= 1 && "must be nested if it's a ranked type" ); |
| 301 | os << formatv(Fmt: " {0}::std::tuple<{1}int>{2} {3}Dims;\n" , |
| 302 | Vals: fmt_repeat(Item: "::llvm::SmallVector<" , Count: nest - 1), |
| 303 | Vals: fmt_repeat(Item: "int, " , Count: rank - 1), Vals: fmt_repeat(Item: ">" , Count: nest - 1), |
| 304 | Vals&: fieldName); |
| 305 | } |
| 306 | } |
| 307 | |
| 308 | os << "};\n" ; |
| 309 | } |
| 310 | |
| 311 | /// Generate the structure that represents the clause-related arguments of the |
| 312 | /// given \c op record of type \c OpenMP_Op. |
| 313 | /// |
| 314 | /// This structure will be defined in terms of the clause operand structures |
| 315 | /// associated to the clauses of the operation. |
| 316 | static void genOperandsDef(const Record *op, raw_ostream &os) { |
| 317 | if (op->isAnonymous()) |
| 318 | return; |
| 319 | |
| 320 | SmallVector<std::string> clauseNames; |
| 321 | for (const Record *clause : op->getValueAsListOfDefs(FieldName: "clauseList" )) |
| 322 | clauseNames.push_back(Elt: (extractOmpClauseName(clause) + "ClauseOps" ).str()); |
| 323 | |
| 324 | StringRef opName = stripPrefixAndSuffix( |
| 325 | str: op->getName(), /*prefixes=*/{"OpenMP_" }, /*suffixes=*/{"Op" }); |
| 326 | os << formatv(Fmt: operationArgStruct, Vals&: opName, Vals: join(R&: clauseNames, Separator: ", " )); |
| 327 | } |
| 328 | |
| 329 | /// Verify that all properties of `OpenMP_Clause`s of records deriving from |
| 330 | /// `OpenMP_Op`s have been inherited by the latter. |
| 331 | static bool verifyDecls(const RecordKeeper &records, raw_ostream &) { |
| 332 | for (const Record *op : records.getAllDerivedDefinitions(ClassName: "OpenMP_Op" )) { |
| 333 | for (const Record *clause : op->getValueAsListOfDefs(FieldName: "clauseList" )) |
| 334 | verifyClause(op, clause); |
| 335 | } |
| 336 | |
| 337 | return false; |
| 338 | } |
| 339 | |
| 340 | /// Generate structures to represent clause-related operands, based on existing |
| 341 | /// `OpenMP_Clause` definitions and aggregate them into operation-specific |
| 342 | /// structures according to the `clauses` argument of each definition deriving |
| 343 | /// from `OpenMP_Op`. |
| 344 | static bool genClauseOps(const RecordKeeper &records, raw_ostream &os) { |
| 345 | mlir::tblgen::NamespaceEmitter ns(os, "mlir::omp" ); |
| 346 | for (const Record *clause : records.getAllDerivedDefinitions(ClassName: "OpenMP_Clause" )) |
| 347 | genClauseOpsStruct(clause, os); |
| 348 | |
| 349 | // Produce base mixin class. |
| 350 | os << baseMixinClass; |
| 351 | |
| 352 | for (const Record *op : records.getAllDerivedDefinitions(ClassName: "OpenMP_Op" )) |
| 353 | genOperandsDef(op, os); |
| 354 | |
| 355 | return false; |
| 356 | } |
| 357 | |
| 358 | // Registers the generator to mlir-tblgen. |
| 359 | static mlir::GenRegistration |
| 360 | verifyOpenmpOps("verify-openmp-ops" , |
| 361 | "Verify OpenMP operations (produce no output file)" , |
| 362 | verifyDecls); |
| 363 | |
| 364 | static mlir::GenRegistration |
| 365 | genOpenmpClauseOps("gen-openmp-clause-ops" , |
| 366 | "Generate OpenMP clause operand structures" , |
| 367 | genClauseOps); |
| 368 | |