1 | //===- ExecutionContext.h - Execution Context Support *- C++ -*-=============// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_TRACING_EXECUTIONCONTEXT_H |
10 | #define MLIR_TRACING_EXECUTIONCONTEXT_H |
11 | |
12 | #include "mlir/Debug/BreakpointManager.h" |
13 | #include "mlir/IR/Action.h" |
14 | #include "llvm/ADT/SmallVector.h" |
15 | |
16 | namespace mlir { |
17 | namespace tracing { |
18 | |
19 | /// This class is used to keep track of the active actions in the stack. |
20 | /// It provides the current action but also access to the parent entry in the |
21 | /// stack. This allows to keep track of the nested nature in which actions may |
22 | /// be executed. |
23 | struct ActionActiveStack { |
24 | public: |
25 | ActionActiveStack(const ActionActiveStack *parent, const Action &action, |
26 | int depth) |
27 | : parent(parent), action(action), depth(depth) {} |
28 | const ActionActiveStack *getParent() const { return parent; } |
29 | const Action &getAction() const { return action; } |
30 | int getDepth() const { return depth; } |
31 | void print(raw_ostream &os, bool withContext) const; |
32 | void dump() const { |
33 | print(os&: llvm::errs(), /*withContext=*/withContext: true); |
34 | llvm::errs() << "\n" ; |
35 | } |
36 | Breakpoint *getBreakpoint() const { return breakpoint; } |
37 | void setBreakpoint(Breakpoint *breakpoint) { this->breakpoint = breakpoint; } |
38 | |
39 | private: |
40 | Breakpoint *breakpoint = nullptr; |
41 | const ActionActiveStack *parent; |
42 | const Action &action; |
43 | int depth; |
44 | }; |
45 | |
46 | /// The ExecutionContext is the main orchestration of the infrastructure, it |
47 | /// acts as a handler in the MLIRContext for executing an Action. When an action |
48 | /// is dispatched, it'll query its set of Breakpoints managers for a breakpoint |
49 | /// matching this action. If a breakpoint is hit, it passes the action and the |
50 | /// breakpoint information to a callback. The callback is responsible for |
51 | /// controlling the execution of the action through an enum value it returns. |
52 | /// Optionally, observers can be registered to be notified before and after the |
53 | /// callback is executed. |
54 | class ExecutionContext { |
55 | public: |
56 | /// Enum that allows the client of the context to control the execution of the |
57 | /// action. |
58 | /// - Apply: The action is executed. |
59 | /// - Skip: The action is skipped. |
60 | /// - Step: The action is executed and the execution is paused before the next |
61 | /// action, including for nested actions encountered before the |
62 | /// current action finishes. |
63 | /// - Next: The action is executed and the execution is paused after the |
64 | /// current action finishes before the next action. |
65 | /// - Finish: The action is executed and the execution is paused only when we |
66 | /// reach the parent/enclosing operation. If there are no enclosing |
67 | /// operation, the execution continues without stopping. |
68 | enum Control { Apply = 1, Skip = 2, Step = 3, Next = 4, Finish = 5 }; |
69 | |
70 | /// The type of the callback that is used to control the execution. |
71 | /// The callback is passed the current action. |
72 | using CallbackTy = function_ref<Control(const ActionActiveStack *)>; |
73 | |
74 | /// Create an ExecutionContext with a callback that is used to control the |
75 | /// execution. |
76 | ExecutionContext(CallbackTy callback) { setCallback(callback); } |
77 | ExecutionContext() = default; |
78 | |
79 | /// Set the callback that is used to control the execution. |
80 | void setCallback(CallbackTy callback) { |
81 | onBreakpointControlExecutionCallback = callback; |
82 | } |
83 | |
84 | /// This abstract class defines the interface used to observe an Action |
85 | /// execution. It allows to be notified before and after the callback is |
86 | /// processed, but can't affect the execution. |
87 | struct Observer { |
88 | virtual ~Observer() = default; |
89 | /// This method is called before the Action is executed |
90 | /// If a breakpoint was hit, it is passed as an argument to the callback. |
91 | /// The `willExecute` argument indicates whether the action will be executed |
92 | /// or not. |
93 | /// Note that this method will be called from multiple threads concurrently |
94 | /// when MLIR multi-threading is enabled. |
95 | virtual void beforeExecute(const ActionActiveStack *action, |
96 | Breakpoint *breakpoint, bool willExecute) {} |
97 | |
98 | /// This method is called after the Action is executed, if it was executed. |
99 | /// It is not called if the action is skipped. |
100 | /// Note that this method will be called from multiple threads concurrently |
101 | /// when MLIR multi-threading is enabled. |
102 | virtual void afterExecute(const ActionActiveStack *action) {} |
103 | }; |
104 | |
105 | /// Register a new `Observer` on this context. It'll be notified before and |
106 | /// after executing an action. Note that this method is not thread-safe: it |
107 | /// isn't supported to add a new observer while actions may be executed. |
108 | void registerObserver(Observer *observer); |
109 | |
110 | /// Register a new `BreakpointManager` on this context. It'll have a chance to |
111 | /// match an action before it gets executed. Note that this method is not |
112 | /// thread-safe: it isn't supported to add a new manager while actions may be |
113 | /// executed. |
114 | void addBreakpointManager(BreakpointManager *manager) { |
115 | breakpoints.push_back(Elt: manager); |
116 | } |
117 | |
118 | /// Process the given action. This is the operator called by MLIRContext on |
119 | /// `executeAction()`. |
120 | void operator()(function_ref<void()> transform, const Action &action); |
121 | |
122 | private: |
123 | /// Callback that is executed when a breakpoint is hit and allows the client |
124 | /// to control the execution. |
125 | CallbackTy onBreakpointControlExecutionCallback; |
126 | |
127 | /// Next point to stop execution as describe by `Control` enum. |
128 | /// This is handle by indicating at which levels of depth the next |
129 | /// break should happen. |
130 | std::optional<int> depthToBreak; |
131 | |
132 | /// Observers that are notified before and after the callback is executed. |
133 | SmallVector<Observer *> observers; |
134 | |
135 | /// The list of managers that are queried for breakpoints. |
136 | SmallVector<BreakpointManager *> breakpoints; |
137 | }; |
138 | |
139 | } // namespace tracing |
140 | } // namespace mlir |
141 | |
142 | #endif // MLIR_TRACING_EXECUTIONCONTEXT_H |
143 | |