1//===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Pass/PassRegistry.h"
10
11#include "mlir/Pass/Pass.h"
12#include "mlir/Pass/PassManager.h"
13#include "llvm/ADT/DenseMap.h"
14#include "llvm/ADT/ScopeExit.h"
15#include "llvm/ADT/StringRef.h"
16#include "llvm/Support/Format.h"
17#include "llvm/Support/ManagedStatic.h"
18#include "llvm/Support/MemoryBuffer.h"
19#include "llvm/Support/SourceMgr.h"
20
21#include <optional>
22#include <utility>
23
24using namespace mlir;
25using namespace detail;
26
/// Static mapping of all of the registered passes, keyed by the pass's
/// command-line argument (`registerPass` inserts under `pass->getArgument()`).
static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;

/// A mapping of the above pass registry entries to the corresponding TypeID
/// of the pass that they generate. `registerPass` uses it to detect two
/// different pass types being registered under the same argument.
static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;

/// Static mapping of all of the registered pass pipelines, keyed by the
/// argument given to `registerPassPipeline`.
static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
    passPipelineRegistry;
37
38/// Utility to create a default registry function from a pass instance.
39static PassRegistryFunction
40buildDefaultRegistryFn(const PassAllocatorFunction &allocator) {
41 return [=](OpPassManager &pm, StringRef options,
42 function_ref<LogicalResult(const Twine &)> errorHandler) {
43 std::unique_ptr<Pass> pass = allocator();
44 LogicalResult result = pass->initializeOptions(options, errorHandler);
45
46 std::optional<StringRef> pmOpName = pm.getOpName();
47 std::optional<StringRef> passOpName = pass->getOpName();
48 if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && pmOpName &&
49 passOpName && *pmOpName != *passOpName) {
50 return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
51 "' restricted to '" + *pass->getOpName() +
52 "' on a PassManager intended to run on '" +
53 pm.getOpAnchorName() + "', did you intend to nest?");
54 }
55 pm.addPass(pass: std::move(pass));
56 return result;
57 };
58}
59
60/// Utility to print the help string for a specific option.
61static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
62 size_t descIndent, bool isTopLevel) {
63 size_t numSpaces = descIndent - indent - 4;
64 llvm::outs().indent(NumSpaces: indent)
65 << "--" << llvm::left_justify(Str: arg, Width: numSpaces) << "- " << desc << '\n';
66}
67
68//===----------------------------------------------------------------------===//
69// PassRegistry
70//===----------------------------------------------------------------------===//
71
72/// Prints the passes that were previously registered and stored in passRegistry
73void mlir::printRegisteredPasses() {
74 size_t maxWidth = 0;
75 for (auto &entry : *passRegistry)
76 maxWidth = std::max(a: maxWidth, b: entry.second.getOptionWidth() + 4);
77
78 // Functor used to print the ordered entries of a registration map.
79 auto printOrderedEntries = [&](StringRef header, auto &map) {
80 llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries;
81 for (auto &kv : map)
82 orderedEntries.push_back(Elt: &kv.second);
83 llvm::array_pod_sort(
84 orderedEntries.begin(), orderedEntries.end(),
85 [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
86 return (*lhs)->getPassArgument().compare(RHS: (*rhs)->getPassArgument());
87 });
88
89 llvm::outs().indent(NumSpaces: 0) << header << ":\n";
90 for (PassRegistryEntry *entry : orderedEntries)
91 entry->printHelpStr(/*indent=*/2, descIndent: maxWidth);
92 };
93
94 // Print the available passes.
95 printOrderedEntries("Passes", *passRegistry);
96}
97
98/// Print the help information for this pass. This includes the argument,
99/// description, and any pass options. `descIndent` is the indent that the
100/// descriptions should be aligned.
101void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
102 printOptionHelp(arg: getPassArgument(), desc: getPassDescription(), indent, descIndent,
103 /*isTopLevel=*/true);
104 // If this entry has options, print the help for those as well.
105 optHandler([=](const PassOptions &options) {
106 options.printHelp(indent, descIndent);
107 });
108}
109
110/// Return the maximum width required when printing the options of this
111/// entry.
112size_t PassRegistryEntry::getOptionWidth() const {
113 size_t maxLen = 0;
114 optHandler([&](const PassOptions &options) mutable {
115 maxLen = options.getOptionWidth() + 2;
116 });
117 return maxLen;
118}
119
120//===----------------------------------------------------------------------===//
121// PassPipelineInfo
122//===----------------------------------------------------------------------===//
123
124void mlir::registerPassPipeline(
125 StringRef arg, StringRef description, const PassRegistryFunction &function,
126 std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
127 PassPipelineInfo pipelineInfo(arg, description, function,
128 std::move(optHandler));
129 bool inserted = passPipelineRegistry->try_emplace(Key: arg, Args&: pipelineInfo).second;
130#ifndef NDEBUG
131 if (!inserted)
132 report_fatal_error(reason: "Pass pipeline " + arg + " registered multiple times");
133#endif
134 (void)inserted;
135}
136
137//===----------------------------------------------------------------------===//
138// PassInfo
139//===----------------------------------------------------------------------===//
140
141PassInfo::PassInfo(StringRef arg, StringRef description,
142 const PassAllocatorFunction &allocator)
143 : PassRegistryEntry(
144 arg, description, buildDefaultRegistryFn(allocator),
145 // Use a temporary pass to provide an options instance.
146 [=](function_ref<void(const PassOptions &)> optHandler) {
147 optHandler(allocator()->passOptions);
148 }) {}
149
150void mlir::registerPass(const PassAllocatorFunction &function) {
151 std::unique_ptr<Pass> pass = function();
152 StringRef arg = pass->getArgument();
153 if (arg.empty())
154 llvm::report_fatal_error(reason: llvm::Twine("Trying to register '") +
155 pass->getName() +
156 "' pass that does not override `getArgument()`");
157 StringRef description = pass->getDescription();
158 PassInfo passInfo(arg, description, function);
159 passRegistry->try_emplace(Key: arg, Args&: passInfo);
160
161 // Verify that the registered pass has the same ID as any registered to this
162 // arg before it.
163 TypeID entryTypeID = pass->getTypeID();
164 auto it = passRegistryTypeIDs->try_emplace(Key: arg, Args&: entryTypeID).first;
165 if (it->second != entryTypeID)
166 llvm::report_fatal_error(
167 reason: "pass allocator creates a different pass than previously "
168 "registered for pass " +
169 arg);
170}
171
172/// Returns the pass info for the specified pass argument or null if unknown.
173const PassInfo *mlir::PassInfo::lookup(StringRef passArg) {
174 auto it = passRegistry->find(Key: passArg);
175 return it == passRegistry->end() ? nullptr : &it->second;
176}
177
178/// Returns the pass pipeline info for the specified pass pipeline argument or
179/// null if unknown.
180const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) {
181 auto it = passPipelineRegistry->find(Key: pipelineArg);
182 return it == passPipelineRegistry->end() ? nullptr : &it->second;
183}
184
185//===----------------------------------------------------------------------===//
186// PassOptions
187//===----------------------------------------------------------------------===//
188
189/// Attempt to find the next occurance of character 'c' in the string starting
190/// from the `index`-th position , omitting any occurances that appear within
191/// intervening ranges or literals.
192static size_t findChar(StringRef str, size_t index, char c) {
193 for (size_t i = index, e = str.size(); i < e; ++i) {
194 if (str[i] == c)
195 return i;
196 // Check for various range characters.
197 if (str[i] == '{')
198 i = findChar(str, index: i + 1, c: '}');
199 else if (str[i] == '(')
200 i = findChar(str, index: i + 1, c: ')');
201 else if (str[i] == '[')
202 i = findChar(str, index: i + 1, c: ']');
203 else if (str[i] == '\"')
204 i = str.find_first_of(C: '\"', From: i + 1);
205 else if (str[i] == '\'')
206 i = str.find_first_of(C: '\'', From: i + 1);
207 if (i == StringRef::npos)
208 return StringRef::npos;
209 }
210 return StringRef::npos;
211}
212
213/// Extract an argument from 'options' and update it to point after the arg.
214/// Returns the cleaned argument string.
215static StringRef extractArgAndUpdateOptions(StringRef &options,
216 size_t argSize) {
217 StringRef str = options.take_front(N: argSize).trim();
218 options = options.drop_front(N: argSize).ltrim();
219
220 // Early exit if there's no escape sequence.
221 if (str.size() <= 1)
222 return str;
223
224 const auto escapePairs = {std::make_pair(x: '\'', y: '\''),
225 std::make_pair(x: '"', y: '"')};
226 for (const auto &escape : escapePairs) {
227 if (str.front() == escape.first && str.back() == escape.second) {
228 // Drop the escape characters and trim.
229 // Don't process additional escape sequences.
230 return str.drop_front().drop_back().trim();
231 }
232 }
233
234 // Arguments may be wrapped in `{...}`. Unlike the quotation markers that
235 // denote literals, we respect scoping here. The outer `{...}` should not
236 // be stripped in cases such as "arg={...},{...}", which can be used to denote
237 // lists of nested option structs.
238 if (str.front() == '{') {
239 unsigned match = findChar(str, index: 1, c: '}');
240 if (match == str.size() - 1)
241 str = str.drop_front().drop_back().trim();
242 }
243
244 return str;
245}
246
247LogicalResult detail::pass_options::parseCommaSeparatedList(
248 llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
249 function_ref<LogicalResult(StringRef)> elementParseFn) {
250 if (optionStr.empty())
251 return success();
252
253 size_t nextElePos = findChar(str: optionStr, index: 0, c: ',');
254 while (nextElePos != StringRef::npos) {
255 // Process the portion before the comma.
256 if (failed(
257 Result: elementParseFn(extractArgAndUpdateOptions(options&: optionStr, argSize: nextElePos))))
258 return failure();
259
260 // Drop the leading ','
261 optionStr = optionStr.drop_front();
262 nextElePos = findChar(str: optionStr, index: 0, c: ',');
263 }
264 return elementParseFn(
265 extractArgAndUpdateOptions(options&: optionStr, argSize: optionStr.size()));
266}
267
/// Out of line virtual function to provide a home for the class in this file.
void detail::PassOptions::OptionBase::anchor() {}
270
271/// Copy the option values from 'other'.
272void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
273 assert(options.size() == other.options.size());
274 if (options.empty())
275 return;
276 for (auto optionsIt : llvm::zip(t&: options, u: other.options))
277 std::get<0>(t&: optionsIt)->copyValueFrom(other: *std::get<1>(t&: optionsIt));
278}
279
/// Parse in the next argument from the given options string. Returns a tuple
/// containing [the key of the option, the value of the option, updated
/// `options` string pointing after the parsed option].
///
/// An option is either a bare `key` (the value is an empty StringRef) or
/// `key=value`. The key ends at the first ' ' or '='. The value ends at the
/// first ' ' that is not inside a quoted ('...' or "...") span or a `{...}`
/// block (braces may nest). Both key and value are cleaned with
/// `extractArgAndUpdateOptions`.
static std::tuple<StringRef, StringRef, StringRef>
parseNextArg(StringRef options) {
  // Try to process the given punctuation, properly escaping any contained
  // characters. If `options[currentPos]` is `punct`, move `currentPos` to the
  // matching closing `punct` and return true. An unterminated quote leaves
  // `currentPos` in place but still returns true. Otherwise return false.
  // `options` is captured by reference, so this sees the string as it is
  // updated by the key parsing below.
  auto tryProcessPunct = [&](size_t &currentPos, char punct) {
    if (options[currentPos] != punct)
      return false;
    size_t nextIt = options.find_first_of(punct, currentPos + 1);
    if (nextIt != StringRef::npos)
      currentPos = nextIt;
    return true;
  };

  // Parse the argument name of the option.
  StringRef argName;
  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
    // Check for the end of the full option: a bare key without '=' yields an
    // empty value.
    if (argEndIt == optionsE || options[argEndIt] == ' ') {
      argName = extractArgAndUpdateOptions(options, argEndIt);
      return std::make_tuple(argName, StringRef(), options);
    }

    // Check for the end of the name and the start of the value.
    if (options[argEndIt] == '=') {
      argName = extractArgAndUpdateOptions(options, argEndIt);
      // `options` now starts at the '='; drop it.
      options = options.drop_front();
      break;
    }
  }

  // Parse the value of the option.
  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
    // Handle the end of the options string.
    if (argEndIt == optionsE || options[argEndIt] == ' ') {
      StringRef value = extractArgAndUpdateOptions(options, argEndIt);
      return std::make_tuple(argName, value, options);
    }

    // Skip over escaped sequences.
    char c = options[argEndIt];
    if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
      continue;
    // '{...}' is used to specify options to passes, properly escape it so
    // that we don't accidentally split any nested options.
    if (c == '{') {
      size_t braceCount = 1;
      for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
        // Allow nested punctuation.
        if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
          continue;
        if (options[argEndIt] == '{')
          ++braceCount;
        else if (options[argEndIt] == '}' && --braceCount == 0)
          break;
      }
      // Account for the increment at the top of the loop. The next iteration
      // lands on the closing '}', or on `optionsE` if the brace was never
      // closed.
      --argEndIt;
    }
  }
  llvm_unreachable("unexpected control flow in pass option parsing");
}
344
345LogicalResult detail::PassOptions::parseFromString(StringRef options,
346 raw_ostream &errorStream) {
347 // NOTE: `options` is modified in place to always refer to the unprocessed
348 // part of the string.
349 while (!options.empty()) {
350 StringRef key, value;
351 std::tie(args&: key, args&: value, args&: options) = parseNextArg(options);
352 if (key.empty())
353 continue;
354
355 auto it = OptionsMap.find(Key: key);
356 if (it == OptionsMap.end()) {
357 errorStream << "<Pass-Options-Parser>: no such option " << key << "\n";
358 return failure();
359 }
360 if (llvm::cl::ProvidePositionalOption(Handler: it->second, Arg: value, i: 0))
361 return failure();
362 }
363
364 return success();
365}
366
367/// Print the options held by this struct in a form that can be parsed via
368/// 'parseFromString'.
369void detail::PassOptions::print(raw_ostream &os) const {
370 // If there are no options, there is nothing left to do.
371 if (OptionsMap.empty())
372 return;
373
374 // Sort the options to make the ordering deterministic.
375 SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
376 auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
377 return (*lhs)->getArgStr().compare(RHS: (*rhs)->getArgStr());
378 };
379 llvm::array_pod_sort(Start: orderedOps.begin(), End: orderedOps.end(), Compare: compareOptionArgs);
380
381 // Interleave the options with ' '.
382 os << '{';
383 llvm::interleave(
384 c: orderedOps, os, each_fn: [&](OptionBase *option) { option->print(os); }, separator: " ");
385 os << '}';
386}
387
388/// Print the help string for the options held by this struct. `descIndent` is
389/// the indent within the stream that the descriptions should be aligned.
390void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
391 // Sort the options to make the ordering deterministic.
392 SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
393 auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
394 return (*lhs)->getArgStr().compare(RHS: (*rhs)->getArgStr());
395 };
396 llvm::array_pod_sort(Start: orderedOps.begin(), End: orderedOps.end(), Compare: compareOptionArgs);
397 for (OptionBase *option : orderedOps) {
398 // TODO: printOptionInfo assumes a specific indent and will
399 // print options with values with incorrect indentation. We should add
400 // support to llvm::cl::Option for passing in a base indent to use when
401 // printing.
402 llvm::outs().indent(NumSpaces: indent);
403 option->getOption()->printOptionInfo(GlobalWidth: descIndent - indent);
404 }
405}
406
407/// Return the maximum width required when printing the help string.
408size_t detail::PassOptions::getOptionWidth() const {
409 size_t max = 0;
410 for (auto *option : options)
411 max = std::max(a: max, b: option->getOption()->getOptionWidth());
412 return max;
413}
414
415//===----------------------------------------------------------------------===//
416// MLIR Options
417//===----------------------------------------------------------------------===//
418
419//===----------------------------------------------------------------------===//
420// OpPassManager: OptionValue
421//===----------------------------------------------------------------------===//
422
423llvm::cl::OptionValue<OpPassManager>::OptionValue() = default;
424llvm::cl::OptionValue<OpPassManager>::OptionValue(
425 const mlir::OpPassManager &value) {
426 setValue(value);
427}
428llvm::cl::OptionValue<OpPassManager>::OptionValue(
429 const llvm::cl::OptionValue<mlir::OpPassManager> &rhs) {
430 if (rhs.hasValue())
431 setValue(rhs.getValue());
432}
433llvm::cl::OptionValue<OpPassManager> &
434llvm::cl::OptionValue<OpPassManager>::operator=(
435 const mlir::OpPassManager &rhs) {
436 setValue(rhs);
437 return *this;
438}
439
440llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default;
441
442void llvm::cl::OptionValue<OpPassManager>::setValue(
443 const OpPassManager &newValue) {
444 if (hasValue())
445 *value = newValue;
446 else
447 value = std::make_unique<mlir::OpPassManager>(args: newValue);
448}
449void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
450 FailureOr<OpPassManager> pipeline = parsePassPipeline(pipeline: pipelineStr);
451 assert(succeeded(pipeline) && "invalid pass pipeline");
452 setValue(*pipeline);
453}
454
455bool llvm::cl::OptionValue<OpPassManager>::compare(
456 const mlir::OpPassManager &rhs) const {
457 std::string lhsStr, rhsStr;
458 {
459 raw_string_ostream lhsStream(lhsStr);
460 value->printAsTextualPipeline(os&: lhsStream);
461
462 raw_string_ostream rhsStream(rhsStr);
463 rhs.printAsTextualPipeline(os&: rhsStream);
464 }
465
466 // Use the textual format for pipeline comparisons.
467 return lhsStr == rhsStr;
468}
469
/// Out of line virtual function to provide a home for the class in this file.
void llvm::cl::OptionValue<OpPassManager>::anchor() {}
471
472//===----------------------------------------------------------------------===//
473// OpPassManager: Parser
474//===----------------------------------------------------------------------===//
475
// Explicit instantiation definition of the generic command-line parser base
// for OpPassManager, emitted in this translation unit.
namespace llvm {
namespace cl {
template class basic_parser<OpPassManager>;
} // namespace cl
} // namespace llvm
481
482bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg,
483 ParsedPassManager &value) {
484 FailureOr<OpPassManager> pipeline = parsePassPipeline(pipeline: arg);
485 if (failed(Result: pipeline))
486 return true;
487 value.value = std::make_unique<OpPassManager>(args: std::move(*pipeline));
488 return false;
489}
490
/// Print `value` to `os` using its textual pipeline form.
void llvm::cl::parser<OpPassManager>::print(raw_ostream &os,
                                            const OpPassManager &value) {
  value.printAsTextualPipeline(os);
}
495
496void llvm::cl::parser<OpPassManager>::printOptionDiff(
497 const Option &opt, OpPassManager &pm, const OptVal &defaultValue,
498 size_t globalWidth) const {
499 printOptionName(O: opt, GlobalWidth: globalWidth);
500 outs() << "= ";
501 pm.printAsTextualPipeline(os&: outs());
502
503 if (defaultValue.hasValue()) {
504 outs().indent(NumSpaces: 2) << " (default: ";
505 defaultValue.getValue().printAsTextualPipeline(os&: outs());
506 outs() << ")";
507 }
508 outs() << "\n";
509}
510
/// Out of line virtual function to provide a home for the class in this file.
void llvm::cl::parser<OpPassManager>::anchor() {}

