//===- PassCrashRecovery.cpp - Pass Crash Recovery Implementation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "PassDetail.h"
10#include "mlir/IR/Diagnostics.h"
11#include "mlir/IR/Dialect.h"
12#include "mlir/IR/SymbolTable.h"
13#include "mlir/IR/Verifier.h"
14#include "mlir/Parser/Parser.h"
15#include "mlir/Pass/Pass.h"
16#include "mlir/Support/FileUtilities.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/ScopeExit.h"
19#include "llvm/ADT/SetVector.h"
20#include "llvm/Support/CommandLine.h"
21#include "llvm/Support/CrashRecoveryContext.h"
22#include "llvm/Support/Mutex.h"
23#include "llvm/Support/Signals.h"
24#include "llvm/Support/Threading.h"
25#include "llvm/Support/ToolOutputFile.h"
26
27using namespace mlir;
28using namespace mlir::detail;
29
//===----------------------------------------------------------------------===//
// RecoveryReproducerContext
//===----------------------------------------------------------------------===//
33
34namespace mlir {
35namespace detail {
36/// This class contains all of the context for generating a recovery reproducer.
37/// Each recovery context is registered globally to allow for generating
38/// reproducers when a signal is raised, such as a segfault.
39struct RecoveryReproducerContext {
40 RecoveryReproducerContext(std::string passPipelineStr, Operation *op,
41 ReproducerStreamFactory &streamFactory,
42 bool verifyPasses);
43 ~RecoveryReproducerContext();
44
45 /// Generate a reproducer with the current context.
46 void generate(std::string &description);
47
48 /// Disable this reproducer context. This prevents the context from generating
49 /// a reproducer in the result of a crash.
50 void disable();
51
52 /// Enable a previously disabled reproducer context.
53 void enable();
54
55private:
56 /// This function is invoked in the event of a crash.
57 static void crashHandler(void *);
58
59 /// Register a signal handler to run in the event of a crash.
60 static void registerSignalHandler();
61
62 /// The textual description of the currently executing pipeline.
63 std::string pipelineElements;
64
65 /// The MLIR operation representing the IR before the crash.
66 Operation *preCrashOperation;
67
68 /// The factory for the reproducer output stream to use when generating the
69 /// reproducer.
70 ReproducerStreamFactory &streamFactory;
71
72 /// Various pass manager and context flags.
73 bool disableThreads;
74 bool verifyPasses;
75
76 /// The current set of active reproducer contexts. This is used in the event
77 /// of a crash. This is not thread_local as the pass manager may produce any
78 /// number of child threads. This uses a set to allow for multiple MLIR pass
79 /// managers to be running at the same time.
80 static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex;
81 static llvm::ManagedStatic<
82 llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
83 reproducerSet;
84};
85} // namespace detail
86} // namespace mlir
87
88llvm::ManagedStatic<llvm::sys::SmartMutex<true>>
89 RecoveryReproducerContext::reproducerMutex;
90llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
91 RecoveryReproducerContext::reproducerSet;
92
93RecoveryReproducerContext::RecoveryReproducerContext(
94 std::string passPipelineStr, Operation *op,
95 ReproducerStreamFactory &streamFactory, bool verifyPasses)
96 : pipelineElements(std::move(passPipelineStr)),
97 preCrashOperation(op->clone()), streamFactory(streamFactory),
98 disableThreads(!op->getContext()->isMultithreadingEnabled()),
99 verifyPasses(verifyPasses) {
100 enable();
101}
102
103RecoveryReproducerContext::~RecoveryReproducerContext() {
104 // Erase the cloned preCrash IR that we cached.
105 preCrashOperation->erase();
106 disable();
107}
108
109static void appendReproducer(std::string &description, Operation *op,
110 const ReproducerStreamFactory &factory,
111 const std::string &pipelineElements,
112 bool disableThreads, bool verifyPasses) {
113 llvm::raw_string_ostream descOS(description);
114
115 // Try to create a new output stream for this crash reproducer.
116 std::string error;
117 std::unique_ptr<ReproducerStream> stream = factory(error);
118 if (!stream) {
119 descOS << "failed to create output stream: " << error;
120 return;
121 }
122 descOS << "reproducer generated at `" << stream->description() << "`";
123
124 std::string pipeline =
125 (op->getName().getStringRef() + "(" + pipelineElements + ")").str();
126 AsmState state(op);
127 state.attachResourcePrinter(
128 name: "mlir_reproducer", printFn: [&](Operation *op, AsmResourceBuilder &builder) {
129 builder.buildString(key: "pipeline", data: pipeline);
130 builder.buildBool(key: "disable_threading", data: disableThreads);
131 builder.buildBool(key: "verify_each", data: verifyPasses);
132 });
133
134 // Output the .mlir module.
135 op->print(os&: stream->os(), state);
136}
137
138void RecoveryReproducerContext::generate(std::string &description) {
139 appendReproducer(description, op: preCrashOperation, factory: streamFactory,
140 pipelineElements, disableThreads, verifyPasses);
141}
142
143void RecoveryReproducerContext::disable() {
144 llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
145 reproducerSet->remove(X: this);
146 if (reproducerSet->empty())
147 llvm::CrashRecoveryContext::Disable();
148}
149
150void RecoveryReproducerContext::enable() {
151 llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
152 if (reproducerSet->empty())
153 llvm::CrashRecoveryContext::Enable();
154 registerSignalHandler();
155 reproducerSet->insert(X: this);
156}
157
158void RecoveryReproducerContext::crashHandler(void *) {
159 // Walk the current stack of contexts and generate a reproducer for each one.
160 // We can't know for certain which one was the cause, so we need to generate
161 // a reproducer for all of them.
162 for (RecoveryReproducerContext *context : *reproducerSet) {
163 std::string description;
164 context->generate(description);
165
166 // Emit an error using information only available within the context.
167 emitError(loc: context->preCrashOperation->getLoc())
168 << "A signal was caught while processing the MLIR module:"
169 << description << "; marking pass as failed";
170 }
171}
172
173void RecoveryReproducerContext::registerSignalHandler() {
174 // Ensure that the handler is only registered once.
175 static bool registered =
176 (llvm::sys::AddSignalHandler(FnPtr: crashHandler, Cookie: nullptr), false);
177 (void)registered;
178}
179
//===----------------------------------------------------------------------===//
// PassCrashReproducerGenerator
//===----------------------------------------------------------------------===//
183
184struct PassCrashReproducerGenerator::Impl {
185 Impl(ReproducerStreamFactory &streamFactory, bool localReproducer)
186 : streamFactory(streamFactory), localReproducer(localReproducer) {}
187
188 /// The factory to use when generating a crash reproducer.
189 ReproducerStreamFactory streamFactory;
190
191 /// Flag indicating if reproducer generation should be localized to the
192 /// failing pass.
193 bool localReproducer = false;
194
195 /// A record of all of the currently active reproducer contexts.
196 SmallVector<std::unique_ptr<RecoveryReproducerContext>> activeContexts;
197
198 /// The set of all currently running passes. Note: This is not populated when
199 /// `localReproducer` is true, as each pass will get its own recovery context.
200 SetVector<std::pair<Pass *, Operation *>> runningPasses;
201
202 /// Various pass manager flags that get emitted when generating a reproducer.
203 bool pmFlagVerifyPasses = false;
204};
205
206PassCrashReproducerGenerator::PassCrashReproducerGenerator(
207 ReproducerStreamFactory &streamFactory, bool localReproducer)
208 : impl(std::make_unique<Impl>(args&: streamFactory, args&: localReproducer)) {}
209PassCrashReproducerGenerator::~PassCrashReproducerGenerator() = default;
210
211void PassCrashReproducerGenerator::initialize(
212 iterator_range<PassManager::pass_iterator> passes, Operation *op,
213 bool pmFlagVerifyPasses) {
214 assert((!impl->localReproducer ||
215 !op->getContext()->isMultithreadingEnabled()) &&
216 "expected multi-threading to be disabled when generating a local "
217 "reproducer");
218
219 llvm::CrashRecoveryContext::Enable();
220 impl->pmFlagVerifyPasses = pmFlagVerifyPasses;
221
222 // If we aren't generating a local reproducer, prepare a reproducer for the
223 // given top-level operation.
224 if (!impl->localReproducer)
225 prepareReproducerFor(passes, op);
226}
227
228static void
229formatPassOpReproducerMessage(Diagnostic &os,
230 std::pair<Pass *, Operation *> passOpPair) {
231 os << "`" << passOpPair.first->getName() << "` on "
232 << "'" << passOpPair.second->getName() << "' operation";
233 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(passOpPair.second))
234 os << ": @" << symbol.getName();
235}
236
237void PassCrashReproducerGenerator::finalize(Operation *rootOp,
238 LogicalResult executionResult) {
239 // Don't generate a reproducer if we have no active contexts.
240 if (impl->activeContexts.empty())
241 return;
242
243 // If the pass manager execution succeeded, we don't generate any reproducers.
244 if (succeeded(result: executionResult))
245 return impl->activeContexts.clear();
246
247 InFlightDiagnostic diag = emitError(loc: rootOp->getLoc())
248 << "Failures have been detected while "
249 "processing an MLIR pass pipeline";
250
251 // If we are generating a global reproducer, we include all of the running
252 // passes in the error message for the only active context.
253 if (!impl->localReproducer) {
254 assert(impl->activeContexts.size() == 1 && "expected one active context");
255
256 // Generate the reproducer.
257 std::string description;
258 impl->activeContexts.front()->generate(description);
259
260 // Emit an error to the user.
261 Diagnostic &note = diag.attachNote() << "Pipeline failed while executing [";
262 llvm::interleaveComma(c: impl->runningPasses, os&: note,
263 each_fn: [&](const std::pair<Pass *, Operation *> &value) {
264 formatPassOpReproducerMessage(os&: note, passOpPair: value);
265 });
266 note << "]: " << description;
267 impl->runningPasses.clear();
268 impl->activeContexts.clear();
269 return;
270 }
271
272 // If we were generating a local reproducer, we generate a reproducer for the
273 // most recently executing pass using the matching entry from `runningPasses`
274 // to generate a localized diagnostic message.
275 assert(impl->activeContexts.size() == impl->runningPasses.size() &&
276 "expected running passes to match active contexts");
277
278 // Generate the reproducer.
279 RecoveryReproducerContext &reproducerContext = *impl->activeContexts.back();
280 std::string description;
281 reproducerContext.generate(description);
282
283 // Emit an error to the user.
284 Diagnostic &note = diag.attachNote() << "Pipeline failed while executing ";
285 formatPassOpReproducerMessage(os&: note, passOpPair: impl->runningPasses.back());
286 note << ": " << description;
287
288 impl->activeContexts.clear();
289 impl->runningPasses.clear();
290}
291
292void PassCrashReproducerGenerator::prepareReproducerFor(Pass *pass,
293 Operation *op) {
294 // If not tracking local reproducers, we simply remember that this pass is
295 // running.
296 impl->runningPasses.insert(X: std::make_pair(x&: pass, y&: op));
297 if (!impl->localReproducer)
298 return;
299
300 // Disable the current pass recovery context, if there is one. This may happen
301 // in the case of dynamic pass pipelines.
302 if (!impl->activeContexts.empty())
303 impl->activeContexts.back()->disable();
304
305 // Collect all of the parent scopes of this operation.
306 SmallVector<OperationName> scopes;
307 while (Operation *parentOp = op->getParentOp()) {
308 scopes.push_back(Elt: op->getName());
309 op = parentOp;
310 }
311
312 // Emit a pass pipeline string for the current pass running on the current
313 // operation type.
314 std::string passStr;
315 llvm::raw_string_ostream passOS(passStr);
316 for (OperationName scope : llvm::reverse(C&: scopes))
317 passOS << scope << "(";
318 pass->printAsTextualPipeline(os&: passOS);
319 for (unsigned i = 0, e = scopes.size(); i < e; ++i)
320 passOS << ")";
321
322 impl->activeContexts.push_back(Elt: std::make_unique<RecoveryReproducerContext>(
323 args&: passOS.str(), args&: op, args&: impl->streamFactory, args&: impl->pmFlagVerifyPasses));
324}
325void PassCrashReproducerGenerator::prepareReproducerFor(
326 iterator_range<PassManager::pass_iterator> passes, Operation *op) {
327 std::string passStr;
328 llvm::raw_string_ostream passOS(passStr);
329 llvm::interleaveComma(
330 c: passes, os&: passOS, each_fn: [&](Pass &pass) { pass.printAsTextualPipeline(os&: passOS); });
331
332 impl->activeContexts.push_back(Elt: std::make_unique<RecoveryReproducerContext>(
333 args&: passOS.str(), args&: op, args&: impl->streamFactory, args&: impl->pmFlagVerifyPasses));
334}
335
336void PassCrashReproducerGenerator::removeLastReproducerFor(Pass *pass,
337 Operation *op) {
338 // We only pop the active context if we are tracking local reproducers.
339 impl->runningPasses.remove(X: std::make_pair(x&: pass, y&: op));
340 if (impl->localReproducer) {
341 impl->activeContexts.pop_back();
342
343 // Re-enable the previous pass recovery context, if there was one. This may
344 // happen in the case of dynamic pass pipelines.
345 if (!impl->activeContexts.empty())
346 impl->activeContexts.back()->enable();
347 }
348}
349
//===----------------------------------------------------------------------===//
// CrashReproducerInstrumentation
//===----------------------------------------------------------------------===//
353
354namespace {
355struct CrashReproducerInstrumentation : public PassInstrumentation {
356 CrashReproducerInstrumentation(PassCrashReproducerGenerator &generator)
357 : generator(generator) {}
358 ~CrashReproducerInstrumentation() override = default;
359
360 void runBeforePass(Pass *pass, Operation *op) override {
361 if (!isa<OpToOpPassAdaptor>(Val: pass))
362 generator.prepareReproducerFor(pass, op);
363 }
364
365 void runAfterPass(Pass *pass, Operation *op) override {
366 if (!isa<OpToOpPassAdaptor>(Val: pass))
367 generator.removeLastReproducerFor(pass, op);
368 }
369
370 void runAfterPassFailed(Pass *pass, Operation *op) override {
371 // Only generate one reproducer per crash reproducer instrumentation.
372 if (alreadyFailed)
373 return;
374
375 alreadyFailed = true;
376 generator.finalize(rootOp: op, /*executionResult=*/failure());
377 }
378
379private:
380 /// The generator used to create crash reproducers.
381 PassCrashReproducerGenerator &generator;
382 bool alreadyFailed = false;
383};
384} // namespace
385
//===----------------------------------------------------------------------===//
// FileReproducerStream
//===----------------------------------------------------------------------===//
389
390namespace {
391/// This class represents a default instance of mlir::ReproducerStream
392/// that is backed by a file.
393struct FileReproducerStream : public mlir::ReproducerStream {
394 FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile)
395 : outputFile(std::move(outputFile)) {}
396 ~FileReproducerStream() override { outputFile->keep(); }
397
398 /// Returns a description of the reproducer stream.
399 StringRef description() override { return outputFile->getFilename(); }
400
401 /// Returns the stream on which to output the reproducer.
402 raw_ostream &os() override { return outputFile->os(); }
403
404private:
405 /// ToolOutputFile corresponding to opened `filename`.
406 std::unique_ptr<llvm::ToolOutputFile> outputFile = nullptr;
407};
408} // namespace
409
//===----------------------------------------------------------------------===//
// PassManager
//===----------------------------------------------------------------------===//
413
414LogicalResult PassManager::runWithCrashRecovery(Operation *op,
415 AnalysisManager am) {
416 crashReproGenerator->initialize(passes: getPasses(), op, pmFlagVerifyPasses: verifyPasses);
417
418 // Safely invoke the passes within a recovery context.
419 LogicalResult passManagerResult = failure();
420 llvm::CrashRecoveryContext recoveryContext;
421 recoveryContext.RunSafelyOnThread(
422 [&] { passManagerResult = runPasses(op, am); });
423 crashReproGenerator->finalize(rootOp: op, executionResult: passManagerResult);
424 return passManagerResult;
425}
426
427static ReproducerStreamFactory
428makeReproducerStreamFactory(StringRef outputFile) {
429 // Capture the filename by value in case outputFile is out of scope when
430 // invoked.
431 std::string filename = outputFile.str();
432 return [filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
433 std::unique_ptr<llvm::ToolOutputFile> outputFile =
434 mlir::openOutputFile(outputFilename: filename, errorMessage: &error);
435 if (!outputFile) {
436 error = "Failed to create reproducer stream: " + error;
437 return nullptr;
438 }
439 return std::make_unique<FileReproducerStream>(args: std::move(outputFile));
440 };
441}
442
443void printAsTextualPipeline(
444 raw_ostream &os, StringRef anchorName,
445 const llvm::iterator_range<OpPassManager::pass_iterator> &passes);
446
447std::string mlir::makeReproducer(
448 StringRef anchorName,
449 const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
450 Operation *op, StringRef outputFile, bool disableThreads,
451 bool verifyPasses) {
452
453 std::string description;
454 std::string pipelineStr;
455 llvm::raw_string_ostream passOS(pipelineStr);
456 ::printAsTextualPipeline(os&: passOS, anchorName, passes);
457 appendReproducer(description, op, factory: makeReproducerStreamFactory(outputFile),
458 pipelineElements: pipelineStr, disableThreads, verifyPasses);
459 return description;
460}
461
462void PassManager::enableCrashReproducerGeneration(StringRef outputFile,
463 bool genLocalReproducer) {
464 enableCrashReproducerGeneration(factory: makeReproducerStreamFactory(outputFile),
465 genLocalReproducer);
466}
467
468void PassManager::enableCrashReproducerGeneration(
469 ReproducerStreamFactory factory, bool genLocalReproducer) {
470 assert(!crashReproGenerator &&
471 "crash reproducer has already been initialized");
472 if (genLocalReproducer && getContext()->isMultithreadingEnabled())
473 llvm::report_fatal_error(
474 reason: "Local crash reproduction can't be setup on a "
475 "pass-manager without disabling multi-threading first.");
476
477 crashReproGenerator = std::make_unique<PassCrashReproducerGenerator>(
478 args&: factory, args&: genLocalReproducer);
479 addInstrumentation(
480 pi: std::make_unique<CrashReproducerInstrumentation>(args&: *crashReproGenerator));
481}
482
//===----------------------------------------------------------------------===//
// Asm Resource
//===----------------------------------------------------------------------===//
486
487void PassReproducerOptions::attachResourceParser(ParserConfig &config) {
488 auto parseFn = [this](AsmParsedResourceEntry &entry) -> LogicalResult {
489 if (entry.getKey() == "pipeline") {
490 FailureOr<std::string> value = entry.parseAsString();
491 if (succeeded(result: value))
492 this->pipeline = std::move(*value);
493 return value;
494 }
495 if (entry.getKey() == "disable_threading") {
496 FailureOr<bool> value = entry.parseAsBool();
497 if (succeeded(result: value))
498 this->disableThreading = *value;
499 return value;
500 }
501 if (entry.getKey() == "verify_each") {
502 FailureOr<bool> value = entry.parseAsBool();
503 if (succeeded(result: value))
504 this->verifyEach = *value;
505 return value;
506 }
507 return entry.emitError() << "unknown 'mlir_reproducer' resource key '"
508 << entry.getKey() << "'";
509 };
510 config.attachResourceParser(name: "mlir_reproducer", parserFn&: parseFn);
511}
512
513LogicalResult PassReproducerOptions::apply(PassManager &pm) const {
514 if (pipeline.has_value()) {
515 FailureOr<OpPassManager> reproPm = parsePassPipeline(pipeline: *pipeline);
516 if (failed(result: reproPm))
517 return failure();
518 static_cast<OpPassManager &>(pm) = std::move(*reproPm);
519 }
520
521 if (disableThreading.has_value())
522 pm.getContext()->disableMultithreading(disable: *disableThreading);
523
524 if (verifyEach.has_value())
525 pm.enableVerifier(enabled: *verifyEach);
526
527 return success();
528}
529

// Source: mlir/lib/Pass/PassCrashRecovery.cpp