1//===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Pass/PassRegistry.h"
10
11#include "mlir/Pass/Pass.h"
12#include "mlir/Pass/PassManager.h"
13#include "llvm/ADT/DenseMap.h"
14#include "llvm/ADT/ScopeExit.h"
15#include "llvm/Support/Format.h"
16#include "llvm/Support/ManagedStatic.h"
17#include "llvm/Support/MemoryBuffer.h"
18#include "llvm/Support/SourceMgr.h"
19
20#include <optional>
21#include <utility>
22
23using namespace mlir;
24using namespace detail;
25
/// Static mapping of all of the registered passes, keyed by the pass argument
/// returned from `Pass::getArgument()`.
static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;

/// A mapping of the above pass registry entries to the corresponding TypeID
/// of the pass that they generate. `registerPass` uses it to detect a
/// different pass type being registered under an already-used argument.
static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;

/// Static mapping of all of the registered pass pipelines, keyed by the
/// pipeline argument.
static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
    passPipelineRegistry;
36
37/// Utility to create a default registry function from a pass instance.
38static PassRegistryFunction
39buildDefaultRegistryFn(const PassAllocatorFunction &allocator) {
40 return [=](OpPassManager &pm, StringRef options,
41 function_ref<LogicalResult(const Twine &)> errorHandler) {
42 std::unique_ptr<Pass> pass = allocator();
43 LogicalResult result = pass->initializeOptions(options, errorHandler);
44
45 std::optional<StringRef> pmOpName = pm.getOpName();
46 std::optional<StringRef> passOpName = pass->getOpName();
47 if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && pmOpName &&
48 passOpName && *pmOpName != *passOpName) {
49 return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
50 "' restricted to '" + *pass->getOpName() +
51 "' on a PassManager intended to run on '" +
52 pm.getOpAnchorName() + "', did you intend to nest?");
53 }
54 pm.addPass(pass: std::move(pass));
55 return result;
56 };
57}
58
59/// Utility to print the help string for a specific option.
60static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
61 size_t descIndent, bool isTopLevel) {
62 size_t numSpaces = descIndent - indent - 4;
63 llvm::outs().indent(NumSpaces: indent)
64 << "--" << llvm::left_justify(Str: arg, Width: numSpaces) << "- " << desc << '\n';
65}
66
67//===----------------------------------------------------------------------===//
68// PassRegistry
69//===----------------------------------------------------------------------===//
70
71/// Print the help information for this pass. This includes the argument,
72/// description, and any pass options. `descIndent` is the indent that the
73/// descriptions should be aligned.
74void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
75 printOptionHelp(arg: getPassArgument(), desc: getPassDescription(), indent, descIndent,
76 /*isTopLevel=*/true);
77 // If this entry has options, print the help for those as well.
78 optHandler([=](const PassOptions &options) {
79 options.printHelp(indent, descIndent);
80 });
81}
82
83/// Return the maximum width required when printing the options of this
84/// entry.
85size_t PassRegistryEntry::getOptionWidth() const {
86 size_t maxLen = 0;
87 optHandler([&](const PassOptions &options) mutable {
88 maxLen = options.getOptionWidth() + 2;
89 });
90 return maxLen;
91}
92
93//===----------------------------------------------------------------------===//
94// PassPipelineInfo
95//===----------------------------------------------------------------------===//
96
97void mlir::registerPassPipeline(
98 StringRef arg, StringRef description, const PassRegistryFunction &function,
99 std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
100 PassPipelineInfo pipelineInfo(arg, description, function,
101 std::move(optHandler));
102 bool inserted = passPipelineRegistry->try_emplace(Key: arg, Args&: pipelineInfo).second;
103#ifndef NDEBUG
104 if (!inserted)
105 report_fatal_error(reason: "Pass pipeline " + arg + " registered multiple times");
106#endif
107 (void)inserted;
108}
109
110//===----------------------------------------------------------------------===//
111// PassInfo
112//===----------------------------------------------------------------------===//
113
114PassInfo::PassInfo(StringRef arg, StringRef description,
115 const PassAllocatorFunction &allocator)
116 : PassRegistryEntry(
117 arg, description, buildDefaultRegistryFn(allocator),
118 // Use a temporary pass to provide an options instance.
119 [=](function_ref<void(const PassOptions &)> optHandler) {
120 optHandler(allocator()->passOptions);
121 }) {}
122
123void mlir::registerPass(const PassAllocatorFunction &function) {
124 std::unique_ptr<Pass> pass = function();
125 StringRef arg = pass->getArgument();
126 if (arg.empty())
127 llvm::report_fatal_error(reason: llvm::Twine("Trying to register '") +
128 pass->getName() +
129 "' pass that does not override `getArgument()`");
130 StringRef description = pass->getDescription();
131 PassInfo passInfo(arg, description, function);
132 passRegistry->try_emplace(Key: arg, Args&: passInfo);
133
134 // Verify that the registered pass has the same ID as any registered to this
135 // arg before it.
136 TypeID entryTypeID = pass->getTypeID();
137 auto it = passRegistryTypeIDs->try_emplace(Key: arg, Args&: entryTypeID).first;
138 if (it->second != entryTypeID)
139 llvm::report_fatal_error(
140 reason: "pass allocator creates a different pass than previously "
141 "registered for pass " +
142 arg);
143}
144
145/// Returns the pass info for the specified pass argument or null if unknown.
146const PassInfo *mlir::PassInfo::lookup(StringRef passArg) {
147 auto it = passRegistry->find(Key: passArg);
148 return it == passRegistry->end() ? nullptr : &it->second;
149}
150
151/// Returns the pass pipeline info for the specified pass pipeline argument or
152/// null if unknown.
153const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) {
154 auto it = passPipelineRegistry->find(Key: pipelineArg);
155 return it == passPipelineRegistry->end() ? nullptr : &it->second;
156}
157
158//===----------------------------------------------------------------------===//
159// PassOptions
160//===----------------------------------------------------------------------===//
161
162LogicalResult detail::pass_options::parseCommaSeparatedList(
163 llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
164 function_ref<LogicalResult(StringRef)> elementParseFn) {
165 // Functor used for finding a character in a string, and skipping over
166 // various "range" characters.
167 llvm::unique_function<size_t(StringRef, size_t, char)> findChar =
168 [&](StringRef str, size_t index, char c) -> size_t {
169 for (size_t i = index, e = str.size(); i < e; ++i) {
170 if (str[i] == c)
171 return i;
172 // Check for various range characters.
173 if (str[i] == '{')
174 i = findChar(str, i + 1, '}');
175 else if (str[i] == '(')
176 i = findChar(str, i + 1, ')');
177 else if (str[i] == '[')
178 i = findChar(str, i + 1, ']');
179 else if (str[i] == '\"')
180 i = str.find_first_of(C: '\"', From: i + 1);
181 else if (str[i] == '\'')
182 i = str.find_first_of(C: '\'', From: i + 1);
183 }
184 return StringRef::npos;
185 };
186
187 size_t nextElePos = findChar(optionStr, 0, ',');
188 while (nextElePos != StringRef::npos) {
189 // Process the portion before the comma.
190 if (failed(result: elementParseFn(optionStr.substr(Start: 0, N: nextElePos))))
191 return failure();
192
193 optionStr = optionStr.substr(Start: nextElePos + 1);
194 nextElePos = findChar(optionStr, 0, ',');
195 }
196 return elementParseFn(optionStr.substr(Start: 0, N: nextElePos));
197}
198
/// Out-of-line virtual function that provides a home for the class in this
/// translation unit.
void detail::PassOptions::OptionBase::anchor() {}
201
202/// Copy the option values from 'other'.
203void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
204 assert(options.size() == other.options.size());
205 if (options.empty())
206 return;
207 for (auto optionsIt : llvm::zip(t&: options, u: other.options))
208 std::get<0>(t&: optionsIt)->copyValueFrom(other: *std::get<1>(t&: optionsIt));
209}
210
/// Parse in the next argument from the given options string. Returns a tuple
/// containing [the key of the option, the value of the option, updated
/// `options` string pointing after the parsed option].
///
/// An option is either `key` or `key=value`, and options are separated by
/// spaces. Inside a value, quoted ('...' or "...") sections and balanced
/// '{...}' sections are kept intact even if they contain spaces. A key with
/// no '=' yields an empty value. If `options` starts with a space, the
/// returned key is empty.
static std::tuple<StringRef, StringRef, StringRef>
parseNextArg(StringRef options) {
  // Functor used to extract an argument from 'options' and update it to point
  // after the arg. The extracted text is trimmed, and leading whitespace is
  // dropped from the remaining string.
  auto extractArgAndUpdateOptions = [&](size_t argSize) {
    StringRef str = options.take_front(argSize).trim();
    options = options.drop_front(argSize).ltrim();
    return str;
  };
  // Try to process the given punctuation, properly escaping any contained
  // characters. Returns true if `options[currentPos]` is `punct`. In that case
  // `currentPos` is moved onto the matching closing `punct` if there is one,
  // and left in place otherwise, so an unterminated quote is treated as an
  // ordinary character.
  auto tryProcessPunct = [&](size_t &currentPos, char punct) {
    if (options[currentPos] != punct)
      return false;
    size_t nextIt = options.find_first_of(punct, currentPos + 1);
    if (nextIt != StringRef::npos)
      currentPos = nextIt;
    return true;
  };

  // Parse the argument name of the option.
  StringRef argName;
  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
    // Check for the end of the full option.
    if (argEndIt == optionsE || options[argEndIt] == ' ') {
      argName = extractArgAndUpdateOptions(argEndIt);
      return std::make_tuple(argName, StringRef(), options);
    }

    // Check for the end of the name and the start of the value.
    if (options[argEndIt] == '=') {
      argName = extractArgAndUpdateOptions(argEndIt);
      // `options` now starts at the '='; drop it so the value parse below
      // starts at the first value character.
      options = options.drop_front();
      break;
    }
  }

  // Parse the value of the option.
  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
    // Handle the end of the options string.
    if (argEndIt == optionsE || options[argEndIt] == ' ') {
      StringRef value = extractArgAndUpdateOptions(argEndIt);
      return std::make_tuple(argName, value, options);
    }

    // Skip over escaped sequences.
    char c = options[argEndIt];
    if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
      continue;
    // '{...}' is used to specify options to passes, properly escape it so
    // that we don't accidentally split any nested options.
    if (c == '{') {
      size_t braceCount = 1;
      for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
        // Allow nested punctuation.
        if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
          continue;
        if (options[argEndIt] == '{')
          ++braceCount;
        else if (options[argEndIt] == '}' && --braceCount == 0)
          break;
      }
      // Account for the increment at the top of the loop: the outer loop next
      // re-examines the closing '}' (an ordinary character there) or, if the
      // braces were unbalanced, lands on the end of the string.
      --argEndIt;
    }
  }
  llvm_unreachable("unexpected control flow in pass option parsing");
}
282
283LogicalResult detail::PassOptions::parseFromString(StringRef options,
284 raw_ostream &errorStream) {
285 // NOTE: `options` is modified in place to always refer to the unprocessed
286 // part of the string.
287 while (!options.empty()) {
288 StringRef key, value;
289 std::tie(args&: key, args&: value, args&: options) = parseNextArg(options);
290 if (key.empty())
291 continue;
292
293 auto it = OptionsMap.find(Key: key);
294 if (it == OptionsMap.end()) {
295 errorStream << "<Pass-Options-Parser>: no such option " << key << "\n";
296 return failure();
297 }
298 if (llvm::cl::ProvidePositionalOption(Handler: it->second, Arg: value, i: 0))
299 return failure();
300 }
301
302 return success();
303}
304
305/// Print the options held by this struct in a form that can be parsed via
306/// 'parseFromString'.
307void detail::PassOptions::print(raw_ostream &os) {
308 // If there are no options, there is nothing left to do.
309 if (OptionsMap.empty())
310 return;
311
312 // Sort the options to make the ordering deterministic.
313 SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
314 auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
315 return (*lhs)->getArgStr().compare(RHS: (*rhs)->getArgStr());
316 };
317 llvm::array_pod_sort(Start: orderedOps.begin(), End: orderedOps.end(), Compare: compareOptionArgs);
318
319 // Interleave the options with ' '.
320 os << '{';
321 llvm::interleave(
322 c: orderedOps, os, each_fn: [&](OptionBase *option) { option->print(os); }, separator: " ");
323 os << '}';
324}
325
326/// Print the help string for the options held by this struct. `descIndent` is
327/// the indent within the stream that the descriptions should be aligned.
328void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
329 // Sort the options to make the ordering deterministic.
330 SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
331 auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
332 return (*lhs)->getArgStr().compare(RHS: (*rhs)->getArgStr());
333 };
334 llvm::array_pod_sort(Start: orderedOps.begin(), End: orderedOps.end(), Compare: compareOptionArgs);
335 for (OptionBase *option : orderedOps) {
336 // TODO: printOptionInfo assumes a specific indent and will
337 // print options with values with incorrect indentation. We should add
338 // support to llvm::cl::Option for passing in a base indent to use when
339 // printing.
340 llvm::outs().indent(NumSpaces: indent);
341 option->getOption()->printOptionInfo(GlobalWidth: descIndent - indent);
342 }
343}
344
345/// Return the maximum width required when printing the help string.
346size_t detail::PassOptions::getOptionWidth() const {
347 size_t max = 0;
348 for (auto *option : options)
349 max = std::max(a: max, b: option->getOption()->getOptionWidth());
350 return max;
351}
352
353//===----------------------------------------------------------------------===//
354// MLIR Options
355//===----------------------------------------------------------------------===//
356
357//===----------------------------------------------------------------------===//
358// OpPassManager: OptionValue
359
/// Default-construct an option value with no pass manager stored.
llvm::cl::OptionValue<OpPassManager>::OptionValue() = default;
/// Construct an option value holding a copy of `value`.
/// NOTE: the parameter shadows the `value` member; `setValue` receives the
/// parameter.
llvm::cl::OptionValue<OpPassManager>::OptionValue(
    const mlir::OpPassManager &value) {
  setValue(value);
}
/// Copy-construct from `rhs`, copying its pass manager only if it holds one.
llvm::cl::OptionValue<OpPassManager>::OptionValue(
    const llvm::cl::OptionValue<mlir::OpPassManager> &rhs) {
  if (rhs.hasValue())
    setValue(rhs.getValue());
}
/// Replace the stored pass manager with a copy of `rhs`.
llvm::cl::OptionValue<OpPassManager> &
llvm::cl::OptionValue<OpPassManager>::operator=(
    const mlir::OpPassManager &rhs) {
  setValue(rhs);
  return *this;
}
376
377llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default;
378
379void llvm::cl::OptionValue<OpPassManager>::setValue(
380 const OpPassManager &newValue) {
381 if (hasValue())
382 *value = newValue;
383 else
384 value = std::make_unique<mlir::OpPassManager>(args: newValue);
385}
386void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
387 FailureOr<OpPassManager> pipeline = parsePassPipeline(pipeline: pipelineStr);
388 assert(succeeded(pipeline) && "invalid pass pipeline");
389 setValue(*pipeline);
390}
391
392bool llvm::cl::OptionValue<OpPassManager>::compare(
393 const mlir::OpPassManager &rhs) const {
394 std::string lhsStr, rhsStr;
395 {
396 raw_string_ostream lhsStream(lhsStr);
397 value->printAsTextualPipeline(os&: lhsStream);
398
399 raw_string_ostream rhsStream(rhsStr);
400 rhs.printAsTextualPipeline(os&: rhsStream);
401 }
402
403 // Use the textual format for pipeline comparisons.
404 return lhsStr == rhsStr;
405}
406
/// Out-of-line anchor for the OptionValue<OpPassManager> specialization.
void llvm::cl::OptionValue<OpPassManager>::anchor() {}
408
409//===----------------------------------------------------------------------===//
410// OpPassManager: Parser
411
// Explicitly instantiate the basic parser for OpPassManager in this
// translation unit.
namespace llvm {
namespace cl {
template class basic_parser<OpPassManager>;
} // namespace cl
} // namespace llvm
417
418bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg,
419 ParsedPassManager &value) {
420 FailureOr<OpPassManager> pipeline = parsePassPipeline(pipeline: arg);
421 if (failed(result: pipeline))
422 return true;
423 value.value = std::make_unique<OpPassManager>(args: std::move(*pipeline));
424 return false;
425}
426
/// Print `value` to `os` as a textual pass pipeline.
void llvm::cl::parser<OpPassManager>::print(raw_ostream &os,
                                            const OpPassManager &value) {
  value.printAsTextualPipeline(os);
}
431
432void llvm::cl::parser<OpPassManager>::printOptionDiff(
433 const Option &opt, OpPassManager &pm, const OptVal &defaultValue,
434 size_t globalWidth) const {
435 printOptionName(O: opt, GlobalWidth: globalWidth);
436 outs() << "= ";
437 pm.printAsTextualPipeline(os&: outs());
438
439 if (defaultValue.hasValue()) {
440 outs().indent(NumSpaces: 2) << " (default: ";
441 defaultValue.getValue().printAsTextualPipeline(os&: outs());
442 outs() << ")";
443 }
444 outs() << "\n";
445}
446
/// Out-of-line anchor for the parser<OpPassManager> specialization.
void llvm::cl::parser<OpPassManager>::anchor() {}
448
// Special members of ParsedPassManager, defaulted out of line.
// NOTE(review): presumably kept out of line so that the owned OpPassManager
// only needs to be a complete type here; confirm against the header.
llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager() =
    default;
llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager(
    ParsedPassManager &&) = default;
llvm::cl::parser<OpPassManager>::ParsedPassManager::~ParsedPassManager() =
    default;
455
456//===----------------------------------------------------------------------===//
457// TextualPassPipeline Parser
458//===----------------------------------------------------------------------===//
459
460namespace {
461/// This class represents a textual description of a pass pipeline.
462class TextualPipeline {
463public:
464 /// Try to initialize this pipeline with the given pipeline text.
465 /// `errorStream` is the output stream to emit errors to.
466 LogicalResult initialize(StringRef text, raw_ostream &errorStream);
467
468 /// Add the internal pipeline elements to the provided pass manager.
469 LogicalResult
470 addToPipeline(OpPassManager &pm,
471 function_ref<LogicalResult(const Twine &)> errorHandler) const;
472
473private:
474 /// A functor used to emit errors found during pipeline handling. The first
475 /// parameter corresponds to the raw location within the pipeline string. This
476 /// should always return failure.
477 using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;
478
479 /// A struct to capture parsed pass pipeline names.
480 ///
481 /// A pipeline is defined as a series of names, each of which may in itself
482 /// recursively contain a nested pipeline. A name is either the name of a pass
483 /// (e.g. "cse") or the name of an operation type (e.g. "buitin.module"). If
484 /// the name is the name of a pass, the InnerPipeline is empty, since passes
485 /// cannot contain inner pipelines.
486 struct PipelineElement {
487 PipelineElement(StringRef name) : name(name) {}
488
489 StringRef name;
490 StringRef options;
491 const PassRegistryEntry *registryEntry = nullptr;
492 std::vector<PipelineElement> innerPipeline;
493 };
494
495 /// Parse the given pipeline text into the internal pipeline vector. This
496 /// function only parses the structure of the pipeline, and does not resolve
497 /// its elements.
498 LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
499
500 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
501 /// the corresponding registry entry.
502 LogicalResult
503 resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
504 ErrorHandlerT errorHandler);
505
506 /// Resolve a single element of the pipeline.
507 LogicalResult resolvePipelineElement(PipelineElement &element,
508 ErrorHandlerT errorHandler);
509
510 /// Add the given pipeline elements to the provided pass manager.
511 LogicalResult
512 addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
513 function_ref<LogicalResult(const Twine &)> errorHandler) const;
514
515 std::vector<PipelineElement> pipeline;
516};
517
518} // namespace
519
520/// Try to initialize this pipeline with the given pipeline text. An option is
521/// given to enable accurate error reporting.
522LogicalResult TextualPipeline::initialize(StringRef text,
523 raw_ostream &errorStream) {
524 if (text.empty())
525 return success();
526
527 // Build a source manager to use for error reporting.
528 llvm::SourceMgr pipelineMgr;
529 pipelineMgr.AddNewSourceBuffer(
530 F: llvm::MemoryBuffer::getMemBuffer(InputData: text, BufferName: "MLIR Textual PassPipeline Parser",
531 /*RequiresNullTerminator=*/false),
532 IncludeLoc: SMLoc());
533 auto errorHandler = [&](const char *rawLoc, Twine msg) {
534 pipelineMgr.PrintMessage(OS&: errorStream, Loc: SMLoc::getFromPointer(Ptr: rawLoc),
535 Kind: llvm::SourceMgr::DK_Error, Msg: msg);
536 return failure();
537 };
538
539 // Parse the provided pipeline string.
540 if (failed(result: parsePipelineText(text, errorHandler)))
541 return failure();
542 return resolvePipelineElements(elements: pipeline, errorHandler);
543}
544
545/// Add the internal pipeline elements to the provided pass manager.
546LogicalResult TextualPipeline::addToPipeline(
547 OpPassManager &pm,
548 function_ref<LogicalResult(const Twine &)> errorHandler) const {
549 // Temporarily disable implicit nesting while we append to the pipeline. We
550 // want the created pipeline to exactly match the parsed text pipeline, so
551 // it's preferrable to just error out if implicit nesting would be required.
552 OpPassManager::Nesting nesting = pm.getNesting();
553 pm.setNesting(OpPassManager::Nesting::Explicit);
554 auto restore = llvm::make_scope_exit(F: [&]() { pm.setNesting(nesting); });
555
556 return addToPipeline(elements: pipeline, pm, errorHandler);
557}
558
/// Parse the given pipeline text into the internal pipeline vector. This
/// function only parses the structure of the pipeline, and does not resolve
/// its elements.
///
/// Informally, the text is a comma separated list of elements, where each
/// element is `name`, `name{options}` or `name(nested-list)`. Element names are
/// trimmed, and spaces after a closing '}' or ')' are skipped.
LogicalResult TextualPipeline::parsePipelineText(StringRef text,
                                                 ErrorHandlerT errorHandler) {
  // Stack of the pipelines being filled. The innermost nested pipeline is at
  // the back, and the top-level pipeline is at the bottom.
  SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
  for (;;) {
    // NOTE: this local shadows the `pipeline` member; it is the pipeline on
    // top of the stack.
    std::vector<PipelineElement> &pipeline = *pipelineStack.back();
    // Everything up to the next structural character is the element name.
    size_t pos = text.find_first_of(",(){");
    pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());

    // If we have a single terminating name, we're done.
    if (pos == StringRef::npos)
      break;

    text = text.substr(pos);
    char sep = text[0];

    // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
    if (sep == '{') {
      text = text.substr(1);

      // Skip over everything until the matching closing '}' and store it as
      // the options. Nested braces are balanced.
      // NOTE(review): unlike parseNextArg, quotes are not special here, so a
      // '}' inside a quoted option value ends the options early.
      size_t close = StringRef::npos;
      for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
        if (text[i] == '{') {
          ++braceCount;
          continue;
        }
        if (text[i] == '}' && --braceCount == 0) {
          close = i;
          break;
        }
      }

      // Check to see if a closing options brace was found. The error points
      // at the opening '{'.
      if (close == StringRef::npos) {
        return errorHandler(
            /*rawLoc=*/text.data() - 1,
            "missing closing '}' while processing pass options");
      }
      pipeline.back().options = text.substr(0, close);
      text = text.substr(close + 1);

      // Consume space characters that a user might add for readability.
      text = text.ltrim();

      // Skip checking for '(' because nested pipelines cannot have options.
    } else if (sep == '(') {
      text = text.substr(1);

      // Push the inner pipeline onto the stack to continue processing.
      pipelineStack.push_back(&pipeline.back().innerPipeline);
      continue;
    }

    // When handling the close parenthesis, we greedily consume them to avoid
    // empty strings in the pipeline.
    while (text.consume_front(")")) {
      // If we try to pop the outer pipeline we have unbalanced parentheses.
      if (pipelineStack.size() == 1)
        return errorHandler(/*rawLoc=*/text.data() - 1,
                            "encountered extra closing ')' creating unbalanced "
                            "parentheses while parsing pipeline");

      pipelineStack.pop_back();
      // Consume space characters that a user might add for readability.
      text = text.ltrim();
    }

    // Check if we've finished parsing.
    if (text.empty())
      break;

    // Otherwise, the end of an inner pipeline always has to be followed by
    // a comma, and then we can continue.
    if (!text.consume_front(","))
      return errorHandler(text.data(), "expected ',' after parsing pipeline");
  }

  // Check for unbalanced parentheses.
  if (pipelineStack.size() > 1)
    return errorHandler(
        text.data(),
        "encountered unbalanced parentheses while parsing pipeline");

  // Here `pipeline` is the member again (the shadowing local is out of scope).
  assert(pipelineStack.back() == &pipeline &&
         "wrong pipeline at the bottom of the stack");
  return success();
}
649
650/// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
651/// the corresponding registry entry.
652LogicalResult TextualPipeline::resolvePipelineElements(
653 MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {
654 for (auto &elt : elements)
655 if (failed(result: resolvePipelineElement(element&: elt, errorHandler)))
656 return failure();
657 return success();
658}
659
660/// Resolve a single element of the pipeline.
661LogicalResult
662TextualPipeline::resolvePipelineElement(PipelineElement &element,
663 ErrorHandlerT errorHandler) {
664 // If the inner pipeline of this element is not empty, this is an operation
665 // pipeline.
666 if (!element.innerPipeline.empty())
667 return resolvePipelineElements(elements: element.innerPipeline, errorHandler);
668
669 // Otherwise, this must be a pass or pass pipeline.
670 // Check to see if a pipeline was registered with this name.
671 if ((element.registryEntry = PassPipelineInfo::lookup(pipelineArg: element.name)))
672 return success();
673
674 // If not, then this must be a specific pass name.
675 if ((element.registryEntry = PassInfo::lookup(passArg: element.name)))
676 return success();
677
678 // Emit an error for the unknown pass.
679 auto *rawLoc = element.name.data();
680 return errorHandler(rawLoc, "'" + element.name +
681 "' does not refer to a "
682 "registered pass or pass pipeline");
683}
684
685/// Add the given pipeline elements to the provided pass manager.
686LogicalResult TextualPipeline::addToPipeline(
687 ArrayRef<PipelineElement> elements, OpPassManager &pm,
688 function_ref<LogicalResult(const Twine &)> errorHandler) const {
689 for (auto &elt : elements) {
690 if (elt.registryEntry) {
691 if (failed(result: elt.registryEntry->addToPipeline(pm, options: elt.options,
692 errorHandler))) {
693 return errorHandler("failed to add `" + elt.name + "` with options `" +
694 elt.options + "`");
695 }
696 } else if (failed(result: addToPipeline(elements: elt.innerPipeline, pm&: pm.nest(nestedName: elt.name),
697 errorHandler))) {
698 return errorHandler("failed to add `" + elt.name + "` with options `" +
699 elt.options + "` to inner pipeline");
700 }
701 }
702 return success();
703}
704
705LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
706 raw_ostream &errorStream) {
707 TextualPipeline pipelineParser;
708 if (failed(result: pipelineParser.initialize(text: pipeline, errorStream)))
709 return failure();
710 auto errorHandler = [&](Twine msg) {
711 errorStream << msg << "\n";
712 return failure();
713 };
714 if (failed(result: pipelineParser.addToPipeline(pm, errorHandler)))
715 return failure();
716 return success();
717}
718
719FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline,
720 raw_ostream &errorStream) {
721 pipeline = pipeline.trim();
722 // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
723 size_t pipelineStart = pipeline.find_first_of(C: '(');
724 if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
725 !pipeline.consume_back(Suffix: ")")) {
726 errorStream << "expected pass pipeline to be wrapped with the anchor "
727 "operation type, e.g. 'builtin.module(...)'";
728 return failure();
729 }
730
731 StringRef opName = pipeline.take_front(N: pipelineStart).rtrim();
732 OpPassManager pm(opName);
733 if (failed(result: parsePassPipeline(pipeline: pipeline.drop_front(N: 1 + pipelineStart), pm,
734 errorStream)))
735 return failure();
736 return pm;
737}
738
739//===----------------------------------------------------------------------===//
740// PassNameParser
741//===----------------------------------------------------------------------===//
742
743namespace {
744/// This struct represents the possible data entries in a parsed pass pipeline
745/// list.
746struct PassArgData {
747 PassArgData() = default;
748 PassArgData(const PassRegistryEntry *registryEntry)
749 : registryEntry(registryEntry) {}
750
751 /// This field is used when the parsed option corresponds to a registered pass
752 /// or pass pipeline.
753 const PassRegistryEntry *registryEntry{nullptr};
754
755 /// This field is set when instance specific pass options have been provided
756 /// on the command line.
757 StringRef options;
758};
759} // namespace
760
761namespace llvm {
762namespace cl {
763/// Define a valid OptionValue for the command line pass argument.
764template <>
765struct OptionValue<PassArgData> final
766 : OptionValueBase<PassArgData, /*isClass=*/true> {
767 OptionValue(const PassArgData &value) { this->setValue(value); }
768 OptionValue() = default;
769 void anchor() override {}
770
771 bool hasValue() const { return true; }
772 const PassArgData &getValue() const { return value; }
773 void setValue(const PassArgData &value) { this->value = value; }
774
775 PassArgData value;
776};
777} // namespace cl
778} // namespace llvm
779
namespace {

/// The name for the command line option used for parsing the textual pass
/// pipeline.
/// NOTE: macros ignore namespaces; this definition stays visible for the rest
/// of the file.
#define PASS_PIPELINE_ARG "pass-pipeline"

/// Adds command line option for each registered pass or pass pipeline, as well
/// as textual pass pipelines.
struct PassNameParser : public llvm::cl::parser<PassArgData> {
  PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}

