//===- SliceAnalysis.cpp - Analysis for Transitive UseDef chains ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements Analysis functions specific to slicing in Function.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Analysis/SliceAnalysis.h"
14#include "mlir/IR/BuiltinOps.h"
15#include "mlir/IR/Operation.h"
16#include "mlir/Interfaces/SideEffectInterfaces.h"
17#include "mlir/Support/LLVM.h"
18#include "llvm/ADT/SetVector.h"
19#include "llvm/ADT/SmallPtrSet.h"
20
21///
22/// Implements Analysis functions specific to slicing in Function.
23///
24
25using namespace mlir;
26
27static void
28getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice,
29 const SliceOptions::TransitiveFilter &filter = nullptr) {
30 if (!op)
31 return;
32
33 // Evaluate whether we should keep this use.
34 // This is useful in particular to implement scoping; i.e. return the
35 // transitive forwardSlice in the current scope.
36 if (filter && !filter(op))
37 return;
38
39 for (Region &region : op->getRegions())
40 for (Block &block : region)
41 for (Operation &blockOp : block)
42 if (forwardSlice->count(key: &blockOp) == 0)
43 getForwardSliceImpl(op: &blockOp, forwardSlice, filter);
44 for (Value result : op->getResults()) {
45 for (Operation *userOp : result.getUsers())
46 if (forwardSlice->count(key: userOp) == 0)
47 getForwardSliceImpl(op: userOp, forwardSlice, filter);
48 }
49
50 forwardSlice->insert(X: op);
51}
52
53void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
54 const ForwardSliceOptions &options) {
55 getForwardSliceImpl(op, forwardSlice, filter: options.filter);
56 if (!options.inclusive) {
57 // Don't insert the top level operation, we just queried on it and don't
58 // want it in the results.
59 forwardSlice->remove(X: op);
60 }
61
62 // Reverse to get back the actual topological order.
63 // std::reverse does not work out of the box on SetVector and I want an
64 // in-place swap based thing (the real std::reverse, not the LLVM adapter).
65 SmallVector<Operation *, 0> v(forwardSlice->takeVector());
66 forwardSlice->insert(Start: v.rbegin(), End: v.rend());
67}
68
69void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
70 const SliceOptions &options) {
71 for (Operation *user : root.getUsers())
72 getForwardSliceImpl(op: user, forwardSlice, filter: options.filter);
73
74 // Reverse to get back the actual topological order.
75 // std::reverse does not work out of the box on SetVector and I want an
76 // in-place swap based thing (the real std::reverse, not the LLVM adapter).
77 SmallVector<Operation *, 0> v(forwardSlice->takeVector());
78 forwardSlice->insert(Start: v.rbegin(), End: v.rend());
79}
80
81static void getBackwardSliceImpl(Operation *op,
82 SetVector<Operation *> *backwardSlice,
83 const BackwardSliceOptions &options) {
84 if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
85 return;
86
87 // Evaluate whether we should keep this def.
88 // This is useful in particular to implement scoping; i.e. return the
89 // transitive backwardSlice in the current scope.
90 if (options.filter && !options.filter(op))
91 return;
92
93 for (const auto &en : llvm::enumerate(First: op->getOperands())) {
94 auto operand = en.value();
95 if (auto *definingOp = operand.getDefiningOp()) {
96 if (backwardSlice->count(key: definingOp) == 0)
97 getBackwardSliceImpl(op: definingOp, backwardSlice, options);
98 } else if (auto blockArg = dyn_cast<BlockArgument>(Val&: operand)) {
99 if (options.omitBlockArguments)
100 continue;
101
102 Block *block = blockArg.getOwner();
103 Operation *parentOp = block->getParentOp();
104 // TODO: determine whether we want to recurse backward into the other
105 // blocks of parentOp, which are not technically backward unless they flow
106 // into us. For now, just bail.
107 if (parentOp && backwardSlice->count(key: parentOp) == 0) {
108 assert(parentOp->getNumRegions() == 1 &&
109 parentOp->getRegion(0).getBlocks().size() == 1);
110 getBackwardSliceImpl(op: parentOp, backwardSlice, options);
111 }
112 } else {
113 llvm_unreachable("No definingOp and not a block argument.");
114 }
115 }
116
117 backwardSlice->insert(X: op);
118}
119
120void mlir::getBackwardSlice(Operation *op,
121 SetVector<Operation *> *backwardSlice,
122 const BackwardSliceOptions &options) {
123 getBackwardSliceImpl(op, backwardSlice, options);
124
125 if (!options.inclusive) {
126 // Don't insert the top level operation, we just queried on it and don't
127 // want it in the results.
128 backwardSlice->remove(X: op);
129 }
130}
131
132void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
133 const BackwardSliceOptions &options) {
134 if (Operation *definingOp = root.getDefiningOp()) {
135 getBackwardSlice(op: definingOp, backwardSlice, options);
136 return;
137 }
138 Operation *bbAargOwner = cast<BlockArgument>(Val&: root).getOwner()->getParentOp();
139 getBackwardSlice(op: bbAargOwner, backwardSlice, options);
140}
141
142SetVector<Operation *>
143mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
144 const ForwardSliceOptions &forwardSliceOptions) {
145 SetVector<Operation *> slice;
146 slice.insert(X: op);
147
148 unsigned currentIndex = 0;
149 SetVector<Operation *> backwardSlice;
150 SetVector<Operation *> forwardSlice;
151 while (currentIndex != slice.size()) {
152 auto *currentOp = (slice)[currentIndex];
153 // Compute and insert the backwardSlice starting from currentOp.
154 backwardSlice.clear();
155 getBackwardSlice(op: currentOp, backwardSlice: &backwardSlice, options: backwardSliceOptions);
156 slice.insert(Start: backwardSlice.begin(), End: backwardSlice.end());
157
158 // Compute and insert the forwardSlice starting from currentOp.
159 forwardSlice.clear();
160 getForwardSlice(op: currentOp, forwardSlice: &forwardSlice, options: forwardSliceOptions);
161 slice.insert(Start: forwardSlice.begin(), End: forwardSlice.end());
162 ++currentIndex;
163 }
164 return topologicalSort(toSort: slice);
165}
166
167namespace {
168/// DFS post-order implementation that maintains a global count to work across
169/// multiple invocations, to help implement topological sort on multi-root DAGs.
170/// We traverse all operations but only record the ones that appear in
171/// `toSort` for the final result.
172struct DFSState {
173 DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {}
174 const SetVector<Operation *> &toSort;
175 SmallVector<Operation *, 16> topologicalCounts;
176 DenseSet<Operation *> seen;
177};
178} // namespace
179
180static void dfsPostorder(Operation *root, DFSState *state) {
181 SmallVector<Operation *> queue(1, root);
182 std::vector<Operation *> ops;
183 while (!queue.empty()) {
184 Operation *current = queue.pop_back_val();
185 ops.push_back(x: current);
186 for (Operation *op : current->getUsers())
187 queue.push_back(Elt: op);
188 for (Region &region : current->getRegions()) {
189 for (Operation &op : region.getOps())
190 queue.push_back(Elt: &op);
191 }
192 }
193
194 for (Operation *op : llvm::reverse(C&: ops)) {
195 if (state->seen.insert(V: op).second && state->toSort.count(key: op) > 0)
196 state->topologicalCounts.push_back(Elt: op);
197 }
198}
199
200SetVector<Operation *>
201mlir::topologicalSort(const SetVector<Operation *> &toSort) {
202 if (toSort.empty()) {
203 return toSort;
204 }
205
206 // Run from each root with global count and `seen` set.
207 DFSState state(toSort);
208 for (auto *s : toSort) {
209 assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
210 dfsPostorder(root: s, state: &state);
211 }
212
213 // Reorder and return.
214 SetVector<Operation *> res;
215 for (auto it = state.topologicalCounts.rbegin(),
216 eit = state.topologicalCounts.rend();
217 it != eit; ++it) {
218 res.insert(X: *it);
219 }
220 return res;
221}
222
223/// Returns true if `value` (transitively) depends on iteration-carried values
224/// of the given `ancestorOp`.
225static bool dependsOnCarriedVals(Value value,
226 ArrayRef<BlockArgument> iterCarriedArgs,
227 Operation *ancestorOp) {
228 // Compute the backward slice of the value.
229 SetVector<Operation *> slice;
230 BackwardSliceOptions sliceOptions;
231 sliceOptions.filter = [&](Operation *op) {
232 return !ancestorOp->isAncestor(other: op);
233 };
234 getBackwardSlice(root: value, backwardSlice: &slice, options: sliceOptions);
235
236 // Check that none of the operands of the operations in the backward slice are
237 // loop iteration arguments, and neither is the value itself.
238 SmallPtrSet<Value, 8> iterCarriedValSet(iterCarriedArgs.begin(),
239 iterCarriedArgs.end());
240 if (iterCarriedValSet.contains(Ptr: value))
241 return true;
242
243 for (Operation *op : slice)
244 for (Value operand : op->getOperands())
245 if (iterCarriedValSet.contains(Ptr: operand))
246 return true;
247
248 return false;
249}
250
251/// Utility to match a generic reduction given a list of iteration-carried
252/// arguments, `iterCarriedArgs` and the position of the potential reduction
253/// argument within the list, `redPos`. If a reduction is matched, returns the
254/// reduced value and the topologically-sorted list of combiner operations
255/// involved in the reduction. Otherwise, returns a null value.
256///
257/// The matching algorithm relies on the following invariants, which are subject
258/// to change:
259/// 1. The first combiner operation must be a binary operation with the
260/// iteration-carried value and the reduced value as operands.
261/// 2. The iteration-carried value and combiner operations must be side
262/// effect-free, have single result and a single use.
263/// 3. Combiner operations must be immediately nested in the region op
264/// performing the reduction.
265/// 4. Reduction def-use chain must end in a terminator op that yields the
266/// next iteration/output values in the same order as the iteration-carried
267/// values in `iterCarriedArgs`.
268/// 5. `iterCarriedArgs` must contain all the iteration-carried/output values
269/// of the region op performing the reduction.
270///
271/// This utility is generic enough to detect reductions involving multiple
272/// combiner operations (disabled for now) across multiple dialects, including
273/// Linalg, Affine and SCF. For the sake of genericity, it does not return
274/// specific enum values for the combiner operations since its goal is also
275/// matching reductions without pre-defined semantics in core MLIR. It's up to
276/// each client to make sense out of the list of combiner operations. It's also
277/// up to each client to check for additional invariants on the expected
278/// reductions not covered by this generic matching.
279Value mlir::matchReduction(ArrayRef<BlockArgument> iterCarriedArgs,
280 unsigned redPos,
281 SmallVectorImpl<Operation *> &combinerOps) {
282 assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds");
283
284 BlockArgument redCarriedVal = iterCarriedArgs[redPos];
285 if (!redCarriedVal.hasOneUse())
286 return nullptr;
287
288 // For now, the first combiner op must be a binary op.
289 Operation *combinerOp = *redCarriedVal.getUsers().begin();
290 if (combinerOp->getNumOperands() != 2)
291 return nullptr;
292 Value reducedVal = combinerOp->getOperand(idx: 0) == redCarriedVal
293 ? combinerOp->getOperand(idx: 1)
294 : combinerOp->getOperand(idx: 0);
295
296 Operation *redRegionOp =
297 iterCarriedArgs.front().getOwner()->getParent()->getParentOp();
298 if (dependsOnCarriedVals(value: reducedVal, iterCarriedArgs, ancestorOp: redRegionOp))
299 return nullptr;
300
301 // Traverse the def-use chain starting from the first combiner op until a
302 // terminator is found. Gather all the combiner ops along the way in
303 // topological order.
304 while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) {
305 if (!isMemoryEffectFree(op: combinerOp) || combinerOp->getNumResults() != 1 ||
306 !combinerOp->hasOneUse() || combinerOp->getParentOp() != redRegionOp)
307 return nullptr;
308
309 combinerOps.push_back(Elt: combinerOp);
310 combinerOp = *combinerOp->getUsers().begin();
311 }
312
313 // Limit matching to single combiner op until we can properly test reductions
314 // involving multiple combiners.
315 if (combinerOps.size() != 1)
316 return nullptr;
317
318 // Check that the yielded value is in the same position as in
319 // `iterCarriedArgs`.
320 Operation *terminatorOp = combinerOp;
321 if (terminatorOp->getOperand(idx: redPos) != combinerOps.back()->getResults()[0])
322 return nullptr;
323
324 return reducedVal;
325}
326

source code of mlir/lib/Analysis/SliceAnalysis.cpp