1//===- DebuggerExecutionContextHook.cpp - Debugger Support ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Debug/DebuggerExecutionContextHook.h"
10
11#include "mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h"
12#include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h"
13
14using namespace mlir;
15using namespace mlir::tracing;
16
17namespace {
18/// This structure tracks the state of the interactive debugger.
19struct DebuggerState {
20 /// This variable keeps track of the current control option. This is set by
21 /// the debugger when control is handed over to it.
22 ExecutionContext::Control debuggerControl = ExecutionContext::Apply;
23
24 /// The breakpoint manager that allows the debugger to set breakpoints on
25 /// action tags.
26 TagBreakpointManager tagBreakpointManager;
27
28 /// The breakpoint manager that allows the debugger to set breakpoints on
29 /// FileLineColLoc locations.
30 FileLineColLocBreakpointManager fileLineColLocBreakpointManager;
31
32 /// Map of breakpoint IDs to breakpoint objects.
33 DenseMap<unsigned, Breakpoint *> breakpointIdsMap;
34
35 /// The current stack of actiive actions.
36 const tracing::ActionActiveStack *actionActiveStack;
37
38 /// This is a "cursor" in the IR, it is used for the debugger to navigate the
39 /// IR associated to the actions.
40 IRUnit cursor;
41};
42} // namespace
43
44static DebuggerState &getGlobalDebuggerState() {
45 static LLVM_THREAD_LOCAL DebuggerState debuggerState;
46 return debuggerState;
47}
48
49extern "C" {
50void mlirDebuggerSetControl(int controlOption) {
51 getGlobalDebuggerState().debuggerControl =
52 static_cast<ExecutionContext::Control>(controlOption);
53}
54
55void mlirDebuggerPrintContext() {
56 DebuggerState &state = getGlobalDebuggerState();
57 if (!state.actionActiveStack) {
58 llvm::outs() << "No active action.\n";
59 return;
60 }
61 const ArrayRef<IRUnit> &units =
62 state.actionActiveStack->getAction().getContextIRUnits();
63 llvm::outs() << units.size() << " available IRUnits:\n";
64 for (const IRUnit &unit : units) {
65 llvm::outs() << " - ";
66 unit.print(
67 os&: llvm::outs(),
68 flags: OpPrintingFlags().useLocalScope().skipRegions().enableDebugInfo());
69 llvm::outs() << "\n";
70 }
71}
72
73void mlirDebuggerPrintActionBacktrace(bool withContext) {
74 DebuggerState &state = getGlobalDebuggerState();
75 if (!state.actionActiveStack) {
76 llvm::outs() << "No active action.\n";
77 return;
78 }
79 state.actionActiveStack->print(os&: llvm::outs(), withContext);
80}
81
82//===----------------------------------------------------------------------===//
83// Cursor Management
84//===----------------------------------------------------------------------===//
85
86void mlirDebuggerCursorPrint(bool withRegion) {
87 auto &state = getGlobalDebuggerState();
88 if (!state.cursor) {
89 llvm::outs() << "No active MLIR cursor, select from the context first\n";
90 return;
91 }
92 state.cursor.print(os&: llvm::outs(), flags: OpPrintingFlags()
93 .skipRegions(skip: !withRegion)
94 .useLocalScope()
95 .enableDebugInfo());
96 llvm::outs() << "\n";
97}
98
99void mlirDebuggerCursorSelectIRUnitFromContext(int index) {
100 auto &state = getGlobalDebuggerState();
101 if (!state.actionActiveStack) {
102 llvm::outs() << "No active MLIR Action stack\n";
103 return;
104 }
105 ArrayRef<IRUnit> units =
106 state.actionActiveStack->getAction().getContextIRUnits();
107 if (index < 0 || index >= static_cast<int>(units.size())) {
108 llvm::outs() << "Index invalid, bounds: [0, " << units.size()
109 << "] but got " << index << "\n";
110 return;
111 }
112 state.cursor = units[index];
113 state.cursor.print(os&: llvm::outs());
114 llvm::outs() << "\n";
115}
116
117void mlirDebuggerCursorSelectParentIRUnit() {
118 auto &state = getGlobalDebuggerState();
119 if (!state.cursor) {
120 llvm::outs() << "No active MLIR cursor, select from the context first\n";
121 return;
122 }
123 IRUnit *unit = &state.cursor;
124 if (auto *op = llvm::dyn_cast_if_present<Operation *>(Val&: *unit)) {
125 state.cursor = op->getBlock();
126 } else if (auto *region = llvm::dyn_cast_if_present<Region *>(Val&: *unit)) {
127 state.cursor = region->getParentOp();
128 } else if (auto *block = llvm::dyn_cast_if_present<Block *>(Val&: *unit)) {
129 state.cursor = block->getParent();
130 } else {
131 llvm::outs() << "Current cursor is not a valid IRUnit";
132 return;
133 }
134 state.cursor.print(os&: llvm::outs());
135 llvm::outs() << "\n";
136}
137
138void mlirDebuggerCursorSelectChildIRUnit(int index) {
139 auto &state = getGlobalDebuggerState();
140 if (!state.cursor) {
141 llvm::outs() << "No active MLIR cursor, select from the context first\n";
142 return;
143 }
144 IRUnit *unit = &state.cursor;
145 if (auto *op = llvm::dyn_cast_if_present<Operation *>(Val&: *unit)) {
146 if (index < 0 || index >= static_cast<int>(op->getNumRegions())) {
147 llvm::outs() << "Index invalid, op has " << op->getNumRegions()
148 << " but got " << index << "\n";
149 return;
150 }
151 state.cursor = &op->getRegion(index);
152 } else if (auto *region = llvm::dyn_cast_if_present<Region *>(Val&: *unit)) {
153 auto block = region->begin();
154 int count = 0;
155 while (block != region->end() && count != index) {
156 ++block;
157 ++count;
158 }
159
160 if (block == region->end()) {
161 llvm::outs() << "Index invalid, region has " << count << " block but got "
162 << index << "\n";
163 return;
164 }
165 state.cursor = &*block;
166 } else if (auto *block = llvm::dyn_cast_if_present<Block *>(Val&: *unit)) {
167 auto op = block->begin();
168 int count = 0;
169 while (op != block->end() && count != index) {
170 ++op;
171 ++count;
172 }
173
174 if (op == block->end()) {
175 llvm::outs() << "Index invalid, block has " << count
176 << "operations but got " << index << "\n";
177 return;
178 }
179 state.cursor = &*op;
180 } else {
181 llvm::outs() << "Current cursor is not a valid IRUnit";
182 return;
183 }
184 state.cursor.print(os&: llvm::outs());
185 llvm::outs() << "\n";
186}
187
188void mlirDebuggerCursorSelectPreviousIRUnit() {
189 auto &state = getGlobalDebuggerState();
190 if (!state.cursor) {
191 llvm::outs() << "No active MLIR cursor, select from the context first\n";
192 return;
193 }
194 IRUnit *unit = &state.cursor;
195 if (auto *op = llvm::dyn_cast_if_present<Operation *>(Val&: *unit)) {
196 Operation *previous = op->getPrevNode();
197 if (!previous) {
198 llvm::outs() << "No previous operation in the current block\n";
199 return;
200 }
201 state.cursor = previous;
202 } else if (auto *region = llvm::dyn_cast_if_present<Region *>(Val&: *unit)) {
203 llvm::outs() << "Has region\n";
204 Operation *parent = region->getParentOp();
205 if (!parent) {
206 llvm::outs() << "No parent operation for the current region\n";
207 return;
208 }
209 if (region->getRegionNumber() == 0) {
210 llvm::outs() << "No previous region in the current operation\n";
211 return;
212 }
213 state.cursor =
214 &region->getParentOp()->getRegion(index: region->getRegionNumber() - 1);
215 } else if (auto *block = llvm::dyn_cast_if_present<Block *>(Val&: *unit)) {
216 Block *previous = block->getPrevNode();
217 if (!previous) {
218 llvm::outs() << "No previous block in the current region\n";
219 return;
220 }
221 state.cursor = previous;
222 } else {
223 llvm::outs() << "Current cursor is not a valid IRUnit";
224 return;
225 }
226 state.cursor.print(os&: llvm::outs());
227 llvm::outs() << "\n";
228}
229
230void mlirDebuggerCursorSelectNextIRUnit() {
231 auto &state = getGlobalDebuggerState();
232 if (!state.cursor) {
233 llvm::outs() << "No active MLIR cursor, select from the context first\n";
234 return;
235 }
236 IRUnit *unit = &state.cursor;
237 if (auto *op = llvm::dyn_cast_if_present<Operation *>(Val&: *unit)) {
238 Operation *next = op->getNextNode();
239 if (!next) {
240 llvm::outs() << "No next operation in the current block\n";
241 return;
242 }
243 state.cursor = next;
244 } else if (auto *region = llvm::dyn_cast_if_present<Region *>(Val&: *unit)) {
245 Operation *parent = region->getParentOp();
246 if (!parent) {
247 llvm::outs() << "No parent operation for the current region\n";
248 return;
249 }
250 if (region->getRegionNumber() == parent->getNumRegions() - 1) {
251 llvm::outs() << "No next region in the current operation\n";
252 return;
253 }
254 state.cursor =
255 &region->getParentOp()->getRegion(index: region->getRegionNumber() + 1);
256 } else if (auto *block = llvm::dyn_cast_if_present<Block *>(Val&: *unit)) {
257 Block *next = block->getNextNode();
258 if (!next) {
259 llvm::outs() << "No next block in the current region\n";
260 return;
261 }
262 state.cursor = next;
263 } else {
264 llvm::outs() << "Current cursor is not a valid IRUnit";
265 return;
266 }
267 state.cursor.print(os&: llvm::outs());
268 llvm::outs() << "\n";
269}
270
271//===----------------------------------------------------------------------===//
272// Breakpoint Management
273//===----------------------------------------------------------------------===//
274
275void mlirDebuggerEnableBreakpoint(BreakpointHandle breakpoint) {
276 reinterpret_cast<Breakpoint *>(breakpoint)->enable();
277}
278
279void mlirDebuggerDisableBreakpoint(BreakpointHandle breakpoint) {
280 reinterpret_cast<Breakpoint *>(breakpoint)->disable();
281}
282
283BreakpointHandle mlirDebuggerAddTagBreakpoint(const char *tag) {
284 DebuggerState &state = getGlobalDebuggerState();
285 Breakpoint *breakpoint =
286 state.tagBreakpointManager.addBreakpoint(tag: StringRef(tag, strlen(s: tag)));
287 int breakpointId = state.breakpointIdsMap.size() + 1;
288 state.breakpointIdsMap[breakpointId] = breakpoint;
289 return reinterpret_cast<BreakpointHandle>(breakpoint);
290}
291
/// Entry point for breakpoints on rewrite patterns. Currently a no-op: the
/// body is empty and `patternNameInfo` is ignored.
void mlirDebuggerAddRewritePatternBreakpoint(const char *patternNameInfo) {
  (void)patternNameInfo;
}
293
294void mlirDebuggerAddFileLineColLocBreakpoint(const char *file, int line,
295 int col) {
296 getGlobalDebuggerState().fileLineColLocBreakpointManager.addBreakpoint(
297 file: StringRef(file, strlen(s: file)), line, col);
298}
299
300} // extern "C"
301
302LLVM_ATTRIBUTE_NOINLINE void mlirDebuggerBreakpointHook() {
303 static LLVM_THREAD_LOCAL void *volatile sink;
304 sink = (void *)&sink;
305}
306
307static void preventLinkerDeadCodeElim() {
308 static void *volatile sink;
309 static bool initialized = [&]() {
310 sink = (void *)mlirDebuggerSetControl;
311 sink = (void *)mlirDebuggerEnableBreakpoint;
312 sink = (void *)mlirDebuggerDisableBreakpoint;
313 sink = (void *)mlirDebuggerPrintContext;
314 sink = (void *)mlirDebuggerPrintActionBacktrace;
315 sink = (void *)mlirDebuggerCursorPrint;
316 sink = (void *)mlirDebuggerCursorSelectIRUnitFromContext;
317 sink = (void *)mlirDebuggerCursorSelectParentIRUnit;
318 sink = (void *)mlirDebuggerCursorSelectChildIRUnit;
319 sink = (void *)mlirDebuggerCursorSelectPreviousIRUnit;
320 sink = (void *)mlirDebuggerCursorSelectNextIRUnit;
321 sink = (void *)mlirDebuggerAddTagBreakpoint;
322 sink = (void *)mlirDebuggerAddRewritePatternBreakpoint;
323 sink = (void *)mlirDebuggerAddFileLineColLocBreakpoint;
324 sink = (void *)&sink;
325 return true;
326 }();
327 (void)initialized;
328}
329
330static tracing::ExecutionContext::Control
331debuggerCallBackFunction(const tracing::ActionActiveStack *actionStack) {
332 preventLinkerDeadCodeElim();
333 // Invoke the breakpoint hook, the debugger is supposed to trap this.
334 // The debugger controls the execution from there by invoking
335 // `mlirDebuggerSetControl()`.
336 auto &state = getGlobalDebuggerState();
337 state.actionActiveStack = actionStack;
338 getGlobalDebuggerState().debuggerControl = ExecutionContext::Apply;
339 actionStack->getAction().print(os&: llvm::outs());
340 llvm::outs() << "\n";
341 mlirDebuggerBreakpointHook();
342 return getGlobalDebuggerState().debuggerControl;
343}
344
345namespace {
346/// Manage the stack of actions that are currently active.
347class DebuggerObserver : public ExecutionContext::Observer {
348 void beforeExecute(const ActionActiveStack *action, Breakpoint *breakpoint,
349 bool willExecute) override {
350 auto &state = getGlobalDebuggerState();
351 state.actionActiveStack = action;
352 }
353 void afterExecute(const ActionActiveStack *action) override {
354 auto &state = getGlobalDebuggerState();
355 state.actionActiveStack = action->getParent();
356 state.cursor = nullptr;
357 }
358};
359} // namespace
360
361void mlir::setupDebuggerExecutionContextHook(
362 tracing::ExecutionContext &executionContext) {
363 executionContext.setCallback(debuggerCallBackFunction);
364 DebuggerState &state = getGlobalDebuggerState();
365 static DebuggerObserver observer;
366 executionContext.registerObserver(observer: &observer);
367 executionContext.addBreakpointManager(manager: &state.fileLineColLocBreakpointManager);
368 executionContext.addBreakpointManager(manager: &state.tagBreakpointManager);
369}
370

// Source: mlir/lib/Debug/DebuggerExecutionContextHook.cpp