//===- ExecutionContext.h - Execution Context Support -----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_TRACING_EXECUTIONCONTEXT_H
10#define MLIR_TRACING_EXECUTIONCONTEXT_H
11
#include "mlir/Debug/BreakpointManager.h"
#include "mlir/IR/Action.h"
#include "llvm/ADT/SmallVector.h"

#include <functional>
#include <optional>
#include <utility>
15
16namespace mlir {
17namespace tracing {
18
19/// This class is used to keep track of the active actions in the stack.
20/// It provides the current action but also access to the parent entry in the
21/// stack. This allows to keep track of the nested nature in which actions may
22/// be executed.
23struct ActionActiveStack {
24public:
25 ActionActiveStack(const ActionActiveStack *parent, const Action &action,
26 int depth)
27 : parent(parent), action(action), depth(depth) {}
28 const ActionActiveStack *getParent() const { return parent; }
29 const Action &getAction() const { return action; }
30 int getDepth() const { return depth; }
31 void print(raw_ostream &os, bool withContext) const;
32 void dump() const {
33 print(os&: llvm::errs(), /*withContext=*/withContext: true);
34 llvm::errs() << "\n";
35 }
36 Breakpoint *getBreakpoint() const { return breakpoint; }
37 void setBreakpoint(Breakpoint *breakpoint) { this->breakpoint = breakpoint; }
38
39private:
40 Breakpoint *breakpoint = nullptr;
41 const ActionActiveStack *parent;
42 const Action &action;
43 int depth;
44};
45
46/// The ExecutionContext is the main orchestration of the infrastructure, it
47/// acts as a handler in the MLIRContext for executing an Action. When an action
48/// is dispatched, it'll query its set of Breakpoints managers for a breakpoint
49/// matching this action. If a breakpoint is hit, it passes the action and the
50/// breakpoint information to a callback. The callback is responsible for
51/// controlling the execution of the action through an enum value it returns.
52/// Optionally, observers can be registered to be notified before and after the
53/// callback is executed.
54class ExecutionContext {
55public:
56 /// Enum that allows the client of the context to control the execution of the
57 /// action.
58 /// - Apply: The action is executed.
59 /// - Skip: The action is skipped.
60 /// - Step: The action is executed and the execution is paused before the next
61 /// action, including for nested actions encountered before the
62 /// current action finishes.
63 /// - Next: The action is executed and the execution is paused after the
64 /// current action finishes before the next action.
65 /// - Finish: The action is executed and the execution is paused only when we
66 /// reach the parent/enclosing operation. If there are no enclosing
67 /// operation, the execution continues without stopping.
68 enum Control { Apply = 1, Skip = 2, Step = 3, Next = 4, Finish = 5 };
69
70 /// The type of the callback that is used to control the execution.
71 /// The callback is passed the current action.
72 using CallbackTy = function_ref<Control(const ActionActiveStack *)>;
73
74 /// Create an ExecutionContext with a callback that is used to control the
75 /// execution.
76 ExecutionContext(CallbackTy callback) { setCallback(callback); }
77 ExecutionContext() = default;
78
79 /// Set the callback that is used to control the execution.
80 void setCallback(CallbackTy callback) {
81 onBreakpointControlExecutionCallback = callback;
82 }
83
84 /// This abstract class defines the interface used to observe an Action
85 /// execution. It allows to be notified before and after the callback is
86 /// processed, but can't affect the execution.
87 struct Observer {
88 virtual ~Observer() = default;
89 /// This method is called before the Action is executed
90 /// If a breakpoint was hit, it is passed as an argument to the callback.
91 /// The `willExecute` argument indicates whether the action will be executed
92 /// or not.
93 /// Note that this method will be called from multiple threads concurrently
94 /// when MLIR multi-threading is enabled.
95 virtual void beforeExecute(const ActionActiveStack *action,
96 Breakpoint *breakpoint, bool willExecute) {}
97
98 /// This method is called after the Action is executed, if it was executed.
99 /// It is not called if the action is skipped.
100 /// Note that this method will be called from multiple threads concurrently
101 /// when MLIR multi-threading is enabled.
102 virtual void afterExecute(const ActionActiveStack *action) {}
103 };
104
105 /// Register a new `Observer` on this context. It'll be notified before and
106 /// after executing an action. Note that this method is not thread-safe: it
107 /// isn't supported to add a new observer while actions may be executed.
108 void registerObserver(Observer *observer);
109
110 /// Register a new `BreakpointManager` on this context. It'll have a chance to
111 /// match an action before it gets executed. Note that this method is not
112 /// thread-safe: it isn't supported to add a new manager while actions may be
113 /// executed.
114 void addBreakpointManager(BreakpointManager *manager) {
115 breakpoints.push_back(Elt: manager);
116 }
117
118 /// Process the given action. This is the operator called by MLIRContext on
119 /// `executeAction()`.
120 void operator()(function_ref<void()> transform, const Action &action);
121
122private:
123 /// Callback that is executed when a breakpoint is hit and allows the client
124 /// to control the execution.
125 CallbackTy onBreakpointControlExecutionCallback;
126
127 /// Next point to stop execution as describe by `Control` enum.
128 /// This is handle by indicating at which levels of depth the next
129 /// break should happen.
130 std::optional<int> depthToBreak;
131
132 /// Observers that are notified before and after the callback is executed.
133 SmallVector<Observer *> observers;
134
135 /// The list of managers that are queried for breakpoints.
136 SmallVector<BreakpointManager *> breakpoints;
137};
138
139} // namespace tracing
140} // namespace mlir
141
142#endif // MLIR_TRACING_EXECUTIONCONTEXT_H
143

source code of mlir/include/mlir/Debug/ExecutionContext.h