1 | //===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Analysis/DataFlow/SparseAnalysis.h" |
10 | #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" |
11 | #include "mlir/Analysis/DataFlowFramework.h" |
12 | #include "mlir/IR/Attributes.h" |
13 | #include "mlir/IR/Operation.h" |
14 | #include "mlir/IR/Region.h" |
15 | #include "mlir/IR/SymbolTable.h" |
16 | #include "mlir/IR/Value.h" |
17 | #include "mlir/IR/ValueRange.h" |
18 | #include "mlir/Interfaces/CallInterfaces.h" |
19 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
20 | #include "mlir/Support/LLVM.h" |
21 | #include "mlir/Support/LogicalResult.h" |
22 | #include "llvm/ADT/STLExtras.h" |
23 | #include "llvm/Support/Casting.h" |
24 | #include <cassert> |
25 | #include <optional> |
26 | |
27 | using namespace mlir; |
28 | using namespace mlir::dataflow; |
29 | |
30 | //===----------------------------------------------------------------------===// |
31 | // AbstractSparseLattice |
32 | //===----------------------------------------------------------------------===// |
33 | |
34 | void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const { |
35 | AnalysisState::onUpdate(solver); |
36 | |
37 | // Push all users of the value to the queue. |
38 | for (Operation *user : point.get<Value>().getUsers()) |
39 | for (DataFlowAnalysis *analysis : useDefSubscribers) |
40 | solver->enqueue(item: {user, analysis}); |
41 | } |
42 | |
43 | //===----------------------------------------------------------------------===// |
44 | // AbstractSparseForwardDataFlowAnalysis |
45 | //===----------------------------------------------------------------------===// |
46 | |
47 | AbstractSparseForwardDataFlowAnalysis::AbstractSparseForwardDataFlowAnalysis( |
48 | DataFlowSolver &solver) |
49 | : DataFlowAnalysis(solver) { |
50 | registerPointKind<CFGEdge>(); |
51 | } |
52 | |
53 | LogicalResult |
54 | AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) { |
55 | // Mark the entry block arguments as having reached their pessimistic |
56 | // fixpoints. |
57 | for (Region ®ion : top->getRegions()) { |
58 | if (region.empty()) |
59 | continue; |
60 | for (Value argument : region.front().getArguments()) |
61 | setToEntryState(getLatticeElement(value: argument)); |
62 | } |
63 | |
64 | return initializeRecursively(op: top); |
65 | } |
66 | |
67 | LogicalResult |
68 | AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) { |
69 | // Initialize the analysis by visiting every owner of an SSA value (all |
70 | // operations and blocks). |
71 | visitOperation(op); |
72 | for (Region ®ion : op->getRegions()) { |
73 | for (Block &block : region) { |
74 | getOrCreate<Executable>(point: &block)->blockContentSubscribe(analysis: this); |
75 | visitBlock(block: &block); |
76 | for (Operation &op : block) |
77 | if (failed(result: initializeRecursively(op: &op))) |
78 | return failure(); |
79 | } |
80 | } |
81 | |
82 | return success(); |
83 | } |
84 | |
85 | LogicalResult AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint point) { |
86 | if (Operation *op = llvm::dyn_cast_if_present<Operation *>(Val&: point)) |
87 | visitOperation(op); |
88 | else if (Block *block = llvm::dyn_cast_if_present<Block *>(Val&: point)) |
89 | visitBlock(block); |
90 | else |
91 | return failure(); |
92 | return success(); |
93 | } |
94 | |
95 | void AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) { |
96 | // Exit early on operations with no results. |
97 | if (op->getNumResults() == 0) |
98 | return; |
99 | |
100 | // If the containing block is not executable, bail out. |
101 | if (!getOrCreate<Executable>(point: op->getBlock())->isLive()) |
102 | return; |
103 | |
104 | // Get the result lattices. |
105 | SmallVector<AbstractSparseLattice *> resultLattices; |
106 | resultLattices.reserve(N: op->getNumResults()); |
107 | for (Value result : op->getResults()) { |
108 | AbstractSparseLattice *resultLattice = getLatticeElement(value: result); |
109 | resultLattices.push_back(Elt: resultLattice); |
110 | } |
111 | |
112 | // The results of a region branch operation are determined by control-flow. |
113 | if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { |
114 | return visitRegionSuccessors(point: {branch}, branch: branch, |
115 | /*successor=*/RegionBranchPoint::parent(), |
116 | lattices: resultLattices); |
117 | } |
118 | |
119 | // Grab the lattice elements of the operands. |
120 | SmallVector<const AbstractSparseLattice *> operandLattices; |
121 | operandLattices.reserve(N: op->getNumOperands()); |
122 | for (Value operand : op->getOperands()) { |
123 | AbstractSparseLattice *operandLattice = getLatticeElement(value: operand); |
124 | operandLattice->useDefSubscribe(analysis: this); |
125 | operandLattices.push_back(Elt: operandLattice); |
126 | } |
127 | |
128 | if (auto call = dyn_cast<CallOpInterface>(op)) { |
129 | // If the call operation is to an external function, attempt to infer the |
130 | // results from the call arguments. |
131 | auto callable = |
132 | dyn_cast_if_present<CallableOpInterface>(call.resolveCallable()); |
133 | if (!getSolverConfig().isInterprocedural() || |
134 | (callable && !callable.getCallableRegion())) { |
135 | return visitExternalCallImpl(call, operandLattices, resultLattices); |
136 | } |
137 | |
138 | // Otherwise, the results of a call operation are determined by the |
139 | // callgraph. |
140 | const auto *predecessors = getOrCreateFor<PredecessorState>(op, call); |
141 | // If not all return sites are known, then conservatively assume we can't |
142 | // reason about the data-flow. |
143 | if (!predecessors->allPredecessorsKnown()) |
144 | return setAllToEntryStates(resultLattices); |
145 | for (Operation *predecessor : predecessors->getKnownPredecessors()) |
146 | for (auto it : llvm::zip(predecessor->getOperands(), resultLattices)) |
147 | join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it))); |
148 | return; |
149 | } |
150 | |
151 | // Invoke the operation transfer function. |
152 | visitOperationImpl(op, operandLattices, resultLattices); |
153 | } |
154 | |
155 | void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { |
156 | // Exit early on blocks with no arguments. |
157 | if (block->getNumArguments() == 0) |
158 | return; |
159 | |
160 | // If the block is not executable, bail out. |
161 | if (!getOrCreate<Executable>(point: block)->isLive()) |
162 | return; |
163 | |
164 | // Get the argument lattices. |
165 | SmallVector<AbstractSparseLattice *> argLattices; |
166 | argLattices.reserve(N: block->getNumArguments()); |
167 | for (BlockArgument argument : block->getArguments()) { |
168 | AbstractSparseLattice *argLattice = getLatticeElement(value: argument); |
169 | argLattices.push_back(Elt: argLattice); |
170 | } |
171 | |
172 | // The argument lattices of entry blocks are set by region control-flow or the |
173 | // callgraph. |
174 | if (block->isEntryBlock()) { |
175 | // Check if this block is the entry block of a callable region. |
176 | auto callable = dyn_cast<CallableOpInterface>(block->getParentOp()); |
177 | if (callable && callable.getCallableRegion() == block->getParent()) { |
178 | const auto *callsites = getOrCreateFor<PredecessorState>(block, callable); |
179 | // If not all callsites are known, conservatively mark all lattices as |
180 | // having reached their pessimistic fixpoints. |
181 | if (!callsites->allPredecessorsKnown() || |
182 | !getSolverConfig().isInterprocedural()) { |
183 | return setAllToEntryStates(argLattices); |
184 | } |
185 | for (Operation *callsite : callsites->getKnownPredecessors()) { |
186 | auto call = cast<CallOpInterface>(callsite); |
187 | for (auto it : llvm::zip(call.getArgOperands(), argLattices)) |
188 | join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it))); |
189 | } |
190 | return; |
191 | } |
192 | |
193 | // Check if the lattices can be determined from region control flow. |
194 | if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) { |
195 | return visitRegionSuccessors(point: block, branch: branch, successor: block->getParent(), |
196 | lattices: argLattices); |
197 | } |
198 | |
199 | // Otherwise, we can't reason about the data-flow. |
200 | return visitNonControlFlowArgumentsImpl(op: block->getParentOp(), |
201 | successor: RegionSuccessor(block->getParent()), |
202 | argLattices, /*firstIndex=*/0); |
203 | } |
204 | |
205 | // Iterate over the predecessors of the non-entry block. |
206 | for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end(); |
207 | it != e; ++it) { |
208 | Block *predecessor = *it; |
209 | |
210 | // If the edge from the predecessor block to the current block is not live, |
211 | // bail out. |
212 | auto *edgeExecutable = |
213 | getOrCreate<Executable>(point: getProgramPoint<CFGEdge>(args&: predecessor, args&: block)); |
214 | edgeExecutable->blockContentSubscribe(analysis: this); |
215 | if (!edgeExecutable->isLive()) |
216 | continue; |
217 | |
218 | // Check if we can reason about the data-flow from the predecessor. |
219 | if (auto branch = |
220 | dyn_cast<BranchOpInterface>(predecessor->getTerminator())) { |
221 | SuccessorOperands operands = |
222 | branch.getSuccessorOperands(it.getSuccessorIndex()); |
223 | for (auto [idx, lattice] : llvm::enumerate(First&: argLattices)) { |
224 | if (Value operand = operands[idx]) { |
225 | join(lhs: lattice, rhs: *getLatticeElementFor(point: block, value: operand)); |
226 | } else { |
227 | // Conservatively consider internally produced arguments as entry |
228 | // points. |
229 | setAllToEntryStates(lattice); |
230 | } |
231 | } |
232 | } else { |
233 | return setAllToEntryStates(argLattices); |
234 | } |
235 | } |
236 | } |
237 | |
238 | void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( |
239 | ProgramPoint point, RegionBranchOpInterface branch, |
240 | RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) { |
241 | const auto *predecessors = getOrCreateFor<PredecessorState>(dependent: point, point); |
242 | assert(predecessors->allPredecessorsKnown() && |
243 | "unexpected unresolved region successors" ); |
244 | |
245 | for (Operation *op : predecessors->getKnownPredecessors()) { |
246 | // Get the incoming successor operands. |
247 | std::optional<OperandRange> operands; |
248 | |
249 | // Check if the predecessor is the parent op. |
250 | if (op == branch) { |
251 | operands = branch.getEntrySuccessorOperands(successor); |
252 | // Otherwise, try to deduce the operands from a region return-like op. |
253 | } else if (auto regionTerminator = |
254 | dyn_cast<RegionBranchTerminatorOpInterface>(op)) { |
255 | operands = regionTerminator.getSuccessorOperands(successor); |
256 | } |
257 | |
258 | if (!operands) { |
259 | // We can't reason about the data-flow. |
260 | return setAllToEntryStates(lattices); |
261 | } |
262 | |
263 | ValueRange inputs = predecessors->getSuccessorInputs(predecessor: op); |
264 | assert(inputs.size() == operands->size() && |
265 | "expected the same number of successor inputs as operands" ); |
266 | |
267 | unsigned firstIndex = 0; |
268 | if (inputs.size() != lattices.size()) { |
269 | if (llvm::dyn_cast_if_present<Operation *>(Val&: point)) { |
270 | if (!inputs.empty()) |
271 | firstIndex = cast<OpResult>(Val: inputs.front()).getResultNumber(); |
272 | visitNonControlFlowArgumentsImpl( |
273 | op: branch, |
274 | successor: RegionSuccessor( |
275 | branch->getResults().slice(firstIndex, inputs.size())), |
276 | argLattices: lattices, firstIndex); |
277 | } else { |
278 | if (!inputs.empty()) |
279 | firstIndex = cast<BlockArgument>(Val: inputs.front()).getArgNumber(); |
280 | Region *region = point.get<Block *>()->getParent(); |
281 | visitNonControlFlowArgumentsImpl( |
282 | op: branch, |
283 | successor: RegionSuccessor(region, region->getArguments().slice( |
284 | N: firstIndex, M: inputs.size())), |
285 | argLattices: lattices, firstIndex); |
286 | } |
287 | } |
288 | |
289 | for (auto it : llvm::zip(t&: *operands, u: lattices.drop_front(N: firstIndex))) |
290 | join(lhs: std::get<1>(t&: it), rhs: *getLatticeElementFor(point, value: std::get<0>(t&: it))); |
291 | } |
292 | } |
293 | |
294 | const AbstractSparseLattice * |
295 | AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point, |
296 | Value value) { |
297 | AbstractSparseLattice *state = getLatticeElement(value); |
298 | addDependency(state, point); |
299 | return state; |
300 | } |
301 | |
302 | void AbstractSparseForwardDataFlowAnalysis::setAllToEntryStates( |
303 | ArrayRef<AbstractSparseLattice *> lattices) { |
304 | for (AbstractSparseLattice *lattice : lattices) |
305 | setToEntryState(lattice); |
306 | } |
307 | |
308 | void AbstractSparseForwardDataFlowAnalysis::join( |
309 | AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) { |
310 | propagateIfChanged(state: lhs, changed: lhs->join(rhs)); |
311 | } |
312 | |
313 | //===----------------------------------------------------------------------===// |
314 | // AbstractSparseBackwardDataFlowAnalysis |
315 | //===----------------------------------------------------------------------===// |
316 | |
317 | AbstractSparseBackwardDataFlowAnalysis::AbstractSparseBackwardDataFlowAnalysis( |
318 | DataFlowSolver &solver, SymbolTableCollection &symbolTable) |
319 | : DataFlowAnalysis(solver), symbolTable(symbolTable) { |
320 | registerPointKind<CFGEdge>(); |
321 | } |
322 | |
323 | LogicalResult |
324 | AbstractSparseBackwardDataFlowAnalysis::initialize(Operation *top) { |
325 | return initializeRecursively(op: top); |
326 | } |
327 | |
328 | LogicalResult |
329 | AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) { |
330 | visitOperation(op); |
331 | for (Region ®ion : op->getRegions()) { |
332 | for (Block &block : region) { |
333 | getOrCreate<Executable>(point: &block)->blockContentSubscribe(analysis: this); |
334 | // Initialize ops in reverse order, so we can do as much initial |
335 | // propagation as possible without having to go through the |
336 | // solver queue. |
337 | for (auto it = block.rbegin(); it != block.rend(); it++) |
338 | if (failed(result: initializeRecursively(op: &*it))) |
339 | return failure(); |
340 | } |
341 | } |
342 | return success(); |
343 | } |
344 | |
345 | LogicalResult |
346 | AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) { |
347 | if (Operation *op = llvm::dyn_cast_if_present<Operation *>(Val&: point)) |
348 | visitOperation(op); |
349 | else if (llvm::dyn_cast_if_present<Block *>(Val&: point)) |
350 | // For backward dataflow, we don't have to do any work for the blocks |
351 | // themselves. CFG edges between blocks are processed by the BranchOp |
352 | // logic in `visitOperation`, and entry blocks for functions are tied |
353 | // to the CallOp arguments by visitOperation. |
354 | return success(); |
355 | else |
356 | return failure(); |
357 | return success(); |
358 | } |
359 | |
360 | SmallVector<AbstractSparseLattice *> |
361 | AbstractSparseBackwardDataFlowAnalysis::getLatticeElements(ValueRange values) { |
362 | SmallVector<AbstractSparseLattice *> resultLattices; |
363 | resultLattices.reserve(N: values.size()); |
364 | for (Value result : values) { |
365 | AbstractSparseLattice *resultLattice = getLatticeElement(value: result); |
366 | resultLattices.push_back(Elt: resultLattice); |
367 | } |
368 | return resultLattices; |
369 | } |
370 | |
371 | SmallVector<const AbstractSparseLattice *> |
372 | AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor( |
373 | ProgramPoint point, ValueRange values) { |
374 | SmallVector<const AbstractSparseLattice *> resultLattices; |
375 | resultLattices.reserve(N: values.size()); |
376 | for (Value result : values) { |
377 | const AbstractSparseLattice *resultLattice = |
378 | getLatticeElementFor(point, value: result); |
379 | resultLattices.push_back(Elt: resultLattice); |
380 | } |
381 | return resultLattices; |
382 | } |
383 | |
384 | static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) { |
385 | return MutableArrayRef<OpOperand>(operands.getBase(), operands.size()); |
386 | } |
387 | |
388 | void AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { |
389 | // If we're in a dead block, bail out. |
390 | if (!getOrCreate<Executable>(point: op->getBlock())->isLive()) |
391 | return; |
392 | |
393 | SmallVector<AbstractSparseLattice *> operandLattices = |
394 | getLatticeElements(values: op->getOperands()); |
395 | SmallVector<const AbstractSparseLattice *> resultLattices = |
396 | getLatticeElementsFor(point: op, values: op->getResults()); |
397 | |
398 | // Block arguments of region branch operations flow back into the operands |
399 | // of the parent op |
400 | if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { |
401 | visitRegionSuccessors(branch: branch, operands: operandLattices); |
402 | return; |
403 | } |
404 | |
405 | if (auto branch = dyn_cast<BranchOpInterface>(op)) { |
406 | // Block arguments of successor blocks flow back into our operands. |
407 | |
408 | // We remember all operands not forwarded to any block in a BitVector. |
409 | // We can't just cut out a range here, since the non-forwarded ops might |
410 | // be non-contiguous (if there's more than one successor). |
411 | BitVector unaccounted(op->getNumOperands(), true); |
412 | |
413 | for (auto [index, block] : llvm::enumerate(First: op->getSuccessors())) { |
414 | SuccessorOperands successorOperands = branch.getSuccessorOperands(index); |
415 | OperandRange forwarded = successorOperands.getForwardedOperands(); |
416 | if (!forwarded.empty()) { |
417 | MutableArrayRef<OpOperand> operands = op->getOpOperands().slice( |
418 | N: forwarded.getBeginOperandIndex(), M: forwarded.size()); |
419 | for (OpOperand &operand : operands) { |
420 | unaccounted.reset(operand.getOperandNumber()); |
421 | if (std::optional<BlockArgument> blockArg = |
422 | detail::getBranchSuccessorArgument( |
423 | successorOperands, operand.getOperandNumber(), block)) { |
424 | meet(getLatticeElement(operand.get()), |
425 | *getLatticeElementFor(op, *blockArg)); |
426 | } |
427 | } |
428 | } |
429 | } |
430 | // Operands not forwarded to successor blocks are typically parameters |
431 | // of the branch operation itself (for example the boolean for if/else). |
432 | for (int index : unaccounted.set_bits()) { |
433 | OpOperand &operand = op->getOpOperand(idx: index); |
434 | visitBranchOperand(operand); |
435 | } |
436 | return; |
437 | } |
438 | |
439 | // For function calls, connect the arguments of the entry blocks to the |
440 | // operands of the call op that are forwarded to these arguments. |
441 | if (auto call = dyn_cast<CallOpInterface>(op)) { |
442 | Operation *callableOp = call.resolveCallable(&symbolTable); |
443 | if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) { |
444 | // Not all operands of a call op forward to arguments. Such operands are |
445 | // stored in `unaccounted`. |
446 | BitVector unaccounted(op->getNumOperands(), true); |
447 | |
448 | // If the call invokes an external function (or a function treated as |
449 | // external due to config), defer to the corresponding extension hook. |
450 | // By default, it just does `visitCallOperand` for all operands. |
451 | OperandRange argOperands = call.getArgOperands(); |
452 | MutableArrayRef<OpOperand> argOpOperands = |
453 | operandsToOpOperands(operands&: argOperands); |
454 | Region *region = callable.getCallableRegion(); |
455 | if (!region || region->empty() || !getSolverConfig().isInterprocedural()) |
456 | return visitExternalCallImpl(call, operandLattices, resultLattices); |
457 | |
458 | // Otherwise, propagate information from the entry point of the function |
459 | // back to operands whenever possible. |
460 | Block &block = region->front(); |
461 | for (auto [blockArg, argOpOperand] : |
462 | llvm::zip(block.getArguments(), argOpOperands)) { |
463 | meet(getLatticeElement(argOpOperand.get()), |
464 | *getLatticeElementFor(op, blockArg)); |
465 | unaccounted.reset(argOpOperand.getOperandNumber()); |
466 | } |
467 | |
468 | // Handle the operands of the call op that aren't forwarded to any |
469 | // arguments. |
470 | for (int index : unaccounted.set_bits()) { |
471 | OpOperand &opOperand = op->getOpOperand(idx: index); |
472 | visitCallOperand(operand&: opOperand); |
473 | } |
474 | return; |
475 | } |
476 | } |
477 | |
478 | // When the region of an op implementing `RegionBranchOpInterface` has a |
479 | // terminator implementing `RegionBranchTerminatorOpInterface` or a |
480 | // return-like terminator, the region's successors' arguments flow back into |
481 | // the "successor operands" of this terminator. |
482 | // |
483 | // A successor operand with respect to an op implementing |
484 | // `RegionBranchOpInterface` is an operand that is forwarded to a region |
485 | // successor's input. There are two types of successor operands: the operands |
486 | // of this op itself and the operands of the terminators of the regions of |
487 | // this op. |
488 | if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) { |
489 | if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) { |
490 | visitRegionSuccessorsFromTerminator(terminator, branch); |
491 | return; |
492 | } |
493 | } |
494 | |
495 | if (op->hasTrait<OpTrait::ReturnLike>()) { |
496 | // Going backwards, the operands of the return are derived from the |
497 | // results of all CallOps calling this CallableOp. |
498 | if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) { |
499 | const PredecessorState *callsites = |
500 | getOrCreateFor<PredecessorState>(op, callable); |
501 | if (callsites->allPredecessorsKnown()) { |
502 | for (Operation *call : callsites->getKnownPredecessors()) { |
503 | SmallVector<const AbstractSparseLattice *> callResultLattices = |
504 | getLatticeElementsFor(op, call->getResults()); |
505 | for (auto [op, result] : |
506 | llvm::zip(operandLattices, callResultLattices)) |
507 | meet(op, *result); |
508 | } |
509 | } else { |
510 | // If we don't know all the callers, we can't know where the |
511 | // returned values go. Note that, in particular, this will trigger |
512 | // for the return ops of any public functions. |
513 | setAllToExitStates(operandLattices); |
514 | } |
515 | return; |
516 | } |
517 | } |
518 | |
519 | visitOperationImpl(op, operandLattices, resultLattices); |
520 | } |
521 | |
522 | void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors( |
523 | RegionBranchOpInterface branch, |
524 | ArrayRef<AbstractSparseLattice *> operandLattices) { |
525 | Operation *op = branch.getOperation(); |
526 | SmallVector<RegionSuccessor> successors; |
527 | SmallVector<Attribute> operands(op->getNumOperands(), nullptr); |
528 | branch.getEntrySuccessorRegions(operands, successors); |
529 | |
530 | // All operands not forwarded to any successor. This set can be non-contiguous |
531 | // in the presence of multiple successors. |
532 | BitVector unaccounted(op->getNumOperands(), true); |
533 | |
534 | for (RegionSuccessor &successor : successors) { |
535 | OperandRange operands = branch.getEntrySuccessorOperands(successor); |
536 | MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands); |
537 | ValueRange inputs = successor.getSuccessorInputs(); |
538 | for (auto [operand, input] : llvm::zip(opoperands, inputs)) { |
539 | meet(getLatticeElement(operand.get()), *getLatticeElementFor(op, input)); |
540 | unaccounted.reset(operand.getOperandNumber()); |
541 | } |
542 | } |
543 | // All operands not forwarded to regions are typically parameters of the |
544 | // branch operation itself (for example the boolean for if/else). |
545 | for (int index : unaccounted.set_bits()) { |
546 | visitBranchOperand(op->getOpOperand(index)); |
547 | } |
548 | } |
549 | |
550 | void AbstractSparseBackwardDataFlowAnalysis:: |
551 | visitRegionSuccessorsFromTerminator( |
552 | RegionBranchTerminatorOpInterface terminator, |
553 | RegionBranchOpInterface branch) { |
554 | assert(isa<RegionBranchTerminatorOpInterface>(terminator) && |
555 | "expected a `RegionBranchTerminatorOpInterface` op" ); |
556 | assert(terminator->getParentOp() == branch.getOperation() && |
557 | "expected `branch` to be the parent op of `terminator`" ); |
558 | |
559 | SmallVector<Attribute> operandAttributes(terminator->getNumOperands(), |
560 | nullptr); |
561 | SmallVector<RegionSuccessor> successors; |
562 | terminator.getSuccessorRegions(operandAttributes, successors); |
563 | // All operands not forwarded to any successor. This set can be |
564 | // non-contiguous in the presence of multiple successors. |
565 | BitVector unaccounted(terminator->getNumOperands(), true); |
566 | |
567 | for (const RegionSuccessor &successor : successors) { |
568 | ValueRange inputs = successor.getSuccessorInputs(); |
569 | OperandRange operands = terminator.getSuccessorOperands(successor); |
570 | MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands); |
571 | for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) { |
572 | meet(getLatticeElement(opOperand.get()), |
573 | *getLatticeElementFor(terminator, input)); |
574 | unaccounted.reset(const_cast<OpOperand &>(opOperand).getOperandNumber()); |
575 | } |
576 | } |
577 | // Visit operands of the branch op not forwarded to the next region. |
578 | // (Like e.g. the boolean of `scf.conditional`) |
579 | for (int index : unaccounted.set_bits()) { |
580 | visitBranchOperand(terminator->getOpOperand(index)); |
581 | } |
582 | } |
583 | |
584 | const AbstractSparseLattice * |
585 | AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point, |
586 | Value value) { |
587 | AbstractSparseLattice *state = getLatticeElement(value); |
588 | addDependency(state, point); |
589 | return state; |
590 | } |
591 | |
592 | void AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates( |
593 | ArrayRef<AbstractSparseLattice *> lattices) { |
594 | for (AbstractSparseLattice *lattice : lattices) |
595 | setToExitState(lattice); |
596 | } |
597 | |
598 | void AbstractSparseBackwardDataFlowAnalysis::meet( |
599 | AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) { |
600 | propagateIfChanged(state: lhs, changed: lhs->meet(rhs)); |
601 | } |
602 | |