//===- SliceAnalysis.cpp - Analysis for Transitive UseDef chains ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements Analysis functions specific to slicing in Function.
//
//===----------------------------------------------------------------------===//
12
13#include "mlir/Analysis/SliceAnalysis.h"
14#include "mlir/Analysis/TopologicalSortUtils.h"
15#include "mlir/IR/Block.h"
16#include "mlir/IR/Operation.h"
17#include "mlir/Interfaces/SideEffectInterfaces.h"
18#include "mlir/Support/LLVM.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/ADT/SetVector.h"
21
///
/// Implements Analysis functions specific to slicing in Function.
///
25
26using namespace mlir;
27
28static void
29getForwardSliceImpl(Operation *op, DenseSet<Operation *> &visited,
30 SetVector<Operation *> *forwardSlice,
31 const SliceOptions::TransitiveFilter &filter = nullptr) {
32 if (!op)
33 return;
34
35 // Evaluate whether we should keep this use.
36 // This is useful in particular to implement scoping; i.e. return the
37 // transitive forwardSlice in the current scope.
38 if (filter && !filter(op))
39 return;
40
41 for (Region &region : op->getRegions())
42 for (Block &block : region)
43 for (Operation &blockOp : block)
44 if (forwardSlice->count(key: &blockOp) == 0) {
45 // We don't have to check if the 'blockOp' is already visited because
46 // there cannot be a traversal path from this nested op to the parent
47 // and thus a cycle cannot be closed here. We still have to mark it
48 // as visited to stop before visiting this operation again if it is
49 // part of a cycle.
50 visited.insert(V: &blockOp);
51 getForwardSliceImpl(op: &blockOp, visited, forwardSlice, filter);
52 visited.erase(V: &blockOp);
53 }
54
55 for (Value result : op->getResults())
56 for (Operation *userOp : result.getUsers()) {
57 // A cycle can only occur within a basic block (not across regions or
58 // basic blocks) because the parent region must be a graph region, graph
59 // regions are restricted to always have 0 or 1 blocks, and there cannot
60 // be a def-use edge from a nested operation to an operation in an
61 // ancestor region. Therefore, we don't have to but may use the same
62 // 'visited' set across regions/blocks as long as we remove operations
63 // from the set again when the DFS traverses back from the leaf to the
64 // root.
65 if (forwardSlice->count(key: userOp) == 0 && visited.insert(V: userOp).second)
66 getForwardSliceImpl(op: userOp, visited, forwardSlice, filter);
67
68 visited.erase(V: userOp);
69 }
70
71 forwardSlice->insert(X: op);
72}
73
74void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
75 const ForwardSliceOptions &options) {
76 DenseSet<Operation *> visited;
77 visited.insert(V: op);
78 getForwardSliceImpl(op, visited, forwardSlice, filter: options.filter);
79 if (!options.inclusive) {
80 // Don't insert the top level operation, we just queried on it and don't
81 // want it in the results.
82 forwardSlice->remove(X: op);
83 }
84
85 // Reverse to get back the actual topological order.
86 // std::reverse does not work out of the box on SetVector and I want an
87 // in-place swap based thing (the real std::reverse, not the LLVM adapter).
88 SmallVector<Operation *, 0> v(forwardSlice->takeVector());
89 forwardSlice->insert(Start: v.rbegin(), End: v.rend());
90}
91
92void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
93 const SliceOptions &options) {
94 DenseSet<Operation *> visited;
95 for (Operation *user : root.getUsers()) {
96 visited.insert(V: user);
97 getForwardSliceImpl(op: user, visited, forwardSlice, filter: options.filter);
98 visited.erase(V: user);
99 }
100
101 // Reverse to get back the actual topological order.
102 // std::reverse does not work out of the box on SetVector and I want an
103 // in-place swap based thing (the real std::reverse, not the LLVM adapter).
104 SmallVector<Operation *, 0> v(forwardSlice->takeVector());
105 forwardSlice->insert(Start: v.rbegin(), End: v.rend());
106}
107
108static LogicalResult getBackwardSliceImpl(Operation *op,
109 DenseSet<Operation *> &visited,
110 SetVector<Operation *> *backwardSlice,
111 const BackwardSliceOptions &options) {
112 if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
113 return success();
114
115 // Evaluate whether we should keep this def.
116 // This is useful in particular to implement scoping; i.e. return the
117 // transitive backwardSlice in the current scope.
118 if (options.filter && !options.filter(op))
119 return success();
120
121 auto processValue = [&](Value value) {
122 if (auto *definingOp = value.getDefiningOp()) {
123 if (backwardSlice->count(key: definingOp) == 0 &&
124 visited.insert(V: definingOp).second)
125 return getBackwardSliceImpl(op: definingOp, visited, backwardSlice,
126 options);
127
128 visited.erase(V: definingOp);
129 } else if (auto blockArg = dyn_cast<BlockArgument>(Val&: value)) {
130 if (options.omitBlockArguments)
131 return success();
132
133 Block *block = blockArg.getOwner();
134 Operation *parentOp = block->getParentOp();
135 // TODO: determine whether we want to recurse backward into the other
136 // blocks of parentOp, which are not technically backward unless they flow
137 // into us. For now, just bail.
138 if (parentOp && backwardSlice->count(key: parentOp) == 0) {
139 if (parentOp->getNumRegions() == 1 &&
140 llvm::hasSingleElement(C&: parentOp->getRegion(index: 0).getBlocks())) {
141 return getBackwardSliceImpl(op: parentOp, visited, backwardSlice,
142 options);
143 }
144 }
145 } else {
146 return failure();
147 }
148 return success();
149 };
150
151 bool succeeded = true;
152
153 if (!options.omitUsesFromAbove) {
154 llvm::for_each(Range: op->getRegions(), F: [&](Region &region) {
155 // Walk this region recursively to collect the regions that descend from
156 // this op's nested regions (inclusive).
157 SmallPtrSet<Region *, 4> descendents;
158 region.walk(
159 callback: [&](Region *childRegion) { descendents.insert(Ptr: childRegion); });
160 region.walk(callback: [&](Operation *op) {
161 for (OpOperand &operand : op->getOpOperands()) {
162 if (!descendents.contains(Ptr: operand.get().getParentRegion()))
163 if (!processValue(operand.get()).succeeded()) {
164 return WalkResult::interrupt();
165 }
166 }
167 return WalkResult::advance();
168 });
169 });
170 }
171 llvm::for_each(Range: op->getOperands(), F: processValue);
172
173 backwardSlice->insert(X: op);
174 return success(IsSuccess: succeeded);
175}
176
177LogicalResult mlir::getBackwardSlice(Operation *op,
178 SetVector<Operation *> *backwardSlice,
179 const BackwardSliceOptions &options) {
180 DenseSet<Operation *> visited;
181 visited.insert(V: op);
182 LogicalResult result =
183 getBackwardSliceImpl(op, visited, backwardSlice, options);
184
185 if (!options.inclusive) {
186 // Don't insert the top level operation, we just queried on it and don't
187 // want it in the results.
188 backwardSlice->remove(X: op);
189 }
190 return result;
191}
192
193LogicalResult mlir::getBackwardSlice(Value root,
194 SetVector<Operation *> *backwardSlice,
195 const BackwardSliceOptions &options) {
196 if (Operation *definingOp = root.getDefiningOp()) {
197 return getBackwardSlice(op: definingOp, backwardSlice, options);
198 }
199 Operation *bbAargOwner = cast<BlockArgument>(Val&: root).getOwner()->getParentOp();
200 return getBackwardSlice(op: bbAargOwner, backwardSlice, options);
201}
202
203SetVector<Operation *>
204mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
205 const ForwardSliceOptions &forwardSliceOptions) {
206 SetVector<Operation *> slice;
207 slice.insert(X: op);
208
209 unsigned currentIndex = 0;
210 SetVector<Operation *> backwardSlice;
211 SetVector<Operation *> forwardSlice;
212 while (currentIndex != slice.size()) {
213 auto *currentOp = (slice)[currentIndex];
214 // Compute and insert the backwardSlice starting from currentOp.
215 backwardSlice.clear();
216 LogicalResult result =
217 getBackwardSlice(op: currentOp, backwardSlice: &backwardSlice, options: backwardSliceOptions);
218 assert(result.succeeded());
219 (void)result;
220 slice.insert_range(R&: backwardSlice);
221
222 // Compute and insert the forwardSlice starting from currentOp.
223 forwardSlice.clear();
224 getForwardSlice(op: currentOp, forwardSlice: &forwardSlice, options: forwardSliceOptions);
225 slice.insert_range(R&: forwardSlice);
226 ++currentIndex;
227 }
228 return topologicalSort(toSort: slice);
229}
230
231/// Returns true if `value` (transitively) depends on iteration-carried values
232/// of the given `ancestorOp`.
233static bool dependsOnCarriedVals(Value value,
234 ArrayRef<BlockArgument> iterCarriedArgs,
235 Operation *ancestorOp) {
236 // Compute the backward slice of the value.
237 SetVector<Operation *> slice;
238 BackwardSliceOptions sliceOptions;
239 sliceOptions.filter = [&](Operation *op) {
240 return !ancestorOp->isAncestor(other: op);
241 };
242 LogicalResult result = getBackwardSlice(root: value, backwardSlice: &slice, options: sliceOptions);
243 assert(result.succeeded());
244 (void)result;
245
246 // Check that none of the operands of the operations in the backward slice are
247 // loop iteration arguments, and neither is the value itself.
248 SmallPtrSet<Value, 8> iterCarriedValSet(llvm::from_range, iterCarriedArgs);
249 if (iterCarriedValSet.contains(Ptr: value))
250 return true;
251
252 for (Operation *op : slice)
253 for (Value operand : op->getOperands())
254 if (iterCarriedValSet.contains(Ptr: operand))
255 return true;
256
257 return false;
258}
259
260/// Utility to match a generic reduction given a list of iteration-carried
261/// arguments, `iterCarriedArgs` and the position of the potential reduction
262/// argument within the list, `redPos`. If a reduction is matched, returns the
263/// reduced value and the topologically-sorted list of combiner operations
264/// involved in the reduction. Otherwise, returns a null value.
265///
266/// The matching algorithm relies on the following invariants, which are subject
267/// to change:
268/// 1. The first combiner operation must be a binary operation with the
269/// iteration-carried value and the reduced value as operands.
270/// 2. The iteration-carried value and combiner operations must be side
271/// effect-free, have single result and a single use.
272/// 3. Combiner operations must be immediately nested in the region op
273/// performing the reduction.
274/// 4. Reduction def-use chain must end in a terminator op that yields the
275/// next iteration/output values in the same order as the iteration-carried
276/// values in `iterCarriedArgs`.
277/// 5. `iterCarriedArgs` must contain all the iteration-carried/output values
278/// of the region op performing the reduction.
279///
280/// This utility is generic enough to detect reductions involving multiple
281/// combiner operations (disabled for now) across multiple dialects, including
282/// Linalg, Affine and SCF. For the sake of genericity, it does not return
283/// specific enum values for the combiner operations since its goal is also
284/// matching reductions without pre-defined semantics in core MLIR. It's up to
285/// each client to make sense out of the list of combiner operations. It's also
286/// up to each client to check for additional invariants on the expected
287/// reductions not covered by this generic matching.
288Value mlir::matchReduction(ArrayRef<BlockArgument> iterCarriedArgs,
289 unsigned redPos,
290 SmallVectorImpl<Operation *> &combinerOps) {
291 assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds");
292
293 BlockArgument redCarriedVal = iterCarriedArgs[redPos];
294 if (!redCarriedVal.hasOneUse())
295 return nullptr;
296
297 // For now, the first combiner op must be a binary op.
298 Operation *combinerOp = *redCarriedVal.getUsers().begin();
299 if (combinerOp->getNumOperands() != 2)
300 return nullptr;
301 Value reducedVal = combinerOp->getOperand(idx: 0) == redCarriedVal
302 ? combinerOp->getOperand(idx: 1)
303 : combinerOp->getOperand(idx: 0);
304
305 Operation *redRegionOp =
306 iterCarriedArgs.front().getOwner()->getParent()->getParentOp();
307 if (dependsOnCarriedVals(value: reducedVal, iterCarriedArgs, ancestorOp: redRegionOp))
308 return nullptr;
309
310 // Traverse the def-use chain starting from the first combiner op until a
311 // terminator is found. Gather all the combiner ops along the way in
312 // topological order.
313 while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) {
314 if (!isMemoryEffectFree(op: combinerOp) || combinerOp->getNumResults() != 1 ||
315 !combinerOp->hasOneUse() || combinerOp->getParentOp() != redRegionOp)
316 return nullptr;
317
318 combinerOps.push_back(Elt: combinerOp);
319 combinerOp = *combinerOp->getUsers().begin();
320 }
321
322 // Limit matching to single combiner op until we can properly test reductions
323 // involving multiple combiners.
324 if (combinerOps.size() != 1)
325 return nullptr;
326
327 // Check that the yielded value is in the same position as in
328 // `iterCarriedArgs`.
329 Operation *terminatorOp = combinerOp;
330 if (terminatorOp->getOperand(idx: redPos) != combinerOps.back()->getResults()[0])
331 return nullptr;
332
333 return reducedVal;
334}
335

// Source: mlir/lib/Analysis/SliceAnalysis.cpp