1//===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements Analysis functions specific to slicing in Function.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Analysis/SliceAnalysis.h"
14#include "mlir/Analysis/TopologicalSortUtils.h"
15#include "mlir/IR/Block.h"
16#include "mlir/IR/Operation.h"
17#include "mlir/Interfaces/SideEffectInterfaces.h"
18#include "mlir/Support/LLVM.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/ADT/SetVector.h"
21#include "llvm/ADT/SmallPtrSet.h"
22
23///
24/// Implements Analysis functions specific to slicing in Function.
25///
26
27using namespace mlir;
28
29static void
30getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice,
31 const SliceOptions::TransitiveFilter &filter = nullptr) {
32 if (!op)
33 return;
34
35 // Evaluate whether we should keep this use.
36 // This is useful in particular to implement scoping; i.e. return the
37 // transitive forwardSlice in the current scope.
38 if (filter && !filter(op))
39 return;
40
41 for (Region &region : op->getRegions())
42 for (Block &block : region)
43 for (Operation &blockOp : block)
44 if (forwardSlice->count(key: &blockOp) == 0)
45 getForwardSliceImpl(op: &blockOp, forwardSlice, filter);
46 for (Value result : op->getResults()) {
47 for (Operation *userOp : result.getUsers())
48 if (forwardSlice->count(key: userOp) == 0)
49 getForwardSliceImpl(op: userOp, forwardSlice, filter);
50 }
51
52 forwardSlice->insert(X: op);
53}
54
55void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
56 const ForwardSliceOptions &options) {
57 getForwardSliceImpl(op, forwardSlice, filter: options.filter);
58 if (!options.inclusive) {
59 // Don't insert the top level operation, we just queried on it and don't
60 // want it in the results.
61 forwardSlice->remove(X: op);
62 }
63
64 // Reverse to get back the actual topological order.
65 // std::reverse does not work out of the box on SetVector and I want an
66 // in-place swap based thing (the real std::reverse, not the LLVM adapter).
67 SmallVector<Operation *, 0> v(forwardSlice->takeVector());
68 forwardSlice->insert(Start: v.rbegin(), End: v.rend());
69}
70
71void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
72 const SliceOptions &options) {
73 for (Operation *user : root.getUsers())
74 getForwardSliceImpl(op: user, forwardSlice, filter: options.filter);
75
76 // Reverse to get back the actual topological order.
77 // std::reverse does not work out of the box on SetVector and I want an
78 // in-place swap based thing (the real std::reverse, not the LLVM adapter).
79 SmallVector<Operation *, 0> v(forwardSlice->takeVector());
80 forwardSlice->insert(Start: v.rbegin(), End: v.rend());
81}
82
83static LogicalResult getBackwardSliceImpl(Operation *op,
84 SetVector<Operation *> *backwardSlice,
85 const BackwardSliceOptions &options) {
86 if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
87 return success();
88
89 // Evaluate whether we should keep this def.
90 // This is useful in particular to implement scoping; i.e. return the
91 // transitive backwardSlice in the current scope.
92 if (options.filter && !options.filter(op))
93 return success();
94
95 auto processValue = [&](Value value) {
96 if (auto *definingOp = value.getDefiningOp()) {
97 if (backwardSlice->count(key: definingOp) == 0)
98 return getBackwardSliceImpl(op: definingOp, backwardSlice, options);
99 } else if (auto blockArg = dyn_cast<BlockArgument>(Val&: value)) {
100 if (options.omitBlockArguments)
101 return success();
102
103 Block *block = blockArg.getOwner();
104 Operation *parentOp = block->getParentOp();
105 // TODO: determine whether we want to recurse backward into the other
106 // blocks of parentOp, which are not technically backward unless they flow
107 // into us. For now, just bail.
108 if (parentOp && backwardSlice->count(key: parentOp) == 0) {
109 if (parentOp->getNumRegions() == 1 &&
110 llvm::hasSingleElement(C&: parentOp->getRegion(index: 0).getBlocks())) {
111 return getBackwardSliceImpl(op: parentOp, backwardSlice, options);
112 }
113 }
114 } else {
115 return failure();
116 }
117 return success();
118 };
119
120 bool succeeded = true;
121
122 if (!options.omitUsesFromAbove) {
123 llvm::for_each(Range: op->getRegions(), F: [&](Region &region) {
124 // Walk this region recursively to collect the regions that descend from
125 // this op's nested regions (inclusive).
126 SmallPtrSet<Region *, 4> descendents;
127 region.walk(
128 callback: [&](Region *childRegion) { descendents.insert(Ptr: childRegion); });
129 region.walk(callback: [&](Operation *op) {
130 for (OpOperand &operand : op->getOpOperands()) {
131 if (!descendents.contains(Ptr: operand.get().getParentRegion()))
132 if (!processValue(operand.get()).succeeded()) {
133 return WalkResult::interrupt();
134 }
135 }
136 return WalkResult::advance();
137 });
138 });
139 }
140 llvm::for_each(Range: op->getOperands(), F: processValue);
141
142 backwardSlice->insert(X: op);
143 return success(IsSuccess: succeeded);
144}
145
146LogicalResult mlir::getBackwardSlice(Operation *op,
147 SetVector<Operation *> *backwardSlice,
148 const BackwardSliceOptions &options) {
149 LogicalResult result = getBackwardSliceImpl(op, backwardSlice, options);
150
151 if (!options.inclusive) {
152 // Don't insert the top level operation, we just queried on it and don't
153 // want it in the results.
154 backwardSlice->remove(X: op);
155 }
156 return result;
157}
158
159LogicalResult mlir::getBackwardSlice(Value root,
160 SetVector<Operation *> *backwardSlice,
161 const BackwardSliceOptions &options) {
162 if (Operation *definingOp = root.getDefiningOp()) {
163 return getBackwardSlice(op: definingOp, backwardSlice, options);
164 }
165 Operation *bbAargOwner = cast<BlockArgument>(Val&: root).getOwner()->getParentOp();
166 return getBackwardSlice(op: bbAargOwner, backwardSlice, options);
167}
168
169SetVector<Operation *>
170mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
171 const ForwardSliceOptions &forwardSliceOptions) {
172 SetVector<Operation *> slice;
173 slice.insert(X: op);
174
175 unsigned currentIndex = 0;
176 SetVector<Operation *> backwardSlice;
177 SetVector<Operation *> forwardSlice;
178 while (currentIndex != slice.size()) {
179 auto *currentOp = (slice)[currentIndex];
180 // Compute and insert the backwardSlice starting from currentOp.
181 backwardSlice.clear();
182 LogicalResult result =
183 getBackwardSlice(op: currentOp, backwardSlice: &backwardSlice, options: backwardSliceOptions);
184 assert(result.succeeded());
185 (void)result;
186 slice.insert_range(R&: backwardSlice);
187
188 // Compute and insert the forwardSlice starting from currentOp.
189 forwardSlice.clear();
190 getForwardSlice(op: currentOp, forwardSlice: &forwardSlice, options: forwardSliceOptions);
191 slice.insert_range(R&: forwardSlice);
192 ++currentIndex;
193 }
194 return topologicalSort(toSort: slice);
195}
196
197/// Returns true if `value` (transitively) depends on iteration-carried values
198/// of the given `ancestorOp`.
199static bool dependsOnCarriedVals(Value value,
200 ArrayRef<BlockArgument> iterCarriedArgs,
201 Operation *ancestorOp) {
202 // Compute the backward slice of the value.
203 SetVector<Operation *> slice;
204 BackwardSliceOptions sliceOptions;
205 sliceOptions.filter = [&](Operation *op) {
206 return !ancestorOp->isAncestor(other: op);
207 };
208 LogicalResult result = getBackwardSlice(root: value, backwardSlice: &slice, options: sliceOptions);
209 assert(result.succeeded());
210 (void)result;
211
212 // Check that none of the operands of the operations in the backward slice are
213 // loop iteration arguments, and neither is the value itself.
214 SmallPtrSet<Value, 8> iterCarriedValSet(llvm::from_range, iterCarriedArgs);
215 if (iterCarriedValSet.contains(Ptr: value))
216 return true;
217
218 for (Operation *op : slice)
219 for (Value operand : op->getOperands())
220 if (iterCarriedValSet.contains(Ptr: operand))
221 return true;
222
223 return false;
224}
225
226/// Utility to match a generic reduction given a list of iteration-carried
227/// arguments, `iterCarriedArgs` and the position of the potential reduction
228/// argument within the list, `redPos`. If a reduction is matched, returns the
229/// reduced value and the topologically-sorted list of combiner operations
230/// involved in the reduction. Otherwise, returns a null value.
231///
232/// The matching algorithm relies on the following invariants, which are subject
233/// to change:
234/// 1. The first combiner operation must be a binary operation with the
235/// iteration-carried value and the reduced value as operands.
236/// 2. The iteration-carried value and combiner operations must be side
237/// effect-free, have single result and a single use.
238/// 3. Combiner operations must be immediately nested in the region op
239/// performing the reduction.
240/// 4. Reduction def-use chain must end in a terminator op that yields the
241/// next iteration/output values in the same order as the iteration-carried
242/// values in `iterCarriedArgs`.
243/// 5. `iterCarriedArgs` must contain all the iteration-carried/output values
244/// of the region op performing the reduction.
245///
246/// This utility is generic enough to detect reductions involving multiple
247/// combiner operations (disabled for now) across multiple dialects, including
248/// Linalg, Affine and SCF. For the sake of genericity, it does not return
249/// specific enum values for the combiner operations since its goal is also
250/// matching reductions without pre-defined semantics in core MLIR. It's up to
251/// each client to make sense out of the list of combiner operations. It's also
252/// up to each client to check for additional invariants on the expected
253/// reductions not covered by this generic matching.
254Value mlir::matchReduction(ArrayRef<BlockArgument> iterCarriedArgs,
255 unsigned redPos,
256 SmallVectorImpl<Operation *> &combinerOps) {
257 assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds");
258
259 BlockArgument redCarriedVal = iterCarriedArgs[redPos];
260 if (!redCarriedVal.hasOneUse())
261 return nullptr;
262
263 // For now, the first combiner op must be a binary op.
264 Operation *combinerOp = *redCarriedVal.getUsers().begin();
265 if (combinerOp->getNumOperands() != 2)
266 return nullptr;
267 Value reducedVal = combinerOp->getOperand(idx: 0) == redCarriedVal
268 ? combinerOp->getOperand(idx: 1)
269 : combinerOp->getOperand(idx: 0);
270
271 Operation *redRegionOp =
272 iterCarriedArgs.front().getOwner()->getParent()->getParentOp();
273 if (dependsOnCarriedVals(value: reducedVal, iterCarriedArgs, ancestorOp: redRegionOp))
274 return nullptr;
275
276 // Traverse the def-use chain starting from the first combiner op until a
277 // terminator is found. Gather all the combiner ops along the way in
278 // topological order.
279 while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) {
280 if (!isMemoryEffectFree(op: combinerOp) || combinerOp->getNumResults() != 1 ||
281 !combinerOp->hasOneUse() || combinerOp->getParentOp() != redRegionOp)
282 return nullptr;
283
284 combinerOps.push_back(Elt: combinerOp);
285 combinerOp = *combinerOp->getUsers().begin();
286 }
287
288 // Limit matching to single combiner op until we can properly test reductions
289 // involving multiple combiners.
290 if (combinerOps.size() != 1)
291 return nullptr;
292
293 // Check that the yielded value is in the same position as in
294 // `iterCarriedArgs`.
295 Operation *terminatorOp = combinerOp;
296 if (terminatorOp->getOperand(idx: redPos) != combinerOps.back()->getResults()[0])
297 return nullptr;
298
299 return reducedVal;
300}
301

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Analysis/SliceAnalysis.cpp