| 1 | //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Pass/PassRegistry.h" |
| 10 | |
| 11 | #include "mlir/Pass/Pass.h" |
| 12 | #include "mlir/Pass/PassManager.h" |
| 13 | #include "llvm/ADT/DenseMap.h" |
| 14 | #include "llvm/ADT/ScopeExit.h" |
| 15 | #include "llvm/ADT/StringRef.h" |
| 16 | #include "llvm/Support/Format.h" |
| 17 | #include "llvm/Support/ManagedStatic.h" |
| 18 | #include "llvm/Support/MemoryBuffer.h" |
| 19 | #include "llvm/Support/SourceMgr.h" |
| 20 | |
| 21 | #include <optional> |
| 22 | #include <utility> |
| 23 | |
| 24 | using namespace mlir; |
| 25 | using namespace detail; |
| 26 | |
| 27 | /// Static mapping of all of the registered passes. |
| 28 | static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry; |
| 29 | |
| 30 | /// A mapping of the above pass registry entries to the corresponding TypeID |
| 31 | /// of the pass that they generate. |
| 32 | static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs; |
| 33 | |
| 34 | /// Static mapping of all of the registered pass pipelines. |
| 35 | static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>> |
| 36 | passPipelineRegistry; |
| 37 | |
| 38 | /// Utility to create a default registry function from a pass instance. |
| 39 | static PassRegistryFunction |
| 40 | buildDefaultRegistryFn(const PassAllocatorFunction &allocator) { |
| 41 | return [=](OpPassManager &pm, StringRef options, |
| 42 | function_ref<LogicalResult(const Twine &)> errorHandler) { |
| 43 | std::unique_ptr<Pass> pass = allocator(); |
| 44 | LogicalResult result = pass->initializeOptions(options, errorHandler); |
| 45 | |
| 46 | std::optional<StringRef> pmOpName = pm.getOpName(); |
| 47 | std::optional<StringRef> passOpName = pass->getOpName(); |
| 48 | if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && pmOpName && |
| 49 | passOpName && *pmOpName != *passOpName) { |
| 50 | return errorHandler(llvm::Twine("Can't add pass '" ) + pass->getName() + |
| 51 | "' restricted to '" + *pass->getOpName() + |
| 52 | "' on a PassManager intended to run on '" + |
| 53 | pm.getOpAnchorName() + "', did you intend to nest?" ); |
| 54 | } |
| 55 | pm.addPass(pass: std::move(pass)); |
| 56 | return result; |
| 57 | }; |
| 58 | } |
| 59 | |
| 60 | /// Utility to print the help string for a specific option. |
| 61 | static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, |
| 62 | size_t descIndent, bool isTopLevel) { |
| 63 | size_t numSpaces = descIndent - indent - 4; |
| 64 | llvm::outs().indent(NumSpaces: indent) |
| 65 | << "--" << llvm::left_justify(Str: arg, Width: numSpaces) << "- " << desc << '\n'; |
| 66 | } |
| 67 | |
| 68 | //===----------------------------------------------------------------------===// |
| 69 | // PassRegistry |
| 70 | //===----------------------------------------------------------------------===// |
| 71 | |
| 72 | /// Prints the passes that were previously registered and stored in passRegistry |
| 73 | void mlir::printRegisteredPasses() { |
| 74 | size_t maxWidth = 0; |
| 75 | for (auto &entry : *passRegistry) |
| 76 | maxWidth = std::max(a: maxWidth, b: entry.second.getOptionWidth() + 4); |
| 77 | |
| 78 | // Functor used to print the ordered entries of a registration map. |
| 79 | auto printOrderedEntries = [&](StringRef , auto &map) { |
| 80 | llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries; |
| 81 | for (auto &kv : map) |
| 82 | orderedEntries.push_back(Elt: &kv.second); |
| 83 | llvm::array_pod_sort( |
| 84 | orderedEntries.begin(), orderedEntries.end(), |
| 85 | [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) { |
| 86 | return (*lhs)->getPassArgument().compare(RHS: (*rhs)->getPassArgument()); |
| 87 | }); |
| 88 | |
| 89 | llvm::outs().indent(NumSpaces: 0) << header << ":\n" ; |
| 90 | for (PassRegistryEntry *entry : orderedEntries) |
| 91 | entry->printHelpStr(/*indent=*/2, descIndent: maxWidth); |
| 92 | }; |
| 93 | |
| 94 | // Print the available passes. |
| 95 | printOrderedEntries("Passes" , *passRegistry); |
| 96 | } |
| 97 | |
| 98 | /// Print the help information for this pass. This includes the argument, |
| 99 | /// description, and any pass options. `descIndent` is the indent that the |
| 100 | /// descriptions should be aligned. |
| 101 | void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const { |
| 102 | printOptionHelp(arg: getPassArgument(), desc: getPassDescription(), indent, descIndent, |
| 103 | /*isTopLevel=*/true); |
| 104 | // If this entry has options, print the help for those as well. |
| 105 | optHandler([=](const PassOptions &options) { |
| 106 | options.printHelp(indent, descIndent); |
| 107 | }); |
| 108 | } |
| 109 | |
| 110 | /// Return the maximum width required when printing the options of this |
| 111 | /// entry. |
| 112 | size_t PassRegistryEntry::getOptionWidth() const { |
| 113 | size_t maxLen = 0; |
| 114 | optHandler([&](const PassOptions &options) mutable { |
| 115 | maxLen = options.getOptionWidth() + 2; |
| 116 | }); |
| 117 | return maxLen; |
| 118 | } |
| 119 | |
| 120 | //===----------------------------------------------------------------------===// |
| 121 | // PassPipelineInfo |
| 122 | //===----------------------------------------------------------------------===// |
| 123 | |
| 124 | void mlir::registerPassPipeline( |
| 125 | StringRef arg, StringRef description, const PassRegistryFunction &function, |
| 126 | std::function<void(function_ref<void(const PassOptions &)>)> optHandler) { |
| 127 | PassPipelineInfo pipelineInfo(arg, description, function, |
| 128 | std::move(optHandler)); |
| 129 | bool inserted = passPipelineRegistry->try_emplace(Key: arg, Args&: pipelineInfo).second; |
| 130 | #ifndef NDEBUG |
| 131 | if (!inserted) |
| 132 | report_fatal_error(reason: "Pass pipeline " + arg + " registered multiple times" ); |
| 133 | #endif |
| 134 | (void)inserted; |
| 135 | } |
| 136 | |
| 137 | //===----------------------------------------------------------------------===// |
| 138 | // PassInfo |
| 139 | //===----------------------------------------------------------------------===// |
| 140 | |
| 141 | PassInfo::PassInfo(StringRef arg, StringRef description, |
| 142 | const PassAllocatorFunction &allocator) |
| 143 | : PassRegistryEntry( |
| 144 | arg, description, buildDefaultRegistryFn(allocator), |
| 145 | // Use a temporary pass to provide an options instance. |
| 146 | [=](function_ref<void(const PassOptions &)> optHandler) { |
| 147 | optHandler(allocator()->passOptions); |
| 148 | }) {} |
| 149 | |
| 150 | void mlir::registerPass(const PassAllocatorFunction &function) { |
| 151 | std::unique_ptr<Pass> pass = function(); |
| 152 | StringRef arg = pass->getArgument(); |
| 153 | if (arg.empty()) |
| 154 | llvm::report_fatal_error(reason: llvm::Twine("Trying to register '" ) + |
| 155 | pass->getName() + |
| 156 | "' pass that does not override `getArgument()`" ); |
| 157 | StringRef description = pass->getDescription(); |
| 158 | PassInfo passInfo(arg, description, function); |
| 159 | passRegistry->try_emplace(Key: arg, Args&: passInfo); |
| 160 | |
| 161 | // Verify that the registered pass has the same ID as any registered to this |
| 162 | // arg before it. |
| 163 | TypeID entryTypeID = pass->getTypeID(); |
| 164 | auto it = passRegistryTypeIDs->try_emplace(Key: arg, Args&: entryTypeID).first; |
| 165 | if (it->second != entryTypeID) |
| 166 | llvm::report_fatal_error( |
| 167 | reason: "pass allocator creates a different pass than previously " |
| 168 | "registered for pass " + |
| 169 | arg); |
| 170 | } |
| 171 | |
| 172 | /// Returns the pass info for the specified pass argument or null if unknown. |
| 173 | const PassInfo *mlir::PassInfo::lookup(StringRef passArg) { |
| 174 | auto it = passRegistry->find(Key: passArg); |
| 175 | return it == passRegistry->end() ? nullptr : &it->second; |
| 176 | } |
| 177 | |
| 178 | /// Returns the pass pipeline info for the specified pass pipeline argument or |
| 179 | /// null if unknown. |
| 180 | const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) { |
| 181 | auto it = passPipelineRegistry->find(Key: pipelineArg); |
| 182 | return it == passPipelineRegistry->end() ? nullptr : &it->second; |
| 183 | } |
| 184 | |
| 185 | //===----------------------------------------------------------------------===// |
| 186 | // PassOptions |
| 187 | //===----------------------------------------------------------------------===// |
| 188 | |
| 189 | /// Attempt to find the next occurance of character 'c' in the string starting |
| 190 | /// from the `index`-th position , omitting any occurances that appear within |
| 191 | /// intervening ranges or literals. |
| 192 | static size_t findChar(StringRef str, size_t index, char c) { |
| 193 | for (size_t i = index, e = str.size(); i < e; ++i) { |
| 194 | if (str[i] == c) |
| 195 | return i; |
| 196 | // Check for various range characters. |
| 197 | if (str[i] == '{') |
| 198 | i = findChar(str, index: i + 1, c: '}'); |
| 199 | else if (str[i] == '(') |
| 200 | i = findChar(str, index: i + 1, c: ')'); |
| 201 | else if (str[i] == '[') |
| 202 | i = findChar(str, index: i + 1, c: ']'); |
| 203 | else if (str[i] == '\"') |
| 204 | i = str.find_first_of(C: '\"', From: i + 1); |
| 205 | else if (str[i] == '\'') |
| 206 | i = str.find_first_of(C: '\'', From: i + 1); |
| 207 | if (i == StringRef::npos) |
| 208 | return StringRef::npos; |
| 209 | } |
| 210 | return StringRef::npos; |
| 211 | } |
| 212 | |
| 213 | /// Extract an argument from 'options' and update it to point after the arg. |
| 214 | /// Returns the cleaned argument string. |
| 215 | static StringRef extractArgAndUpdateOptions(StringRef &options, |
| 216 | size_t argSize) { |
| 217 | StringRef str = options.take_front(N: argSize).trim(); |
| 218 | options = options.drop_front(N: argSize).ltrim(); |
| 219 | |
| 220 | // Early exit if there's no escape sequence. |
| 221 | if (str.size() <= 1) |
| 222 | return str; |
| 223 | |
| 224 | const auto escapePairs = {std::make_pair(x: '\'', y: '\''), |
| 225 | std::make_pair(x: '"', y: '"')}; |
| 226 | for (const auto &escape : escapePairs) { |
| 227 | if (str.front() == escape.first && str.back() == escape.second) { |
| 228 | // Drop the escape characters and trim. |
| 229 | // Don't process additional escape sequences. |
| 230 | return str.drop_front().drop_back().trim(); |
| 231 | } |
| 232 | } |
| 233 | |
| 234 | // Arguments may be wrapped in `{...}`. Unlike the quotation markers that |
| 235 | // denote literals, we respect scoping here. The outer `{...}` should not |
| 236 | // be stripped in cases such as "arg={...},{...}", which can be used to denote |
| 237 | // lists of nested option structs. |
| 238 | if (str.front() == '{') { |
| 239 | unsigned match = findChar(str, index: 1, c: '}'); |
| 240 | if (match == str.size() - 1) |
| 241 | str = str.drop_front().drop_back().trim(); |
| 242 | } |
| 243 | |
| 244 | return str; |
| 245 | } |
| 246 | |
| 247 | LogicalResult detail::pass_options::parseCommaSeparatedList( |
| 248 | llvm::cl::Option &opt, StringRef argName, StringRef optionStr, |
| 249 | function_ref<LogicalResult(StringRef)> elementParseFn) { |
| 250 | if (optionStr.empty()) |
| 251 | return success(); |
| 252 | |
| 253 | size_t nextElePos = findChar(str: optionStr, index: 0, c: ','); |
| 254 | while (nextElePos != StringRef::npos) { |
| 255 | // Process the portion before the comma. |
| 256 | if (failed( |
| 257 | Result: elementParseFn(extractArgAndUpdateOptions(options&: optionStr, argSize: nextElePos)))) |
| 258 | return failure(); |
| 259 | |
| 260 | // Drop the leading ',' |
| 261 | optionStr = optionStr.drop_front(); |
| 262 | nextElePos = findChar(str: optionStr, index: 0, c: ','); |
| 263 | } |
| 264 | return elementParseFn( |
| 265 | extractArgAndUpdateOptions(options&: optionStr, argSize: optionStr.size())); |
| 266 | } |
| 267 | |
| 268 | /// Out of line virtual function to provide home for the class. |
| 269 | void detail::PassOptions::OptionBase::anchor() {} |
| 270 | |
| 271 | /// Copy the option values from 'other'. |
| 272 | void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) { |
| 273 | assert(options.size() == other.options.size()); |
| 274 | if (options.empty()) |
| 275 | return; |
| 276 | for (auto optionsIt : llvm::zip(t&: options, u: other.options)) |
| 277 | std::get<0>(t&: optionsIt)->copyValueFrom(other: *std::get<1>(t&: optionsIt)); |
| 278 | } |
| 279 | |
| 280 | /// Parse in the next argument from the given options string. Returns a tuple |
| 281 | /// containing [the key of the option, the value of the option, updated |
| 282 | /// `options` string pointing after the parsed option]. |
| 283 | static std::tuple<StringRef, StringRef, StringRef> |
| 284 | parseNextArg(StringRef options) { |
| 285 | // Try to process the given punctuation, properly escaping any contained |
| 286 | // characters. |
| 287 | auto tryProcessPunct = [&](size_t ¤tPos, char punct) { |
| 288 | if (options[currentPos] != punct) |
| 289 | return false; |
| 290 | size_t nextIt = options.find_first_of(C: punct, From: currentPos + 1); |
| 291 | if (nextIt != StringRef::npos) |
| 292 | currentPos = nextIt; |
| 293 | return true; |
| 294 | }; |
| 295 | |
| 296 | // Parse the argument name of the option. |
| 297 | StringRef argName; |
| 298 | for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) { |
| 299 | // Check for the end of the full option. |
| 300 | if (argEndIt == optionsE || options[argEndIt] == ' ') { |
| 301 | argName = extractArgAndUpdateOptions(options, argSize: argEndIt); |
| 302 | return std::make_tuple(args&: argName, args: StringRef(), args&: options); |
| 303 | } |
| 304 | |
| 305 | // Check for the end of the name and the start of the value. |
| 306 | if (options[argEndIt] == '=') { |
| 307 | argName = extractArgAndUpdateOptions(options, argSize: argEndIt); |
| 308 | options = options.drop_front(); |
| 309 | break; |
| 310 | } |
| 311 | } |
| 312 | |
| 313 | // Parse the value of the option. |
| 314 | for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) { |
| 315 | // Handle the end of the options string. |
| 316 | if (argEndIt == optionsE || options[argEndIt] == ' ') { |
| 317 | StringRef value = extractArgAndUpdateOptions(options, argSize: argEndIt); |
| 318 | return std::make_tuple(args&: argName, args&: value, args&: options); |
| 319 | } |
| 320 | |
| 321 | // Skip over escaped sequences. |
| 322 | char c = options[argEndIt]; |
| 323 | if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"')) |
| 324 | continue; |
| 325 | // '{...}' is used to specify options to passes, properly escape it so |
| 326 | // that we don't accidentally split any nested options. |
| 327 | if (c == '{') { |
| 328 | size_t braceCount = 1; |
| 329 | for (++argEndIt; argEndIt != optionsE; ++argEndIt) { |
| 330 | // Allow nested punctuation. |
| 331 | if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"')) |
| 332 | continue; |
| 333 | if (options[argEndIt] == '{') |
| 334 | ++braceCount; |
| 335 | else if (options[argEndIt] == '}' && --braceCount == 0) |
| 336 | break; |
| 337 | } |
| 338 | // Account for the increment at the top of the loop. |
| 339 | --argEndIt; |
| 340 | } |
| 341 | } |
| 342 | llvm_unreachable("unexpected control flow in pass option parsing" ); |
| 343 | } |
| 344 | |
| 345 | LogicalResult detail::PassOptions::parseFromString(StringRef options, |
| 346 | raw_ostream &errorStream) { |
| 347 | // NOTE: `options` is modified in place to always refer to the unprocessed |
| 348 | // part of the string. |
| 349 | while (!options.empty()) { |
| 350 | StringRef key, value; |
| 351 | std::tie(args&: key, args&: value, args&: options) = parseNextArg(options); |
| 352 | if (key.empty()) |
| 353 | continue; |
| 354 | |
| 355 | auto it = OptionsMap.find(Key: key); |
| 356 | if (it == OptionsMap.end()) { |
| 357 | errorStream << "<Pass-Options-Parser>: no such option " << key << "\n" ; |
| 358 | return failure(); |
| 359 | } |
| 360 | if (llvm::cl::ProvidePositionalOption(Handler: it->second, Arg: value, i: 0)) |
| 361 | return failure(); |
| 362 | } |
| 363 | |
| 364 | return success(); |
| 365 | } |
| 366 | |
| 367 | /// Print the options held by this struct in a form that can be parsed via |
| 368 | /// 'parseFromString'. |
| 369 | void detail::PassOptions::print(raw_ostream &os) const { |
| 370 | // If there are no options, there is nothing left to do. |
| 371 | if (OptionsMap.empty()) |
| 372 | return; |
| 373 | |
| 374 | // Sort the options to make the ordering deterministic. |
| 375 | SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end()); |
| 376 | auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) { |
| 377 | return (*lhs)->getArgStr().compare(RHS: (*rhs)->getArgStr()); |
| 378 | }; |
| 379 | llvm::array_pod_sort(Start: orderedOps.begin(), End: orderedOps.end(), Compare: compareOptionArgs); |
| 380 | |
| 381 | // Interleave the options with ' '. |
| 382 | os << '{'; |
| 383 | llvm::interleave( |
| 384 | c: orderedOps, os, each_fn: [&](OptionBase *option) { option->print(os); }, separator: " " ); |
| 385 | os << '}'; |
| 386 | } |
| 387 | |
| 388 | /// Print the help string for the options held by this struct. `descIndent` is |
| 389 | /// the indent within the stream that the descriptions should be aligned. |
| 390 | void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const { |
| 391 | // Sort the options to make the ordering deterministic. |
| 392 | SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end()); |
| 393 | auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) { |
| 394 | return (*lhs)->getArgStr().compare(RHS: (*rhs)->getArgStr()); |
| 395 | }; |
| 396 | llvm::array_pod_sort(Start: orderedOps.begin(), End: orderedOps.end(), Compare: compareOptionArgs); |
| 397 | for (OptionBase *option : orderedOps) { |
| 398 | // TODO: printOptionInfo assumes a specific indent and will |
| 399 | // print options with values with incorrect indentation. We should add |
| 400 | // support to llvm::cl::Option for passing in a base indent to use when |
| 401 | // printing. |
| 402 | llvm::outs().indent(NumSpaces: indent); |
| 403 | option->getOption()->printOptionInfo(GlobalWidth: descIndent - indent); |
| 404 | } |
| 405 | } |
| 406 | |
| 407 | /// Return the maximum width required when printing the help string. |
| 408 | size_t detail::PassOptions::getOptionWidth() const { |
| 409 | size_t max = 0; |
| 410 | for (auto *option : options) |
| 411 | max = std::max(a: max, b: option->getOption()->getOptionWidth()); |
| 412 | return max; |
| 413 | } |
| 414 | |
| 415 | //===----------------------------------------------------------------------===// |
| 416 | // MLIR Options |
| 417 | //===----------------------------------------------------------------------===// |
| 418 | |
| 419 | //===----------------------------------------------------------------------===// |
| 420 | // OpPassManager: OptionValue |
| 421 | //===----------------------------------------------------------------------===// |
| 422 | |
| 423 | llvm::cl::OptionValue<OpPassManager>::OptionValue() = default; |
| 424 | llvm::cl::OptionValue<OpPassManager>::OptionValue( |
| 425 | const mlir::OpPassManager &value) { |
| 426 | setValue(value); |
| 427 | } |
| 428 | llvm::cl::OptionValue<OpPassManager>::OptionValue( |
| 429 | const llvm::cl::OptionValue<mlir::OpPassManager> &rhs) { |
| 430 | if (rhs.hasValue()) |
| 431 | setValue(rhs.getValue()); |
| 432 | } |
| 433 | llvm::cl::OptionValue<OpPassManager> & |
| 434 | llvm::cl::OptionValue<OpPassManager>::operator=( |
| 435 | const mlir::OpPassManager &rhs) { |
| 436 | setValue(rhs); |
| 437 | return *this; |
| 438 | } |
| 439 | |
| 440 | llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default; |
| 441 | |
| 442 | void llvm::cl::OptionValue<OpPassManager>::setValue( |
| 443 | const OpPassManager &newValue) { |
| 444 | if (hasValue()) |
| 445 | *value = newValue; |
| 446 | else |
| 447 | value = std::make_unique<mlir::OpPassManager>(args: newValue); |
| 448 | } |
| 449 | void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) { |
| 450 | FailureOr<OpPassManager> pipeline = parsePassPipeline(pipeline: pipelineStr); |
| 451 | assert(succeeded(pipeline) && "invalid pass pipeline" ); |
| 452 | setValue(*pipeline); |
| 453 | } |
| 454 | |
| 455 | bool llvm::cl::OptionValue<OpPassManager>::compare( |
| 456 | const mlir::OpPassManager &rhs) const { |
| 457 | std::string lhsStr, rhsStr; |
| 458 | { |
| 459 | raw_string_ostream lhsStream(lhsStr); |
| 460 | value->printAsTextualPipeline(os&: lhsStream); |
| 461 | |
| 462 | raw_string_ostream rhsStream(rhsStr); |
| 463 | rhs.printAsTextualPipeline(os&: rhsStream); |
| 464 | } |
| 465 | |
| 466 | // Use the textual format for pipeline comparisons. |
| 467 | return lhsStr == rhsStr; |
| 468 | } |
| 469 | |
| 470 | void llvm::cl::OptionValue<OpPassManager>::anchor() {} |
| 471 | |
| 472 | //===----------------------------------------------------------------------===// |
| 473 | // OpPassManager: Parser |
| 474 | //===----------------------------------------------------------------------===// |
| 475 | |
| 476 | namespace llvm { |
| 477 | namespace cl { |
| 478 | template class basic_parser<OpPassManager>; |
| 479 | } // namespace cl |
| 480 | } // namespace llvm |
| 481 | |
| 482 | bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg, |
| 483 | ParsedPassManager &value) { |
| 484 | FailureOr<OpPassManager> pipeline = parsePassPipeline(pipeline: arg); |
| 485 | if (failed(Result: pipeline)) |
| 486 | return true; |
| 487 | value.value = std::make_unique<OpPassManager>(args: std::move(*pipeline)); |
| 488 | return false; |
| 489 | } |
| 490 | |
| 491 | void llvm::cl::parser<OpPassManager>::print(raw_ostream &os, |
| 492 | const OpPassManager &value) { |
| 493 | value.printAsTextualPipeline(os); |
| 494 | } |
| 495 | |
| 496 | void llvm::cl::parser<OpPassManager>::printOptionDiff( |
| 497 | const Option &opt, OpPassManager &pm, const OptVal &defaultValue, |
| 498 | size_t globalWidth) const { |
| 499 | printOptionName(O: opt, GlobalWidth: globalWidth); |
| 500 | outs() << "= " ; |
| 501 | pm.printAsTextualPipeline(os&: outs()); |
| 502 | |
| 503 | if (defaultValue.hasValue()) { |
| 504 | outs().indent(NumSpaces: 2) << " (default: " ; |
| 505 | defaultValue.getValue().printAsTextualPipeline(os&: outs()); |
| 506 | outs() << ")" ; |
| 507 | } |
| 508 | outs() << "\n" ; |
| 509 | } |
| 510 | |
| 511 | void llvm::cl::parser<OpPassManager>::anchor() {} |
| 512 | |
| 513 | llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager() = |
| 514 | default; |
| 515 | llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager( |
| 516 | ParsedPassManager &&) = default; |
| 517 | llvm::cl::parser<OpPassManager>::ParsedPassManager::~ParsedPassManager() = |
| 518 | default; |
| 519 | |
| 520 | //===----------------------------------------------------------------------===// |
| 521 | // TextualPassPipeline Parser |
| 522 | //===----------------------------------------------------------------------===// |
| 523 | |
| 524 | namespace { |
| 525 | /// This class represents a textual description of a pass pipeline. |
| 526 | class TextualPipeline { |
| 527 | public: |
| 528 | /// Try to initialize this pipeline with the given pipeline text. |
| 529 | /// `errorStream` is the output stream to emit errors to. |
| 530 | LogicalResult initialize(StringRef text, raw_ostream &errorStream); |
| 531 | |
| 532 | /// Add the internal pipeline elements to the provided pass manager. |
| 533 | LogicalResult |
| 534 | addToPipeline(OpPassManager &pm, |
| 535 | function_ref<LogicalResult(const Twine &)> errorHandler) const; |
| 536 | |
| 537 | private: |
| 538 | /// A functor used to emit errors found during pipeline handling. The first |
| 539 | /// parameter corresponds to the raw location within the pipeline string. This |
| 540 | /// should always return failure. |
| 541 | using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>; |
| 542 | |
| 543 | /// A struct to capture parsed pass pipeline names. |
| 544 | /// |
| 545 | /// A pipeline is defined as a series of names, each of which may in itself |
| 546 | /// recursively contain a nested pipeline. A name is either the name of a pass |
| 547 | /// (e.g. "cse") or the name of an operation type (e.g. "buitin.module"). If |
| 548 | /// the name is the name of a pass, the InnerPipeline is empty, since passes |
| 549 | /// cannot contain inner pipelines. |
| 550 | struct PipelineElement { |
| 551 | PipelineElement(StringRef name) : name(name) {} |
| 552 | |
| 553 | StringRef name; |
| 554 | StringRef options; |
| 555 | const PassRegistryEntry *registryEntry = nullptr; |
| 556 | std::vector<PipelineElement> innerPipeline; |
| 557 | }; |
| 558 | |
| 559 | /// Parse the given pipeline text into the internal pipeline vector. This |
| 560 | /// function only parses the structure of the pipeline, and does not resolve |
| 561 | /// its elements. |
| 562 | LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler); |
| 563 | |
| 564 | /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to |
| 565 | /// the corresponding registry entry. |
| 566 | LogicalResult |
| 567 | resolvePipelineElements(MutableArrayRef<PipelineElement> elements, |
| 568 | ErrorHandlerT errorHandler); |
| 569 | |
| 570 | /// Resolve a single element of the pipeline. |
| 571 | LogicalResult resolvePipelineElement(PipelineElement &element, |
| 572 | ErrorHandlerT errorHandler); |
| 573 | |
| 574 | /// Add the given pipeline elements to the provided pass manager. |
| 575 | LogicalResult |
| 576 | addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm, |
| 577 | function_ref<LogicalResult(const Twine &)> errorHandler) const; |
| 578 | |
| 579 | std::vector<PipelineElement> pipeline; |
| 580 | }; |
| 581 | |
| 582 | } // namespace |
| 583 | |
| 584 | /// Try to initialize this pipeline with the given pipeline text. An option is |
| 585 | /// given to enable accurate error reporting. |
| 586 | LogicalResult TextualPipeline::initialize(StringRef text, |
| 587 | raw_ostream &errorStream) { |
| 588 | if (text.empty()) |
| 589 | return success(); |
| 590 | |
| 591 | // Build a source manager to use for error reporting. |
| 592 | llvm::SourceMgr pipelineMgr; |
| 593 | pipelineMgr.AddNewSourceBuffer( |
| 594 | F: llvm::MemoryBuffer::getMemBuffer(InputData: text, BufferName: "MLIR Textual PassPipeline Parser" , |
| 595 | /*RequiresNullTerminator=*/false), |
| 596 | IncludeLoc: SMLoc()); |
| 597 | auto errorHandler = [&](const char *rawLoc, Twine msg) { |
| 598 | pipelineMgr.PrintMessage(OS&: errorStream, Loc: SMLoc::getFromPointer(Ptr: rawLoc), |
| 599 | Kind: llvm::SourceMgr::DK_Error, Msg: msg); |
| 600 | return failure(); |
| 601 | }; |
| 602 | |
| 603 | // Parse the provided pipeline string. |
| 604 | if (failed(Result: parsePipelineText(text, errorHandler))) |
| 605 | return failure(); |
| 606 | return resolvePipelineElements(elements: pipeline, errorHandler); |
| 607 | } |
| 608 | |
| 609 | /// Add the internal pipeline elements to the provided pass manager. |
| 610 | LogicalResult TextualPipeline::addToPipeline( |
| 611 | OpPassManager &pm, |
| 612 | function_ref<LogicalResult(const Twine &)> errorHandler) const { |
| 613 | // Temporarily disable implicit nesting while we append to the pipeline. We |
| 614 | // want the created pipeline to exactly match the parsed text pipeline, so |
| 615 | // it's preferrable to just error out if implicit nesting would be required. |
| 616 | OpPassManager::Nesting nesting = pm.getNesting(); |
| 617 | pm.setNesting(OpPassManager::Nesting::Explicit); |
| 618 | auto restore = llvm::make_scope_exit(F: [&]() { pm.setNesting(nesting); }); |
| 619 | |
| 620 | return addToPipeline(elements: pipeline, pm, errorHandler); |
| 621 | } |
| 622 | |
| 623 | /// Parse the given pipeline text into the internal pipeline vector. This |
| 624 | /// function only parses the structure of the pipeline, and does not resolve |
| 625 | /// its elements. |
| 626 | LogicalResult TextualPipeline::parsePipelineText(StringRef text, |
| 627 | ErrorHandlerT errorHandler) { |
| 628 | SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline}; |
| 629 | for (;;) { |
| 630 | std::vector<PipelineElement> &pipeline = *pipelineStack.back(); |
| 631 | size_t pos = text.find_first_of(Chars: ",(){" ); |
| 632 | pipeline.emplace_back(/*name=*/args: text.substr(Start: 0, N: pos).trim()); |
| 633 | |
| 634 | // If we have a single terminating name, we're done. |
| 635 | if (pos == StringRef::npos) |
| 636 | break; |
| 637 | |
| 638 | text = text.substr(Start: pos); |
| 639 | char sep = text[0]; |
| 640 | |
| 641 | // Handle pulling ... from 'pass{...}' out as PipelineElement.options. |
| 642 | if (sep == '{') { |
| 643 | text = text.substr(Start: 1); |
| 644 | |
| 645 | // Skip over everything until the closing '}' and store as options. |
| 646 | size_t close = StringRef::npos; |
| 647 | for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) { |
| 648 | if (text[i] == '{') { |
| 649 | ++braceCount; |
| 650 | continue; |
| 651 | } |
| 652 | if (text[i] == '}' && --braceCount == 0) { |
| 653 | close = i; |
| 654 | break; |
| 655 | } |
| 656 | } |
| 657 | |
| 658 | // Check to see if a closing options brace was found. |
| 659 | if (close == StringRef::npos) { |
| 660 | return errorHandler( |
| 661 | /*rawLoc=*/text.data() - 1, |
| 662 | "missing closing '}' while processing pass options" ); |
| 663 | } |
| 664 | pipeline.back().options = text.substr(Start: 0, N: close); |
| 665 | text = text.substr(Start: close + 1); |
| 666 | |
| 667 | // Consume space characters that an user might add for readability. |
| 668 | text = text.ltrim(); |
| 669 | |
| 670 | // Skip checking for '(' because nested pipelines cannot have options. |
| 671 | } else if (sep == '(') { |
| 672 | text = text.substr(Start: 1); |
| 673 | |
| 674 | // Push the inner pipeline onto the stack to continue processing. |
| 675 | pipelineStack.push_back(Elt: &pipeline.back().innerPipeline); |
| 676 | continue; |
| 677 | } |
| 678 | |
| 679 | // When handling the close parenthesis, we greedily consume them to avoid |
| 680 | // empty strings in the pipeline. |
| 681 | while (text.consume_front(Prefix: ")" )) { |
| 682 | // If we try to pop the outer pipeline we have unbalanced parentheses. |
| 683 | if (pipelineStack.size() == 1) |
| 684 | return errorHandler(/*rawLoc=*/text.data() - 1, |
| 685 | "encountered extra closing ')' creating unbalanced " |
| 686 | "parentheses while parsing pipeline" ); |
| 687 | |
| 688 | pipelineStack.pop_back(); |
| 689 | // Consume space characters that an user might add for readability. |
| 690 | text = text.ltrim(); |
| 691 | } |
| 692 | |
| 693 | // Check if we've finished parsing. |
| 694 | if (text.empty()) |
| 695 | break; |
| 696 | |
| 697 | // Otherwise, the end of an inner pipeline always has to be followed by |
| 698 | // a comma, and then we can continue. |
| 699 | if (!text.consume_front(Prefix: "," )) |
| 700 | return errorHandler(text.data(), "expected ',' after parsing pipeline" ); |
| 701 | } |
| 702 | |
| 703 | // Check for unbalanced parentheses. |
| 704 | if (pipelineStack.size() > 1) |
| 705 | return errorHandler( |
| 706 | text.data(), |
| 707 | "encountered unbalanced parentheses while parsing pipeline" ); |
| 708 | |
| 709 | assert(pipelineStack.back() == &pipeline && |
| 710 | "wrong pipeline at the bottom of the stack" ); |
| 711 | return success(); |
| 712 | } |
| 713 | |
| 714 | /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to |
| 715 | /// the corresponding registry entry. |
| 716 | LogicalResult TextualPipeline::resolvePipelineElements( |
| 717 | MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) { |
| 718 | for (auto &elt : elements) |
| 719 | if (failed(Result: resolvePipelineElement(element&: elt, errorHandler))) |
| 720 | return failure(); |
| 721 | return success(); |
| 722 | } |
| 723 | |
| 724 | /// Resolve a single element of the pipeline. |
| 725 | LogicalResult |
| 726 | TextualPipeline::resolvePipelineElement(PipelineElement &element, |
| 727 | ErrorHandlerT errorHandler) { |
| 728 | // If the inner pipeline of this element is not empty, this is an operation |
| 729 | // pipeline. |
| 730 | if (!element.innerPipeline.empty()) |
| 731 | return resolvePipelineElements(elements: element.innerPipeline, errorHandler); |
| 732 | |
| 733 | // Otherwise, this must be a pass or pass pipeline. |
| 734 | // Check to see if a pipeline was registered with this name. |
| 735 | if ((element.registryEntry = PassPipelineInfo::lookup(pipelineArg: element.name))) |
| 736 | return success(); |
| 737 | |
| 738 | // If not, then this must be a specific pass name. |
| 739 | if ((element.registryEntry = PassInfo::lookup(passArg: element.name))) |
| 740 | return success(); |
| 741 | |
| 742 | // Emit an error for the unknown pass. |
| 743 | auto *rawLoc = element.name.data(); |
| 744 | return errorHandler(rawLoc, "'" + element.name + |
| 745 | "' does not refer to a " |
| 746 | "registered pass or pass pipeline" ); |
| 747 | } |
| 748 | |
| 749 | /// Add the given pipeline elements to the provided pass manager. |
| 750 | LogicalResult TextualPipeline::addToPipeline( |
| 751 | ArrayRef<PipelineElement> elements, OpPassManager &pm, |
| 752 | function_ref<LogicalResult(const Twine &)> errorHandler) const { |
| 753 | for (auto &elt : elements) { |
| 754 | if (elt.registryEntry) { |
| 755 | if (failed(Result: elt.registryEntry->addToPipeline(pm, options: elt.options, |
| 756 | errorHandler))) { |
| 757 | return errorHandler("failed to add `" + elt.name + "` with options `" + |
| 758 | elt.options + "`" ); |
| 759 | } |
| 760 | } else if (failed(Result: addToPipeline(elements: elt.innerPipeline, pm&: pm.nest(nestedName: elt.name), |
| 761 | errorHandler))) { |
| 762 | return errorHandler("failed to add `" + elt.name + "` with options `" + |
| 763 | elt.options + "` to inner pipeline" ); |
| 764 | } |
| 765 | } |
| 766 | return success(); |
| 767 | } |
| 768 | |
| 769 | LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm, |
| 770 | raw_ostream &errorStream) { |
| 771 | TextualPipeline pipelineParser; |
| 772 | if (failed(Result: pipelineParser.initialize(text: pipeline, errorStream))) |
| 773 | return failure(); |
| 774 | auto errorHandler = [&](Twine msg) { |
| 775 | errorStream << msg << "\n" ; |
| 776 | return failure(); |
| 777 | }; |
| 778 | if (failed(Result: pipelineParser.addToPipeline(pm, errorHandler))) |
| 779 | return failure(); |
| 780 | return success(); |
| 781 | } |
| 782 | |
| 783 | FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline, |
| 784 | raw_ostream &errorStream) { |
| 785 | pipeline = pipeline.trim(); |
| 786 | // Pipelines are expected to be of the form `<op-name>(<pipeline>)`. |
| 787 | size_t pipelineStart = pipeline.find_first_of(C: '('); |
| 788 | if (pipelineStart == 0 || pipelineStart == StringRef::npos || |
| 789 | !pipeline.consume_back(Suffix: ")" )) { |
| 790 | errorStream << "expected pass pipeline to be wrapped with the anchor " |
| 791 | "operation type, e.g. 'builtin.module(...)'" ; |
| 792 | return failure(); |
| 793 | } |
| 794 | |
| 795 | StringRef opName = pipeline.take_front(N: pipelineStart).rtrim(); |
| 796 | OpPassManager pm(opName); |
| 797 | if (failed(Result: parsePassPipeline(pipeline: pipeline.drop_front(N: 1 + pipelineStart), pm, |
| 798 | errorStream))) |
| 799 | return failure(); |
| 800 | return pm; |
| 801 | } |
| 802 | |
| 803 | //===----------------------------------------------------------------------===// |
| 804 | // PassNameParser |
| 805 | //===----------------------------------------------------------------------===// |
| 806 | |
| 807 | namespace { |
| 808 | /// This struct represents the possible data entries in a parsed pass pipeline |
| 809 | /// list. |
| 810 | struct PassArgData { |
| 811 | PassArgData() = default; |
| 812 | PassArgData(const PassRegistryEntry *registryEntry) |
| 813 | : registryEntry(registryEntry) {} |
| 814 | |
| 815 | /// This field is used when the parsed option corresponds to a registered pass |
| 816 | /// or pass pipeline. |
| 817 | const PassRegistryEntry *registryEntry{nullptr}; |
| 818 | |
| 819 | /// This field is set when instance specific pass options have been provided |
| 820 | /// on the command line. |
| 821 | StringRef options; |
| 822 | }; |
| 823 | } // namespace |
| 824 | |
| 825 | namespace llvm { |
| 826 | namespace cl { |
| 827 | /// Define a valid OptionValue for the command line pass argument. |
| 828 | template <> |
| 829 | struct OptionValue<PassArgData> final |
| 830 | : OptionValueBase<PassArgData, /*isClass=*/true> { |
| 831 | OptionValue(const PassArgData &value) { this->setValue(value); } |
| 832 | OptionValue() = default; |
| 833 | void anchor() override {} |
| 834 | |
| 835 | bool hasValue() const { return true; } |
| 836 | const PassArgData &getValue() const { return value; } |
| 837 | void setValue(const PassArgData &value) { this->value = value; } |
| 838 | |
| 839 | PassArgData value; |
| 840 | }; |
| 841 | } // namespace cl |
| 842 | } // namespace llvm |
| 843 | |
| 844 | namespace { |
| 845 | |
| 846 | /// The name for the command line option used for parsing the textual pass |
| 847 | /// pipeline. |
| 848 | #define PASS_PIPELINE_ARG "pass-pipeline" |
| 849 | |
| 850 | /// Adds command line option for each registered pass or pass pipeline, as well |
| 851 | /// as textual pass pipelines. |
| 852 | struct PassNameParser : public llvm::cl::parser<PassArgData> { |
| 853 | PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {} |
| 854 | |
| 855 | void initialize(); |
| 856 | void printOptionInfo(const llvm::cl::Option &opt, |
| 857 | size_t globalWidth) const override; |
| 858 | size_t getOptionWidth(const llvm::cl::Option &opt) const override; |
| 859 | bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, |
| 860 | PassArgData &value); |
| 861 | |
| 862 | /// If true, this parser only parses entries that correspond to a concrete |
| 863 | /// pass registry entry, and does not include pipeline entries or the options |
| 864 | /// for pass entries. |
| 865 | bool passNamesOnly = false; |
| 866 | }; |
| 867 | } // namespace |
| 868 | |
| 869 | void PassNameParser::initialize() { |
| 870 | llvm::cl::parser<PassArgData>::initialize(); |
| 871 | |
| 872 | /// Add the pass entries. |
| 873 | for (const auto &kv : *passRegistry) { |
| 874 | addLiteralOption(Name: kv.second.getPassArgument(), V: &kv.second, |
| 875 | HelpStr: kv.second.getPassDescription()); |
| 876 | } |
| 877 | /// Add the pass pipeline entries. |
| 878 | if (!passNamesOnly) { |
| 879 | for (const auto &kv : *passPipelineRegistry) { |
| 880 | addLiteralOption(Name: kv.second.getPassArgument(), V: &kv.second, |
| 881 | HelpStr: kv.second.getPassDescription()); |
| 882 | } |
| 883 | } |
| 884 | } |
| 885 | |
| 886 | void PassNameParser::printOptionInfo(const llvm::cl::Option &opt, |
| 887 | size_t globalWidth) const { |
| 888 | // If this parser is just parsing pass names, print a simplified option |
| 889 | // string. |
| 890 | if (passNamesOnly) { |
| 891 | llvm::outs() << " --" << opt.ArgStr << "=<pass-arg>" ; |
| 892 | opt.printHelpStr(HelpStr: opt.HelpStr, Indent: globalWidth, FirstLineIndentedBy: opt.ArgStr.size() + 18); |
| 893 | return; |
| 894 | } |
| 895 | |
| 896 | // Print the information for the top-level option. |
| 897 | if (opt.hasArgStr()) { |
| 898 | llvm::outs() << " --" << opt.ArgStr; |
| 899 | opt.printHelpStr(HelpStr: opt.HelpStr, Indent: globalWidth, FirstLineIndentedBy: opt.ArgStr.size() + 7); |
| 900 | } else { |
| 901 | llvm::outs() << " " << opt.HelpStr << '\n'; |
| 902 | } |
| 903 | |
| 904 | // Functor used to print the ordered entries of a registration map. |
| 905 | auto printOrderedEntries = [&](StringRef , auto &map) { |
| 906 | llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries; |
| 907 | for (auto &kv : map) |
| 908 | orderedEntries.push_back(Elt: &kv.second); |
| 909 | llvm::array_pod_sort( |
| 910 | orderedEntries.begin(), orderedEntries.end(), |
| 911 | [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) { |
| 912 | return (*lhs)->getPassArgument().compare(RHS: (*rhs)->getPassArgument()); |
| 913 | }); |
| 914 | |
| 915 | llvm::outs().indent(NumSpaces: 4) << header << ":\n" ; |
| 916 | for (PassRegistryEntry *entry : orderedEntries) |
| 917 | entry->printHelpStr(/*indent=*/6, descIndent: globalWidth); |
| 918 | }; |
| 919 | |
| 920 | // Print the available passes. |
| 921 | printOrderedEntries("Passes" , *passRegistry); |
| 922 | |
| 923 | // Print the available pass pipelines. |
| 924 | if (!passPipelineRegistry->empty()) |
| 925 | printOrderedEntries("Pass Pipelines" , *passPipelineRegistry); |
| 926 | } |
| 927 | |
| 928 | size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const { |
| 929 | size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(O: opt) + 2; |
| 930 | |
| 931 | // Check for any wider pass or pipeline options. |
| 932 | for (auto &entry : *passRegistry) |
| 933 | maxWidth = std::max(a: maxWidth, b: entry.second.getOptionWidth() + 4); |
| 934 | for (auto &entry : *passPipelineRegistry) |
| 935 | maxWidth = std::max(a: maxWidth, b: entry.second.getOptionWidth() + 4); |
| 936 | return maxWidth; |
| 937 | } |
| 938 | |
| 939 | bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName, |
| 940 | StringRef arg, PassArgData &value) { |
| 941 | if (llvm::cl::parser<PassArgData>::parse(O&: opt, ArgName: argName, Arg: arg, V&: value)) |
| 942 | return true; |
| 943 | value.options = arg; |
| 944 | return false; |
| 945 | } |
| 946 | |
| 947 | //===----------------------------------------------------------------------===// |
| 948 | // PassPipelineCLParser |
| 949 | //===----------------------------------------------------------------------===// |
| 950 | |
| 951 | namespace mlir { |
| 952 | namespace detail { |
| 953 | struct PassPipelineCLParserImpl { |
| 954 | PassPipelineCLParserImpl(StringRef arg, StringRef description, |
| 955 | bool passNamesOnly) |
| 956 | : passList(arg, llvm::cl::desc(description)) { |
| 957 | passList.getParser().passNamesOnly = passNamesOnly; |
| 958 | passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional); |
| 959 | } |
| 960 | |
| 961 | /// Returns true if the given pass registry entry was registered at the |
| 962 | /// top-level of the parser, i.e. not within an explicit textual pipeline. |
| 963 | bool contains(const PassRegistryEntry *entry) const { |
| 964 | return llvm::any_of(Range: passList, P: [&](const PassArgData &data) { |
| 965 | return data.registryEntry == entry; |
| 966 | }); |
| 967 | } |
| 968 | |
| 969 | /// The set of passes and pass pipelines to run. |
| 970 | llvm::cl::list<PassArgData, bool, PassNameParser> passList; |
| 971 | }; |
| 972 | } // namespace detail |
| 973 | } // namespace mlir |
| 974 | |
| 975 | /// Construct a pass pipeline parser with the given command line description. |
| 976 | PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description) |
| 977 | : impl(std::make_unique<detail::PassPipelineCLParserImpl>( |
| 978 | args&: arg, args&: description, /*passNamesOnly=*/args: false)), |
| 979 | passPipeline( |
| 980 | PASS_PIPELINE_ARG, |
| 981 | llvm::cl::desc("Textual description of the pass pipeline to run" )) {} |
| 982 | |
| 983 | PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description, |
| 984 | StringRef alias) |
| 985 | : PassPipelineCLParser(arg, description) { |
| 986 | passPipelineAlias.emplace(args&: alias, |
| 987 | args: llvm::cl::desc("Alias for --" PASS_PIPELINE_ARG), |
| 988 | args: llvm::cl::aliasopt(passPipeline)); |
| 989 | } |
| 990 | |
| 991 | PassPipelineCLParser::~PassPipelineCLParser() = default; |
| 992 | |
| 993 | /// Returns true if this parser contains any valid options to add. |
| 994 | bool PassPipelineCLParser::hasAnyOccurrences() const { |
| 995 | return passPipeline.getNumOccurrences() != 0 || |
| 996 | impl->passList.getNumOccurrences() != 0; |
| 997 | } |
| 998 | |
| 999 | /// Returns true if the given pass registry entry was registered at the |
| 1000 | /// top-level of the parser, i.e. not within an explicit textual pipeline. |
| 1001 | bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const { |
| 1002 | return impl->contains(entry); |
| 1003 | } |
| 1004 | |
| 1005 | /// Adds the passes defined by this parser entry to the given pass manager. |
| 1006 | LogicalResult PassPipelineCLParser::addToPipeline( |
| 1007 | OpPassManager &pm, |
| 1008 | function_ref<LogicalResult(const Twine &)> errorHandler) const { |
| 1009 | if (passPipeline.getNumOccurrences()) { |
| 1010 | if (impl->passList.getNumOccurrences()) |
| 1011 | return errorHandler( |
| 1012 | "'-" PASS_PIPELINE_ARG |
| 1013 | "' option can't be used with individual pass options" ); |
| 1014 | std::string errMsg; |
| 1015 | llvm::raw_string_ostream os(errMsg); |
| 1016 | FailureOr<OpPassManager> parsed = parsePassPipeline(pipeline: passPipeline, errorStream&: os); |
| 1017 | if (failed(Result: parsed)) |
| 1018 | return errorHandler(errMsg); |
| 1019 | pm = std::move(*parsed); |
| 1020 | return success(); |
| 1021 | } |
| 1022 | |
| 1023 | for (auto &passIt : impl->passList) { |
| 1024 | if (failed(Result: passIt.registryEntry->addToPipeline(pm, options: passIt.options, |
| 1025 | errorHandler))) |
| 1026 | return failure(); |
| 1027 | } |
| 1028 | return success(); |
| 1029 | } |
| 1030 | |
| 1031 | //===----------------------------------------------------------------------===// |
| 1032 | // PassNameCLParser |
| 1033 | //===----------------------------------------------------------------------===// |
| 1034 | |
| 1035 | /// Construct a pass pipeline parser with the given command line description. |
| 1036 | PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description) |
| 1037 | : impl(std::make_unique<detail::PassPipelineCLParserImpl>( |
| 1038 | args&: arg, args&: description, /*passNamesOnly=*/args: true)) { |
| 1039 | impl->passList.setMiscFlag(llvm::cl::CommaSeparated); |
| 1040 | } |
| 1041 | PassNameCLParser::~PassNameCLParser() = default; |
| 1042 | |
| 1043 | /// Returns true if this parser contains any valid options to add. |
| 1044 | bool PassNameCLParser::hasAnyOccurrences() const { |
| 1045 | return impl->passList.getNumOccurrences() != 0; |
| 1046 | } |
| 1047 | |
| 1048 | /// Returns true if the given pass registry entry was registered at the |
| 1049 | /// top-level of the parser, i.e. not within an explicit textual pipeline. |
| 1050 | bool PassNameCLParser::contains(const PassRegistryEntry *entry) const { |
| 1051 | return impl->contains(entry); |
| 1052 | } |
| 1053 | |