| 1 | //===- MlirOptMain.cpp - MLIR Optimizer Driver ----------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This is a utility that runs an optimization pass and prints the result back |
| 10 | // out. It is designed to support unit testing. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Tools/mlir-opt/MlirOptMain.h" |
| 15 | #include "mlir/Bytecode/BytecodeWriter.h" |
| 16 | #include "mlir/Debug/CLOptionsSetup.h" |
| 17 | #include "mlir/Debug/Counter.h" |
| 18 | #include "mlir/Debug/DebuggerExecutionContextHook.h" |
| 19 | #include "mlir/Debug/ExecutionContext.h" |
| 20 | #include "mlir/Debug/Observers/ActionLogging.h" |
| 21 | #include "mlir/Dialect/IRDL/IR/IRDL.h" |
| 22 | #include "mlir/Dialect/IRDL/IRDLLoading.h" |
| 23 | #include "mlir/IR/AsmState.h" |
| 24 | #include "mlir/IR/Attributes.h" |
| 25 | #include "mlir/IR/BuiltinOps.h" |
| 26 | #include "mlir/IR/Diagnostics.h" |
| 27 | #include "mlir/IR/Dialect.h" |
| 28 | #include "mlir/IR/Location.h" |
| 29 | #include "mlir/IR/MLIRContext.h" |
| 30 | #include "mlir/Parser/Parser.h" |
| 31 | #include "mlir/Pass/Pass.h" |
| 32 | #include "mlir/Pass/PassManager.h" |
| 33 | #include "mlir/Pass/PassRegistry.h" |
| 34 | #include "mlir/Support/FileUtilities.h" |
| 35 | #include "mlir/Support/Timing.h" |
| 36 | #include "mlir/Support/ToolUtilities.h" |
| 37 | #include "mlir/Tools/ParseUtilities.h" |
| 38 | #include "mlir/Tools/Plugins/DialectPlugin.h" |
| 39 | #include "mlir/Tools/Plugins/PassPlugin.h" |
| 40 | #include "llvm/ADT/StringRef.h" |
| 41 | #include "llvm/Support/CommandLine.h" |
| 42 | #include "llvm/Support/FileUtilities.h" |
| 43 | #include "llvm/Support/InitLLVM.h" |
| 44 | #include "llvm/Support/LogicalResult.h" |
| 45 | #include "llvm/Support/ManagedStatic.h" |
| 46 | #include "llvm/Support/Process.h" |
| 47 | #include "llvm/Support/Regex.h" |
| 48 | #include "llvm/Support/SourceMgr.h" |
| 49 | #include "llvm/Support/StringSaver.h" |
| 50 | #include "llvm/Support/ThreadPool.h" |
| 51 | #include "llvm/Support/ToolOutputFile.h" |
| 52 | |
| 53 | using namespace mlir; |
| 54 | using namespace llvm; |
| 55 | |
| 56 | namespace { |
| 57 | class BytecodeVersionParser : public cl::parser<std::optional<int64_t>> { |
| 58 | public: |
| 59 | BytecodeVersionParser(cl::Option &o) |
| 60 | : cl::parser<std::optional<int64_t>>(o) {} |
| 61 | |
| 62 | bool parse(cl::Option &o, StringRef /*argName*/, StringRef arg, |
| 63 | std::optional<int64_t> &v) { |
| 64 | long long w; |
| 65 | if (getAsSignedInteger(Str: arg, Radix: 10, Result&: w)) |
| 66 | return o.error(Message: "Invalid argument '" + arg + |
| 67 | "', only integer is supported." ); |
| 68 | v = w; |
| 69 | return false; |
| 70 | } |
| 71 | }; |
| 72 | |
| 73 | /// This class is intended to manage the handling of command line options for |
| 74 | /// creating a *-opt config. This is a singleton. |
| 75 | struct MlirOptMainConfigCLOptions : public MlirOptMainConfig { |
| 76 | MlirOptMainConfigCLOptions() { |
| 77 | // These options are static but all uses ExternalStorage to initialize the |
| 78 | // members of the parent class. This is unusual but since this class is a |
| 79 | // singleton it basically attaches command line option to the singleton |
| 80 | // members. |
| 81 | |
| 82 | static cl::opt<bool, /*ExternalStorage=*/true> allowUnregisteredDialects( |
| 83 | "allow-unregistered-dialect" , |
| 84 | cl::desc("Allow operation with no registered dialects" ), |
| 85 | cl::location(L&: allowUnregisteredDialectsFlag), cl::init(Val: false)); |
| 86 | |
| 87 | static cl::opt<bool, /*ExternalStorage=*/true> dumpPassPipeline( |
| 88 | "dump-pass-pipeline" , cl::desc("Print the pipeline that will be run" ), |
| 89 | cl::location(L&: dumpPassPipelineFlag), cl::init(Val: false)); |
| 90 | |
| 91 | static cl::opt<bool, /*ExternalStorage=*/true> emitBytecode( |
| 92 | "emit-bytecode" , cl::desc("Emit bytecode when generating output" ), |
| 93 | cl::location(L&: emitBytecodeFlag), cl::init(Val: false)); |
| 94 | |
| 95 | static cl::opt<bool, /*ExternalStorage=*/true> elideResourcesFromBytecode( |
| 96 | "elide-resource-data-from-bytecode" , |
| 97 | cl::desc("Elide resources when generating bytecode" ), |
| 98 | cl::location(L&: elideResourceDataFromBytecodeFlag), cl::init(Val: false)); |
| 99 | |
| 100 | static cl::opt<std::optional<int64_t>, /*ExternalStorage=*/true, |
| 101 | BytecodeVersionParser> |
| 102 | bytecodeVersion( |
| 103 | "emit-bytecode-version" , |
| 104 | cl::desc("Use specified bytecode when generating output" ), |
| 105 | cl::location(L&: emitBytecodeVersion), cl::init(Val: std::nullopt)); |
| 106 | |
| 107 | static cl::opt<std::string, /*ExternalStorage=*/true> irdlFile( |
| 108 | "irdl-file" , |
| 109 | cl::desc("IRDL file to register before processing the input" ), |
| 110 | cl::location(L&: irdlFileFlag), cl::init(Val: "" ), cl::value_desc("filename" )); |
| 111 | |
| 112 | static cl::opt<VerbosityLevel, /*ExternalStorage=*/true> |
| 113 | diagnosticVerbosityLevel( |
| 114 | "mlir-diagnostic-verbosity-level" , |
| 115 | cl::desc("Choose level of diagnostic information" ), |
| 116 | cl::location(L&: diagnosticVerbosityLevelFlag), |
| 117 | cl::init(Val: VerbosityLevel::ErrorsWarningsAndRemarks), |
| 118 | cl::values( |
| 119 | clEnumValN(VerbosityLevel::ErrorsOnly, "errors" , "Errors only" ), |
| 120 | clEnumValN(VerbosityLevel::ErrorsAndWarnings, "warnings" , |
| 121 | "Errors and warnings" ), |
| 122 | clEnumValN(VerbosityLevel::ErrorsWarningsAndRemarks, "remarks" , |
| 123 | "Errors, warnings and remarks" ))); |
| 124 | |
| 125 | static cl::opt<bool, /*ExternalStorage=*/true> disableDiagnosticNotes( |
| 126 | "mlir-disable-diagnostic-notes" , cl::desc("Disable diagnostic notes." ), |
| 127 | cl::location(L&: disableDiagnosticNotesFlag), cl::init(Val: false)); |
| 128 | |
| 129 | static cl::opt<bool, /*ExternalStorage=*/true> explicitModule( |
| 130 | "no-implicit-module" , |
| 131 | cl::desc("Disable implicit addition of a top-level module op during " |
| 132 | "parsing" ), |
| 133 | cl::location(L&: useExplicitModuleFlag), cl::init(Val: false)); |
| 134 | |
| 135 | static cl::opt<bool, /*ExternalStorage=*/true> listPasses( |
| 136 | "list-passes" , cl::desc("Print the list of registered passes and exit" ), |
| 137 | cl::location(L&: listPassesFlag), cl::init(Val: false)); |
| 138 | |
| 139 | static cl::opt<bool, /*ExternalStorage=*/true> runReproducer( |
| 140 | "run-reproducer" , cl::desc("Run the pipeline stored in the reproducer" ), |
| 141 | cl::location(L&: runReproducerFlag), cl::init(Val: false)); |
| 142 | |
| 143 | static cl::opt<bool, /*ExternalStorage=*/true> showDialects( |
| 144 | "show-dialects" , |
| 145 | cl::desc("Print the list of registered dialects and exit" ), |
| 146 | cl::location(L&: showDialectsFlag), cl::init(Val: false)); |
| 147 | |
| 148 | static cl::opt<std::string, /*ExternalStorage=*/true> splitInputFile{ |
| 149 | "split-input-file" , |
| 150 | llvm::cl::ValueOptional, |
| 151 | cl::callback(CB: [&](const std::string &str) { |
| 152 | // Implicit value: use default marker if flag was used without value. |
| 153 | if (str.empty()) |
| 154 | splitInputFile.setValue(V: kDefaultSplitMarker); |
| 155 | }), |
| 156 | cl::desc("Split the input file into chunks using the given or " |
| 157 | "default marker and process each chunk independently" ), |
| 158 | cl::location(L&: splitInputFileFlag), |
| 159 | cl::init(Val: "" )}; |
| 160 | |
| 161 | static cl::opt<std::string, /*ExternalStorage=*/true> outputSplitMarker( |
| 162 | "output-split-marker" , |
| 163 | cl::desc("Split marker to use for merging the ouput" ), |
| 164 | cl::location(L&: outputSplitMarkerFlag), cl::init(Val: kDefaultSplitMarker)); |
| 165 | |
| 166 | static cl::opt<SourceMgrDiagnosticVerifierHandler::Level, |
| 167 | /*ExternalStorage=*/true> |
| 168 | verifyDiagnostics{ |
| 169 | "verify-diagnostics" , llvm::cl::ValueOptional, |
| 170 | cl::desc("Check that emitted diagnostics match expected-* lines on " |
| 171 | "the corresponding line" ), |
| 172 | cl::location(L&: verifyDiagnosticsFlag), |
| 173 | cl::values( |
| 174 | clEnumValN(SourceMgrDiagnosticVerifierHandler::Level::All, |
| 175 | "all" , |
| 176 | "Check all diagnostics (expected, unexpected, " |
| 177 | "near-misses)" ), |
| 178 | // Implicit value: when passed with no arguments, e.g. |
| 179 | // `--verify-diagnostics` or `--verify-diagnostics=`. |
| 180 | clEnumValN(SourceMgrDiagnosticVerifierHandler::Level::All, "" , |
| 181 | "Check all diagnostics (expected, unexpected, " |
| 182 | "near-misses)" ), |
| 183 | clEnumValN( |
| 184 | SourceMgrDiagnosticVerifierHandler::Level::OnlyExpected, |
| 185 | "only-expected" , "Check only expected diagnostics" ))}; |
| 186 | |
| 187 | static cl::opt<bool, /*ExternalStorage=*/true> verifyPasses( |
| 188 | "verify-each" , |
| 189 | cl::desc("Run the verifier after each transformation pass" ), |
| 190 | cl::location(L&: verifyPassesFlag), cl::init(Val: true)); |
| 191 | |
| 192 | static cl::opt<bool, /*ExternalStorage=*/true> disableVerifyOnParsing( |
| 193 | "mlir-very-unsafe-disable-verifier-on-parsing" , |
| 194 | cl::desc("Disable the verifier on parsing (very unsafe)" ), |
| 195 | cl::location(L&: disableVerifierOnParsingFlag), cl::init(Val: false)); |
| 196 | |
| 197 | static cl::opt<bool, /*ExternalStorage=*/true> verifyRoundtrip( |
| 198 | "verify-roundtrip" , |
| 199 | cl::desc("Round-trip the IR after parsing and ensure it succeeds" ), |
| 200 | cl::location(L&: verifyRoundtripFlag), cl::init(Val: false)); |
| 201 | |
| 202 | static cl::list<std::string> passPlugins( |
| 203 | "load-pass-plugin" , cl::desc("Load passes from plugin library" )); |
| 204 | |
| 205 | static cl::opt<std::string, /*ExternalStorage=*/true> |
| 206 | generateReproducerFile( |
| 207 | "mlir-generate-reproducer" , |
| 208 | llvm::cl::desc( |
| 209 | "Generate an mlir reproducer at the provided filename" |
| 210 | " (no crash required)" ), |
| 211 | cl::location(L&: generateReproducerFileFlag), cl::init(Val: "" ), |
| 212 | cl::value_desc("filename" )); |
| 213 | |
| 214 | /// Set the callback to load a pass plugin. |
| 215 | passPlugins.setCallback([&](const std::string &pluginPath) { |
| 216 | auto plugin = PassPlugin::load(filename: pluginPath); |
| 217 | if (!plugin) { |
| 218 | errs() << "Failed to load passes from '" << pluginPath |
| 219 | << "'. Request ignored.\n" ; |
| 220 | return; |
| 221 | } |
| 222 | plugin.get().registerPassRegistryCallbacks(); |
| 223 | }); |
| 224 | |
| 225 | static cl::list<std::string> dialectPlugins( |
| 226 | "load-dialect-plugin" , cl::desc("Load dialects from plugin library" )); |
| 227 | this->dialectPlugins = std::addressof(r&: dialectPlugins); |
| 228 | |
| 229 | static PassPipelineCLParser passPipeline("" , "Compiler passes to run" , "p" ); |
| 230 | setPassPipelineParser(passPipeline); |
| 231 | } |
| 232 | |
| 233 | /// Set the callback to load a dialect plugin. |
| 234 | void setDialectPluginsCallback(DialectRegistry ®istry); |
| 235 | |
| 236 | /// Pointer to static dialectPlugins variable in constructor, needed by |
| 237 | /// setDialectPluginsCallback(DialectRegistry&). |
| 238 | cl::list<std::string> *dialectPlugins = nullptr; |
| 239 | }; |
| 240 | |
| 241 | /// A scoped diagnostic handler that suppresses certain diagnostics based on |
| 242 | /// the verbosity level and whether the diagnostic is a note. |
| 243 | class DiagnosticFilter : public ScopedDiagnosticHandler { |
| 244 | public: |
| 245 | DiagnosticFilter(MLIRContext *ctx, VerbosityLevel verbosityLevel, |
| 246 | bool showNotes = true) |
| 247 | : ScopedDiagnosticHandler(ctx) { |
| 248 | setHandler([verbosityLevel, showNotes](Diagnostic &diag) { |
| 249 | auto severity = diag.getSeverity(); |
| 250 | switch (severity) { |
| 251 | case DiagnosticSeverity::Error: |
| 252 | // failure indicates that the error is not handled by the filter and |
| 253 | // goes through to the default handler. Therefore, the error can be |
| 254 | // successfully printed. |
| 255 | return failure(); |
| 256 | case DiagnosticSeverity::Warning: |
| 257 | if (verbosityLevel == VerbosityLevel::ErrorsOnly) |
| 258 | return success(); |
| 259 | else |
| 260 | return failure(); |
| 261 | case DiagnosticSeverity::Remark: |
| 262 | if (verbosityLevel == VerbosityLevel::ErrorsOnly || |
| 263 | verbosityLevel == VerbosityLevel::ErrorsAndWarnings) |
| 264 | return success(); |
| 265 | else |
| 266 | return failure(); |
| 267 | case DiagnosticSeverity::Note: |
| 268 | if (showNotes) |
| 269 | return failure(); |
| 270 | else |
| 271 | return success(); |
| 272 | } |
| 273 | llvm_unreachable("Unknown diagnostic severity" ); |
| 274 | }); |
| 275 | } |
| 276 | }; |
| 277 | } // namespace |
| 278 | |
| 279 | ManagedStatic<MlirOptMainConfigCLOptions> clOptionsConfig; |
| 280 | |
| 281 | void MlirOptMainConfig::registerCLOptions(DialectRegistry ®istry) { |
| 282 | clOptionsConfig->setDialectPluginsCallback(registry); |
| 283 | tracing::DebugConfig::registerCLOptions(); |
| 284 | } |
| 285 | |
| 286 | MlirOptMainConfig MlirOptMainConfig::createFromCLOptions() { |
| 287 | clOptionsConfig->setDebugConfig(tracing::DebugConfig::createFromCLOptions()); |
| 288 | return *clOptionsConfig; |
| 289 | } |
| 290 | |
| 291 | MlirOptMainConfig &MlirOptMainConfig::setPassPipelineParser( |
| 292 | const PassPipelineCLParser &passPipeline) { |
| 293 | passPipelineCallback = [&](PassManager &pm) { |
| 294 | auto errorHandler = [&](const Twine &msg) { |
| 295 | emitError(UnknownLoc::get(pm.getContext())) << msg; |
| 296 | return failure(); |
| 297 | }; |
| 298 | if (failed(Result: passPipeline.addToPipeline(pm, errorHandler))) |
| 299 | return failure(); |
| 300 | if (this->shouldDumpPassPipeline()) { |
| 301 | |
| 302 | pm.dump(); |
| 303 | llvm::errs() << "\n" ; |
| 304 | } |
| 305 | return success(); |
| 306 | }; |
| 307 | return *this; |
| 308 | } |
| 309 | |
| 310 | void MlirOptMainConfigCLOptions::setDialectPluginsCallback( |
| 311 | DialectRegistry ®istry) { |
| 312 | dialectPlugins->setCallback([&](const std::string &pluginPath) { |
| 313 | auto plugin = DialectPlugin::load(filename: pluginPath); |
| 314 | if (!plugin) { |
| 315 | errs() << "Failed to load dialect plugin from '" << pluginPath |
| 316 | << "'. Request ignored.\n" ; |
| 317 | return; |
| 318 | }; |
| 319 | plugin.get().registerDialectRegistryCallbacks(dialectRegistry&: registry); |
| 320 | }); |
| 321 | } |
| 322 | |
| 323 | LogicalResult loadIRDLDialects(StringRef irdlFile, MLIRContext &ctx) { |
| 324 | DialectRegistry registry; |
| 325 | registry.insert<irdl::IRDLDialect>(); |
| 326 | ctx.appendDialectRegistry(registry); |
| 327 | |
| 328 | // Set up the input file. |
| 329 | std::string errorMessage; |
| 330 | std::unique_ptr<MemoryBuffer> file = openInputFile(inputFilename: irdlFile, errorMessage: &errorMessage); |
| 331 | if (!file) { |
| 332 | emitError(UnknownLoc::get(&ctx)) << errorMessage; |
| 333 | return failure(); |
| 334 | } |
| 335 | |
| 336 | // Give the buffer to the source manager. |
| 337 | // This will be picked up by the parser. |
| 338 | SourceMgr sourceMgr; |
| 339 | sourceMgr.AddNewSourceBuffer(F: std::move(file), IncludeLoc: SMLoc()); |
| 340 | |
| 341 | SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &ctx); |
| 342 | |
| 343 | // Parse the input file. |
| 344 | OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, &ctx)); |
| 345 | if (!module) |
| 346 | return failure(); |
| 347 | |
| 348 | // Load IRDL dialects. |
| 349 | return irdl::loadDialects(module.get()); |
| 350 | } |
| 351 | |
| 352 | // Return success if the module can correctly round-trip. This intended to test |
| 353 | // that the custom printers/parsers are complete. |
| 354 | static LogicalResult doVerifyRoundTrip(Operation *op, |
| 355 | const MlirOptMainConfig &config, |
| 356 | bool useBytecode) { |
| 357 | // We use a new context to avoid resource handle renaming issue in the diff. |
| 358 | MLIRContext roundtripContext; |
| 359 | OwningOpRef<Operation *> roundtripModule; |
| 360 | roundtripContext.appendDialectRegistry( |
| 361 | registry: op->getContext()->getDialectRegistry()); |
| 362 | if (op->getContext()->allowsUnregisteredDialects()) |
| 363 | roundtripContext.allowUnregisteredDialects(); |
| 364 | StringRef irdlFile = config.getIrdlFile(); |
| 365 | if (!irdlFile.empty() && failed(Result: loadIRDLDialects(irdlFile, ctx&: roundtripContext))) |
| 366 | return failure(); |
| 367 | |
| 368 | std::string testType = (useBytecode) ? "bytecode" : "textual" ; |
| 369 | // Print a first time with custom format (or bytecode) and parse it back to |
| 370 | // the roundtripModule. |
| 371 | { |
| 372 | std::string buffer; |
| 373 | llvm::raw_string_ostream ostream(buffer); |
| 374 | if (useBytecode) { |
| 375 | if (failed(Result: writeBytecodeToFile(op, os&: ostream))) { |
| 376 | op->emitOpError() |
| 377 | << "failed to write bytecode, cannot verify round-trip.\n" ; |
| 378 | return failure(); |
| 379 | } |
| 380 | } else { |
| 381 | op->print(os&: ostream, |
| 382 | flags: OpPrintingFlags().printGenericOpForm().enableDebugInfo()); |
| 383 | } |
| 384 | FallbackAsmResourceMap fallbackResourceMap; |
| 385 | ParserConfig parseConfig(&roundtripContext, config.shouldVerifyOnParsing(), |
| 386 | &fallbackResourceMap); |
| 387 | roundtripModule = parseSourceString<Operation *>(sourceStr: buffer, config: parseConfig); |
| 388 | if (!roundtripModule) { |
| 389 | op->emitOpError() << "failed to parse " << testType |
| 390 | << " content back, cannot verify round-trip.\n" ; |
| 391 | return failure(); |
| 392 | } |
| 393 | } |
| 394 | |
| 395 | // Print in the generic form for the reference module and the round-tripped |
| 396 | // one and compare the outputs. |
| 397 | std::string reference, roundtrip; |
| 398 | { |
| 399 | llvm::raw_string_ostream ostreamref(reference); |
| 400 | op->print(os&: ostreamref, |
| 401 | flags: OpPrintingFlags().printGenericOpForm().enableDebugInfo()); |
| 402 | llvm::raw_string_ostream ostreamrndtrip(roundtrip); |
| 403 | roundtripModule.get()->print( |
| 404 | os&: ostreamrndtrip, |
| 405 | flags: OpPrintingFlags().printGenericOpForm().enableDebugInfo()); |
| 406 | } |
| 407 | if (reference != roundtrip) { |
| 408 | // TODO implement a diff. |
| 409 | return op->emitOpError() |
| 410 | << testType |
| 411 | << " roundTrip testing roundtripped module differs " |
| 412 | "from reference:\n<<<<<<Reference\n" |
| 413 | << reference << "\n=====\n" |
| 414 | << roundtrip << "\n>>>>>roundtripped\n" ; |
| 415 | } |
| 416 | |
| 417 | return success(); |
| 418 | } |
| 419 | |
| 420 | static LogicalResult doVerifyRoundTrip(Operation *op, |
| 421 | const MlirOptMainConfig &config) { |
| 422 | auto txtStatus = doVerifyRoundTrip(op, config, /*useBytecode=*/false); |
| 423 | auto bcStatus = doVerifyRoundTrip(op, config, /*useBytecode=*/true); |
| 424 | return success(IsSuccess: succeeded(Result: txtStatus) && succeeded(Result: bcStatus)); |
| 425 | } |
| 426 | |
| 427 | /// Perform the actions on the input file indicated by the command line flags |
| 428 | /// within the specified context. |
| 429 | /// |
| 430 | /// This typically parses the main source file, runs zero or more optimization |
| 431 | /// passes, then prints the output. |
| 432 | /// |
| 433 | static LogicalResult |
| 434 | performActions(raw_ostream &os, |
| 435 | const std::shared_ptr<llvm::SourceMgr> &sourceMgr, |
| 436 | MLIRContext *context, const MlirOptMainConfig &config) { |
| 437 | DefaultTimingManager tm; |
| 438 | applyDefaultTimingManagerCLOptions(tm); |
| 439 | TimingScope timing = tm.getRootScope(); |
| 440 | |
| 441 | // Disable multi-threading when parsing the input file. This removes the |
| 442 | // unnecessary/costly context synchronization when parsing. |
| 443 | bool wasThreadingEnabled = context->isMultithreadingEnabled(); |
| 444 | context->disableMultithreading(); |
| 445 | |
| 446 | // Prepare the parser config, and attach any useful/necessary resource |
| 447 | // handlers. Unhandled external resources are treated as passthrough, i.e. |
| 448 | // they are not processed and will be emitted directly to the output |
| 449 | // untouched. |
| 450 | PassReproducerOptions reproOptions; |
| 451 | FallbackAsmResourceMap fallbackResourceMap; |
| 452 | ParserConfig parseConfig(context, config.shouldVerifyOnParsing(), |
| 453 | &fallbackResourceMap); |
| 454 | if (config.shouldRunReproducer()) |
| 455 | reproOptions.attachResourceParser(config&: parseConfig); |
| 456 | |
| 457 | // Parse the input file and reset the context threading state. |
| 458 | TimingScope parserTiming = timing.nest(args: "Parser" ); |
| 459 | OwningOpRef<Operation *> op = parseSourceFileForTool( |
| 460 | sourceMgr, config: parseConfig, insertImplicitModule: !config.shouldUseExplicitModule()); |
| 461 | parserTiming.stop(); |
| 462 | if (!op) |
| 463 | return failure(); |
| 464 | |
| 465 | // Perform round-trip verification if requested |
| 466 | if (config.shouldVerifyRoundtrip() && |
| 467 | failed(Result: doVerifyRoundTrip(op: op.get(), config))) |
| 468 | return failure(); |
| 469 | |
| 470 | context->enableMultithreading(enable: wasThreadingEnabled); |
| 471 | |
| 472 | // Prepare the pass manager, applying command-line and reproducer options. |
| 473 | PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit); |
| 474 | pm.enableVerifier(enabled: config.shouldVerifyPasses()); |
| 475 | if (failed(Result: applyPassManagerCLOptions(pm))) |
| 476 | return failure(); |
| 477 | pm.enableTiming(timingScope&: timing); |
| 478 | if (config.shouldRunReproducer() && failed(Result: reproOptions.apply(pm))) |
| 479 | return failure(); |
| 480 | if (failed(Result: config.setupPassPipeline(pm))) |
| 481 | return failure(); |
| 482 | |
| 483 | // Run the pipeline. |
| 484 | if (failed(Result: pm.run(op: *op))) |
| 485 | return failure(); |
| 486 | |
| 487 | // Generate reproducers if requested |
| 488 | if (!config.getReproducerFilename().empty()) { |
| 489 | StringRef anchorName = pm.getAnyOpAnchorName(); |
| 490 | const auto &passes = pm.getPasses(); |
| 491 | makeReproducer(anchorName, passes, op: op.get(), |
| 492 | outputFile: config.getReproducerFilename()); |
| 493 | } |
| 494 | |
| 495 | // Print the output. |
| 496 | TimingScope outputTiming = timing.nest(args: "Output" ); |
| 497 | if (config.shouldEmitBytecode()) { |
| 498 | BytecodeWriterConfig writerConfig(fallbackResourceMap); |
| 499 | if (auto v = config.bytecodeVersionToEmit()) |
| 500 | writerConfig.setDesiredBytecodeVersion(*v); |
| 501 | if (config.shouldElideResourceDataFromBytecode()) |
| 502 | writerConfig.setElideResourceDataFlag(); |
| 503 | return writeBytecodeToFile(op: op.get(), os, config: writerConfig); |
| 504 | } |
| 505 | |
| 506 | if (config.bytecodeVersionToEmit().has_value()) |
| 507 | return emitError(UnknownLoc::get(pm.getContext())) |
| 508 | << "bytecode version while not emitting bytecode" ; |
| 509 | AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr, |
| 510 | &fallbackResourceMap); |
| 511 | op.get()->print(os, state&: asmState); |
| 512 | os << '\n'; |
| 513 | return success(); |
| 514 | } |
| 515 | |
| 516 | /// Parses the memory buffer. If successfully, run a series of passes against |
| 517 | /// it and print the result. |
| 518 | static LogicalResult processBuffer(raw_ostream &os, |
| 519 | std::unique_ptr<MemoryBuffer> ownedBuffer, |
| 520 | const MlirOptMainConfig &config, |
| 521 | DialectRegistry ®istry, |
| 522 | llvm::ThreadPoolInterface *threadPool) { |
| 523 | // Tell sourceMgr about this buffer, which is what the parser will pick up. |
| 524 | auto sourceMgr = std::make_shared<SourceMgr>(); |
| 525 | sourceMgr->AddNewSourceBuffer(F: std::move(ownedBuffer), IncludeLoc: SMLoc()); |
| 526 | |
| 527 | // Create a context just for the current buffer. Disable threading on creation |
| 528 | // since we'll inject the thread-pool separately. |
| 529 | MLIRContext context(registry, MLIRContext::Threading::DISABLED); |
| 530 | if (threadPool) |
| 531 | context.setThreadPool(*threadPool); |
| 532 | |
| 533 | StringRef irdlFile = config.getIrdlFile(); |
| 534 | if (!irdlFile.empty() && failed(Result: loadIRDLDialects(irdlFile, ctx&: context))) |
| 535 | return failure(); |
| 536 | |
| 537 | // Parse the input file. |
| 538 | context.allowUnregisteredDialects(allow: config.shouldAllowUnregisteredDialects()); |
| 539 | if (config.shouldVerifyDiagnostics()) |
| 540 | context.printOpOnDiagnostic(enable: false); |
| 541 | |
| 542 | tracing::InstallDebugHandler installDebugHandler(context, |
| 543 | config.getDebugConfig()); |
| 544 | |
| 545 | // If we are in verify diagnostics mode then we have a lot of work to do, |
| 546 | // otherwise just perform the actions without worrying about it. |
| 547 | if (!config.shouldVerifyDiagnostics()) { |
| 548 | SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context); |
| 549 | DiagnosticFilter diagnosticFilter(&context, |
| 550 | config.getDiagnosticVerbosityLevel(), |
| 551 | config.shouldShowNotes()); |
| 552 | return performActions(os, sourceMgr, context: &context, config); |
| 553 | } |
| 554 | |
| 555 | SourceMgrDiagnosticVerifierHandler sourceMgrHandler( |
| 556 | *sourceMgr, &context, config.verifyDiagnosticsLevel()); |
| 557 | |
| 558 | // Do any processing requested by command line flags. We don't care whether |
| 559 | // these actions succeed or fail, we only care what diagnostics they produce |
| 560 | // and whether they match our expectations. |
| 561 | (void)performActions(os, sourceMgr, context: &context, config); |
| 562 | |
| 563 | // Verify the diagnostic handler to make sure that each of the diagnostics |
| 564 | // matched. |
| 565 | return sourceMgrHandler.verify(); |
| 566 | } |
| 567 | |
| 568 | std::pair<std::string, std::string> |
| 569 | mlir::registerAndParseCLIOptions(int argc, char **argv, |
| 570 | llvm::StringRef toolName, |
| 571 | DialectRegistry ®istry) { |
| 572 | static cl::opt<std::string> inputFilename( |
| 573 | cl::Positional, cl::desc("<input file>" ), cl::init(Val: "-" )); |
| 574 | |
| 575 | static cl::opt<std::string> outputFilename("o" , cl::desc("Output filename" ), |
| 576 | cl::value_desc("filename" ), |
| 577 | cl::init(Val: "-" )); |
| 578 | // Register any command line options. |
| 579 | MlirOptMainConfig::registerCLOptions(registry); |
| 580 | registerAsmPrinterCLOptions(); |
| 581 | registerMLIRContextCLOptions(); |
| 582 | registerPassManagerCLOptions(); |
| 583 | registerDefaultTimingManagerCLOptions(); |
| 584 | tracing::DebugCounter::registerCLOptions(); |
| 585 | |
| 586 | // Build the list of dialects as a header for the --help message. |
| 587 | std::string = (toolName + "\nAvailable Dialects: " ).str(); |
| 588 | { |
| 589 | llvm::raw_string_ostream os(helpHeader); |
| 590 | interleaveComma(c: registry.getDialectNames(), os, |
| 591 | each_fn: [&](auto name) { os << name; }); |
| 592 | } |
| 593 | // Parse pass names in main to ensure static initialization completed. |
| 594 | cl::ParseCommandLineOptions(argc, argv, Overview: helpHeader); |
| 595 | return std::make_pair(x&: inputFilename.getValue(), y&: outputFilename.getValue()); |
| 596 | } |
| 597 | |
| 598 | static LogicalResult printRegisteredDialects(DialectRegistry ®istry) { |
| 599 | llvm::outs() << "Available Dialects: " ; |
| 600 | interleave(c: registry.getDialectNames(), os&: llvm::outs(), separator: "," ); |
| 601 | llvm::outs() << "\n" ; |
| 602 | return success(); |
| 603 | } |
| 604 | |
| 605 | static LogicalResult printRegisteredPassesAndReturn() { |
| 606 | mlir::printRegisteredPasses(); |
| 607 | return success(); |
| 608 | } |
| 609 | |
| 610 | LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream, |
| 611 | std::unique_ptr<llvm::MemoryBuffer> buffer, |
| 612 | DialectRegistry ®istry, |
| 613 | const MlirOptMainConfig &config) { |
| 614 | if (config.shouldShowDialects()) |
| 615 | return printRegisteredDialects(registry); |
| 616 | |
| 617 | if (config.shouldListPasses()) |
| 618 | return printRegisteredPassesAndReturn(); |
| 619 | |
| 620 | // The split-input-file mode is a very specific mode that slices the file |
| 621 | // up into small pieces and checks each independently. |
| 622 | // We use an explicit threadpool to avoid creating and joining/destroying |
| 623 | // threads for each of the split. |
| 624 | ThreadPoolInterface *threadPool = nullptr; |
| 625 | |
| 626 | // Create a temporary context for the sake of checking if |
| 627 | // --mlir-disable-threading was passed on the command line. |
| 628 | // We use the thread-pool this context is creating, and avoid |
| 629 | // creating any thread when disabled. |
| 630 | MLIRContext threadPoolCtx; |
| 631 | if (threadPoolCtx.isMultithreadingEnabled()) |
| 632 | threadPool = &threadPoolCtx.getThreadPool(); |
| 633 | |
| 634 | auto chunkFn = [&](std::unique_ptr<MemoryBuffer> chunkBuffer, |
| 635 | raw_ostream &os) { |
| 636 | return processBuffer(os, ownedBuffer: std::move(chunkBuffer), config, registry, |
| 637 | threadPool); |
| 638 | }; |
| 639 | return splitAndProcessBuffer(originalBuffer: std::move(buffer), processChunkBuffer: chunkFn, os&: outputStream, |
| 640 | inputSplitMarker: config.inputSplitMarker(), |
| 641 | outputSplitMarker: config.outputSplitMarker()); |
| 642 | } |
| 643 | |
| 644 | LogicalResult mlir::MlirOptMain(int argc, char **argv, |
| 645 | llvm::StringRef inputFilename, |
| 646 | llvm::StringRef outputFilename, |
| 647 | DialectRegistry ®istry) { |
| 648 | |
| 649 | InitLLVM y(argc, argv); |
| 650 | |
| 651 | MlirOptMainConfig config = MlirOptMainConfig::createFromCLOptions(); |
| 652 | |
| 653 | if (config.shouldShowDialects()) |
| 654 | return printRegisteredDialects(registry); |
| 655 | |
| 656 | if (config.shouldListPasses()) |
| 657 | return printRegisteredPassesAndReturn(); |
| 658 | |
| 659 | // When reading from stdin and the input is a tty, it is often a user mistake |
| 660 | // and the process "appears to be stuck". Print a message to let the user know |
| 661 | // about it! |
| 662 | if (inputFilename == "-" && |
| 663 | sys::Process::FileDescriptorIsDisplayed(fd: fileno(stdin))) |
| 664 | llvm::errs() << "(processing input from stdin now, hit ctrl-c/ctrl-d to " |
| 665 | "interrupt)\n" ; |
| 666 | |
| 667 | // Set up the input file. |
| 668 | std::string errorMessage; |
| 669 | auto file = openInputFile(inputFilename, errorMessage: &errorMessage); |
| 670 | if (!file) { |
| 671 | llvm::errs() << errorMessage << "\n" ; |
| 672 | return failure(); |
| 673 | } |
| 674 | |
| 675 | auto output = openOutputFile(outputFilename, errorMessage: &errorMessage); |
| 676 | if (!output) { |
| 677 | llvm::errs() << errorMessage << "\n" ; |
| 678 | return failure(); |
| 679 | } |
| 680 | if (failed(Result: MlirOptMain(outputStream&: output->os(), buffer: std::move(file), registry, config))) |
| 681 | return failure(); |
| 682 | |
| 683 | // Keep the output file if the invocation of MlirOptMain was successful. |
| 684 | output->keep(); |
| 685 | return success(); |
| 686 | } |
| 687 | |
| 688 | LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName, |
| 689 | DialectRegistry ®istry) { |
| 690 | |
| 691 | // Register and parse command line options. |
| 692 | std::string inputFilename, outputFilename; |
| 693 | std::tie(args&: inputFilename, args&: outputFilename) = |
| 694 | registerAndParseCLIOptions(argc, argv, toolName, registry); |
| 695 | |
| 696 | return MlirOptMain(argc, argv, inputFilename, outputFilename, registry); |
| 697 | } |
| 698 | |