1 | //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements Analysis functions specific to slicing in Function. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Analysis/SliceAnalysis.h" |
14 | #include "mlir/IR/BuiltinOps.h" |
15 | #include "mlir/IR/Operation.h" |
16 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
17 | #include "mlir/Support/LLVM.h" |
18 | #include "llvm/ADT/SetVector.h" |
19 | #include "llvm/ADT/SmallPtrSet.h" |
20 | |
21 | /// |
22 | /// Implements Analysis functions specific to slicing in Function. |
23 | /// |
24 | |
25 | using namespace mlir; |
26 | |
27 | static void |
28 | getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice, |
29 | const SliceOptions::TransitiveFilter &filter = nullptr) { |
30 | if (!op) |
31 | return; |
32 | |
33 | // Evaluate whether we should keep this use. |
34 | // This is useful in particular to implement scoping; i.e. return the |
35 | // transitive forwardSlice in the current scope. |
36 | if (filter && !filter(op)) |
37 | return; |
38 | |
39 | for (Region ®ion : op->getRegions()) |
40 | for (Block &block : region) |
41 | for (Operation &blockOp : block) |
42 | if (forwardSlice->count(key: &blockOp) == 0) |
43 | getForwardSliceImpl(op: &blockOp, forwardSlice, filter); |
44 | for (Value result : op->getResults()) { |
45 | for (Operation *userOp : result.getUsers()) |
46 | if (forwardSlice->count(key: userOp) == 0) |
47 | getForwardSliceImpl(op: userOp, forwardSlice, filter); |
48 | } |
49 | |
50 | forwardSlice->insert(X: op); |
51 | } |
52 | |
53 | void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice, |
54 | const ForwardSliceOptions &options) { |
55 | getForwardSliceImpl(op, forwardSlice, filter: options.filter); |
56 | if (!options.inclusive) { |
57 | // Don't insert the top level operation, we just queried on it and don't |
58 | // want it in the results. |
59 | forwardSlice->remove(X: op); |
60 | } |
61 | |
62 | // Reverse to get back the actual topological order. |
63 | // std::reverse does not work out of the box on SetVector and I want an |
64 | // in-place swap based thing (the real std::reverse, not the LLVM adapter). |
65 | SmallVector<Operation *, 0> v(forwardSlice->takeVector()); |
66 | forwardSlice->insert(Start: v.rbegin(), End: v.rend()); |
67 | } |
68 | |
69 | void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice, |
70 | const SliceOptions &options) { |
71 | for (Operation *user : root.getUsers()) |
72 | getForwardSliceImpl(op: user, forwardSlice, filter: options.filter); |
73 | |
74 | // Reverse to get back the actual topological order. |
75 | // std::reverse does not work out of the box on SetVector and I want an |
76 | // in-place swap based thing (the real std::reverse, not the LLVM adapter). |
77 | SmallVector<Operation *, 0> v(forwardSlice->takeVector()); |
78 | forwardSlice->insert(Start: v.rbegin(), End: v.rend()); |
79 | } |
80 | |
81 | static void getBackwardSliceImpl(Operation *op, |
82 | SetVector<Operation *> *backwardSlice, |
83 | const BackwardSliceOptions &options) { |
84 | if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>()) |
85 | return; |
86 | |
87 | // Evaluate whether we should keep this def. |
88 | // This is useful in particular to implement scoping; i.e. return the |
89 | // transitive backwardSlice in the current scope. |
90 | if (options.filter && !options.filter(op)) |
91 | return; |
92 | |
93 | for (const auto &en : llvm::enumerate(First: op->getOperands())) { |
94 | auto operand = en.value(); |
95 | if (auto *definingOp = operand.getDefiningOp()) { |
96 | if (backwardSlice->count(key: definingOp) == 0) |
97 | getBackwardSliceImpl(op: definingOp, backwardSlice, options); |
98 | } else if (auto blockArg = dyn_cast<BlockArgument>(Val&: operand)) { |
99 | if (options.omitBlockArguments) |
100 | continue; |
101 | |
102 | Block *block = blockArg.getOwner(); |
103 | Operation *parentOp = block->getParentOp(); |
104 | // TODO: determine whether we want to recurse backward into the other |
105 | // blocks of parentOp, which are not technically backward unless they flow |
106 | // into us. For now, just bail. |
107 | if (parentOp && backwardSlice->count(key: parentOp) == 0) { |
108 | assert(parentOp->getNumRegions() == 1 && |
109 | parentOp->getRegion(0).getBlocks().size() == 1); |
110 | getBackwardSliceImpl(op: parentOp, backwardSlice, options); |
111 | } |
112 | } else { |
113 | llvm_unreachable("No definingOp and not a block argument." ); |
114 | } |
115 | } |
116 | |
117 | backwardSlice->insert(X: op); |
118 | } |
119 | |
120 | void mlir::getBackwardSlice(Operation *op, |
121 | SetVector<Operation *> *backwardSlice, |
122 | const BackwardSliceOptions &options) { |
123 | getBackwardSliceImpl(op, backwardSlice, options); |
124 | |
125 | if (!options.inclusive) { |
126 | // Don't insert the top level operation, we just queried on it and don't |
127 | // want it in the results. |
128 | backwardSlice->remove(X: op); |
129 | } |
130 | } |
131 | |
132 | void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice, |
133 | const BackwardSliceOptions &options) { |
134 | if (Operation *definingOp = root.getDefiningOp()) { |
135 | getBackwardSlice(op: definingOp, backwardSlice, options); |
136 | return; |
137 | } |
138 | Operation *bbAargOwner = cast<BlockArgument>(Val&: root).getOwner()->getParentOp(); |
139 | getBackwardSlice(op: bbAargOwner, backwardSlice, options); |
140 | } |
141 | |
142 | SetVector<Operation *> |
143 | mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions, |
144 | const ForwardSliceOptions &forwardSliceOptions) { |
145 | SetVector<Operation *> slice; |
146 | slice.insert(X: op); |
147 | |
148 | unsigned currentIndex = 0; |
149 | SetVector<Operation *> backwardSlice; |
150 | SetVector<Operation *> forwardSlice; |
151 | while (currentIndex != slice.size()) { |
152 | auto *currentOp = (slice)[currentIndex]; |
153 | // Compute and insert the backwardSlice starting from currentOp. |
154 | backwardSlice.clear(); |
155 | getBackwardSlice(op: currentOp, backwardSlice: &backwardSlice, options: backwardSliceOptions); |
156 | slice.insert(Start: backwardSlice.begin(), End: backwardSlice.end()); |
157 | |
158 | // Compute and insert the forwardSlice starting from currentOp. |
159 | forwardSlice.clear(); |
160 | getForwardSlice(op: currentOp, forwardSlice: &forwardSlice, options: forwardSliceOptions); |
161 | slice.insert(Start: forwardSlice.begin(), End: forwardSlice.end()); |
162 | ++currentIndex; |
163 | } |
164 | return topologicalSort(toSort: slice); |
165 | } |
166 | |
167 | namespace { |
168 | /// DFS post-order implementation that maintains a global count to work across |
169 | /// multiple invocations, to help implement topological sort on multi-root DAGs. |
170 | /// We traverse all operations but only record the ones that appear in |
171 | /// `toSort` for the final result. |
172 | struct DFSState { |
173 | DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {} |
174 | const SetVector<Operation *> &toSort; |
175 | SmallVector<Operation *, 16> topologicalCounts; |
176 | DenseSet<Operation *> seen; |
177 | }; |
178 | } // namespace |
179 | |
180 | static void dfsPostorder(Operation *root, DFSState *state) { |
181 | SmallVector<Operation *> queue(1, root); |
182 | std::vector<Operation *> ops; |
183 | while (!queue.empty()) { |
184 | Operation *current = queue.pop_back_val(); |
185 | ops.push_back(x: current); |
186 | for (Operation *op : current->getUsers()) |
187 | queue.push_back(Elt: op); |
188 | for (Region ®ion : current->getRegions()) { |
189 | for (Operation &op : region.getOps()) |
190 | queue.push_back(Elt: &op); |
191 | } |
192 | } |
193 | |
194 | for (Operation *op : llvm::reverse(C&: ops)) { |
195 | if (state->seen.insert(V: op).second && state->toSort.count(key: op) > 0) |
196 | state->topologicalCounts.push_back(Elt: op); |
197 | } |
198 | } |
199 | |
200 | SetVector<Operation *> |
201 | mlir::topologicalSort(const SetVector<Operation *> &toSort) { |
202 | if (toSort.empty()) { |
203 | return toSort; |
204 | } |
205 | |
206 | // Run from each root with global count and `seen` set. |
207 | DFSState state(toSort); |
208 | for (auto *s : toSort) { |
209 | assert(toSort.count(s) == 1 && "NYI: multi-sets not supported" ); |
210 | dfsPostorder(root: s, state: &state); |
211 | } |
212 | |
213 | // Reorder and return. |
214 | SetVector<Operation *> res; |
215 | for (auto it = state.topologicalCounts.rbegin(), |
216 | eit = state.topologicalCounts.rend(); |
217 | it != eit; ++it) { |
218 | res.insert(X: *it); |
219 | } |
220 | return res; |
221 | } |
222 | |
223 | /// Returns true if `value` (transitively) depends on iteration-carried values |
224 | /// of the given `ancestorOp`. |
225 | static bool dependsOnCarriedVals(Value value, |
226 | ArrayRef<BlockArgument> iterCarriedArgs, |
227 | Operation *ancestorOp) { |
228 | // Compute the backward slice of the value. |
229 | SetVector<Operation *> slice; |
230 | BackwardSliceOptions sliceOptions; |
231 | sliceOptions.filter = [&](Operation *op) { |
232 | return !ancestorOp->isAncestor(other: op); |
233 | }; |
234 | getBackwardSlice(root: value, backwardSlice: &slice, options: sliceOptions); |
235 | |
236 | // Check that none of the operands of the operations in the backward slice are |
237 | // loop iteration arguments, and neither is the value itself. |
238 | SmallPtrSet<Value, 8> iterCarriedValSet(iterCarriedArgs.begin(), |
239 | iterCarriedArgs.end()); |
240 | if (iterCarriedValSet.contains(Ptr: value)) |
241 | return true; |
242 | |
243 | for (Operation *op : slice) |
244 | for (Value operand : op->getOperands()) |
245 | if (iterCarriedValSet.contains(Ptr: operand)) |
246 | return true; |
247 | |
248 | return false; |
249 | } |
250 | |
251 | /// Utility to match a generic reduction given a list of iteration-carried |
252 | /// arguments, `iterCarriedArgs` and the position of the potential reduction |
253 | /// argument within the list, `redPos`. If a reduction is matched, returns the |
254 | /// reduced value and the topologically-sorted list of combiner operations |
255 | /// involved in the reduction. Otherwise, returns a null value. |
256 | /// |
257 | /// The matching algorithm relies on the following invariants, which are subject |
258 | /// to change: |
259 | /// 1. The first combiner operation must be a binary operation with the |
260 | /// iteration-carried value and the reduced value as operands. |
261 | /// 2. The iteration-carried value and combiner operations must be side |
262 | /// effect-free, have single result and a single use. |
263 | /// 3. Combiner operations must be immediately nested in the region op |
264 | /// performing the reduction. |
265 | /// 4. Reduction def-use chain must end in a terminator op that yields the |
266 | /// next iteration/output values in the same order as the iteration-carried |
267 | /// values in `iterCarriedArgs`. |
268 | /// 5. `iterCarriedArgs` must contain all the iteration-carried/output values |
269 | /// of the region op performing the reduction. |
270 | /// |
271 | /// This utility is generic enough to detect reductions involving multiple |
272 | /// combiner operations (disabled for now) across multiple dialects, including |
273 | /// Linalg, Affine and SCF. For the sake of genericity, it does not return |
274 | /// specific enum values for the combiner operations since its goal is also |
275 | /// matching reductions without pre-defined semantics in core MLIR. It's up to |
276 | /// each client to make sense out of the list of combiner operations. It's also |
277 | /// up to each client to check for additional invariants on the expected |
278 | /// reductions not covered by this generic matching. |
279 | Value mlir::matchReduction(ArrayRef<BlockArgument> iterCarriedArgs, |
280 | unsigned redPos, |
281 | SmallVectorImpl<Operation *> &combinerOps) { |
282 | assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds" ); |
283 | |
284 | BlockArgument redCarriedVal = iterCarriedArgs[redPos]; |
285 | if (!redCarriedVal.hasOneUse()) |
286 | return nullptr; |
287 | |
288 | // For now, the first combiner op must be a binary op. |
289 | Operation *combinerOp = *redCarriedVal.getUsers().begin(); |
290 | if (combinerOp->getNumOperands() != 2) |
291 | return nullptr; |
292 | Value reducedVal = combinerOp->getOperand(idx: 0) == redCarriedVal |
293 | ? combinerOp->getOperand(idx: 1) |
294 | : combinerOp->getOperand(idx: 0); |
295 | |
296 | Operation *redRegionOp = |
297 | iterCarriedArgs.front().getOwner()->getParent()->getParentOp(); |
298 | if (dependsOnCarriedVals(value: reducedVal, iterCarriedArgs, ancestorOp: redRegionOp)) |
299 | return nullptr; |
300 | |
301 | // Traverse the def-use chain starting from the first combiner op until a |
302 | // terminator is found. Gather all the combiner ops along the way in |
303 | // topological order. |
304 | while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) { |
305 | if (!isMemoryEffectFree(op: combinerOp) || combinerOp->getNumResults() != 1 || |
306 | !combinerOp->hasOneUse() || combinerOp->getParentOp() != redRegionOp) |
307 | return nullptr; |
308 | |
309 | combinerOps.push_back(Elt: combinerOp); |
310 | combinerOp = *combinerOp->getUsers().begin(); |
311 | } |
312 | |
313 | // Limit matching to single combiner op until we can properly test reductions |
314 | // involving multiple combiners. |
315 | if (combinerOps.size() != 1) |
316 | return nullptr; |
317 | |
318 | // Check that the yielded value is in the same position as in |
319 | // `iterCarriedArgs`. |
320 | Operation *terminatorOp = combinerOp; |
321 | if (terminatorOp->getOperand(idx: redPos) != combinerOps.back()->getResults()[0]) |
322 | return nullptr; |
323 | |
324 | return reducedVal; |
325 | } |
326 | |