1 | //===- TagBreakpointManager.h - Simple breakpoint Support -------*- C++ -*-===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_DEBUG_BREAKPOINTMANAGERS_TAGBREAKPOINTMANAGER_H |
10 | #define MLIR_DEBUG_BREAKPOINTMANAGERS_TAGBREAKPOINTMANAGER_H |
11 | |
12 | #include "mlir/Debug/BreakpointManager.h" |
13 | #include "mlir/Debug/ExecutionContext.h" |
14 | #include "mlir/IR/Action.h" |
15 | #include "llvm/ADT/MapVector.h" |
16 | |
17 | namespace mlir { |
18 | namespace tracing { |
19 | |
20 | /// Simple breakpoint matching an action "tag". |
21 | class TagBreakpoint : public BreakpointBase<TagBreakpoint> { |
22 | public: |
23 | TagBreakpoint(StringRef tag) : tag(tag) {} |
24 | |
25 | void print(raw_ostream &os) const override { os << "Tag: `"<< tag << '`'; } |
26 | |
27 | private: |
28 | /// A tag to associate the TagBreakpoint with. |
29 | std::string tag; |
30 | |
31 | /// Allow access to `tag`. |
32 | friend class TagBreakpointManager; |
33 | }; |
34 | |
35 | /// This is a manager to store a collection of breakpoints that trigger |
36 | /// on tags. |
37 | class TagBreakpointManager |
38 | : public BreakpointManagerBase<TagBreakpointManager> { |
39 | public: |
40 | Breakpoint *match(const Action &action) const override { |
41 | auto it = breakpoints.find(Key: action.getTag()); |
42 | if (it != breakpoints.end() && it->second->isEnabled()) |
43 | return it->second.get(); |
44 | return {}; |
45 | } |
46 | |
47 | /// Add a breakpoint to the manager for the given tag and return it. |
48 | /// If a breakpoint already exists for the given tag, return the existing |
49 | /// instance. |
50 | TagBreakpoint *addBreakpoint(StringRef tag) { |
51 | auto result = breakpoints.insert(KV: {tag, nullptr}); |
52 | auto &it = result.first; |
53 | if (result.second) |
54 | it->second = std::make_unique<TagBreakpoint>(args: tag.str()); |
55 | return it->second.get(); |
56 | } |
57 | |
58 | private: |
59 | llvm::StringMap<std::unique_ptr<TagBreakpoint>> breakpoints; |
60 | }; |
61 | |
62 | } // namespace tracing |
63 | } // namespace mlir |
64 | |
65 | #endif // MLIR_DEBUG_BREAKPOINTMANAGERS_TAGBREAKPOINTMANAGER_H |
66 |