// ParsedPassManager's special members are defaulted out of line. It holds its
// OpPassManager through a std::unique_ptr (`parse` assigns the result of
// `std::make_unique<OpPassManager>` to it). NOTE(review): presumably they are
// out of line so OpPassManager only needs to be complete in this file —
// confirm against the header.
llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager() =
    default;
llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager(
    ParsedPassManager &&) = default;
llvm::cl::parser<OpPassManager>::ParsedPassManager::~ParsedPassManager() =
    default;
519
520//===----------------------------------------------------------------------===//
521// TextualPassPipeline Parser
522//===----------------------------------------------------------------------===//
523
524namespace {
525/// This class represents a textual description of a pass pipeline.
526class TextualPipeline {
527public:
528 /// Try to initialize this pipeline with the given pipeline text.
529 /// `errorStream` is the output stream to emit errors to.
530 LogicalResult initialize(StringRef text, raw_ostream &errorStream);
531
532 /// Add the internal pipeline elements to the provided pass manager.
533 LogicalResult
534 addToPipeline(OpPassManager &pm,
535 function_ref<LogicalResult(const Twine &)> errorHandler) const;
536
537private:
538 /// A functor used to emit errors found during pipeline handling. The first
539 /// parameter corresponds to the raw location within the pipeline string. This
540 /// should always return failure.
541 using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;
542
543 /// A struct to capture parsed pass pipeline names.
544 ///
545 /// A pipeline is defined as a series of names, each of which may in itself
546 /// recursively contain a nested pipeline. A name is either the name of a pass
547 /// (e.g. "cse") or the name of an operation type (e.g. "buitin.module"). If
548 /// the name is the name of a pass, the InnerPipeline is empty, since passes
549 /// cannot contain inner pipelines.
550 struct PipelineElement {
551 PipelineElement(StringRef name) : name(name) {}
552
553 StringRef name;
554 StringRef options;
555 const PassRegistryEntry *registryEntry = nullptr;
556 std::vector<PipelineElement> innerPipeline;
557 };
558
559 /// Parse the given pipeline text into the internal pipeline vector. This
560 /// function only parses the structure of the pipeline, and does not resolve
561 /// its elements.
562 LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
563
564 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
565 /// the corresponding registry entry.
566 LogicalResult
567 resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
568 ErrorHandlerT errorHandler);
569
570 /// Resolve a single element of the pipeline.
571 LogicalResult resolvePipelineElement(PipelineElement &element,
572 ErrorHandlerT errorHandler);
573
574 /// Add the given pipeline elements to the provided pass manager.
575 LogicalResult
576 addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
577 function_ref<LogicalResult(const Twine &)> errorHandler) const;
578
579 std::vector<PipelineElement> pipeline;
580};
581
582} // namespace
583
584/// Try to initialize this pipeline with the given pipeline text. An option is
585/// given to enable accurate error reporting.
586LogicalResult TextualPipeline::initialize(StringRef text,
587 raw_ostream &errorStream) {
588 if (text.empty())
589 return success();
590
591 // Build a source manager to use for error reporting.
592 llvm::SourceMgr pipelineMgr;
593 pipelineMgr.AddNewSourceBuffer(
594 F: llvm::MemoryBuffer::getMemBuffer(InputData: text, BufferName: "MLIR Textual PassPipeline Parser",
595 /*RequiresNullTerminator=*/false),
596 IncludeLoc: SMLoc());
597 auto errorHandler = [&](const char *rawLoc, Twine msg) {
598 pipelineMgr.PrintMessage(OS&: errorStream, Loc: SMLoc::getFromPointer(Ptr: rawLoc),
599 Kind: llvm::SourceMgr::DK_Error, Msg: msg);
600 return failure();
601 };
602
603 // Parse the provided pipeline string.
604 if (failed(Result: parsePipelineText(text, errorHandler)))
605 return failure();
606 return resolvePipelineElements(elements: pipeline, errorHandler);
607}
608
609/// Add the internal pipeline elements to the provided pass manager.
610LogicalResult TextualPipeline::addToPipeline(
611 OpPassManager &pm,
612 function_ref<LogicalResult(const Twine &)> errorHandler) const {
613 // Temporarily disable implicit nesting while we append to the pipeline. We
614 // want the created pipeline to exactly match the parsed text pipeline, so
615 // it's preferrable to just error out if implicit nesting would be required.
616 OpPassManager::Nesting nesting = pm.getNesting();
617 pm.setNesting(OpPassManager::Nesting::Explicit);
618 auto restore = llvm::make_scope_exit(F: [&]() { pm.setNesting(nesting); });
619
620 return addToPipeline(elements: pipeline, pm, errorHandler);
621}
622
/// Parse the given pipeline text into the internal pipeline vector. This
/// function only parses the structure of the pipeline, and does not resolve
/// its elements.
///
/// Informally, the accepted text is a comma-separated list of elements. Each
/// element is `name`, `name{options}`, or `name(<nested list>)`. Structural
/// errors are reported through `errorHandler` with a pointer into `text` so
/// the diagnostic can point at the offending character: an unterminated `{`,
/// unbalanced parentheses, or a missing `,` between elements.
LogicalResult TextualPipeline::parsePipelineText(StringRef text,
                                                 ErrorHandlerT errorHandler) {
  // Stack of the element lists currently being filled. The bottom entry is the
  // top-level `pipeline` member. Each '(' pushes the inner pipeline of the
  // element just parsed, and each ')' pops it.
  SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
  for (;;) {
    // NOTE: this local reference shadows the `pipeline` member inside the loop.
    std::vector<PipelineElement> &pipeline = *pipelineStack.back();
    // The element name runs up to the next structural character, if any.
    size_t pos = text.find_first_of(",(){");
    pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());

    // If we have a single terminating name, we're done.
    if (pos == StringRef::npos)
      break;

    text = text.substr(pos);
    char sep = text[0];

    // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
    if (sep == '{') {
      text = text.substr(1);

      // Skip over everything until the closing '}' and store as options.
      // Nested braces are balanced. NOTE(review): quotes are not treated
      // specially here, so a '}' inside a quoted option value ends the scan.
      size_t close = StringRef::npos;
      for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
        if (text[i] == '{') {
          ++braceCount;
          continue;
        }
        if (text[i] == '}' && --braceCount == 0) {
          close = i;
          break;
        }
      }

      // Check to see if a closing options brace was found. `text.data() - 1`
      // points at the opening '{'.
      if (close == StringRef::npos) {
        return errorHandler(
            /*rawLoc=*/text.data() - 1,
            "missing closing '}' while processing pass options");
      }
      pipeline.back().options = text.substr(0, close);
      text = text.substr(close + 1);

      // Consume space characters that a user might add for readability.
      text = text.ltrim();

      // Skip checking for '(' because nested pipelines cannot have options.
    } else if (sep == '(') {
      text = text.substr(1);

      // Push the inner pipeline onto the stack to continue processing.
      pipelineStack.push_back(&pipeline.back().innerPipeline);
      continue;
    }

    // When handling the close parenthesis, we greedily consume them to avoid
    // empty strings in the pipeline.
    while (text.consume_front(")")) {
      // If we try to pop the outer pipeline we have unbalanced parentheses.
      // `text.data() - 1` points at the ')' just consumed.
      if (pipelineStack.size() == 1)
        return errorHandler(/*rawLoc=*/text.data() - 1,
                            "encountered extra closing ')' creating unbalanced "
                            "parentheses while parsing pipeline");

      pipelineStack.pop_back();
      // Consume space characters that a user might add for readability.
      text = text.ltrim();
    }

    // Check if we've finished parsing.
    if (text.empty())
      break;

    // Otherwise, the end of an inner pipeline always has to be followed by
    // a comma, and then we can continue.
    if (!text.consume_front(","))
      return errorHandler(text.data(), "expected ',' after parsing pipeline");
  }

  // Check for unbalanced parentheses.
  if (pipelineStack.size() > 1)
    return errorHandler(
        text.data(),
        "encountered unbalanced parentheses while parsing pipeline");

  // `pipeline` refers to the member again here; the loop-local reference is
  // out of scope.
  assert(pipelineStack.back() == &pipeline &&
         "wrong pipeline at the bottom of the stack");
  return success();
}
713
714/// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
715/// the corresponding registry entry.
716LogicalResult TextualPipeline::resolvePipelineElements(
717 MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {
718 for (auto &elt : elements)
719 if (failed(Result: resolvePipelineElement(element&: elt, errorHandler)))
720 return failure();
721 return success();
722}
723
724/// Resolve a single element of the pipeline.
725LogicalResult
726TextualPipeline::resolvePipelineElement(PipelineElement &element,
727 ErrorHandlerT errorHandler) {
728 // If the inner pipeline of this element is not empty, this is an operation
729 // pipeline.
730 if (!element.innerPipeline.empty())
731 return resolvePipelineElements(elements: element.innerPipeline, errorHandler);
732
733 // Otherwise, this must be a pass or pass pipeline.
734 // Check to see if a pipeline was registered with this name.
735 if ((element.registryEntry = PassPipelineInfo::lookup(pipelineArg: element.name)))
736 return success();
737
738 // If not, then this must be a specific pass name.
739 if ((element.registryEntry = PassInfo::lookup(passArg: element.name)))
740 return success();
741
742 // Emit an error for the unknown pass.
743 auto *rawLoc = element.name.data();
744 return errorHandler(rawLoc, "'" + element.name +
745 "' does not refer to a "
746 "registered pass or pass pipeline");
747}
748
749/// Add the given pipeline elements to the provided pass manager.
750LogicalResult TextualPipeline::addToPipeline(
751 ArrayRef<PipelineElement> elements, OpPassManager &pm,
752 function_ref<LogicalResult(const Twine &)> errorHandler) const {
753 for (auto &elt : elements) {
754 if (elt.registryEntry) {
755 if (failed(Result: elt.registryEntry->addToPipeline(pm, options: elt.options,
756 errorHandler))) {
757 return errorHandler("failed to add `" + elt.name + "` with options `" +
758 elt.options + "`");
759 }
760 } else if (failed(Result: addToPipeline(elements: elt.innerPipeline, pm&: pm.nest(nestedName: elt.name),
761 errorHandler))) {
762 return errorHandler("failed to add `" + elt.name + "` with options `" +
763 elt.options + "` to inner pipeline");
764 }
765 }
766 return success();
767}
768
769LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
770 raw_ostream &errorStream) {
771 TextualPipeline pipelineParser;
772 if (failed(Result: pipelineParser.initialize(text: pipeline, errorStream)))
773 return failure();
774 auto errorHandler = [&](Twine msg) {
775 errorStream << msg << "\n";
776 return failure();
777 };
778 if (failed(Result: pipelineParser.addToPipeline(pm, errorHandler)))
779 return failure();
780 return success();
781}
782
783FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline,
784 raw_ostream &errorStream) {
785 pipeline = pipeline.trim();
786 // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
787 size_t pipelineStart = pipeline.find_first_of(C: '(');
788 if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
789 !pipeline.consume_back(Suffix: ")")) {
790 errorStream << "expected pass pipeline to be wrapped with the anchor "
791 "operation type, e.g. 'builtin.module(...)'";
792 return failure();
793 }
794
795 StringRef opName = pipeline.take_front(N: pipelineStart).rtrim();
796 OpPassManager pm(opName);
797 if (failed(Result: parsePassPipeline(pipeline: pipeline.drop_front(N: 1 + pipelineStart), pm,
798 errorStream)))
799 return failure();
800 return pm;
801}
802
803//===----------------------------------------------------------------------===//
804// PassNameParser
805//===----------------------------------------------------------------------===//
806
807namespace {
808/// This struct represents the possible data entries in a parsed pass pipeline
809/// list.
810struct PassArgData {
811 PassArgData() = default;
812 PassArgData(const PassRegistryEntry *registryEntry)
813 : registryEntry(registryEntry) {}
814
815 /// This field is used when the parsed option corresponds to a registered pass
816 /// or pass pipeline.
817 const PassRegistryEntry *registryEntry{nullptr};
818
819 /// This field is set when instance specific pass options have been provided
820 /// on the command line.
821 StringRef options;
822};
823} // namespace
824
825namespace llvm {
826namespace cl {
827/// Define a valid OptionValue for the command line pass argument.
828template <>
829struct OptionValue<PassArgData> final
830 : OptionValueBase<PassArgData, /*isClass=*/true> {
831 OptionValue(const PassArgData &value) { this->setValue(value); }
832 OptionValue() = default;
833 void anchor() override {}
834
835 bool hasValue() const { return true; }
836 const PassArgData &getValue() const { return value; }
837 void setValue(const PassArgData &value) { this->value = value; }
838
839 PassArgData value;
840};
841} // namespace cl
842} // namespace llvm
843
844namespace {
845
846/// The name for the command line option used for parsing the textual pass
847/// pipeline.
848#define PASS_PIPELINE_ARG "pass-pipeline"
849
850/// Adds command line option for each registered pass or pass pipeline, as well
851/// as textual pass pipelines.
852struct PassNameParser : public llvm::cl::parser<PassArgData> {
853 PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}
854
855 void initialize();
856 void printOptionInfo(const llvm::cl::Option &opt,
857 size_t globalWidth) const override;
858 size_t getOptionWidth(const llvm::cl::Option &opt) const override;
859 bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
860 PassArgData &value);
861
862 /// If true, this parser only parses entries that correspond to a concrete
863 /// pass registry entry, and does not include pipeline entries or the options
864 /// for pass entries.
865 bool passNamesOnly = false;
866};
867} // namespace
868
869void PassNameParser::initialize() {
870 llvm::cl::parser<PassArgData>::initialize();
871
872 /// Add the pass entries.
873 for (const auto &kv : *passRegistry) {
874 addLiteralOption(Name: kv.second.getPassArgument(), V: &kv.second,
875 HelpStr: kv.second.getPassDescription());
876 }
877 /// Add the pass pipeline entries.
878 if (!passNamesOnly) {
879 for (const auto &kv : *passPipelineRegistry) {
880 addLiteralOption(Name: kv.second.getPassArgument(), V: &kv.second,
881 HelpStr: kv.second.getPassDescription());
882 }
883 }
884}
885
886void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
887 size_t globalWidth) const {
888 // If this parser is just parsing pass names, print a simplified option
889 // string.
890 if (passNamesOnly) {
891 llvm::outs() << " --" << opt.ArgStr << "=<pass-arg>";
892 opt.printHelpStr(HelpStr: opt.HelpStr, Indent: globalWidth, FirstLineIndentedBy: opt.ArgStr.size() + 18);
893 return;
894 }
895
896 // Print the information for the top-level option.
897 if (opt.hasArgStr()) {
898 llvm::outs() << " --" << opt.ArgStr;
899 opt.printHelpStr(HelpStr: opt.HelpStr, Indent: globalWidth, FirstLineIndentedBy: opt.ArgStr.size() + 7);
900 } else {
901 llvm::outs() << " " << opt.HelpStr << '\n';
902 }
903
904 // Functor used to print the ordered entries of a registration map.
905 auto printOrderedEntries = [&](StringRef header, auto &map) {
906 llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries;
907 for (auto &kv : map)
908 orderedEntries.push_back(Elt: &kv.second);
909 llvm::array_pod_sort(
910 orderedEntries.begin(), orderedEntries.end(),
911 [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
912 return (*lhs)->getPassArgument().compare(RHS: (*rhs)->getPassArgument());
913 });
914
915 llvm::outs().indent(NumSpaces: 4) << header << ":\n";
916 for (PassRegistryEntry *entry : orderedEntries)
917 entry->printHelpStr(/*indent=*/6, descIndent: globalWidth);
918 };
919
920 // Print the available passes.
921 printOrderedEntries("Passes", *passRegistry);
922
923 // Print the available pass pipelines.
924 if (!passPipelineRegistry->empty())
925 printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
926}
927
928size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
929 size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(O: opt) + 2;
930
931 // Check for any wider pass or pipeline options.
932 for (auto &entry : *passRegistry)
933 maxWidth = std::max(a: maxWidth, b: entry.second.getOptionWidth() + 4);
934 for (auto &entry : *passPipelineRegistry)
935 maxWidth = std::max(a: maxWidth, b: entry.second.getOptionWidth() + 4);
936 return maxWidth;
937}
938
939bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
940 StringRef arg, PassArgData &value) {
941 if (llvm::cl::parser<PassArgData>::parse(O&: opt, ArgName: argName, Arg: arg, V&: value))
942 return true;
943 value.options = arg;
944 return false;
945}
946
947//===----------------------------------------------------------------------===//
948// PassPipelineCLParser
949//===----------------------------------------------------------------------===//
950
951namespace mlir {
952namespace detail {
953struct PassPipelineCLParserImpl {
954 PassPipelineCLParserImpl(StringRef arg, StringRef description,
955 bool passNamesOnly)
956 : passList(arg, llvm::cl::desc(description)) {
957 passList.getParser().passNamesOnly = passNamesOnly;
958 passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
959 }
960
961 /// Returns true if the given pass registry entry was registered at the
962 /// top-level of the parser, i.e. not within an explicit textual pipeline.
963 bool contains(const PassRegistryEntry *entry) const {
964 return llvm::any_of(Range: passList, P: [&](const PassArgData &data) {
965 return data.registryEntry == entry;
966 });
967 }
968
969 /// The set of passes and pass pipelines to run.
970 llvm::cl::list<PassArgData, bool, PassNameParser> passList;
971};
972} // namespace detail
973} // namespace mlir
974
975/// Construct a pass pipeline parser with the given command line description.
976PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
977 : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
978 args&: arg, args&: description, /*passNamesOnly=*/args: false)),
979 passPipeline(
980 PASS_PIPELINE_ARG,
981 llvm::cl::desc("Textual description of the pass pipeline to run")) {}
982
983PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description,
984 StringRef alias)
985 : PassPipelineCLParser(arg, description) {
986 passPipelineAlias.emplace(args&: alias,
987 args: llvm::cl::desc("Alias for --" PASS_PIPELINE_ARG),
988 args: llvm::cl::aliasopt(passPipeline));
989}
990
/// Defaulted out of line. NOTE(review): presumably so that `impl` (created via
/// std::make_unique in the constructor) is destroyed where
/// detail::PassPipelineCLParserImpl is a complete type; confirm in the header.
PassPipelineCLParser::~PassPipelineCLParser() = default;
992
993/// Returns true if this parser contains any valid options to add.
994bool PassPipelineCLParser::hasAnyOccurrences() const {
995 return passPipeline.getNumOccurrences() != 0 ||
996 impl->passList.getNumOccurrences() != 0;
997}
998
999/// Returns true if the given pass registry entry was registered at the
1000/// top-level of the parser, i.e. not within an explicit textual pipeline.
1001bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const {
1002 return impl->contains(entry);
1003}
1004
1005/// Adds the passes defined by this parser entry to the given pass manager.
1006LogicalResult PassPipelineCLParser::addToPipeline(
1007 OpPassManager &pm,
1008 function_ref<LogicalResult(const Twine &)> errorHandler) const {
1009 if (passPipeline.getNumOccurrences()) {
1010 if (impl->passList.getNumOccurrences())
1011 return errorHandler(
1012 "'-" PASS_PIPELINE_ARG
1013 "' option can't be used with individual pass options");
1014 std::string errMsg;
1015 llvm::raw_string_ostream os(errMsg);
1016 FailureOr<OpPassManager> parsed = parsePassPipeline(pipeline: passPipeline, errorStream&: os);
1017 if (failed(Result: parsed))
1018 return errorHandler(errMsg);
1019 pm = std::move(*parsed);
1020 return success();
1021 }
1022
1023 for (auto &passIt : impl->passList) {
1024 if (failed(Result: passIt.registryEntry->addToPipeline(pm, options: passIt.options,
1025 errorHandler)))
1026 return failure();
1027 }
1028 return success();
1029}
1030
1031//===----------------------------------------------------------------------===//
1032// PassNameCLParser
1033//===----------------------------------------------------------------------===//
1034
1035/// Construct a pass pipeline parser with the given command line description.
1036PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)
1037 : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
1038 args&: arg, args&: description, /*passNamesOnly=*/args: true)) {
1039 impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
1040}
/// Defaulted out of line. NOTE(review): presumably so that `impl` (created via
/// std::make_unique in the constructor) is destroyed where
/// detail::PassPipelineCLParserImpl is a complete type; confirm in the header.
PassNameCLParser::~PassNameCLParser() = default;
1042
1043/// Returns true if this parser contains any valid options to add.
1044bool PassNameCLParser::hasAnyOccurrences() const {
1045 return impl->passList.getNumOccurrences() != 0;
1046}
1047
1048/// Returns true if the given pass registry entry was registered at the
1049/// top-level of the parser, i.e. not within an explicit textual pipeline.
1050bool PassNameCLParser::contains(const PassRegistryEntry *entry) const {
1051 return impl->contains(entry);
1052}
1053

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Pass/PassRegistry.cpp