1 | //===- SliceMatchers.h - Matchers for slicing analysis ----------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file provides matchers for MLIRQuery that perform slicing analysis.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H |
14 | #define MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H |
15 | |
16 | #include "mlir/Analysis/SliceAnalysis.h" |
17 | #include "mlir/IR/Operation.h" |
18 | |
/// A matcher encapsulating the `getBackwardSlice` method from SliceAnalysis.h.
/// Additionally, it limits the slice computation to a certain depth level using
/// a custom filter.
///
/// Example: starting from node 9, assuming the matcher
/// computes the slice for the first two depth levels:
/// ============================
///    1       2      3      4
///    |_______|      |______|
///    |   |             |
///    |   5             6
///    |___|_____________|
///      |               |
///      7               8
///      |_______________|
///              |
///              9
///
/// Assuming all local orders match the numbering order:
///     {5, 7, 6, 8, 9}
39 | namespace mlir::query::matcher { |
40 | |
41 | template <typename Matcher> |
42 | class BackwardSliceMatcher { |
43 | public: |
44 | BackwardSliceMatcher(Matcher innerMatcher, int64_t maxDepth, bool inclusive, |
45 | bool omitBlockArguments, bool omitUsesFromAbove) |
46 | : innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth), |
47 | inclusive(inclusive), omitBlockArguments(omitBlockArguments), |
48 | omitUsesFromAbove(omitUsesFromAbove) {} |
49 | |
50 | bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) { |
51 | BackwardSliceOptions options; |
52 | options.inclusive = inclusive; |
53 | options.omitUsesFromAbove = omitUsesFromAbove; |
54 | options.omitBlockArguments = omitBlockArguments; |
55 | return (innerMatcher.match(rootOp) && |
56 | matches(rootOp, backwardSlice, options, maxDepth)); |
57 | } |
58 | |
59 | private: |
60 | bool matches(Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice, |
61 | BackwardSliceOptions &options, int64_t maxDepth); |
62 | |
63 | private: |
64 | // The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher |
65 | // to determine whether we want to traverse the IR or not. For example, we |
66 | // want to explore the IR only if the top-level operation name is |
67 | // `"arith.addf"`. |
68 | Matcher innerMatcher; |
69 | // `maxDepth` specifies the maximum depth that the matcher can traverse the |
70 | // IR. For example, if `maxDepth` is 2, the matcher will explore the defining |
71 | // operations of the top-level op up to 2 levels. |
72 | int64_t maxDepth; |
73 | bool inclusive; |
74 | bool omitBlockArguments; |
75 | bool omitUsesFromAbove; |
76 | }; |
77 | |
78 | template <typename Matcher> |
79 | bool BackwardSliceMatcher<Matcher>::matches( |
80 | Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice, |
81 | BackwardSliceOptions &options, int64_t maxDepth) { |
82 | backwardSlice.clear(); |
83 | llvm::DenseMap<Operation *, int64_t> opDepths; |
84 | // Initializing the root op with a depth of 0 |
85 | opDepths[rootOp] = 0; |
86 | options.filter = [&](Operation *subOp) { |
87 | // If the subOp hasn't been recorded in opDepths, it is deeper than |
88 | // maxDepth. |
89 | if (!opDepths.contains(Val: subOp)) |
90 | return false; |
91 | // Examine subOp's operands to compute depths of their defining operations. |
92 | for (auto operand : subOp->getOperands()) { |
93 | int64_t newDepth = opDepths[subOp] + 1; |
94 | // If the newDepth is greater than maxDepth, further computation can be |
95 | // skipped. |
96 | if (newDepth > maxDepth) |
97 | continue; |
98 | |
99 | if (auto definingOp = operand.getDefiningOp()) { |
100 | // Registers the minimum depth |
101 | if (!opDepths.contains(Val: definingOp) || newDepth < opDepths[definingOp]) |
102 | opDepths[definingOp] = newDepth; |
103 | } else { |
104 | auto blockArgument = cast<BlockArgument>(Val&: operand); |
105 | Operation *parentOp = blockArgument.getOwner()->getParentOp(); |
106 | if (!parentOp) |
107 | continue; |
108 | |
109 | if (!opDepths.contains(Val: parentOp) || newDepth < opDepths[parentOp]) |
110 | opDepths[parentOp] = newDepth; |
111 | } |
112 | } |
113 | return true; |
114 | }; |
115 | LogicalResult result = getBackwardSlice(op: rootOp, backwardSlice: &backwardSlice, options); |
116 | assert(result.succeeded() && "expected backward slice to succeed" ); |
117 | (void)result; |
118 | return options.inclusive ? backwardSlice.size() > 1 |
119 | : backwardSlice.size() >= 1; |
120 | } |
121 | |
122 | /// Matches transitive defs of a top-level operation up to N levels. |
123 | template <typename Matcher> |
124 | inline BackwardSliceMatcher<Matcher> |
125 | m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive, |
126 | bool omitBlockArguments, bool omitUsesFromAbove) { |
127 | assert(maxDepth >= 0 && "maxDepth must be non-negative" ); |
128 | return BackwardSliceMatcher<Matcher>(std::move(innerMatcher), maxDepth, |
129 | inclusive, omitBlockArguments, |
130 | omitUsesFromAbove); |
131 | } |
132 | |
133 | /// Matches all transitive defs of a top-level operation up to N levels |
134 | template <typename Matcher> |
135 | inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher, |
136 | int64_t maxDepth) { |
137 | assert(maxDepth >= 0 && "maxDepth must be non-negative" ); |
138 | return BackwardSliceMatcher<Matcher>(std::move(innerMatcher), maxDepth, true, |
139 | false, false); |
140 | } |
141 | |
142 | } // namespace mlir::query::matcher |
143 | |
144 | #endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H |
145 | |