1 | //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements Analysis functions specific to slicing in Function. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Analysis/SliceAnalysis.h" |
14 | #include "mlir/Analysis/TopologicalSortUtils.h" |
15 | #include "mlir/IR/Block.h" |
16 | #include "mlir/IR/Operation.h" |
17 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
18 | #include "mlir/Support/LLVM.h" |
19 | #include "llvm/ADT/STLExtras.h" |
20 | #include "llvm/ADT/SetVector.h" |
21 | #include "llvm/ADT/SmallPtrSet.h" |
22 | |
23 | /// |
24 | /// Implements Analysis functions specific to slicing in Function. |
25 | /// |
26 | |
27 | using namespace mlir; |
28 | |
29 | static void |
30 | getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice, |
31 | const SliceOptions::TransitiveFilter &filter = nullptr) { |
32 | if (!op) |
33 | return; |
34 | |
35 | // Evaluate whether we should keep this use. |
36 | // This is useful in particular to implement scoping; i.e. return the |
37 | // transitive forwardSlice in the current scope. |
38 | if (filter && !filter(op)) |
39 | return; |
40 | |
41 | for (Region ®ion : op->getRegions()) |
42 | for (Block &block : region) |
43 | for (Operation &blockOp : block) |
44 | if (forwardSlice->count(key: &blockOp) == 0) |
45 | getForwardSliceImpl(op: &blockOp, forwardSlice, filter); |
46 | for (Value result : op->getResults()) { |
47 | for (Operation *userOp : result.getUsers()) |
48 | if (forwardSlice->count(key: userOp) == 0) |
49 | getForwardSliceImpl(op: userOp, forwardSlice, filter); |
50 | } |
51 | |
52 | forwardSlice->insert(X: op); |
53 | } |
54 | |
55 | void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice, |
56 | const ForwardSliceOptions &options) { |
57 | getForwardSliceImpl(op, forwardSlice, filter: options.filter); |
58 | if (!options.inclusive) { |
59 | // Don't insert the top level operation, we just queried on it and don't |
60 | // want it in the results. |
61 | forwardSlice->remove(X: op); |
62 | } |
63 | |
64 | // Reverse to get back the actual topological order. |
65 | // std::reverse does not work out of the box on SetVector and I want an |
66 | // in-place swap based thing (the real std::reverse, not the LLVM adapter). |
67 | SmallVector<Operation *, 0> v(forwardSlice->takeVector()); |
68 | forwardSlice->insert(Start: v.rbegin(), End: v.rend()); |
69 | } |
70 | |
71 | void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice, |
72 | const SliceOptions &options) { |
73 | for (Operation *user : root.getUsers()) |
74 | getForwardSliceImpl(op: user, forwardSlice, filter: options.filter); |
75 | |
76 | // Reverse to get back the actual topological order. |
77 | // std::reverse does not work out of the box on SetVector and I want an |
78 | // in-place swap based thing (the real std::reverse, not the LLVM adapter). |
79 | SmallVector<Operation *, 0> v(forwardSlice->takeVector()); |
80 | forwardSlice->insert(Start: v.rbegin(), End: v.rend()); |
81 | } |
82 | |
83 | static LogicalResult getBackwardSliceImpl(Operation *op, |
84 | SetVector<Operation *> *backwardSlice, |
85 | const BackwardSliceOptions &options) { |
86 | if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>()) |
87 | return success(); |
88 | |
89 | // Evaluate whether we should keep this def. |
90 | // This is useful in particular to implement scoping; i.e. return the |
91 | // transitive backwardSlice in the current scope. |
92 | if (options.filter && !options.filter(op)) |
93 | return success(); |
94 | |
95 | auto processValue = [&](Value value) { |
96 | if (auto *definingOp = value.getDefiningOp()) { |
97 | if (backwardSlice->count(key: definingOp) == 0) |
98 | return getBackwardSliceImpl(op: definingOp, backwardSlice, options); |
99 | } else if (auto blockArg = dyn_cast<BlockArgument>(Val&: value)) { |
100 | if (options.omitBlockArguments) |
101 | return success(); |
102 | |
103 | Block *block = blockArg.getOwner(); |
104 | Operation *parentOp = block->getParentOp(); |
105 | // TODO: determine whether we want to recurse backward into the other |
106 | // blocks of parentOp, which are not technically backward unless they flow |
107 | // into us. For now, just bail. |
108 | if (parentOp && backwardSlice->count(key: parentOp) == 0) { |
109 | if (parentOp->getNumRegions() == 1 && |
110 | llvm::hasSingleElement(C&: parentOp->getRegion(index: 0).getBlocks())) { |
111 | return getBackwardSliceImpl(op: parentOp, backwardSlice, options); |
112 | } |
113 | } |
114 | } else { |
115 | return failure(); |
116 | } |
117 | return success(); |
118 | }; |
119 | |
120 | bool succeeded = true; |
121 | |
122 | if (!options.omitUsesFromAbove) { |
123 | llvm::for_each(Range: op->getRegions(), F: [&](Region ®ion) { |
124 | // Walk this region recursively to collect the regions that descend from |
125 | // this op's nested regions (inclusive). |
126 | SmallPtrSet<Region *, 4> descendents; |
127 | region.walk( |
128 | callback: [&](Region *childRegion) { descendents.insert(Ptr: childRegion); }); |
129 | region.walk(callback: [&](Operation *op) { |
130 | for (OpOperand &operand : op->getOpOperands()) { |
131 | if (!descendents.contains(Ptr: operand.get().getParentRegion())) |
132 | if (!processValue(operand.get()).succeeded()) { |
133 | return WalkResult::interrupt(); |
134 | } |
135 | } |
136 | return WalkResult::advance(); |
137 | }); |
138 | }); |
139 | } |
140 | llvm::for_each(Range: op->getOperands(), F: processValue); |
141 | |
142 | backwardSlice->insert(X: op); |
143 | return success(IsSuccess: succeeded); |
144 | } |
145 | |
146 | LogicalResult mlir::getBackwardSlice(Operation *op, |
147 | SetVector<Operation *> *backwardSlice, |
148 | const BackwardSliceOptions &options) { |
149 | LogicalResult result = getBackwardSliceImpl(op, backwardSlice, options); |
150 | |
151 | if (!options.inclusive) { |
152 | // Don't insert the top level operation, we just queried on it and don't |
153 | // want it in the results. |
154 | backwardSlice->remove(X: op); |
155 | } |
156 | return result; |
157 | } |
158 | |
159 | LogicalResult mlir::getBackwardSlice(Value root, |
160 | SetVector<Operation *> *backwardSlice, |
161 | const BackwardSliceOptions &options) { |
162 | if (Operation *definingOp = root.getDefiningOp()) { |
163 | return getBackwardSlice(op: definingOp, backwardSlice, options); |
164 | } |
165 | Operation *bbAargOwner = cast<BlockArgument>(Val&: root).getOwner()->getParentOp(); |
166 | return getBackwardSlice(op: bbAargOwner, backwardSlice, options); |
167 | } |
168 | |
169 | SetVector<Operation *> |
170 | mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions, |
171 | const ForwardSliceOptions &forwardSliceOptions) { |
172 | SetVector<Operation *> slice; |
173 | slice.insert(X: op); |
174 | |
175 | unsigned currentIndex = 0; |
176 | SetVector<Operation *> backwardSlice; |
177 | SetVector<Operation *> forwardSlice; |
178 | while (currentIndex != slice.size()) { |
179 | auto *currentOp = (slice)[currentIndex]; |
180 | // Compute and insert the backwardSlice starting from currentOp. |
181 | backwardSlice.clear(); |
182 | LogicalResult result = |
183 | getBackwardSlice(op: currentOp, backwardSlice: &backwardSlice, options: backwardSliceOptions); |
184 | assert(result.succeeded()); |
185 | (void)result; |
186 | slice.insert_range(R&: backwardSlice); |
187 | |
188 | // Compute and insert the forwardSlice starting from currentOp. |
189 | forwardSlice.clear(); |
190 | getForwardSlice(op: currentOp, forwardSlice: &forwardSlice, options: forwardSliceOptions); |
191 | slice.insert_range(R&: forwardSlice); |
192 | ++currentIndex; |
193 | } |
194 | return topologicalSort(toSort: slice); |
195 | } |
196 | |
197 | /// Returns true if `value` (transitively) depends on iteration-carried values |
198 | /// of the given `ancestorOp`. |
199 | static bool dependsOnCarriedVals(Value value, |
200 | ArrayRef<BlockArgument> iterCarriedArgs, |
201 | Operation *ancestorOp) { |
202 | // Compute the backward slice of the value. |
203 | SetVector<Operation *> slice; |
204 | BackwardSliceOptions sliceOptions; |
205 | sliceOptions.filter = [&](Operation *op) { |
206 | return !ancestorOp->isAncestor(other: op); |
207 | }; |
208 | LogicalResult result = getBackwardSlice(root: value, backwardSlice: &slice, options: sliceOptions); |
209 | assert(result.succeeded()); |
210 | (void)result; |
211 | |
212 | // Check that none of the operands of the operations in the backward slice are |
213 | // loop iteration arguments, and neither is the value itself. |
214 | SmallPtrSet<Value, 8> iterCarriedValSet(llvm::from_range, iterCarriedArgs); |
215 | if (iterCarriedValSet.contains(Ptr: value)) |
216 | return true; |
217 | |
218 | for (Operation *op : slice) |
219 | for (Value operand : op->getOperands()) |
220 | if (iterCarriedValSet.contains(Ptr: operand)) |
221 | return true; |
222 | |
223 | return false; |
224 | } |
225 | |
226 | /// Utility to match a generic reduction given a list of iteration-carried |
227 | /// arguments, `iterCarriedArgs` and the position of the potential reduction |
228 | /// argument within the list, `redPos`. If a reduction is matched, returns the |
229 | /// reduced value and the topologically-sorted list of combiner operations |
230 | /// involved in the reduction. Otherwise, returns a null value. |
231 | /// |
232 | /// The matching algorithm relies on the following invariants, which are subject |
233 | /// to change: |
234 | /// 1. The first combiner operation must be a binary operation with the |
235 | /// iteration-carried value and the reduced value as operands. |
236 | /// 2. The iteration-carried value and combiner operations must be side |
237 | /// effect-free, have single result and a single use. |
238 | /// 3. Combiner operations must be immediately nested in the region op |
239 | /// performing the reduction. |
240 | /// 4. Reduction def-use chain must end in a terminator op that yields the |
241 | /// next iteration/output values in the same order as the iteration-carried |
242 | /// values in `iterCarriedArgs`. |
243 | /// 5. `iterCarriedArgs` must contain all the iteration-carried/output values |
244 | /// of the region op performing the reduction. |
245 | /// |
246 | /// This utility is generic enough to detect reductions involving multiple |
247 | /// combiner operations (disabled for now) across multiple dialects, including |
248 | /// Linalg, Affine and SCF. For the sake of genericity, it does not return |
249 | /// specific enum values for the combiner operations since its goal is also |
250 | /// matching reductions without pre-defined semantics in core MLIR. It's up to |
251 | /// each client to make sense out of the list of combiner operations. It's also |
252 | /// up to each client to check for additional invariants on the expected |
253 | /// reductions not covered by this generic matching. |
254 | Value mlir::matchReduction(ArrayRef<BlockArgument> iterCarriedArgs, |
255 | unsigned redPos, |
256 | SmallVectorImpl<Operation *> &combinerOps) { |
257 | assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds" ); |
258 | |
259 | BlockArgument redCarriedVal = iterCarriedArgs[redPos]; |
260 | if (!redCarriedVal.hasOneUse()) |
261 | return nullptr; |
262 | |
263 | // For now, the first combiner op must be a binary op. |
264 | Operation *combinerOp = *redCarriedVal.getUsers().begin(); |
265 | if (combinerOp->getNumOperands() != 2) |
266 | return nullptr; |
267 | Value reducedVal = combinerOp->getOperand(idx: 0) == redCarriedVal |
268 | ? combinerOp->getOperand(idx: 1) |
269 | : combinerOp->getOperand(idx: 0); |
270 | |
271 | Operation *redRegionOp = |
272 | iterCarriedArgs.front().getOwner()->getParent()->getParentOp(); |
273 | if (dependsOnCarriedVals(value: reducedVal, iterCarriedArgs, ancestorOp: redRegionOp)) |
274 | return nullptr; |
275 | |
276 | // Traverse the def-use chain starting from the first combiner op until a |
277 | // terminator is found. Gather all the combiner ops along the way in |
278 | // topological order. |
279 | while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) { |
280 | if (!isMemoryEffectFree(op: combinerOp) || combinerOp->getNumResults() != 1 || |
281 | !combinerOp->hasOneUse() || combinerOp->getParentOp() != redRegionOp) |
282 | return nullptr; |
283 | |
284 | combinerOps.push_back(Elt: combinerOp); |
285 | combinerOp = *combinerOp->getUsers().begin(); |
286 | } |
287 | |
288 | // Limit matching to single combiner op until we can properly test reductions |
289 | // involving multiple combiners. |
290 | if (combinerOps.size() != 1) |
291 | return nullptr; |
292 | |
293 | // Check that the yielded value is in the same position as in |
294 | // `iterCarriedArgs`. |
295 | Operation *terminatorOp = combinerOp; |
296 | if (terminatorOp->getOperand(idx: redPos) != combinerOps.back()->getResults()[0]) |
297 | return nullptr; |
298 | |
299 | return reducedVal; |
300 | } |
301 | |