1//===- PassCrashRecovery.cpp - Pass Crash Recovery Implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "PassDetail.h"
10#include "mlir/IR/Diagnostics.h"
11#include "mlir/IR/Dialect.h"
12#include "mlir/IR/SymbolTable.h"
13#include "mlir/IR/Verifier.h"
14#include "mlir/Parser/Parser.h"
15#include "mlir/Pass/Pass.h"
16#include "mlir/Support/FileUtilities.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/ScopeExit.h"
19#include "llvm/ADT/SetVector.h"
20#include "llvm/Support/CommandLine.h"
21#include "llvm/Support/CrashRecoveryContext.h"
22#include "llvm/Support/ManagedStatic.h"
23#include "llvm/Support/Mutex.h"
24#include "llvm/Support/Signals.h"
25#include "llvm/Support/Threading.h"
26#include "llvm/Support/ToolOutputFile.h"
27
28using namespace mlir;
29using namespace mlir::detail;
30
31//===----------------------------------------------------------------------===//
32// RecoveryReproducerContext
33//===----------------------------------------------------------------------===//
34
35namespace mlir {
36namespace detail {
37/// This class contains all of the context for generating a recovery reproducer.
38/// Each recovery context is registered globally to allow for generating
39/// reproducers when a signal is raised, such as a segfault.
40struct RecoveryReproducerContext {
41 RecoveryReproducerContext(std::string passPipelineStr, Operation *op,
42 ReproducerStreamFactory &streamFactory,
43 bool verifyPasses);
44 ~RecoveryReproducerContext();
45
46 /// Generate a reproducer with the current context.
47 void generate(std::string &description);
48
49 /// Disable this reproducer context. This prevents the context from generating
50 /// a reproducer in the result of a crash.
51 void disable();
52
53 /// Enable a previously disabled reproducer context.
54 void enable();
55
56private:
57 /// This function is invoked in the event of a crash.
58 static void crashHandler(void *);
59
60 /// Register a signal handler to run in the event of a crash.
61 static void registerSignalHandler();
62
63 /// The textual description of the currently executing pipeline.
64 std::string pipelineElements;
65
66 /// The MLIR operation representing the IR before the crash.
67 Operation *preCrashOperation;
68
69 /// The factory for the reproducer output stream to use when generating the
70 /// reproducer.
71 ReproducerStreamFactory &streamFactory;
72
73 /// Various pass manager and context flags.
74 bool disableThreads;
75 bool verifyPasses;
76
77 /// The current set of active reproducer contexts. This is used in the event
78 /// of a crash. This is not thread_local as the pass manager may produce any
79 /// number of child threads. This uses a set to allow for multiple MLIR pass
80 /// managers to be running at the same time.
81 static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex;
82 static llvm::ManagedStatic<
83 llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
84 reproducerSet;
85};
86} // namespace detail
87} // namespace mlir
88
89llvm::ManagedStatic<llvm::sys::SmartMutex<true>>
90 RecoveryReproducerContext::reproducerMutex;
91llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
92 RecoveryReproducerContext::reproducerSet;
93
94RecoveryReproducerContext::RecoveryReproducerContext(
95 std::string passPipelineStr, Operation *op,
96 ReproducerStreamFactory &streamFactory, bool verifyPasses)
97 : pipelineElements(std::move(passPipelineStr)),
98 preCrashOperation(op->clone()), streamFactory(streamFactory),
99 disableThreads(!op->getContext()->isMultithreadingEnabled()),
100 verifyPasses(verifyPasses) {
101 enable();
102}
103
104RecoveryReproducerContext::~RecoveryReproducerContext() {
105 // Erase the cloned preCrash IR that we cached.
106 preCrashOperation->erase();
107 disable();
108}
109
110static void appendReproducer(std::string &description, Operation *op,
111 const ReproducerStreamFactory &factory,
112 const std::string &pipelineElements,
113 bool disableThreads, bool verifyPasses) {
114 llvm::raw_string_ostream descOS(description);
115
116 // Try to create a new output stream for this crash reproducer.
117 std::string error;
118 std::unique_ptr<ReproducerStream> stream = factory(error);
119 if (!stream) {
120 descOS << "failed to create output stream: " << error;
121 return;
122 }
123 descOS << "reproducer generated at `" << stream->description() << "`";
124
125 std::string pipeline =
126 (op->getName().getStringRef() + "(" + pipelineElements + ")").str();
127 AsmState state(op);
128 state.attachResourcePrinter(
129 name: "mlir_reproducer", printFn: [&](Operation *op, AsmResourceBuilder &builder) {
130 builder.buildString(key: "pipeline", data: pipeline);
131 builder.buildBool(key: "disable_threading", data: disableThreads);
132 builder.buildBool(key: "verify_each", data: verifyPasses);
133 });
134
135 // Output the .mlir module.
136 op->print(os&: stream->os(), state);
137}
138
139void RecoveryReproducerContext::generate(std::string &description) {
140 appendReproducer(description, op: preCrashOperation, factory: streamFactory,
141 pipelineElements, disableThreads, verifyPasses);
142}
143
144void RecoveryReproducerContext::disable() {
145 llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
146 reproducerSet->remove(X: this);
147 if (reproducerSet->empty())
148 llvm::CrashRecoveryContext::Disable();
149}
150
151void RecoveryReproducerContext::enable() {
152 llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
153 if (reproducerSet->empty())
154 llvm::CrashRecoveryContext::Enable();
155 registerSignalHandler();
156 reproducerSet->insert(X: this);
157}
158
159void RecoveryReproducerContext::crashHandler(void *) {
160 // Walk the current stack of contexts and generate a reproducer for each one.
161 // We can't know for certain which one was the cause, so we need to generate
162 // a reproducer for all of them.
163 for (RecoveryReproducerContext *context : *reproducerSet) {
164 std::string description;
165 context->generate(description);
166
167 // Emit an error using information only available within the context.
168 emitError(loc: context->preCrashOperation->getLoc())
169 << "A signal was caught while processing the MLIR module:"
170 << description << "; marking pass as failed";
171 }
172}
173
174void RecoveryReproducerContext::registerSignalHandler() {
175 // Ensure that the handler is only registered once.
176 static bool registered =
177 (llvm::sys::AddSignalHandler(FnPtr: crashHandler, Cookie: nullptr), false);
178 (void)registered;
179}
180
181//===----------------------------------------------------------------------===//
182// PassCrashReproducerGenerator
183//===----------------------------------------------------------------------===//
184
185struct PassCrashReproducerGenerator::Impl {
186 Impl(ReproducerStreamFactory &streamFactory, bool localReproducer)
187 : streamFactory(streamFactory), localReproducer(localReproducer) {}
188
189 /// The factory to use when generating a crash reproducer.
190 ReproducerStreamFactory streamFactory;
191
192 /// Flag indicating if reproducer generation should be localized to the
193 /// failing pass.
194 bool localReproducer = false;
195
196 /// A record of all of the currently active reproducer contexts.
197 SmallVector<std::unique_ptr<RecoveryReproducerContext>> activeContexts;
198
199 /// The set of all currently running passes. Note: This is not populated when
200 /// `localReproducer` is true, as each pass will get its own recovery context.
201 SetVector<std::pair<Pass *, Operation *>> runningPasses;
202
203 /// Various pass manager flags that get emitted when generating a reproducer.
204 bool pmFlagVerifyPasses = false;
205};
206
207PassCrashReproducerGenerator::PassCrashReproducerGenerator(
208 ReproducerStreamFactory &streamFactory, bool localReproducer)
209 : impl(std::make_unique<Impl>(args&: streamFactory, args&: localReproducer)) {}
210PassCrashReproducerGenerator::~PassCrashReproducerGenerator() = default;
211
212void PassCrashReproducerGenerator::initialize(
213 iterator_range<PassManager::pass_iterator> passes, Operation *op,
214 bool pmFlagVerifyPasses) {
215 assert((!impl->localReproducer ||
216 !op->getContext()->isMultithreadingEnabled()) &&
217 "expected multi-threading to be disabled when generating a local "
218 "reproducer");
219
220 llvm::CrashRecoveryContext::Enable();
221 impl->pmFlagVerifyPasses = pmFlagVerifyPasses;
222
223 // If we aren't generating a local reproducer, prepare a reproducer for the
224 // given top-level operation.
225 if (!impl->localReproducer)
226 prepareReproducerFor(passes, op);
227}
228
229static void
230formatPassOpReproducerMessage(Diagnostic &os,
231 std::pair<Pass *, Operation *> passOpPair) {
232 os << "`" << passOpPair.first->getName() << "` on "
233 << "'" << passOpPair.second->getName() << "' operation";
234 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(passOpPair.second))
235 os << ": @" << symbol.getName();
236}
237
238void PassCrashReproducerGenerator::finalize(Operation *rootOp,
239 LogicalResult executionResult) {
240 // Don't generate a reproducer if we have no active contexts.
241 if (impl->activeContexts.empty())
242 return;
243
244 // If the pass manager execution succeeded, we don't generate any reproducers.
245 if (succeeded(Result: executionResult))
246 return impl->activeContexts.clear();
247
248 InFlightDiagnostic diag = emitError(loc: rootOp->getLoc())
249 << "Failures have been detected while "
250 "processing an MLIR pass pipeline";
251
252 // If we are generating a global reproducer, we include all of the running
253 // passes in the error message for the only active context.
254 if (!impl->localReproducer) {
255 assert(impl->activeContexts.size() == 1 && "expected one active context");
256
257 // Generate the reproducer.
258 std::string description;
259 impl->activeContexts.front()->generate(description);
260
261 // Emit an error to the user.
262 Diagnostic &note = diag.attachNote() << "Pipeline failed while executing [";
263 llvm::interleaveComma(c: impl->runningPasses, os&: note,
264 each_fn: [&](const std::pair<Pass *, Operation *> &value) {
265 formatPassOpReproducerMessage(os&: note, passOpPair: value);
266 });
267 note << "]: " << description;
268 impl->runningPasses.clear();
269 impl->activeContexts.clear();
270 return;
271 }
272
273 // If we were generating a local reproducer, we generate a reproducer for the
274 // most recently executing pass using the matching entry from `runningPasses`
275 // to generate a localized diagnostic message.
276 assert(impl->activeContexts.size() == impl->runningPasses.size() &&
277 "expected running passes to match active contexts");
278
279 // Generate the reproducer.
280 RecoveryReproducerContext &reproducerContext = *impl->activeContexts.back();
281 std::string description;
282 reproducerContext.generate(description);
283
284 // Emit an error to the user.
285 Diagnostic &note = diag.attachNote() << "Pipeline failed while executing ";
286 formatPassOpReproducerMessage(os&: note, passOpPair: impl->runningPasses.back());
287 note << ": " << description;
288
289 impl->activeContexts.clear();
290 impl->runningPasses.clear();
291}
292
293void PassCrashReproducerGenerator::prepareReproducerFor(Pass *pass,
294 Operation *op) {
295 // If not tracking local reproducers, we simply remember that this pass is
296 // running.
297 impl->runningPasses.insert(X: std::make_pair(x&: pass, y&: op));
298 if (!impl->localReproducer)
299 return;
300
301 // Disable the current pass recovery context, if there is one. This may happen
302 // in the case of dynamic pass pipelines.
303 if (!impl->activeContexts.empty())
304 impl->activeContexts.back()->disable();
305
306 // Collect all of the parent scopes of this operation.
307 SmallVector<OperationName> scopes;
308 while (Operation *parentOp = op->getParentOp()) {
309 scopes.push_back(Elt: op->getName());
310 op = parentOp;
311 }
312
313 // Emit a pass pipeline string for the current pass running on the current
314 // operation type.
315 std::string passStr;
316 llvm::raw_string_ostream passOS(passStr);
317 for (OperationName scope : llvm::reverse(C&: scopes))
318 passOS << scope << "(";
319 pass->printAsTextualPipeline(os&: passOS);
320 for (unsigned i = 0, e = scopes.size(); i < e; ++i)
321 passOS << ")";
322
323 impl->activeContexts.push_back(Elt: std::make_unique<RecoveryReproducerContext>(
324 args&: passStr, args&: op, args&: impl->streamFactory, args&: impl->pmFlagVerifyPasses));
325}
326void PassCrashReproducerGenerator::prepareReproducerFor(
327 iterator_range<PassManager::pass_iterator> passes, Operation *op) {
328 std::string passStr;
329 llvm::raw_string_ostream passOS(passStr);
330 llvm::interleaveComma(
331 c: passes, os&: passOS, each_fn: [&](Pass &pass) { pass.printAsTextualPipeline(os&: passOS); });
332
333 impl->activeContexts.push_back(Elt: std::make_unique<RecoveryReproducerContext>(
334 args&: passStr, args&: op, args&: impl->streamFactory, args&: impl->pmFlagVerifyPasses));
335}
336
337void PassCrashReproducerGenerator::removeLastReproducerFor(Pass *pass,
338 Operation *op) {
339 // We only pop the active context if we are tracking local reproducers.
340 impl->runningPasses.remove(X: std::make_pair(x&: pass, y&: op));
341 if (impl->localReproducer) {
342 impl->activeContexts.pop_back();
343
344 // Re-enable the previous pass recovery context, if there was one. This may
345 // happen in the case of dynamic pass pipelines.
346 if (!impl->activeContexts.empty())
347 impl->activeContexts.back()->enable();
348 }
349}
350
351//===----------------------------------------------------------------------===//
352// CrashReproducerInstrumentation
353//===----------------------------------------------------------------------===//
354
355namespace {
356struct CrashReproducerInstrumentation : public PassInstrumentation {
357 CrashReproducerInstrumentation(PassCrashReproducerGenerator &generator)
358 : generator(generator) {}
359 ~CrashReproducerInstrumentation() override = default;
360
361 void runBeforePass(Pass *pass, Operation *op) override {
362 if (!isa<OpToOpPassAdaptor>(Val: pass))
363 generator.prepareReproducerFor(pass, op);
364 }
365
366 void runAfterPass(Pass *pass, Operation *op) override {
367 if (!isa<OpToOpPassAdaptor>(Val: pass))
368 generator.removeLastReproducerFor(pass, op);
369 }
370
371 void runAfterPassFailed(Pass *pass, Operation *op) override {
372 // Only generate one reproducer per crash reproducer instrumentation.
373 if (alreadyFailed)
374 return;
375
376 alreadyFailed = true;
377 generator.finalize(rootOp: op, /*executionResult=*/failure());
378 }
379
380private:
381 /// The generator used to create crash reproducers.
382 PassCrashReproducerGenerator &generator;
383 bool alreadyFailed = false;
384};
385} // namespace
386
387//===----------------------------------------------------------------------===//
388// FileReproducerStream
389//===----------------------------------------------------------------------===//
390
391namespace {
392/// This class represents a default instance of mlir::ReproducerStream
393/// that is backed by a file.
394struct FileReproducerStream : public mlir::ReproducerStream {
395 FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile)
396 : outputFile(std::move(outputFile)) {}
397 ~FileReproducerStream() override { outputFile->keep(); }
398
399 /// Returns a description of the reproducer stream.
400 StringRef description() override { return outputFile->getFilename(); }
401
402 /// Returns the stream on which to output the reproducer.
403 raw_ostream &os() override { return outputFile->os(); }
404
405private:
406 /// ToolOutputFile corresponding to opened `filename`.
407 std::unique_ptr<llvm::ToolOutputFile> outputFile = nullptr;
408};
409} // namespace
410
411//===----------------------------------------------------------------------===//
412// PassManager
413//===----------------------------------------------------------------------===//
414
415LogicalResult PassManager::runWithCrashRecovery(Operation *op,
416 AnalysisManager am) {
417 crashReproGenerator->initialize(passes: getPasses(), op, pmFlagVerifyPasses: verifyPasses);
418
419 // Safely invoke the passes within a recovery context.
420 LogicalResult passManagerResult = failure();
421 llvm::CrashRecoveryContext recoveryContext;
422 recoveryContext.RunSafelyOnThread(
423 [&] { passManagerResult = runPasses(op, am); });
424 crashReproGenerator->finalize(rootOp: op, executionResult: passManagerResult);
425 return passManagerResult;
426}
427
428static ReproducerStreamFactory
429makeReproducerStreamFactory(StringRef outputFile) {
430 // Capture the filename by value in case outputFile is out of scope when
431 // invoked.
432 std::string filename = outputFile.str();
433 return [filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
434 std::unique_ptr<llvm::ToolOutputFile> outputFile =
435 mlir::openOutputFile(outputFilename: filename, errorMessage: &error);
436 if (!outputFile) {
437 error = "Failed to create reproducer stream: " + error;
438 return nullptr;
439 }
440 return std::make_unique<FileReproducerStream>(args: std::move(outputFile));
441 };
442}
443
444void printAsTextualPipeline(
445 raw_ostream &os, StringRef anchorName,
446 const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
447 bool pretty = false);
448
449std::string mlir::makeReproducer(
450 StringRef anchorName,
451 const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
452 Operation *op, StringRef outputFile, bool disableThreads,
453 bool verifyPasses) {
454
455 std::string description;
456 std::string pipelineStr;
457 llvm::raw_string_ostream passOS(pipelineStr);
458 ::printAsTextualPipeline(os&: passOS, anchorName, passes);
459 appendReproducer(description, op, factory: makeReproducerStreamFactory(outputFile),
460 pipelineElements: pipelineStr, disableThreads, verifyPasses);
461 return description;
462}
463
464void PassManager::enableCrashReproducerGeneration(StringRef outputFile,
465 bool genLocalReproducer) {
466 enableCrashReproducerGeneration(factory: makeReproducerStreamFactory(outputFile),
467 genLocalReproducer);
468}
469
470void PassManager::enableCrashReproducerGeneration(
471 ReproducerStreamFactory factory, bool genLocalReproducer) {
472 assert(!crashReproGenerator &&
473 "crash reproducer has already been initialized");
474 if (genLocalReproducer && getContext()->isMultithreadingEnabled())
475 llvm::report_fatal_error(
476 reason: "Local crash reproduction can't be setup on a "
477 "pass-manager without disabling multi-threading first.");
478
479 crashReproGenerator = std::make_unique<PassCrashReproducerGenerator>(
480 args&: factory, args&: genLocalReproducer);
481 addInstrumentation(
482 pi: std::make_unique<CrashReproducerInstrumentation>(args&: *crashReproGenerator));
483}
484
485//===----------------------------------------------------------------------===//
486// Asm Resource
487//===----------------------------------------------------------------------===//
488
489void PassReproducerOptions::attachResourceParser(ParserConfig &config) {
490 auto parseFn = [this](AsmParsedResourceEntry &entry) -> LogicalResult {
491 if (entry.getKey() == "pipeline") {
492 FailureOr<std::string> value = entry.parseAsString();
493 if (succeeded(Result: value))
494 this->pipeline = std::move(*value);
495 return value;
496 }
497 if (entry.getKey() == "disable_threading") {
498 FailureOr<bool> value = entry.parseAsBool();
499 if (succeeded(Result: value))
500 this->disableThreading = *value;
501 return value;
502 }
503 if (entry.getKey() == "verify_each") {
504 FailureOr<bool> value = entry.parseAsBool();
505 if (succeeded(Result: value))
506 this->verifyEach = *value;
507 return value;
508 }
509 return entry.emitError() << "unknown 'mlir_reproducer' resource key '"
510 << entry.getKey() << "'";
511 };
512 config.attachResourceParser(name: "mlir_reproducer", parserFn&: parseFn);
513}
514
515LogicalResult PassReproducerOptions::apply(PassManager &pm) const {
516 if (pipeline.has_value()) {
517 FailureOr<OpPassManager> reproPm = parsePassPipeline(pipeline: *pipeline);
518 if (failed(Result: reproPm))
519 return failure();
520 static_cast<OpPassManager &>(pm) = std::move(*reproPm);
521 }
522
523 if (disableThreading.has_value())
524 pm.getContext()->disableMultithreading(disable: *disableThreading);
525
526 if (verifyEach.has_value())
527 pm.enableVerifier(enabled: *verifyEach);
528
529 return success();
530}
531

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Pass/PassCrashRecovery.cpp