1 | //===- PassCrashRecovery.cpp - Pass Crash Recovery Implementation ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "PassDetail.h" |
10 | #include "mlir/IR/Diagnostics.h" |
11 | #include "mlir/IR/Dialect.h" |
12 | #include "mlir/IR/SymbolTable.h" |
13 | #include "mlir/IR/Verifier.h" |
14 | #include "mlir/Parser/Parser.h" |
15 | #include "mlir/Pass/Pass.h" |
16 | #include "mlir/Support/FileUtilities.h" |
17 | #include "llvm/ADT/STLExtras.h" |
18 | #include "llvm/ADT/ScopeExit.h" |
19 | #include "llvm/ADT/SetVector.h" |
20 | #include "llvm/Support/CommandLine.h" |
21 | #include "llvm/Support/CrashRecoveryContext.h" |
22 | #include "llvm/Support/Mutex.h" |
23 | #include "llvm/Support/Signals.h" |
24 | #include "llvm/Support/Threading.h" |
25 | #include "llvm/Support/ToolOutputFile.h" |
26 | |
27 | using namespace mlir; |
28 | using namespace mlir::detail; |
29 | |
30 | //===----------------------------------------------------------------------===// |
31 | // RecoveryReproducerContext |
32 | //===----------------------------------------------------------------------===// |
33 | |
34 | namespace mlir { |
35 | namespace detail { |
36 | /// This class contains all of the context for generating a recovery reproducer. |
37 | /// Each recovery context is registered globally to allow for generating |
38 | /// reproducers when a signal is raised, such as a segfault. |
39 | struct RecoveryReproducerContext { |
40 | RecoveryReproducerContext(std::string passPipelineStr, Operation *op, |
41 | ReproducerStreamFactory &streamFactory, |
42 | bool verifyPasses); |
43 | ~RecoveryReproducerContext(); |
44 | |
45 | /// Generate a reproducer with the current context. |
46 | void generate(std::string &description); |
47 | |
48 | /// Disable this reproducer context. This prevents the context from generating |
49 | /// a reproducer in the result of a crash. |
50 | void disable(); |
51 | |
52 | /// Enable a previously disabled reproducer context. |
53 | void enable(); |
54 | |
55 | private: |
56 | /// This function is invoked in the event of a crash. |
57 | static void crashHandler(void *); |
58 | |
59 | /// Register a signal handler to run in the event of a crash. |
60 | static void registerSignalHandler(); |
61 | |
62 | /// The textual description of the currently executing pipeline. |
63 | std::string pipelineElements; |
64 | |
65 | /// The MLIR operation representing the IR before the crash. |
66 | Operation *preCrashOperation; |
67 | |
68 | /// The factory for the reproducer output stream to use when generating the |
69 | /// reproducer. |
70 | ReproducerStreamFactory &streamFactory; |
71 | |
72 | /// Various pass manager and context flags. |
73 | bool disableThreads; |
74 | bool verifyPasses; |
75 | |
76 | /// The current set of active reproducer contexts. This is used in the event |
77 | /// of a crash. This is not thread_local as the pass manager may produce any |
78 | /// number of child threads. This uses a set to allow for multiple MLIR pass |
79 | /// managers to be running at the same time. |
80 | static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex; |
81 | static llvm::ManagedStatic< |
82 | llvm::SmallSetVector<RecoveryReproducerContext *, 1>> |
83 | reproducerSet; |
84 | }; |
85 | } // namespace detail |
86 | } // namespace mlir |
87 | |
88 | llvm::ManagedStatic<llvm::sys::SmartMutex<true>> |
89 | RecoveryReproducerContext::reproducerMutex; |
90 | llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>> |
91 | RecoveryReproducerContext::reproducerSet; |
92 | |
93 | RecoveryReproducerContext::RecoveryReproducerContext( |
94 | std::string passPipelineStr, Operation *op, |
95 | ReproducerStreamFactory &streamFactory, bool verifyPasses) |
96 | : pipelineElements(std::move(passPipelineStr)), |
97 | preCrashOperation(op->clone()), streamFactory(streamFactory), |
98 | disableThreads(!op->getContext()->isMultithreadingEnabled()), |
99 | verifyPasses(verifyPasses) { |
100 | enable(); |
101 | } |
102 | |
103 | RecoveryReproducerContext::~RecoveryReproducerContext() { |
104 | // Erase the cloned preCrash IR that we cached. |
105 | preCrashOperation->erase(); |
106 | disable(); |
107 | } |
108 | |
109 | static void appendReproducer(std::string &description, Operation *op, |
110 | const ReproducerStreamFactory &factory, |
111 | const std::string &pipelineElements, |
112 | bool disableThreads, bool verifyPasses) { |
113 | llvm::raw_string_ostream descOS(description); |
114 | |
115 | // Try to create a new output stream for this crash reproducer. |
116 | std::string error; |
117 | std::unique_ptr<ReproducerStream> stream = factory(error); |
118 | if (!stream) { |
119 | descOS << "failed to create output stream: " << error; |
120 | return; |
121 | } |
122 | descOS << "reproducer generated at `" << stream->description() << "`" ; |
123 | |
124 | std::string pipeline = |
125 | (op->getName().getStringRef() + "(" + pipelineElements + ")" ).str(); |
126 | AsmState state(op); |
127 | state.attachResourcePrinter( |
128 | name: "mlir_reproducer" , printFn: [&](Operation *op, AsmResourceBuilder &builder) { |
129 | builder.buildString(key: "pipeline" , data: pipeline); |
130 | builder.buildBool(key: "disable_threading" , data: disableThreads); |
131 | builder.buildBool(key: "verify_each" , data: verifyPasses); |
132 | }); |
133 | |
134 | // Output the .mlir module. |
135 | op->print(os&: stream->os(), state); |
136 | } |
137 | |
138 | void RecoveryReproducerContext::generate(std::string &description) { |
139 | appendReproducer(description, op: preCrashOperation, factory: streamFactory, |
140 | pipelineElements, disableThreads, verifyPasses); |
141 | } |
142 | |
143 | void RecoveryReproducerContext::disable() { |
144 | llvm::sys::SmartScopedLock<true> lock(*reproducerMutex); |
145 | reproducerSet->remove(X: this); |
146 | if (reproducerSet->empty()) |
147 | llvm::CrashRecoveryContext::Disable(); |
148 | } |
149 | |
150 | void RecoveryReproducerContext::enable() { |
151 | llvm::sys::SmartScopedLock<true> lock(*reproducerMutex); |
152 | if (reproducerSet->empty()) |
153 | llvm::CrashRecoveryContext::Enable(); |
154 | registerSignalHandler(); |
155 | reproducerSet->insert(X: this); |
156 | } |
157 | |
158 | void RecoveryReproducerContext::crashHandler(void *) { |
159 | // Walk the current stack of contexts and generate a reproducer for each one. |
160 | // We can't know for certain which one was the cause, so we need to generate |
161 | // a reproducer for all of them. |
162 | for (RecoveryReproducerContext *context : *reproducerSet) { |
163 | std::string description; |
164 | context->generate(description); |
165 | |
166 | // Emit an error using information only available within the context. |
167 | emitError(loc: context->preCrashOperation->getLoc()) |
168 | << "A signal was caught while processing the MLIR module:" |
169 | << description << "; marking pass as failed" ; |
170 | } |
171 | } |
172 | |
173 | void RecoveryReproducerContext::registerSignalHandler() { |
174 | // Ensure that the handler is only registered once. |
175 | static bool registered = |
176 | (llvm::sys::AddSignalHandler(FnPtr: crashHandler, Cookie: nullptr), false); |
177 | (void)registered; |
178 | } |
179 | |
180 | //===----------------------------------------------------------------------===// |
181 | // PassCrashReproducerGenerator |
182 | //===----------------------------------------------------------------------===// |
183 | |
184 | struct PassCrashReproducerGenerator::Impl { |
185 | Impl(ReproducerStreamFactory &streamFactory, bool localReproducer) |
186 | : streamFactory(streamFactory), localReproducer(localReproducer) {} |
187 | |
188 | /// The factory to use when generating a crash reproducer. |
189 | ReproducerStreamFactory streamFactory; |
190 | |
191 | /// Flag indicating if reproducer generation should be localized to the |
192 | /// failing pass. |
193 | bool localReproducer = false; |
194 | |
195 | /// A record of all of the currently active reproducer contexts. |
196 | SmallVector<std::unique_ptr<RecoveryReproducerContext>> activeContexts; |
197 | |
198 | /// The set of all currently running passes. Note: This is not populated when |
199 | /// `localReproducer` is true, as each pass will get its own recovery context. |
200 | SetVector<std::pair<Pass *, Operation *>> runningPasses; |
201 | |
202 | /// Various pass manager flags that get emitted when generating a reproducer. |
203 | bool pmFlagVerifyPasses = false; |
204 | }; |
205 | |
206 | PassCrashReproducerGenerator::PassCrashReproducerGenerator( |
207 | ReproducerStreamFactory &streamFactory, bool localReproducer) |
208 | : impl(std::make_unique<Impl>(args&: streamFactory, args&: localReproducer)) {} |
209 | PassCrashReproducerGenerator::~PassCrashReproducerGenerator() = default; |
210 | |
211 | void PassCrashReproducerGenerator::initialize( |
212 | iterator_range<PassManager::pass_iterator> passes, Operation *op, |
213 | bool pmFlagVerifyPasses) { |
214 | assert((!impl->localReproducer || |
215 | !op->getContext()->isMultithreadingEnabled()) && |
216 | "expected multi-threading to be disabled when generating a local " |
217 | "reproducer" ); |
218 | |
219 | llvm::CrashRecoveryContext::Enable(); |
220 | impl->pmFlagVerifyPasses = pmFlagVerifyPasses; |
221 | |
222 | // If we aren't generating a local reproducer, prepare a reproducer for the |
223 | // given top-level operation. |
224 | if (!impl->localReproducer) |
225 | prepareReproducerFor(passes, op); |
226 | } |
227 | |
228 | static void |
229 | formatPassOpReproducerMessage(Diagnostic &os, |
230 | std::pair<Pass *, Operation *> passOpPair) { |
231 | os << "`" << passOpPair.first->getName() << "` on " |
232 | << "'" << passOpPair.second->getName() << "' operation" ; |
233 | if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(passOpPair.second)) |
234 | os << ": @" << symbol.getName(); |
235 | } |
236 | |
237 | void PassCrashReproducerGenerator::finalize(Operation *rootOp, |
238 | LogicalResult executionResult) { |
239 | // Don't generate a reproducer if we have no active contexts. |
240 | if (impl->activeContexts.empty()) |
241 | return; |
242 | |
243 | // If the pass manager execution succeeded, we don't generate any reproducers. |
244 | if (succeeded(result: executionResult)) |
245 | return impl->activeContexts.clear(); |
246 | |
247 | InFlightDiagnostic diag = emitError(loc: rootOp->getLoc()) |
248 | << "Failures have been detected while " |
249 | "processing an MLIR pass pipeline" ; |
250 | |
251 | // If we are generating a global reproducer, we include all of the running |
252 | // passes in the error message for the only active context. |
253 | if (!impl->localReproducer) { |
254 | assert(impl->activeContexts.size() == 1 && "expected one active context" ); |
255 | |
256 | // Generate the reproducer. |
257 | std::string description; |
258 | impl->activeContexts.front()->generate(description); |
259 | |
260 | // Emit an error to the user. |
261 | Diagnostic ¬e = diag.attachNote() << "Pipeline failed while executing [" ; |
262 | llvm::interleaveComma(c: impl->runningPasses, os&: note, |
263 | each_fn: [&](const std::pair<Pass *, Operation *> &value) { |
264 | formatPassOpReproducerMessage(os&: note, passOpPair: value); |
265 | }); |
266 | note << "]: " << description; |
267 | impl->runningPasses.clear(); |
268 | impl->activeContexts.clear(); |
269 | return; |
270 | } |
271 | |
272 | // If we were generating a local reproducer, we generate a reproducer for the |
273 | // most recently executing pass using the matching entry from `runningPasses` |
274 | // to generate a localized diagnostic message. |
275 | assert(impl->activeContexts.size() == impl->runningPasses.size() && |
276 | "expected running passes to match active contexts" ); |
277 | |
278 | // Generate the reproducer. |
279 | RecoveryReproducerContext &reproducerContext = *impl->activeContexts.back(); |
280 | std::string description; |
281 | reproducerContext.generate(description); |
282 | |
283 | // Emit an error to the user. |
284 | Diagnostic ¬e = diag.attachNote() << "Pipeline failed while executing " ; |
285 | formatPassOpReproducerMessage(os&: note, passOpPair: impl->runningPasses.back()); |
286 | note << ": " << description; |
287 | |
288 | impl->activeContexts.clear(); |
289 | impl->runningPasses.clear(); |
290 | } |
291 | |
292 | void PassCrashReproducerGenerator::prepareReproducerFor(Pass *pass, |
293 | Operation *op) { |
294 | // If not tracking local reproducers, we simply remember that this pass is |
295 | // running. |
296 | impl->runningPasses.insert(X: std::make_pair(x&: pass, y&: op)); |
297 | if (!impl->localReproducer) |
298 | return; |
299 | |
300 | // Disable the current pass recovery context, if there is one. This may happen |
301 | // in the case of dynamic pass pipelines. |
302 | if (!impl->activeContexts.empty()) |
303 | impl->activeContexts.back()->disable(); |
304 | |
305 | // Collect all of the parent scopes of this operation. |
306 | SmallVector<OperationName> scopes; |
307 | while (Operation *parentOp = op->getParentOp()) { |
308 | scopes.push_back(Elt: op->getName()); |
309 | op = parentOp; |
310 | } |
311 | |
312 | // Emit a pass pipeline string for the current pass running on the current |
313 | // operation type. |
314 | std::string passStr; |
315 | llvm::raw_string_ostream passOS(passStr); |
316 | for (OperationName scope : llvm::reverse(C&: scopes)) |
317 | passOS << scope << "(" ; |
318 | pass->printAsTextualPipeline(os&: passOS); |
319 | for (unsigned i = 0, e = scopes.size(); i < e; ++i) |
320 | passOS << ")" ; |
321 | |
322 | impl->activeContexts.push_back(Elt: std::make_unique<RecoveryReproducerContext>( |
323 | args&: passOS.str(), args&: op, args&: impl->streamFactory, args&: impl->pmFlagVerifyPasses)); |
324 | } |
325 | void PassCrashReproducerGenerator::prepareReproducerFor( |
326 | iterator_range<PassManager::pass_iterator> passes, Operation *op) { |
327 | std::string passStr; |
328 | llvm::raw_string_ostream passOS(passStr); |
329 | llvm::interleaveComma( |
330 | c: passes, os&: passOS, each_fn: [&](Pass &pass) { pass.printAsTextualPipeline(os&: passOS); }); |
331 | |
332 | impl->activeContexts.push_back(Elt: std::make_unique<RecoveryReproducerContext>( |
333 | args&: passOS.str(), args&: op, args&: impl->streamFactory, args&: impl->pmFlagVerifyPasses)); |
334 | } |
335 | |
336 | void PassCrashReproducerGenerator::removeLastReproducerFor(Pass *pass, |
337 | Operation *op) { |
338 | // We only pop the active context if we are tracking local reproducers. |
339 | impl->runningPasses.remove(X: std::make_pair(x&: pass, y&: op)); |
340 | if (impl->localReproducer) { |
341 | impl->activeContexts.pop_back(); |
342 | |
343 | // Re-enable the previous pass recovery context, if there was one. This may |
344 | // happen in the case of dynamic pass pipelines. |
345 | if (!impl->activeContexts.empty()) |
346 | impl->activeContexts.back()->enable(); |
347 | } |
348 | } |
349 | |
350 | //===----------------------------------------------------------------------===// |
351 | // CrashReproducerInstrumentation |
352 | //===----------------------------------------------------------------------===// |
353 | |
354 | namespace { |
355 | struct CrashReproducerInstrumentation : public PassInstrumentation { |
356 | CrashReproducerInstrumentation(PassCrashReproducerGenerator &generator) |
357 | : generator(generator) {} |
358 | ~CrashReproducerInstrumentation() override = default; |
359 | |
360 | void runBeforePass(Pass *pass, Operation *op) override { |
361 | if (!isa<OpToOpPassAdaptor>(Val: pass)) |
362 | generator.prepareReproducerFor(pass, op); |
363 | } |
364 | |
365 | void runAfterPass(Pass *pass, Operation *op) override { |
366 | if (!isa<OpToOpPassAdaptor>(Val: pass)) |
367 | generator.removeLastReproducerFor(pass, op); |
368 | } |
369 | |
370 | void runAfterPassFailed(Pass *pass, Operation *op) override { |
371 | // Only generate one reproducer per crash reproducer instrumentation. |
372 | if (alreadyFailed) |
373 | return; |
374 | |
375 | alreadyFailed = true; |
376 | generator.finalize(rootOp: op, /*executionResult=*/failure()); |
377 | } |
378 | |
379 | private: |
380 | /// The generator used to create crash reproducers. |
381 | PassCrashReproducerGenerator &generator; |
382 | bool alreadyFailed = false; |
383 | }; |
384 | } // namespace |
385 | |
386 | //===----------------------------------------------------------------------===// |
387 | // FileReproducerStream |
388 | //===----------------------------------------------------------------------===// |
389 | |
390 | namespace { |
391 | /// This class represents a default instance of mlir::ReproducerStream |
392 | /// that is backed by a file. |
393 | struct FileReproducerStream : public mlir::ReproducerStream { |
394 | FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile) |
395 | : outputFile(std::move(outputFile)) {} |
396 | ~FileReproducerStream() override { outputFile->keep(); } |
397 | |
398 | /// Returns a description of the reproducer stream. |
399 | StringRef description() override { return outputFile->getFilename(); } |
400 | |
401 | /// Returns the stream on which to output the reproducer. |
402 | raw_ostream &os() override { return outputFile->os(); } |
403 | |
404 | private: |
405 | /// ToolOutputFile corresponding to opened `filename`. |
406 | std::unique_ptr<llvm::ToolOutputFile> outputFile = nullptr; |
407 | }; |
408 | } // namespace |
409 | |
410 | //===----------------------------------------------------------------------===// |
411 | // PassManager |
412 | //===----------------------------------------------------------------------===// |
413 | |
414 | LogicalResult PassManager::runWithCrashRecovery(Operation *op, |
415 | AnalysisManager am) { |
416 | crashReproGenerator->initialize(passes: getPasses(), op, pmFlagVerifyPasses: verifyPasses); |
417 | |
418 | // Safely invoke the passes within a recovery context. |
419 | LogicalResult passManagerResult = failure(); |
420 | llvm::CrashRecoveryContext recoveryContext; |
421 | recoveryContext.RunSafelyOnThread( |
422 | [&] { passManagerResult = runPasses(op, am); }); |
423 | crashReproGenerator->finalize(rootOp: op, executionResult: passManagerResult); |
424 | return passManagerResult; |
425 | } |
426 | |
427 | static ReproducerStreamFactory |
428 | makeReproducerStreamFactory(StringRef outputFile) { |
429 | // Capture the filename by value in case outputFile is out of scope when |
430 | // invoked. |
431 | std::string filename = outputFile.str(); |
432 | return [filename](std::string &error) -> std::unique_ptr<ReproducerStream> { |
433 | std::unique_ptr<llvm::ToolOutputFile> outputFile = |
434 | mlir::openOutputFile(outputFilename: filename, errorMessage: &error); |
435 | if (!outputFile) { |
436 | error = "Failed to create reproducer stream: " + error; |
437 | return nullptr; |
438 | } |
439 | return std::make_unique<FileReproducerStream>(args: std::move(outputFile)); |
440 | }; |
441 | } |
442 | |
443 | void printAsTextualPipeline( |
444 | raw_ostream &os, StringRef anchorName, |
445 | const llvm::iterator_range<OpPassManager::pass_iterator> &passes); |
446 | |
447 | std::string mlir::makeReproducer( |
448 | StringRef anchorName, |
449 | const llvm::iterator_range<OpPassManager::pass_iterator> &passes, |
450 | Operation *op, StringRef outputFile, bool disableThreads, |
451 | bool verifyPasses) { |
452 | |
453 | std::string description; |
454 | std::string pipelineStr; |
455 | llvm::raw_string_ostream passOS(pipelineStr); |
456 | ::printAsTextualPipeline(os&: passOS, anchorName, passes); |
457 | appendReproducer(description, op, factory: makeReproducerStreamFactory(outputFile), |
458 | pipelineElements: pipelineStr, disableThreads, verifyPasses); |
459 | return description; |
460 | } |
461 | |
462 | void PassManager::enableCrashReproducerGeneration(StringRef outputFile, |
463 | bool genLocalReproducer) { |
464 | enableCrashReproducerGeneration(factory: makeReproducerStreamFactory(outputFile), |
465 | genLocalReproducer); |
466 | } |
467 | |
468 | void PassManager::enableCrashReproducerGeneration( |
469 | ReproducerStreamFactory factory, bool genLocalReproducer) { |
470 | assert(!crashReproGenerator && |
471 | "crash reproducer has already been initialized" ); |
472 | if (genLocalReproducer && getContext()->isMultithreadingEnabled()) |
473 | llvm::report_fatal_error( |
474 | reason: "Local crash reproduction can't be setup on a " |
475 | "pass-manager without disabling multi-threading first." ); |
476 | |
477 | crashReproGenerator = std::make_unique<PassCrashReproducerGenerator>( |
478 | args&: factory, args&: genLocalReproducer); |
479 | addInstrumentation( |
480 | pi: std::make_unique<CrashReproducerInstrumentation>(args&: *crashReproGenerator)); |
481 | } |
482 | |
483 | //===----------------------------------------------------------------------===// |
484 | // Asm Resource |
485 | //===----------------------------------------------------------------------===// |
486 | |
487 | void PassReproducerOptions::attachResourceParser(ParserConfig &config) { |
488 | auto parseFn = [this](AsmParsedResourceEntry &entry) -> LogicalResult { |
489 | if (entry.getKey() == "pipeline" ) { |
490 | FailureOr<std::string> value = entry.parseAsString(); |
491 | if (succeeded(result: value)) |
492 | this->pipeline = std::move(*value); |
493 | return value; |
494 | } |
495 | if (entry.getKey() == "disable_threading" ) { |
496 | FailureOr<bool> value = entry.parseAsBool(); |
497 | if (succeeded(result: value)) |
498 | this->disableThreading = *value; |
499 | return value; |
500 | } |
501 | if (entry.getKey() == "verify_each" ) { |
502 | FailureOr<bool> value = entry.parseAsBool(); |
503 | if (succeeded(result: value)) |
504 | this->verifyEach = *value; |
505 | return value; |
506 | } |
507 | return entry.emitError() << "unknown 'mlir_reproducer' resource key '" |
508 | << entry.getKey() << "'" ; |
509 | }; |
510 | config.attachResourceParser(name: "mlir_reproducer" , parserFn&: parseFn); |
511 | } |
512 | |
513 | LogicalResult PassReproducerOptions::apply(PassManager &pm) const { |
514 | if (pipeline.has_value()) { |
515 | FailureOr<OpPassManager> reproPm = parsePassPipeline(pipeline: *pipeline); |
516 | if (failed(result: reproPm)) |
517 | return failure(); |
518 | static_cast<OpPassManager &>(pm) = std::move(*reproPm); |
519 | } |
520 | |
521 | if (disableThreading.has_value()) |
522 | pm.getContext()->disableMultithreading(disable: *disableThreading); |
523 | |
524 | if (verifyEach.has_value()) |
525 | pm.enableVerifier(enabled: *verifyEach); |
526 | |
527 | return success(); |
528 | } |
529 | |