1//===- MlirOptMain.cpp - MLIR Optimizer Driver ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is a utility that runs an optimization pass and prints the result back
10// out. It is designed to support unit testing.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Tools/mlir-opt/MlirOptMain.h"
15#include "mlir/Bytecode/BytecodeWriter.h"
16#include "mlir/Debug/CLOptionsSetup.h"
17#include "mlir/Debug/Counter.h"
18#include "mlir/Debug/DebuggerExecutionContextHook.h"
19#include "mlir/Debug/ExecutionContext.h"
20#include "mlir/Debug/Observers/ActionLogging.h"
21#include "mlir/Dialect/IRDL/IR/IRDL.h"
22#include "mlir/Dialect/IRDL/IRDLLoading.h"
23#include "mlir/IR/AsmState.h"
24#include "mlir/IR/Attributes.h"
25#include "mlir/IR/BuiltinOps.h"
26#include "mlir/IR/Diagnostics.h"
27#include "mlir/IR/Dialect.h"
28#include "mlir/IR/Location.h"
29#include "mlir/IR/MLIRContext.h"
30#include "mlir/Parser/Parser.h"
31#include "mlir/Pass/Pass.h"
32#include "mlir/Pass/PassManager.h"
33#include "mlir/Pass/PassRegistry.h"
34#include "mlir/Support/FileUtilities.h"
35#include "mlir/Support/Timing.h"
36#include "mlir/Support/ToolUtilities.h"
37#include "mlir/Tools/ParseUtilities.h"
38#include "mlir/Tools/Plugins/DialectPlugin.h"
39#include "mlir/Tools/Plugins/PassPlugin.h"
40#include "llvm/ADT/StringRef.h"
41#include "llvm/Support/CommandLine.h"
42#include "llvm/Support/FileUtilities.h"
43#include "llvm/Support/InitLLVM.h"
44#include "llvm/Support/LogicalResult.h"
45#include "llvm/Support/ManagedStatic.h"
46#include "llvm/Support/Process.h"
47#include "llvm/Support/Regex.h"
48#include "llvm/Support/SourceMgr.h"
49#include "llvm/Support/StringSaver.h"
50#include "llvm/Support/ThreadPool.h"
51#include "llvm/Support/ToolOutputFile.h"
52
53using namespace mlir;
54using namespace llvm;
55
56namespace {
57class BytecodeVersionParser : public cl::parser<std::optional<int64_t>> {
58public:
59 BytecodeVersionParser(cl::Option &o)
60 : cl::parser<std::optional<int64_t>>(o) {}
61
62 bool parse(cl::Option &o, StringRef /*argName*/, StringRef arg,
63 std::optional<int64_t> &v) {
64 long long w;
65 if (getAsSignedInteger(Str: arg, Radix: 10, Result&: w))
66 return o.error(Message: "Invalid argument '" + arg +
67 "', only integer is supported.");
68 v = w;
69 return false;
70 }
71};
72
73/// This class is intended to manage the handling of command line options for
74/// creating a *-opt config. This is a singleton.
75struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
76 MlirOptMainConfigCLOptions() {
77 // These options are static but all uses ExternalStorage to initialize the
78 // members of the parent class. This is unusual but since this class is a
79 // singleton it basically attaches command line option to the singleton
80 // members.
81
82 static cl::opt<bool, /*ExternalStorage=*/true> allowUnregisteredDialects(
83 "allow-unregistered-dialect",
84 cl::desc("Allow operation with no registered dialects"),
85 cl::location(L&: allowUnregisteredDialectsFlag), cl::init(Val: false));
86
87 static cl::opt<bool, /*ExternalStorage=*/true> dumpPassPipeline(
88 "dump-pass-pipeline", cl::desc("Print the pipeline that will be run"),
89 cl::location(L&: dumpPassPipelineFlag), cl::init(Val: false));
90
91 static cl::opt<bool, /*ExternalStorage=*/true> emitBytecode(
92 "emit-bytecode", cl::desc("Emit bytecode when generating output"),
93 cl::location(L&: emitBytecodeFlag), cl::init(Val: false));
94
95 static cl::opt<bool, /*ExternalStorage=*/true> elideResourcesFromBytecode(
96 "elide-resource-data-from-bytecode",
97 cl::desc("Elide resources when generating bytecode"),
98 cl::location(L&: elideResourceDataFromBytecodeFlag), cl::init(Val: false));
99
100 static cl::opt<std::optional<int64_t>, /*ExternalStorage=*/true,
101 BytecodeVersionParser>
102 bytecodeVersion(
103 "emit-bytecode-version",
104 cl::desc("Use specified bytecode when generating output"),
105 cl::location(L&: emitBytecodeVersion), cl::init(Val: std::nullopt));
106
107 static cl::opt<std::string, /*ExternalStorage=*/true> irdlFile(
108 "irdl-file",
109 cl::desc("IRDL file to register before processing the input"),
110 cl::location(L&: irdlFileFlag), cl::init(Val: ""), cl::value_desc("filename"));
111
112 static cl::opt<VerbosityLevel, /*ExternalStorage=*/true>
113 diagnosticVerbosityLevel(
114 "mlir-diagnostic-verbosity-level",
115 cl::desc("Choose level of diagnostic information"),
116 cl::location(L&: diagnosticVerbosityLevelFlag),
117 cl::init(Val: VerbosityLevel::ErrorsWarningsAndRemarks),
118 cl::values(
119 clEnumValN(VerbosityLevel::ErrorsOnly, "errors", "Errors only"),
120 clEnumValN(VerbosityLevel::ErrorsAndWarnings, "warnings",
121 "Errors and warnings"),
122 clEnumValN(VerbosityLevel::ErrorsWarningsAndRemarks, "remarks",
123 "Errors, warnings and remarks")));
124
125 static cl::opt<bool, /*ExternalStorage=*/true> disableDiagnosticNotes(
126 "mlir-disable-diagnostic-notes", cl::desc("Disable diagnostic notes."),
127 cl::location(L&: disableDiagnosticNotesFlag), cl::init(Val: false));
128
129 static cl::opt<bool, /*ExternalStorage=*/true> explicitModule(
130 "no-implicit-module",
131 cl::desc("Disable implicit addition of a top-level module op during "
132 "parsing"),
133 cl::location(L&: useExplicitModuleFlag), cl::init(Val: false));
134
135 static cl::opt<bool, /*ExternalStorage=*/true> listPasses(
136 "list-passes", cl::desc("Print the list of registered passes and exit"),
137 cl::location(L&: listPassesFlag), cl::init(Val: false));
138
139 static cl::opt<bool, /*ExternalStorage=*/true> runReproducer(
140 "run-reproducer", cl::desc("Run the pipeline stored in the reproducer"),
141 cl::location(L&: runReproducerFlag), cl::init(Val: false));
142
143 static cl::opt<bool, /*ExternalStorage=*/true> showDialects(
144 "show-dialects",
145 cl::desc("Print the list of registered dialects and exit"),
146 cl::location(L&: showDialectsFlag), cl::init(Val: false));
147
148 static cl::opt<std::string, /*ExternalStorage=*/true> splitInputFile{
149 "split-input-file",
150 llvm::cl::ValueOptional,
151 cl::callback(CB: [&](const std::string &str) {
152 // Implicit value: use default marker if flag was used without value.
153 if (str.empty())
154 splitInputFile.setValue(V: kDefaultSplitMarker);
155 }),
156 cl::desc("Split the input file into chunks using the given or "
157 "default marker and process each chunk independently"),
158 cl::location(L&: splitInputFileFlag),
159 cl::init(Val: "")};
160
161 static cl::opt<std::string, /*ExternalStorage=*/true> outputSplitMarker(
162 "output-split-marker",
163 cl::desc("Split marker to use for merging the ouput"),
164 cl::location(L&: outputSplitMarkerFlag), cl::init(Val: kDefaultSplitMarker));
165
166 static cl::opt<SourceMgrDiagnosticVerifierHandler::Level,
167 /*ExternalStorage=*/true>
168 verifyDiagnostics{
169 "verify-diagnostics", llvm::cl::ValueOptional,
170 cl::desc("Check that emitted diagnostics match expected-* lines on "
171 "the corresponding line"),
172 cl::location(L&: verifyDiagnosticsFlag),
173 cl::values(
174 clEnumValN(SourceMgrDiagnosticVerifierHandler::Level::All,
175 "all",
176 "Check all diagnostics (expected, unexpected, "
177 "near-misses)"),
178 // Implicit value: when passed with no arguments, e.g.
179 // `--verify-diagnostics` or `--verify-diagnostics=`.
180 clEnumValN(SourceMgrDiagnosticVerifierHandler::Level::All, "",
181 "Check all diagnostics (expected, unexpected, "
182 "near-misses)"),
183 clEnumValN(
184 SourceMgrDiagnosticVerifierHandler::Level::OnlyExpected,
185 "only-expected", "Check only expected diagnostics"))};
186
187 static cl::opt<bool, /*ExternalStorage=*/true> verifyPasses(
188 "verify-each",
189 cl::desc("Run the verifier after each transformation pass"),
190 cl::location(L&: verifyPassesFlag), cl::init(Val: true));
191
192 static cl::opt<bool, /*ExternalStorage=*/true> disableVerifyOnParsing(
193 "mlir-very-unsafe-disable-verifier-on-parsing",
194 cl::desc("Disable the verifier on parsing (very unsafe)"),
195 cl::location(L&: disableVerifierOnParsingFlag), cl::init(Val: false));
196
197 static cl::opt<bool, /*ExternalStorage=*/true> verifyRoundtrip(
198 "verify-roundtrip",
199 cl::desc("Round-trip the IR after parsing and ensure it succeeds"),
200 cl::location(L&: verifyRoundtripFlag), cl::init(Val: false));
201
202 static cl::list<std::string> passPlugins(
203 "load-pass-plugin", cl::desc("Load passes from plugin library"));
204
205 static cl::opt<std::string, /*ExternalStorage=*/true>
206 generateReproducerFile(
207 "mlir-generate-reproducer",
208 llvm::cl::desc(
209 "Generate an mlir reproducer at the provided filename"
210 " (no crash required)"),
211 cl::location(L&: generateReproducerFileFlag), cl::init(Val: ""),
212 cl::value_desc("filename"));
213
214 /// Set the callback to load a pass plugin.
215 passPlugins.setCallback([&](const std::string &pluginPath) {
216 auto plugin = PassPlugin::load(filename: pluginPath);
217 if (!plugin) {
218 errs() << "Failed to load passes from '" << pluginPath
219 << "'. Request ignored.\n";
220 return;
221 }
222 plugin.get().registerPassRegistryCallbacks();
223 });
224
225 static cl::list<std::string> dialectPlugins(
226 "load-dialect-plugin", cl::desc("Load dialects from plugin library"));
227 this->dialectPlugins = std::addressof(r&: dialectPlugins);
228
229 static PassPipelineCLParser passPipeline("", "Compiler passes to run", "p");
230 setPassPipelineParser(passPipeline);
231 }
232
233 /// Set the callback to load a dialect plugin.
234 void setDialectPluginsCallback(DialectRegistry &registry);
235
236 /// Pointer to static dialectPlugins variable in constructor, needed by
237 /// setDialectPluginsCallback(DialectRegistry&).
238 cl::list<std::string> *dialectPlugins = nullptr;
239};
240
241/// A scoped diagnostic handler that suppresses certain diagnostics based on
242/// the verbosity level and whether the diagnostic is a note.
243class DiagnosticFilter : public ScopedDiagnosticHandler {
244public:
245 DiagnosticFilter(MLIRContext *ctx, VerbosityLevel verbosityLevel,
246 bool showNotes = true)
247 : ScopedDiagnosticHandler(ctx) {
248 setHandler([verbosityLevel, showNotes](Diagnostic &diag) {
249 auto severity = diag.getSeverity();
250 switch (severity) {
251 case DiagnosticSeverity::Error:
252 // failure indicates that the error is not handled by the filter and
253 // goes through to the default handler. Therefore, the error can be
254 // successfully printed.
255 return failure();
256 case DiagnosticSeverity::Warning:
257 if (verbosityLevel == VerbosityLevel::ErrorsOnly)
258 return success();
259 else
260 return failure();
261 case DiagnosticSeverity::Remark:
262 if (verbosityLevel == VerbosityLevel::ErrorsOnly ||
263 verbosityLevel == VerbosityLevel::ErrorsAndWarnings)
264 return success();
265 else
266 return failure();
267 case DiagnosticSeverity::Note:
268 if (showNotes)
269 return failure();
270 else
271 return success();
272 }
273 llvm_unreachable("Unknown diagnostic severity");
274 });
275 }
276};
277} // namespace
278
279ManagedStatic<MlirOptMainConfigCLOptions> clOptionsConfig;
280
281void MlirOptMainConfig::registerCLOptions(DialectRegistry &registry) {
282 clOptionsConfig->setDialectPluginsCallback(registry);
283 tracing::DebugConfig::registerCLOptions();
284}
285
286MlirOptMainConfig MlirOptMainConfig::createFromCLOptions() {
287 clOptionsConfig->setDebugConfig(tracing::DebugConfig::createFromCLOptions());
288 return *clOptionsConfig;
289}
290
291MlirOptMainConfig &MlirOptMainConfig::setPassPipelineParser(
292 const PassPipelineCLParser &passPipeline) {
293 passPipelineCallback = [&](PassManager &pm) {
294 auto errorHandler = [&](const Twine &msg) {
295 emitError(UnknownLoc::get(pm.getContext())) << msg;
296 return failure();
297 };
298 if (failed(Result: passPipeline.addToPipeline(pm, errorHandler)))
299 return failure();
300 if (this->shouldDumpPassPipeline()) {
301
302 pm.dump();
303 llvm::errs() << "\n";
304 }
305 return success();
306 };
307 return *this;
308}
309
310void MlirOptMainConfigCLOptions::setDialectPluginsCallback(
311 DialectRegistry &registry) {
312 dialectPlugins->setCallback([&](const std::string &pluginPath) {
313 auto plugin = DialectPlugin::load(filename: pluginPath);
314 if (!plugin) {
315 errs() << "Failed to load dialect plugin from '" << pluginPath
316 << "'. Request ignored.\n";
317 return;
318 };
319 plugin.get().registerDialectRegistryCallbacks(dialectRegistry&: registry);
320 });
321}
322
323LogicalResult loadIRDLDialects(StringRef irdlFile, MLIRContext &ctx) {
324 DialectRegistry registry;
325 registry.insert<irdl::IRDLDialect>();
326 ctx.appendDialectRegistry(registry);
327
328 // Set up the input file.
329 std::string errorMessage;
330 std::unique_ptr<MemoryBuffer> file = openInputFile(inputFilename: irdlFile, errorMessage: &errorMessage);
331 if (!file) {
332 emitError(UnknownLoc::get(&ctx)) << errorMessage;
333 return failure();
334 }
335
336 // Give the buffer to the source manager.
337 // This will be picked up by the parser.
338 SourceMgr sourceMgr;
339 sourceMgr.AddNewSourceBuffer(F: std::move(file), IncludeLoc: SMLoc());
340
341 SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &ctx);
342
343 // Parse the input file.
344 OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, &ctx));
345 if (!module)
346 return failure();
347
348 // Load IRDL dialects.
349 return irdl::loadDialects(module.get());
350}
351
352// Return success if the module can correctly round-trip. This intended to test
353// that the custom printers/parsers are complete.
354static LogicalResult doVerifyRoundTrip(Operation *op,
355 const MlirOptMainConfig &config,
356 bool useBytecode) {
357 // We use a new context to avoid resource handle renaming issue in the diff.
358 MLIRContext roundtripContext;
359 OwningOpRef<Operation *> roundtripModule;
360 roundtripContext.appendDialectRegistry(
361 registry: op->getContext()->getDialectRegistry());
362 if (op->getContext()->allowsUnregisteredDialects())
363 roundtripContext.allowUnregisteredDialects();
364 StringRef irdlFile = config.getIrdlFile();
365 if (!irdlFile.empty() && failed(Result: loadIRDLDialects(irdlFile, ctx&: roundtripContext)))
366 return failure();
367
368 std::string testType = (useBytecode) ? "bytecode" : "textual";
369 // Print a first time with custom format (or bytecode) and parse it back to
370 // the roundtripModule.
371 {
372 std::string buffer;
373 llvm::raw_string_ostream ostream(buffer);
374 if (useBytecode) {
375 if (failed(Result: writeBytecodeToFile(op, os&: ostream))) {
376 op->emitOpError()
377 << "failed to write bytecode, cannot verify round-trip.\n";
378 return failure();
379 }
380 } else {
381 op->print(os&: ostream,
382 flags: OpPrintingFlags().printGenericOpForm().enableDebugInfo());
383 }
384 FallbackAsmResourceMap fallbackResourceMap;
385 ParserConfig parseConfig(&roundtripContext, config.shouldVerifyOnParsing(),
386 &fallbackResourceMap);
387 roundtripModule = parseSourceString<Operation *>(sourceStr: buffer, config: parseConfig);
388 if (!roundtripModule) {
389 op->emitOpError() << "failed to parse " << testType
390 << " content back, cannot verify round-trip.\n";
391 return failure();
392 }
393 }
394
395 // Print in the generic form for the reference module and the round-tripped
396 // one and compare the outputs.
397 std::string reference, roundtrip;
398 {
399 llvm::raw_string_ostream ostreamref(reference);
400 op->print(os&: ostreamref,
401 flags: OpPrintingFlags().printGenericOpForm().enableDebugInfo());
402 llvm::raw_string_ostream ostreamrndtrip(roundtrip);
403 roundtripModule.get()->print(
404 os&: ostreamrndtrip,
405 flags: OpPrintingFlags().printGenericOpForm().enableDebugInfo());
406 }
407 if (reference != roundtrip) {
408 // TODO implement a diff.
409 return op->emitOpError()
410 << testType
411 << " roundTrip testing roundtripped module differs "
412 "from reference:\n<<<<<<Reference\n"
413 << reference << "\n=====\n"
414 << roundtrip << "\n>>>>>roundtripped\n";
415 }
416
417 return success();
418}
419
420static LogicalResult doVerifyRoundTrip(Operation *op,
421 const MlirOptMainConfig &config) {
422 auto txtStatus = doVerifyRoundTrip(op, config, /*useBytecode=*/false);
423 auto bcStatus = doVerifyRoundTrip(op, config, /*useBytecode=*/true);
424 return success(IsSuccess: succeeded(Result: txtStatus) && succeeded(Result: bcStatus));
425}
426
427/// Perform the actions on the input file indicated by the command line flags
428/// within the specified context.
429///
430/// This typically parses the main source file, runs zero or more optimization
431/// passes, then prints the output.
432///
433static LogicalResult
434performActions(raw_ostream &os,
435 const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
436 MLIRContext *context, const MlirOptMainConfig &config) {
437 DefaultTimingManager tm;
438 applyDefaultTimingManagerCLOptions(tm);
439 TimingScope timing = tm.getRootScope();
440
441 // Disable multi-threading when parsing the input file. This removes the
442 // unnecessary/costly context synchronization when parsing.
443 bool wasThreadingEnabled = context->isMultithreadingEnabled();
444 context->disableMultithreading();
445
446 // Prepare the parser config, and attach any useful/necessary resource
447 // handlers. Unhandled external resources are treated as passthrough, i.e.
448 // they are not processed and will be emitted directly to the output
449 // untouched.
450 PassReproducerOptions reproOptions;
451 FallbackAsmResourceMap fallbackResourceMap;
452 ParserConfig parseConfig(context, config.shouldVerifyOnParsing(),
453 &fallbackResourceMap);
454 if (config.shouldRunReproducer())
455 reproOptions.attachResourceParser(config&: parseConfig);
456
457 // Parse the input file and reset the context threading state.
458 TimingScope parserTiming = timing.nest(args: "Parser");
459 OwningOpRef<Operation *> op = parseSourceFileForTool(
460 sourceMgr, config: parseConfig, insertImplicitModule: !config.shouldUseExplicitModule());
461 parserTiming.stop();
462 if (!op)
463 return failure();
464
465 // Perform round-trip verification if requested
466 if (config.shouldVerifyRoundtrip() &&
467 failed(Result: doVerifyRoundTrip(op: op.get(), config)))
468 return failure();
469
470 context->enableMultithreading(enable: wasThreadingEnabled);
471
472 // Prepare the pass manager, applying command-line and reproducer options.
473 PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit);
474 pm.enableVerifier(enabled: config.shouldVerifyPasses());
475 if (failed(Result: applyPassManagerCLOptions(pm)))
476 return failure();
477 pm.enableTiming(timingScope&: timing);
478 if (config.shouldRunReproducer() && failed(Result: reproOptions.apply(pm)))
479 return failure();
480 if (failed(Result: config.setupPassPipeline(pm)))
481 return failure();
482
483 // Run the pipeline.
484 if (failed(Result: pm.run(op: *op)))
485 return failure();
486
487 // Generate reproducers if requested
488 if (!config.getReproducerFilename().empty()) {
489 StringRef anchorName = pm.getAnyOpAnchorName();
490 const auto &passes = pm.getPasses();
491 makeReproducer(anchorName, passes, op: op.get(),
492 outputFile: config.getReproducerFilename());
493 }
494
495 // Print the output.
496 TimingScope outputTiming = timing.nest(args: "Output");
497 if (config.shouldEmitBytecode()) {
498 BytecodeWriterConfig writerConfig(fallbackResourceMap);
499 if (auto v = config.bytecodeVersionToEmit())
500 writerConfig.setDesiredBytecodeVersion(*v);
501 if (config.shouldElideResourceDataFromBytecode())
502 writerConfig.setElideResourceDataFlag();
503 return writeBytecodeToFile(op: op.get(), os, config: writerConfig);
504 }
505
506 if (config.bytecodeVersionToEmit().has_value())
507 return emitError(UnknownLoc::get(pm.getContext()))
508 << "bytecode version while not emitting bytecode";
509 AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr,
510 &fallbackResourceMap);
511 op.get()->print(os, state&: asmState);
512 os << '\n';
513 return success();
514}
515
516/// Parses the memory buffer. If successfully, run a series of passes against
517/// it and print the result.
518static LogicalResult processBuffer(raw_ostream &os,
519 std::unique_ptr<MemoryBuffer> ownedBuffer,
520 const MlirOptMainConfig &config,
521 DialectRegistry &registry,
522 llvm::ThreadPoolInterface *threadPool) {
523 // Tell sourceMgr about this buffer, which is what the parser will pick up.
524 auto sourceMgr = std::make_shared<SourceMgr>();
525 sourceMgr->AddNewSourceBuffer(F: std::move(ownedBuffer), IncludeLoc: SMLoc());
526
527 // Create a context just for the current buffer. Disable threading on creation
528 // since we'll inject the thread-pool separately.
529 MLIRContext context(registry, MLIRContext::Threading::DISABLED);
530 if (threadPool)
531 context.setThreadPool(*threadPool);
532
533 StringRef irdlFile = config.getIrdlFile();
534 if (!irdlFile.empty() && failed(Result: loadIRDLDialects(irdlFile, ctx&: context)))
535 return failure();
536
537 // Parse the input file.
538 context.allowUnregisteredDialects(allow: config.shouldAllowUnregisteredDialects());
539 if (config.shouldVerifyDiagnostics())
540 context.printOpOnDiagnostic(enable: false);
541
542 tracing::InstallDebugHandler installDebugHandler(context,
543 config.getDebugConfig());
544
545 // If we are in verify diagnostics mode then we have a lot of work to do,
546 // otherwise just perform the actions without worrying about it.
547 if (!config.shouldVerifyDiagnostics()) {
548 SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context);
549 DiagnosticFilter diagnosticFilter(&context,
550 config.getDiagnosticVerbosityLevel(),
551 config.shouldShowNotes());
552 return performActions(os, sourceMgr, context: &context, config);
553 }
554
555 SourceMgrDiagnosticVerifierHandler sourceMgrHandler(
556 *sourceMgr, &context, config.verifyDiagnosticsLevel());
557
558 // Do any processing requested by command line flags. We don't care whether
559 // these actions succeed or fail, we only care what diagnostics they produce
560 // and whether they match our expectations.
561 (void)performActions(os, sourceMgr, context: &context, config);
562
563 // Verify the diagnostic handler to make sure that each of the diagnostics
564 // matched.
565 return sourceMgrHandler.verify();
566}
567
568std::pair<std::string, std::string>
569mlir::registerAndParseCLIOptions(int argc, char **argv,
570 llvm::StringRef toolName,
571 DialectRegistry &registry) {
572 static cl::opt<std::string> inputFilename(
573 cl::Positional, cl::desc("<input file>"), cl::init(Val: "-"));
574
575 static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
576 cl::value_desc("filename"),
577 cl::init(Val: "-"));
578 // Register any command line options.
579 MlirOptMainConfig::registerCLOptions(registry);
580 registerAsmPrinterCLOptions();
581 registerMLIRContextCLOptions();
582 registerPassManagerCLOptions();
583 registerDefaultTimingManagerCLOptions();
584 tracing::DebugCounter::registerCLOptions();
585
586 // Build the list of dialects as a header for the --help message.
587 std::string helpHeader = (toolName + "\nAvailable Dialects: ").str();
588 {
589 llvm::raw_string_ostream os(helpHeader);
590 interleaveComma(c: registry.getDialectNames(), os,
591 each_fn: [&](auto name) { os << name; });
592 }
593 // Parse pass names in main to ensure static initialization completed.
594 cl::ParseCommandLineOptions(argc, argv, Overview: helpHeader);
595 return std::make_pair(x&: inputFilename.getValue(), y&: outputFilename.getValue());
596}
597
598static LogicalResult printRegisteredDialects(DialectRegistry &registry) {
599 llvm::outs() << "Available Dialects: ";
600 interleave(c: registry.getDialectNames(), os&: llvm::outs(), separator: ",");
601 llvm::outs() << "\n";
602 return success();
603}
604
605static LogicalResult printRegisteredPassesAndReturn() {
606 mlir::printRegisteredPasses();
607 return success();
608}
609
610LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream,
611 std::unique_ptr<llvm::MemoryBuffer> buffer,
612 DialectRegistry &registry,
613 const MlirOptMainConfig &config) {
614 if (config.shouldShowDialects())
615 return printRegisteredDialects(registry);
616
617 if (config.shouldListPasses())
618 return printRegisteredPassesAndReturn();
619
620 // The split-input-file mode is a very specific mode that slices the file
621 // up into small pieces and checks each independently.
622 // We use an explicit threadpool to avoid creating and joining/destroying
623 // threads for each of the split.
624 ThreadPoolInterface *threadPool = nullptr;
625
626 // Create a temporary context for the sake of checking if
627 // --mlir-disable-threading was passed on the command line.
628 // We use the thread-pool this context is creating, and avoid
629 // creating any thread when disabled.
630 MLIRContext threadPoolCtx;
631 if (threadPoolCtx.isMultithreadingEnabled())
632 threadPool = &threadPoolCtx.getThreadPool();
633
634 auto chunkFn = [&](std::unique_ptr<MemoryBuffer> chunkBuffer,
635 raw_ostream &os) {
636 return processBuffer(os, ownedBuffer: std::move(chunkBuffer), config, registry,
637 threadPool);
638 };
639 return splitAndProcessBuffer(originalBuffer: std::move(buffer), processChunkBuffer: chunkFn, os&: outputStream,
640 inputSplitMarker: config.inputSplitMarker(),
641 outputSplitMarker: config.outputSplitMarker());
642}
643
644LogicalResult mlir::MlirOptMain(int argc, char **argv,
645 llvm::StringRef inputFilename,
646 llvm::StringRef outputFilename,
647 DialectRegistry &registry) {
648
649 InitLLVM y(argc, argv);
650
651 MlirOptMainConfig config = MlirOptMainConfig::createFromCLOptions();
652
653 if (config.shouldShowDialects())
654 return printRegisteredDialects(registry);
655
656 if (config.shouldListPasses())
657 return printRegisteredPassesAndReturn();
658
659 // When reading from stdin and the input is a tty, it is often a user mistake
660 // and the process "appears to be stuck". Print a message to let the user know
661 // about it!
662 if (inputFilename == "-" &&
663 sys::Process::FileDescriptorIsDisplayed(fd: fileno(stdin)))
664 llvm::errs() << "(processing input from stdin now, hit ctrl-c/ctrl-d to "
665 "interrupt)\n";
666
667 // Set up the input file.
668 std::string errorMessage;
669 auto file = openInputFile(inputFilename, errorMessage: &errorMessage);
670 if (!file) {
671 llvm::errs() << errorMessage << "\n";
672 return failure();
673 }
674
675 auto output = openOutputFile(outputFilename, errorMessage: &errorMessage);
676 if (!output) {
677 llvm::errs() << errorMessage << "\n";
678 return failure();
679 }
680 if (failed(Result: MlirOptMain(outputStream&: output->os(), buffer: std::move(file), registry, config)))
681 return failure();
682
683 // Keep the output file if the invocation of MlirOptMain was successful.
684 output->keep();
685 return success();
686}
687
688LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
689 DialectRegistry &registry) {
690
691 // Register and parse command line options.
692 std::string inputFilename, outputFilename;
693 std::tie(args&: inputFilename, args&: outputFilename) =
694 registerAndParseCLIOptions(argc, argv, toolName, registry);
695
696 return MlirOptMain(argc, argv, inputFilename, outputFilename, registry);
697}
698

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Tools/mlir-opt/MlirOptMain.cpp