1 | //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Pass/PassRegistry.h" |
10 | |
11 | #include "mlir/Pass/Pass.h" |
12 | #include "mlir/Pass/PassManager.h" |
13 | #include "llvm/ADT/DenseMap.h" |
14 | #include "llvm/ADT/ScopeExit.h" |
15 | #include "llvm/Support/Format.h" |
16 | #include "llvm/Support/ManagedStatic.h" |
17 | #include "llvm/Support/MemoryBuffer.h" |
18 | #include "llvm/Support/SourceMgr.h" |
19 | |
20 | #include <optional> |
21 | #include <utility> |
22 | |
23 | using namespace mlir; |
24 | using namespace detail; |
25 | |
26 | /// Static mapping of all of the registered passes. |
27 | static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry; |
28 | |
29 | /// A mapping of the above pass registry entries to the corresponding TypeID |
30 | /// of the pass that they generate. |
31 | static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs; |
32 | |
33 | /// Static mapping of all of the registered pass pipelines. |
34 | static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>> |
35 | passPipelineRegistry; |
36 | |
37 | /// Utility to create a default registry function from a pass instance. |
38 | static PassRegistryFunction |
39 | buildDefaultRegistryFn(const PassAllocatorFunction &allocator) { |
40 | return [=](OpPassManager &pm, StringRef options, |
41 | function_ref<LogicalResult(const Twine &)> errorHandler) { |
42 | std::unique_ptr<Pass> pass = allocator(); |
43 | LogicalResult result = pass->initializeOptions(options, errorHandler); |
44 | |
45 | std::optional<StringRef> pmOpName = pm.getOpName(); |
46 | std::optional<StringRef> passOpName = pass->getOpName(); |
47 | if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && pmOpName && |
48 | passOpName && *pmOpName != *passOpName) { |
49 | return errorHandler(llvm::Twine("Can't add pass '" ) + pass->getName() + |
50 | "' restricted to '" + *pass->getOpName() + |
51 | "' on a PassManager intended to run on '" + |
52 | pm.getOpAnchorName() + "', did you intend to nest?" ); |
53 | } |
54 | pm.addPass(pass: std::move(pass)); |
55 | return result; |
56 | }; |
57 | } |
58 | |
59 | /// Utility to print the help string for a specific option. |
60 | static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, |
61 | size_t descIndent, bool isTopLevel) { |
62 | size_t numSpaces = descIndent - indent - 4; |
63 | llvm::outs().indent(NumSpaces: indent) |
64 | << "--" << llvm::left_justify(Str: arg, Width: numSpaces) << "- " << desc << '\n'; |
65 | } |
66 | |
67 | //===----------------------------------------------------------------------===// |
68 | // PassRegistry |
69 | //===----------------------------------------------------------------------===// |
70 | |
71 | /// Print the help information for this pass. This includes the argument, |
72 | /// description, and any pass options. `descIndent` is the indent that the |
73 | /// descriptions should be aligned. |
74 | void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const { |
75 | printOptionHelp(arg: getPassArgument(), desc: getPassDescription(), indent, descIndent, |
76 | /*isTopLevel=*/true); |
77 | // If this entry has options, print the help for those as well. |
78 | optHandler([=](const PassOptions &options) { |
79 | options.printHelp(indent, descIndent); |
80 | }); |
81 | } |
82 | |
83 | /// Return the maximum width required when printing the options of this |
84 | /// entry. |
85 | size_t PassRegistryEntry::getOptionWidth() const { |
86 | size_t maxLen = 0; |
87 | optHandler([&](const PassOptions &options) mutable { |
88 | maxLen = options.getOptionWidth() + 2; |
89 | }); |
90 | return maxLen; |
91 | } |
92 | |
93 | //===----------------------------------------------------------------------===// |
94 | // PassPipelineInfo |
95 | //===----------------------------------------------------------------------===// |
96 | |
97 | void mlir::registerPassPipeline( |
98 | StringRef arg, StringRef description, const PassRegistryFunction &function, |
99 | std::function<void(function_ref<void(const PassOptions &)>)> optHandler) { |
100 | PassPipelineInfo pipelineInfo(arg, description, function, |
101 | std::move(optHandler)); |
102 | bool inserted = passPipelineRegistry->try_emplace(Key: arg, Args&: pipelineInfo).second; |
103 | #ifndef NDEBUG |
104 | if (!inserted) |
105 | report_fatal_error(reason: "Pass pipeline " + arg + " registered multiple times" ); |
106 | #endif |
107 | (void)inserted; |
108 | } |
109 | |
110 | //===----------------------------------------------------------------------===// |
111 | // PassInfo |
112 | //===----------------------------------------------------------------------===// |
113 | |
114 | PassInfo::PassInfo(StringRef arg, StringRef description, |
115 | const PassAllocatorFunction &allocator) |
116 | : PassRegistryEntry( |
117 | arg, description, buildDefaultRegistryFn(allocator), |
118 | // Use a temporary pass to provide an options instance. |
119 | [=](function_ref<void(const PassOptions &)> optHandler) { |
120 | optHandler(allocator()->passOptions); |
121 | }) {} |
122 | |
123 | void mlir::registerPass(const PassAllocatorFunction &function) { |
124 | std::unique_ptr<Pass> pass = function(); |
125 | StringRef arg = pass->getArgument(); |
126 | if (arg.empty()) |
127 | llvm::report_fatal_error(reason: llvm::Twine("Trying to register '" ) + |
128 | pass->getName() + |
129 | "' pass that does not override `getArgument()`" ); |
130 | StringRef description = pass->getDescription(); |
131 | PassInfo passInfo(arg, description, function); |
132 | passRegistry->try_emplace(Key: arg, Args&: passInfo); |
133 | |
134 | // Verify that the registered pass has the same ID as any registered to this |
135 | // arg before it. |
136 | TypeID entryTypeID = pass->getTypeID(); |
137 | auto it = passRegistryTypeIDs->try_emplace(Key: arg, Args&: entryTypeID).first; |
138 | if (it->second != entryTypeID) |
139 | llvm::report_fatal_error( |
140 | reason: "pass allocator creates a different pass than previously " |
141 | "registered for pass " + |
142 | arg); |
143 | } |
144 | |
145 | /// Returns the pass info for the specified pass argument or null if unknown. |
146 | const PassInfo *mlir::PassInfo::lookup(StringRef passArg) { |
147 | auto it = passRegistry->find(Key: passArg); |
148 | return it == passRegistry->end() ? nullptr : &it->second; |
149 | } |
150 | |
151 | /// Returns the pass pipeline info for the specified pass pipeline argument or |
152 | /// null if unknown. |
153 | const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) { |
154 | auto it = passPipelineRegistry->find(Key: pipelineArg); |
155 | return it == passPipelineRegistry->end() ? nullptr : &it->second; |
156 | } |
157 | |
158 | //===----------------------------------------------------------------------===// |
159 | // PassOptions |
160 | //===----------------------------------------------------------------------===// |
161 | |
162 | LogicalResult detail::pass_options::parseCommaSeparatedList( |
163 | llvm::cl::Option &opt, StringRef argName, StringRef optionStr, |
164 | function_ref<LogicalResult(StringRef)> elementParseFn) { |
165 | // Functor used for finding a character in a string, and skipping over |
166 | // various "range" characters. |
167 | llvm::unique_function<size_t(StringRef, size_t, char)> findChar = |
168 | [&](StringRef str, size_t index, char c) -> size_t { |
169 | for (size_t i = index, e = str.size(); i < e; ++i) { |
170 | if (str[i] == c) |
171 | return i; |
172 | // Check for various range characters. |
173 | if (str[i] == '{') |
174 | i = findChar(str, i + 1, '}'); |
175 | else if (str[i] == '(') |
176 | i = findChar(str, i + 1, ')'); |
177 | else if (str[i] == '[') |
178 | i = findChar(str, i + 1, ']'); |
179 | else if (str[i] == '\"') |
180 | i = str.find_first_of(C: '\"', From: i + 1); |
181 | else if (str[i] == '\'') |
182 | i = str.find_first_of(C: '\'', From: i + 1); |
183 | } |
184 | return StringRef::npos; |
185 | }; |
186 | |
187 | size_t nextElePos = findChar(optionStr, 0, ','); |
188 | while (nextElePos != StringRef::npos) { |
189 | // Process the portion before the comma. |
190 | if (failed(result: elementParseFn(optionStr.substr(Start: 0, N: nextElePos)))) |
191 | return failure(); |
192 | |
193 | optionStr = optionStr.substr(Start: nextElePos + 1); |
194 | nextElePos = findChar(optionStr, 0, ','); |
195 | } |
196 | return elementParseFn(optionStr.substr(Start: 0, N: nextElePos)); |
197 | } |
198 | |
199 | /// Out of line virtual function to provide home for the class. |
200 | void detail::PassOptions::OptionBase::anchor() {} |
201 | |
202 | /// Copy the option values from 'other'. |
203 | void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) { |
204 | assert(options.size() == other.options.size()); |
205 | if (options.empty()) |
206 | return; |
207 | for (auto optionsIt : llvm::zip(t&: options, u: other.options)) |
208 | std::get<0>(t&: optionsIt)->copyValueFrom(other: *std::get<1>(t&: optionsIt)); |
209 | } |
210 | |
211 | /// Parse in the next argument from the given options string. Returns a tuple |
212 | /// containing [the key of the option, the value of the option, updated |
213 | /// `options` string pointing after the parsed option]. |
214 | static std::tuple<StringRef, StringRef, StringRef> |
215 | parseNextArg(StringRef options) { |
216 | // Functor used to extract an argument from 'options' and update it to point |
217 | // after the arg. |
218 | auto extractArgAndUpdateOptions = [&](size_t argSize) { |
219 | StringRef str = options.take_front(N: argSize).trim(); |
220 | options = options.drop_front(N: argSize).ltrim(); |
221 | return str; |
222 | }; |
223 | // Try to process the given punctuation, properly escaping any contained |
224 | // characters. |
225 | auto tryProcessPunct = [&](size_t ¤tPos, char punct) { |
226 | if (options[currentPos] != punct) |
227 | return false; |
228 | size_t nextIt = options.find_first_of(C: punct, From: currentPos + 1); |
229 | if (nextIt != StringRef::npos) |
230 | currentPos = nextIt; |
231 | return true; |
232 | }; |
233 | |
234 | // Parse the argument name of the option. |
235 | StringRef argName; |
236 | for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) { |
237 | // Check for the end of the full option. |
238 | if (argEndIt == optionsE || options[argEndIt] == ' ') { |
239 | argName = extractArgAndUpdateOptions(argEndIt); |
240 | return std::make_tuple(args&: argName, args: StringRef(), args&: options); |
241 | } |
242 | |
243 | // Check for the end of the name and the start of the value. |
244 | if (options[argEndIt] == '=') { |
245 | argName = extractArgAndUpdateOptions(argEndIt); |
246 | options = options.drop_front(); |
247 | break; |
248 | } |
249 | } |
250 | |
251 | // Parse the value of the option. |
252 | for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) { |
253 | // Handle the end of the options string. |
254 | if (argEndIt == optionsE || options[argEndIt] == ' ') { |
255 | StringRef value = extractArgAndUpdateOptions(argEndIt); |
256 | return std::make_tuple(args&: argName, args&: value, args&: options); |
257 | } |
258 | |
259 | // Skip over escaped sequences. |
260 | char c = options[argEndIt]; |
261 | if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"')) |
262 | continue; |
263 | // '{...}' is used to specify options to passes, properly escape it so |
264 | // that we don't accidentally split any nested options. |
265 | if (c == '{') { |
266 | size_t braceCount = 1; |
267 | for (++argEndIt; argEndIt != optionsE; ++argEndIt) { |
268 | // Allow nested punctuation. |
269 | if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"')) |
270 | continue; |
271 | if (options[argEndIt] == '{') |
272 | ++braceCount; |
273 | else if (options[argEndIt] == '}' && --braceCount == 0) |
274 | break; |
275 | } |
276 | // Account for the increment at the top of the loop. |
277 | --argEndIt; |
278 | } |
279 | } |
280 | llvm_unreachable("unexpected control flow in pass option parsing" ); |
281 | } |
282 | |
283 | LogicalResult detail::PassOptions::parseFromString(StringRef options, |
284 | raw_ostream &errorStream) { |
285 | // NOTE: `options` is modified in place to always refer to the unprocessed |
286 | // part of the string. |
287 | while (!options.empty()) { |
288 | StringRef key, value; |
289 | std::tie(args&: key, args&: value, args&: options) = parseNextArg(options); |
290 | if (key.empty()) |
291 | continue; |
292 | |
293 | auto it = OptionsMap.find(Key: key); |
294 | if (it == OptionsMap.end()) { |
295 | errorStream << "<Pass-Options-Parser>: no such option " << key << "\n" ; |
296 | return failure(); |
297 | } |
298 | if (llvm::cl::ProvidePositionalOption(Handler: it->second, Arg: value, i: 0)) |
299 | return failure(); |
300 | } |
301 | |
302 | return success(); |
303 | } |
304 | |
305 | /// Print the options held by this struct in a form that can be parsed via |
306 | /// 'parseFromString'. |
307 | void detail::PassOptions::print(raw_ostream &os) { |
308 | // If there are no options, there is nothing left to do. |
309 | if (OptionsMap.empty()) |
310 | return; |
311 | |
312 | // Sort the options to make the ordering deterministic. |
313 | SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end()); |
314 | auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) { |
315 | return (*lhs)->getArgStr().compare(RHS: (*rhs)->getArgStr()); |
316 | }; |
317 | llvm::array_pod_sort(Start: orderedOps.begin(), End: orderedOps.end(), Compare: compareOptionArgs); |
318 | |
319 | // Interleave the options with ' '. |
320 | os << '{'; |
321 | llvm::interleave( |
322 | c: orderedOps, os, each_fn: [&](OptionBase *option) { option->print(os); }, separator: " " ); |
323 | os << '}'; |
324 | } |
325 | |
326 | /// Print the help string for the options held by this struct. `descIndent` is |
327 | /// the indent within the stream that the descriptions should be aligned. |
328 | void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const { |
329 | // Sort the options to make the ordering deterministic. |
330 | SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end()); |
331 | auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) { |
332 | return (*lhs)->getArgStr().compare(RHS: (*rhs)->getArgStr()); |
333 | }; |
334 | llvm::array_pod_sort(Start: orderedOps.begin(), End: orderedOps.end(), Compare: compareOptionArgs); |
335 | for (OptionBase *option : orderedOps) { |
336 | // TODO: printOptionInfo assumes a specific indent and will |
337 | // print options with values with incorrect indentation. We should add |
338 | // support to llvm::cl::Option for passing in a base indent to use when |
339 | // printing. |
340 | llvm::outs().indent(NumSpaces: indent); |
341 | option->getOption()->printOptionInfo(GlobalWidth: descIndent - indent); |
342 | } |
343 | } |
344 | |
345 | /// Return the maximum width required when printing the help string. |
346 | size_t detail::PassOptions::getOptionWidth() const { |
347 | size_t max = 0; |
348 | for (auto *option : options) |
349 | max = std::max(a: max, b: option->getOption()->getOptionWidth()); |
350 | return max; |
351 | } |
352 | |
353 | //===----------------------------------------------------------------------===// |
354 | // MLIR Options |
355 | //===----------------------------------------------------------------------===// |
356 | |
357 | //===----------------------------------------------------------------------===// |
358 | // OpPassManager: OptionValue |
359 | |
360 | llvm::cl::OptionValue<OpPassManager>::OptionValue() = default; |
361 | llvm::cl::OptionValue<OpPassManager>::OptionValue( |
362 | const mlir::OpPassManager &value) { |
363 | setValue(value); |
364 | } |
365 | llvm::cl::OptionValue<OpPassManager>::OptionValue( |
366 | const llvm::cl::OptionValue<mlir::OpPassManager> &rhs) { |
367 | if (rhs.hasValue()) |
368 | setValue(rhs.getValue()); |
369 | } |
370 | llvm::cl::OptionValue<OpPassManager> & |
371 | llvm::cl::OptionValue<OpPassManager>::operator=( |
372 | const mlir::OpPassManager &rhs) { |
373 | setValue(rhs); |
374 | return *this; |
375 | } |
376 | |
377 | llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default; |
378 | |
379 | void llvm::cl::OptionValue<OpPassManager>::setValue( |
380 | const OpPassManager &newValue) { |
381 | if (hasValue()) |
382 | *value = newValue; |
383 | else |
384 | value = std::make_unique<mlir::OpPassManager>(args: newValue); |
385 | } |
386 | void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) { |
387 | FailureOr<OpPassManager> pipeline = parsePassPipeline(pipeline: pipelineStr); |
388 | assert(succeeded(pipeline) && "invalid pass pipeline" ); |
389 | setValue(*pipeline); |
390 | } |
391 | |
392 | bool llvm::cl::OptionValue<OpPassManager>::compare( |
393 | const mlir::OpPassManager &rhs) const { |
394 | std::string lhsStr, rhsStr; |
395 | { |
396 | raw_string_ostream lhsStream(lhsStr); |
397 | value->printAsTextualPipeline(os&: lhsStream); |
398 | |
399 | raw_string_ostream rhsStream(rhsStr); |
400 | rhs.printAsTextualPipeline(os&: rhsStream); |
401 | } |
402 | |
403 | // Use the textual format for pipeline comparisons. |
404 | return lhsStr == rhsStr; |
405 | } |
406 | |
407 | void llvm::cl::OptionValue<OpPassManager>::anchor() {} |
408 | |
409 | //===----------------------------------------------------------------------===// |
410 | // OpPassManager: Parser |
411 | |
412 | namespace llvm { |
413 | namespace cl { |
414 | template class basic_parser<OpPassManager>; |
415 | } // namespace cl |
416 | } // namespace llvm |
417 | |
418 | bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg, |
419 | ParsedPassManager &value) { |
420 | FailureOr<OpPassManager> pipeline = parsePassPipeline(pipeline: arg); |
421 | if (failed(result: pipeline)) |
422 | return true; |
423 | value.value = std::make_unique<OpPassManager>(args: std::move(*pipeline)); |
424 | return false; |
425 | } |
426 | |
427 | void llvm::cl::parser<OpPassManager>::print(raw_ostream &os, |
428 | const OpPassManager &value) { |
429 | value.printAsTextualPipeline(os); |
430 | } |
431 | |
432 | void llvm::cl::parser<OpPassManager>::printOptionDiff( |
433 | const Option &opt, OpPassManager &pm, const OptVal &defaultValue, |
434 | size_t globalWidth) const { |
435 | printOptionName(O: opt, GlobalWidth: globalWidth); |
436 | outs() << "= " ; |
437 | pm.printAsTextualPipeline(os&: outs()); |
438 | |
439 | if (defaultValue.hasValue()) { |
440 | outs().indent(NumSpaces: 2) << " (default: " ; |
441 | defaultValue.getValue().printAsTextualPipeline(os&: outs()); |
442 | outs() << ")" ; |
443 | } |
444 | outs() << "\n" ; |
445 | } |
446 | |
447 | void llvm::cl::parser<OpPassManager>::anchor() {} |
448 | |
449 | llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager() = |
450 | default; |
451 | llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager( |
452 | ParsedPassManager &&) = default; |
453 | llvm::cl::parser<OpPassManager>::ParsedPassManager::~ParsedPassManager() = |
454 | default; |
455 | |
456 | //===----------------------------------------------------------------------===// |
457 | // TextualPassPipeline Parser |
458 | //===----------------------------------------------------------------------===// |
459 | |
460 | namespace { |
461 | /// This class represents a textual description of a pass pipeline. |
462 | class TextualPipeline { |
463 | public: |
464 | /// Try to initialize this pipeline with the given pipeline text. |
465 | /// `errorStream` is the output stream to emit errors to. |
466 | LogicalResult initialize(StringRef text, raw_ostream &errorStream); |
467 | |
468 | /// Add the internal pipeline elements to the provided pass manager. |
469 | LogicalResult |
470 | addToPipeline(OpPassManager &pm, |
471 | function_ref<LogicalResult(const Twine &)> errorHandler) const; |
472 | |
473 | private: |
474 | /// A functor used to emit errors found during pipeline handling. The first |
475 | /// parameter corresponds to the raw location within the pipeline string. This |
476 | /// should always return failure. |
477 | using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>; |
478 | |
479 | /// A struct to capture parsed pass pipeline names. |
480 | /// |
481 | /// A pipeline is defined as a series of names, each of which may in itself |
482 | /// recursively contain a nested pipeline. A name is either the name of a pass |
483 | /// (e.g. "cse") or the name of an operation type (e.g. "buitin.module"). If |
484 | /// the name is the name of a pass, the InnerPipeline is empty, since passes |
485 | /// cannot contain inner pipelines. |
486 | struct PipelineElement { |
487 | PipelineElement(StringRef name) : name(name) {} |
488 | |
489 | StringRef name; |
490 | StringRef options; |
491 | const PassRegistryEntry *registryEntry = nullptr; |
492 | std::vector<PipelineElement> innerPipeline; |
493 | }; |
494 | |
495 | /// Parse the given pipeline text into the internal pipeline vector. This |
496 | /// function only parses the structure of the pipeline, and does not resolve |
497 | /// its elements. |
498 | LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler); |
499 | |
500 | /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to |
501 | /// the corresponding registry entry. |
502 | LogicalResult |
503 | resolvePipelineElements(MutableArrayRef<PipelineElement> elements, |
504 | ErrorHandlerT errorHandler); |
505 | |
506 | /// Resolve a single element of the pipeline. |
507 | LogicalResult resolvePipelineElement(PipelineElement &element, |
508 | ErrorHandlerT errorHandler); |
509 | |
510 | /// Add the given pipeline elements to the provided pass manager. |
511 | LogicalResult |
512 | addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm, |
513 | function_ref<LogicalResult(const Twine &)> errorHandler) const; |
514 | |
515 | std::vector<PipelineElement> pipeline; |
516 | }; |
517 | |
518 | } // namespace |
519 | |
520 | /// Try to initialize this pipeline with the given pipeline text. An option is |
521 | /// given to enable accurate error reporting. |
522 | LogicalResult TextualPipeline::initialize(StringRef text, |
523 | raw_ostream &errorStream) { |
524 | if (text.empty()) |
525 | return success(); |
526 | |
527 | // Build a source manager to use for error reporting. |
528 | llvm::SourceMgr pipelineMgr; |
529 | pipelineMgr.AddNewSourceBuffer( |
530 | F: llvm::MemoryBuffer::getMemBuffer(InputData: text, BufferName: "MLIR Textual PassPipeline Parser" , |
531 | /*RequiresNullTerminator=*/false), |
532 | IncludeLoc: SMLoc()); |
533 | auto errorHandler = [&](const char *rawLoc, Twine msg) { |
534 | pipelineMgr.PrintMessage(OS&: errorStream, Loc: SMLoc::getFromPointer(Ptr: rawLoc), |
535 | Kind: llvm::SourceMgr::DK_Error, Msg: msg); |
536 | return failure(); |
537 | }; |
538 | |
539 | // Parse the provided pipeline string. |
540 | if (failed(result: parsePipelineText(text, errorHandler))) |
541 | return failure(); |
542 | return resolvePipelineElements(elements: pipeline, errorHandler); |
543 | } |
544 | |
545 | /// Add the internal pipeline elements to the provided pass manager. |
546 | LogicalResult TextualPipeline::addToPipeline( |
547 | OpPassManager &pm, |
548 | function_ref<LogicalResult(const Twine &)> errorHandler) const { |
549 | // Temporarily disable implicit nesting while we append to the pipeline. We |
550 | // want the created pipeline to exactly match the parsed text pipeline, so |
551 | // it's preferrable to just error out if implicit nesting would be required. |
552 | OpPassManager::Nesting nesting = pm.getNesting(); |
553 | pm.setNesting(OpPassManager::Nesting::Explicit); |
554 | auto restore = llvm::make_scope_exit(F: [&]() { pm.setNesting(nesting); }); |
555 | |
556 | return addToPipeline(elements: pipeline, pm, errorHandler); |
557 | } |
558 | |
559 | /// Parse the given pipeline text into the internal pipeline vector. This |
560 | /// function only parses the structure of the pipeline, and does not resolve |
561 | /// its elements. |
562 | LogicalResult TextualPipeline::parsePipelineText(StringRef text, |
563 | ErrorHandlerT errorHandler) { |
564 | SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline}; |
565 | for (;;) { |
566 | std::vector<PipelineElement> &pipeline = *pipelineStack.back(); |
567 | size_t pos = text.find_first_of(Chars: ",(){" ); |
568 | pipeline.emplace_back(/*name=*/args: text.substr(Start: 0, N: pos).trim()); |
569 | |
570 | // If we have a single terminating name, we're done. |
571 | if (pos == StringRef::npos) |
572 | break; |
573 | |
574 | text = text.substr(Start: pos); |
575 | char sep = text[0]; |
576 | |
577 | // Handle pulling ... from 'pass{...}' out as PipelineElement.options. |
578 | if (sep == '{') { |
579 | text = text.substr(Start: 1); |
580 | |
581 | // Skip over everything until the closing '}' and store as options. |
582 | size_t close = StringRef::npos; |
583 | for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) { |
584 | if (text[i] == '{') { |
585 | ++braceCount; |
586 | continue; |
587 | } |
588 | if (text[i] == '}' && --braceCount == 0) { |
589 | close = i; |
590 | break; |
591 | } |
592 | } |
593 | |
594 | // Check to see if a closing options brace was found. |
595 | if (close == StringRef::npos) { |
596 | return errorHandler( |
597 | /*rawLoc=*/text.data() - 1, |
598 | "missing closing '}' while processing pass options" ); |
599 | } |
600 | pipeline.back().options = text.substr(Start: 0, N: close); |
601 | text = text.substr(Start: close + 1); |
602 | |
603 | // Consume space characters that an user might add for readability. |
604 | text = text.ltrim(); |
605 | |
606 | // Skip checking for '(' because nested pipelines cannot have options. |
607 | } else if (sep == '(') { |
608 | text = text.substr(Start: 1); |
609 | |
610 | // Push the inner pipeline onto the stack to continue processing. |
611 | pipelineStack.push_back(Elt: &pipeline.back().innerPipeline); |
612 | continue; |
613 | } |
614 | |
615 | // When handling the close parenthesis, we greedily consume them to avoid |
616 | // empty strings in the pipeline. |
617 | while (text.consume_front(Prefix: ")" )) { |
618 | // If we try to pop the outer pipeline we have unbalanced parentheses. |
619 | if (pipelineStack.size() == 1) |
620 | return errorHandler(/*rawLoc=*/text.data() - 1, |
621 | "encountered extra closing ')' creating unbalanced " |
622 | "parentheses while parsing pipeline" ); |
623 | |
624 | pipelineStack.pop_back(); |
625 | // Consume space characters that an user might add for readability. |
626 | text = text.ltrim(); |
627 | } |
628 | |
629 | // Check if we've finished parsing. |
630 | if (text.empty()) |
631 | break; |
632 | |
633 | // Otherwise, the end of an inner pipeline always has to be followed by |
634 | // a comma, and then we can continue. |
635 | if (!text.consume_front(Prefix: "," )) |
636 | return errorHandler(text.data(), "expected ',' after parsing pipeline" ); |
637 | } |
638 | |
639 | // Check for unbalanced parentheses. |
640 | if (pipelineStack.size() > 1) |
641 | return errorHandler( |
642 | text.data(), |
643 | "encountered unbalanced parentheses while parsing pipeline" ); |
644 | |
645 | assert(pipelineStack.back() == &pipeline && |
646 | "wrong pipeline at the bottom of the stack" ); |
647 | return success(); |
648 | } |
649 | |
650 | /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to |
651 | /// the corresponding registry entry. |
652 | LogicalResult TextualPipeline::resolvePipelineElements( |
653 | MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) { |
654 | for (auto &elt : elements) |
655 | if (failed(result: resolvePipelineElement(element&: elt, errorHandler))) |
656 | return failure(); |
657 | return success(); |
658 | } |
659 | |
660 | /// Resolve a single element of the pipeline. |
661 | LogicalResult |
662 | TextualPipeline::resolvePipelineElement(PipelineElement &element, |
663 | ErrorHandlerT errorHandler) { |
664 | // If the inner pipeline of this element is not empty, this is an operation |
665 | // pipeline. |
666 | if (!element.innerPipeline.empty()) |
667 | return resolvePipelineElements(elements: element.innerPipeline, errorHandler); |
668 | |
669 | // Otherwise, this must be a pass or pass pipeline. |
670 | // Check to see if a pipeline was registered with this name. |
671 | if ((element.registryEntry = PassPipelineInfo::lookup(pipelineArg: element.name))) |
672 | return success(); |
673 | |
674 | // If not, then this must be a specific pass name. |
675 | if ((element.registryEntry = PassInfo::lookup(passArg: element.name))) |
676 | return success(); |
677 | |
678 | // Emit an error for the unknown pass. |
679 | auto *rawLoc = element.name.data(); |
680 | return errorHandler(rawLoc, "'" + element.name + |
681 | "' does not refer to a " |
682 | "registered pass or pass pipeline" ); |
683 | } |
684 | |
685 | /// Add the given pipeline elements to the provided pass manager. |
686 | LogicalResult TextualPipeline::addToPipeline( |
687 | ArrayRef<PipelineElement> elements, OpPassManager &pm, |
688 | function_ref<LogicalResult(const Twine &)> errorHandler) const { |
689 | for (auto &elt : elements) { |
690 | if (elt.registryEntry) { |
691 | if (failed(result: elt.registryEntry->addToPipeline(pm, options: elt.options, |
692 | errorHandler))) { |
693 | return errorHandler("failed to add `" + elt.name + "` with options `" + |
694 | elt.options + "`" ); |
695 | } |
696 | } else if (failed(result: addToPipeline(elements: elt.innerPipeline, pm&: pm.nest(nestedName: elt.name), |
697 | errorHandler))) { |
698 | return errorHandler("failed to add `" + elt.name + "` with options `" + |
699 | elt.options + "` to inner pipeline" ); |
700 | } |
701 | } |
702 | return success(); |
703 | } |
704 | |
705 | LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm, |
706 | raw_ostream &errorStream) { |
707 | TextualPipeline pipelineParser; |
708 | if (failed(result: pipelineParser.initialize(text: pipeline, errorStream))) |
709 | return failure(); |
710 | auto errorHandler = [&](Twine msg) { |
711 | errorStream << msg << "\n" ; |
712 | return failure(); |
713 | }; |
714 | if (failed(result: pipelineParser.addToPipeline(pm, errorHandler))) |
715 | return failure(); |
716 | return success(); |
717 | } |
718 | |
719 | FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline, |
720 | raw_ostream &errorStream) { |
721 | pipeline = pipeline.trim(); |
722 | // Pipelines are expected to be of the form `<op-name>(<pipeline>)`. |
723 | size_t pipelineStart = pipeline.find_first_of(C: '('); |
724 | if (pipelineStart == 0 || pipelineStart == StringRef::npos || |
725 | !pipeline.consume_back(Suffix: ")" )) { |
726 | errorStream << "expected pass pipeline to be wrapped with the anchor " |
727 | "operation type, e.g. 'builtin.module(...)'" ; |
728 | return failure(); |
729 | } |
730 | |
731 | StringRef opName = pipeline.take_front(N: pipelineStart).rtrim(); |
732 | OpPassManager pm(opName); |
733 | if (failed(result: parsePassPipeline(pipeline: pipeline.drop_front(N: 1 + pipelineStart), pm, |
734 | errorStream))) |
735 | return failure(); |
736 | return pm; |
737 | } |
738 | |
739 | //===----------------------------------------------------------------------===// |
740 | // PassNameParser |
741 | //===----------------------------------------------------------------------===// |
742 | |
743 | namespace { |
744 | /// This struct represents the possible data entries in a parsed pass pipeline |
745 | /// list. |
746 | struct PassArgData { |
747 | PassArgData() = default; |
748 | PassArgData(const PassRegistryEntry *registryEntry) |
749 | : registryEntry(registryEntry) {} |
750 | |
751 | /// This field is used when the parsed option corresponds to a registered pass |
752 | /// or pass pipeline. |
753 | const PassRegistryEntry *registryEntry{nullptr}; |
754 | |
755 | /// This field is set when instance specific pass options have been provided |
756 | /// on the command line. |
757 | StringRef options; |
758 | }; |
759 | } // namespace |
760 | |
761 | namespace llvm { |
762 | namespace cl { |
763 | /// Define a valid OptionValue for the command line pass argument. |
764 | template <> |
765 | struct OptionValue<PassArgData> final |
766 | : OptionValueBase<PassArgData, /*isClass=*/true> { |
767 | OptionValue(const PassArgData &value) { this->setValue(value); } |
768 | OptionValue() = default; |
769 | void anchor() override {} |
770 | |
771 | bool hasValue() const { return true; } |
772 | const PassArgData &getValue() const { return value; } |
773 | void setValue(const PassArgData &value) { this->value = value; } |
774 | |
775 | PassArgData value; |
776 | }; |
777 | } // namespace cl |
778 | } // namespace llvm |
779 | |
namespace {

/// The name for the command line option used for parsing the textual pass
/// pipeline. Kept as a macro because it is concatenated with adjacent string
/// literals where it is used.
#define PASS_PIPELINE_ARG "pass-pipeline"

/// Command line parser producing `PassArgData` values. It adds a command line
/// option for each registered pass or pass pipeline, as well as textual pass
/// pipelines.
struct PassNameParser : public llvm::cl::parser<PassArgData> {
  PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}