  /// Register a literal option for each registered pass and, unless
  /// `passNamesOnly` is set, for each registered pass pipeline.
  void initialize();
  /// Print the help for this option, including the available entries.
  void printOptionInfo(const llvm::cl::Option &opt,
                       size_t globalWidth) const override;
  /// Return the width needed to align the help of this option and of all
  /// registered entries' options.
  size_t getOptionWidth(const llvm::cl::Option &opt) const override;
  /// Parse a pass argument, recording `arg` as the entry's options. Returns
  /// true on error.
  bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
             PassArgData &value);

  /// If true, this parser only parses entries that correspond to a concrete
  /// pass registry entry, and does not include pipeline entries or the options
  /// for pass entries.
  bool passNamesOnly = false;
};
} // namespace
804
805void PassNameParser::initialize() {
806 llvm::cl::parser<PassArgData>::initialize();
807
808 /// Add the pass entries.
809 for (const auto &kv : *passRegistry) {
810 addLiteralOption(Name: kv.second.getPassArgument(), V: &kv.second,
811 HelpStr: kv.second.getPassDescription());
812 }
813 /// Add the pass pipeline entries.
814 if (!passNamesOnly) {
815 for (const auto &kv : *passPipelineRegistry) {
816 addLiteralOption(Name: kv.second.getPassArgument(), V: &kv.second,
817 HelpStr: kv.second.getPassDescription());
818 }
819 }
820}
821
822void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
823 size_t globalWidth) const {
824 // If this parser is just parsing pass names, print a simplified option
825 // string.
826 if (passNamesOnly) {
827 llvm::outs() << " --" << opt.ArgStr << "=<pass-arg>";
828 opt.printHelpStr(HelpStr: opt.HelpStr, Indent: globalWidth, FirstLineIndentedBy: opt.ArgStr.size() + 18);
829 return;
830 }
831
832 // Print the information for the top-level option.
833 if (opt.hasArgStr()) {
834 llvm::outs() << " --" << opt.ArgStr;
835 opt.printHelpStr(HelpStr: opt.HelpStr, Indent: globalWidth, FirstLineIndentedBy: opt.ArgStr.size() + 7);
836 } else {
837 llvm::outs() << " " << opt.HelpStr << '\n';
838 }
839
840 // Functor used to print the ordered entries of a registration map.
841 auto printOrderedEntries = [&](StringRef header, auto &map) {
842 llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries;
843 for (auto &kv : map)
844 orderedEntries.push_back(Elt: &kv.second);
845 llvm::array_pod_sort(
846 orderedEntries.begin(), orderedEntries.end(),
847 [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
848 return (*lhs)->getPassArgument().compare(RHS: (*rhs)->getPassArgument());
849 });
850
851 llvm::outs().indent(NumSpaces: 4) << header << ":\n";
852 for (PassRegistryEntry *entry : orderedEntries)
853 entry->printHelpStr(/*indent=*/6, descIndent: globalWidth);
854 };
855
856 // Print the available passes.
857 printOrderedEntries("Passes", *passRegistry);
858
859 // Print the available pass pipelines.
860 if (!passPipelineRegistry->empty())
861 printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
862}
863
864size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
865 size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(O: opt) + 2;
866
867 // Check for any wider pass or pipeline options.
868 for (auto &entry : *passRegistry)
869 maxWidth = std::max(a: maxWidth, b: entry.second.getOptionWidth() + 4);
870 for (auto &entry : *passPipelineRegistry)
871 maxWidth = std::max(a: maxWidth, b: entry.second.getOptionWidth() + 4);
872 return maxWidth;
873}
874
875bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
876 StringRef arg, PassArgData &value) {
877 if (llvm::cl::parser<PassArgData>::parse(O&: opt, ArgName: argName, Arg: arg, V&: value))
878 return true;
879 value.options = arg;
880 return false;
881}
882
883//===----------------------------------------------------------------------===//
884// PassPipelineCLParser
885//===----------------------------------------------------------------------===//
886
887namespace mlir {
888namespace detail {
889struct PassPipelineCLParserImpl {
890 PassPipelineCLParserImpl(StringRef arg, StringRef description,
891 bool passNamesOnly)
892 : passList(arg, llvm::cl::desc(description)) {
893 passList.getParser().passNamesOnly = passNamesOnly;
894 passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
895 }
896
897 /// Returns true if the given pass registry entry was registered at the
898 /// top-level of the parser, i.e. not within an explicit textual pipeline.
899 bool contains(const PassRegistryEntry *entry) const {
900 return llvm::any_of(Range: passList, P: [&](const PassArgData &data) {
901 return data.registryEntry == entry;
902 });
903 }
904
905 /// The set of passes and pass pipelines to run.
906 llvm::cl::list<PassArgData, bool, PassNameParser> passList;
907};
908} // namespace detail
909} // namespace mlir
910
911/// Construct a pass pipeline parser with the given command line description.
912PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
913 : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
914 args&: arg, args&: description, /*passNamesOnly=*/args: false)),
915 passPipeline(
916 PASS_PIPELINE_ARG,
917 llvm::cl::desc("Textual description of the pass pipeline to run")) {}
918
919PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description,
920 StringRef alias)
921 : PassPipelineCLParser(arg, description) {
922 passPipelineAlias.emplace(args&: alias,
923 args: llvm::cl::desc("Alias for --" PASS_PIPELINE_ARG),
924 args: llvm::cl::aliasopt(passPipeline));
925}
926
927PassPipelineCLParser::~PassPipelineCLParser() = default;
928
929/// Returns true if this parser contains any valid options to add.
930bool PassPipelineCLParser::hasAnyOccurrences() const {
931 return passPipeline.getNumOccurrences() != 0 ||
932 impl->passList.getNumOccurrences() != 0;
933}
934
935/// Returns true if the given pass registry entry was registered at the
936/// top-level of the parser, i.e. not within an explicit textual pipeline.
937bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const {
938 return impl->contains(entry);
939}
940
941/// Adds the passes defined by this parser entry to the given pass manager.
942LogicalResult PassPipelineCLParser::addToPipeline(
943 OpPassManager &pm,
944 function_ref<LogicalResult(const Twine &)> errorHandler) const {
945 if (passPipeline.getNumOccurrences()) {
946 if (impl->passList.getNumOccurrences())
947 return errorHandler(
948 "'-" PASS_PIPELINE_ARG
949 "' option can't be used with individual pass options");
950 std::string errMsg;
951 llvm::raw_string_ostream os(errMsg);
952 FailureOr<OpPassManager> parsed = parsePassPipeline(pipeline: passPipeline, errorStream&: os);
953 if (failed(result: parsed))
954 return errorHandler(errMsg);
955 pm = std::move(*parsed);
956 return success();
957 }
958
959 for (auto &passIt : impl->passList) {
960 if (failed(result: passIt.registryEntry->addToPipeline(pm, options: passIt.options,
961 errorHandler)))
962 return failure();
963 }
964 return success();
965}
966
967//===----------------------------------------------------------------------===//
968// PassNameCLParser
969
970/// Construct a pass pipeline parser with the given command line description.
971PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)
972 : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
973 args&: arg, args&: description, /*passNamesOnly=*/args: true)) {
974 impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
975}
976PassNameCLParser::~PassNameCLParser() = default;
977
978/// Returns true if this parser contains any valid options to add.
979bool PassNameCLParser::hasAnyOccurrences() const {
980 return impl->passList.getNumOccurrences() != 0;
981}
982
983/// Returns true if the given pass registry entry was registered at the
984/// top-level of the parser, i.e. not within an explicit textual pipeline.
985bool PassNameCLParser::contains(const PassRegistryEntry *entry) const {
986 return impl->contains(entry);
987}
988

source code of mlir/lib/Pass/PassRegistry.cpp