  /// Registers a literal option for every entry of the pass registry and,
  /// unless `passNamesOnly` is set, of the pass pipeline registry.
  void initialize();

  /// Prints the help text for `opt`. Unless `passNamesOnly` is set, this also
  /// lists the registered passes and pass pipelines.
  void printOptionInfo(const llvm::cl::Option &opt,
                       size_t globalWidth) const override;

  /// Returns the option width, widened to fit the registered pass and pass
  /// pipeline entries.
  size_t getOptionWidth(const llvm::cl::Option &opt) const override;

  /// Runs the base class parser; when it returns false, `arg` is additionally
  /// stored in `value.options`. The base parser's result is returned.
  bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
             PassArgData &value);

  /// If true, this parser only parses entries that correspond to a concrete
  /// pass registry entry, and does not include pipeline entries or the options
  /// for pass entries.
  bool passNamesOnly = false;
};
} // namespace
804 | |
805 | void PassNameParser::initialize() { |
806 | llvm::cl::parser<PassArgData>::initialize(); |
807 | |
808 | /// Add the pass entries. |
809 | for (const auto &kv : *passRegistry) { |
810 | addLiteralOption(Name: kv.second.getPassArgument(), V: &kv.second, |
811 | HelpStr: kv.second.getPassDescription()); |
812 | } |
813 | /// Add the pass pipeline entries. |
814 | if (!passNamesOnly) { |
815 | for (const auto &kv : *passPipelineRegistry) { |
816 | addLiteralOption(Name: kv.second.getPassArgument(), V: &kv.second, |
817 | HelpStr: kv.second.getPassDescription()); |
818 | } |
819 | } |
820 | } |
821 | |
822 | void PassNameParser::printOptionInfo(const llvm::cl::Option &opt, |
823 | size_t globalWidth) const { |
824 | // If this parser is just parsing pass names, print a simplified option |
825 | // string. |
826 | if (passNamesOnly) { |
827 | llvm::outs() << " --" << opt.ArgStr << "=<pass-arg>" ; |
828 | opt.printHelpStr(HelpStr: opt.HelpStr, Indent: globalWidth, FirstLineIndentedBy: opt.ArgStr.size() + 18); |
829 | return; |
830 | } |
831 | |
832 | // Print the information for the top-level option. |
833 | if (opt.hasArgStr()) { |
834 | llvm::outs() << " --" << opt.ArgStr; |
835 | opt.printHelpStr(HelpStr: opt.HelpStr, Indent: globalWidth, FirstLineIndentedBy: opt.ArgStr.size() + 7); |
836 | } else { |
837 | llvm::outs() << " " << opt.HelpStr << '\n'; |
838 | } |
839 | |
840 | // Functor used to print the ordered entries of a registration map. |
841 | auto printOrderedEntries = [&](StringRef , auto &map) { |
842 | llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries; |
843 | for (auto &kv : map) |
844 | orderedEntries.push_back(Elt: &kv.second); |
845 | llvm::array_pod_sort( |
846 | orderedEntries.begin(), orderedEntries.end(), |
847 | [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) { |
848 | return (*lhs)->getPassArgument().compare(RHS: (*rhs)->getPassArgument()); |
849 | }); |
850 | |
851 | llvm::outs().indent(NumSpaces: 4) << header << ":\n" ; |
852 | for (PassRegistryEntry *entry : orderedEntries) |
853 | entry->printHelpStr(/*indent=*/6, descIndent: globalWidth); |
854 | }; |
855 | |
856 | // Print the available passes. |
857 | printOrderedEntries("Passes" , *passRegistry); |
858 | |
859 | // Print the available pass pipelines. |
860 | if (!passPipelineRegistry->empty()) |
861 | printOrderedEntries("Pass Pipelines" , *passPipelineRegistry); |
862 | } |
863 | |
864 | size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const { |
865 | size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(O: opt) + 2; |
866 | |
867 | // Check for any wider pass or pipeline options. |
868 | for (auto &entry : *passRegistry) |
869 | maxWidth = std::max(a: maxWidth, b: entry.second.getOptionWidth() + 4); |
870 | for (auto &entry : *passPipelineRegistry) |
871 | maxWidth = std::max(a: maxWidth, b: entry.second.getOptionWidth() + 4); |
872 | return maxWidth; |
873 | } |
874 | |
875 | bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName, |
876 | StringRef arg, PassArgData &value) { |
877 | if (llvm::cl::parser<PassArgData>::parse(O&: opt, ArgName: argName, Arg: arg, V&: value)) |
878 | return true; |
879 | value.options = arg; |
880 | return false; |
881 | } |
882 | |
883 | //===----------------------------------------------------------------------===// |
884 | // PassPipelineCLParser |
885 | //===----------------------------------------------------------------------===// |
886 | |
887 | namespace mlir { |
888 | namespace detail { |
889 | struct PassPipelineCLParserImpl { |
890 | PassPipelineCLParserImpl(StringRef arg, StringRef description, |
891 | bool passNamesOnly) |
892 | : passList(arg, llvm::cl::desc(description)) { |
893 | passList.getParser().passNamesOnly = passNamesOnly; |
894 | passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional); |
895 | } |
896 | |
897 | /// Returns true if the given pass registry entry was registered at the |
898 | /// top-level of the parser, i.e. not within an explicit textual pipeline. |
899 | bool contains(const PassRegistryEntry *entry) const { |
900 | return llvm::any_of(Range: passList, P: [&](const PassArgData &data) { |
901 | return data.registryEntry == entry; |
902 | }); |
903 | } |
904 | |
905 | /// The set of passes and pass pipelines to run. |
906 | llvm::cl::list<PassArgData, bool, PassNameParser> passList; |
907 | }; |
908 | } // namespace detail |
909 | } // namespace mlir |
910 | |
911 | /// Construct a pass pipeline parser with the given command line description. |
912 | PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description) |
913 | : impl(std::make_unique<detail::PassPipelineCLParserImpl>( |
914 | args&: arg, args&: description, /*passNamesOnly=*/args: false)), |
915 | passPipeline( |
916 | PASS_PIPELINE_ARG, |
917 | llvm::cl::desc("Textual description of the pass pipeline to run" )) {} |
918 | |
919 | PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description, |
920 | StringRef alias) |
921 | : PassPipelineCLParser(arg, description) { |
922 | passPipelineAlias.emplace(args&: alias, |
923 | args: llvm::cl::desc("Alias for --" PASS_PIPELINE_ARG), |
924 | args: llvm::cl::aliasopt(passPipeline)); |
925 | } |
926 | |
/// Defaulted out of line in this file, where `PassPipelineCLParserImpl` is
/// fully defined.
PassPipelineCLParser::~PassPipelineCLParser() = default;
928 | |
929 | /// Returns true if this parser contains any valid options to add. |
930 | bool PassPipelineCLParser::hasAnyOccurrences() const { |
931 | return passPipeline.getNumOccurrences() != 0 || |
932 | impl->passList.getNumOccurrences() != 0; |
933 | } |
934 | |
/// Returns true if the given pass registry entry was registered at the
/// top-level of the parser, i.e. not within an explicit textual pipeline.
/// Delegates to the implementation object, which looks for `entry` among the
/// parsed pass list values.
bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const {
  return impl->contains(entry);
}
940 | |
941 | /// Adds the passes defined by this parser entry to the given pass manager. |
942 | LogicalResult PassPipelineCLParser::addToPipeline( |
943 | OpPassManager &pm, |
944 | function_ref<LogicalResult(const Twine &)> errorHandler) const { |
945 | if (passPipeline.getNumOccurrences()) { |
946 | if (impl->passList.getNumOccurrences()) |
947 | return errorHandler( |
948 | "'-" PASS_PIPELINE_ARG |
949 | "' option can't be used with individual pass options" ); |
950 | std::string errMsg; |
951 | llvm::raw_string_ostream os(errMsg); |
952 | FailureOr<OpPassManager> parsed = parsePassPipeline(pipeline: passPipeline, errorStream&: os); |
953 | if (failed(result: parsed)) |
954 | return errorHandler(errMsg); |
955 | pm = std::move(*parsed); |
956 | return success(); |
957 | } |
958 | |
959 | for (auto &passIt : impl->passList) { |
960 | if (failed(result: passIt.registryEntry->addToPipeline(pm, options: passIt.options, |
961 | errorHandler))) |
962 | return failure(); |
963 | } |
964 | return success(); |
965 | } |
966 | |
967 | //===----------------------------------------------------------------------===// |
968 | // PassNameCLParser |
969 | |
970 | /// Construct a pass pipeline parser with the given command line description. |
971 | PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description) |
972 | : impl(std::make_unique<detail::PassPipelineCLParserImpl>( |
973 | args&: arg, args&: description, /*passNamesOnly=*/args: true)) { |
974 | impl->passList.setMiscFlag(llvm::cl::CommaSeparated); |
975 | } |
/// Defaulted out of line in this file, where `PassPipelineCLParserImpl` is
/// fully defined.
PassNameCLParser::~PassNameCLParser() = default;
977 | |
/// Returns true if this parser contains any valid options to add, i.e. if the
/// pass list option occurred at least once on the command line.
bool PassNameCLParser::hasAnyOccurrences() const {
  return impl->passList.getNumOccurrences() != 0;
}
982 | |
/// Returns true if the given pass registry entry was registered at the
/// top-level of the parser, i.e. if it is among the values parsed for the
/// pass list option. Delegates to the implementation object.
bool PassNameCLParser::contains(const PassRegistryEntry *entry) const {
  return impl->contains(entry);
}
988 | |