1//===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
10#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
11#include "mlir/Analysis/DataFlowFramework.h"
12#include "mlir/IR/Attributes.h"
13#include "mlir/IR/Operation.h"
14#include "mlir/IR/Region.h"
15#include "mlir/IR/SymbolTable.h"
16#include "mlir/IR/Value.h"
17#include "mlir/IR/ValueRange.h"
18#include "mlir/Interfaces/CallInterfaces.h"
19#include "mlir/Interfaces/ControlFlowInterfaces.h"
20#include "mlir/Support/LLVM.h"
21#include "llvm/ADT/STLExtras.h"
22#include "llvm/Support/Debug.h"
23#include <cassert>
24#include <optional>
25
26using namespace mlir;
27using namespace mlir::dataflow;
28
29#define DEBUG_TYPE "dataflow"
30
31//===----------------------------------------------------------------------===//
32// AbstractSparseLattice
33//===----------------------------------------------------------------------===//
34
35void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
36 AnalysisState::onUpdate(solver);
37
38 // Push all users of the value to the queue.
39 for (Operation *user : cast<Value>(Val: anchor).getUsers())
40 for (DataFlowAnalysis *analysis : useDefSubscribers)
41 solver->enqueue(item: {solver->getProgramPointAfter(op: user), analysis});
42}
43
44//===----------------------------------------------------------------------===//
45// AbstractSparseForwardDataFlowAnalysis
46//===----------------------------------------------------------------------===//
47
48AbstractSparseForwardDataFlowAnalysis::AbstractSparseForwardDataFlowAnalysis(
49 DataFlowSolver &solver)
50 : DataFlowAnalysis(solver) {
51 registerAnchorKind<CFGEdge>();
52}
53
54LogicalResult
55AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) {
56 // Mark the entry block arguments as having reached their pessimistic
57 // fixpoints.
58 for (Region &region : top->getRegions()) {
59 if (region.empty())
60 continue;
61 for (Value argument : region.front().getArguments())
62 setToEntryState(getLatticeElement(value: argument));
63 }
64
65 return initializeRecursively(op: top);
66}
67
68LogicalResult
69AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
70 LLVM_DEBUG({
71 llvm::dbgs() << "Initializing recursively for operation: " << op->getName()
72 << "\n";
73 });
74
75 // Initialize the analysis by visiting every owner of an SSA value (all
76 // operations and blocks).
77 if (failed(Result: visitOperation(op))) {
78 LLVM_DEBUG({
79 llvm::dbgs() << "Failed to visit operation: " << op->getName() << "\n";
80 });
81 return failure();
82 }
83
84 for (Region &region : op->getRegions()) {
85 LLVM_DEBUG({
86 llvm::dbgs() << "Processing region with " << region.getBlocks().size()
87 << " blocks"
88 << "\n";
89 });
90 for (Block &block : region) {
91 LLVM_DEBUG({
92 llvm::dbgs() << "Processing block with " << block.getNumArguments()
93 << " arguments"
94 << "\n";
95 });
96 getOrCreate<Executable>(anchor: getProgramPointBefore(block: &block))
97 ->blockContentSubscribe(analysis: this);
98 visitBlock(block: &block);
99 for (Operation &op : block) {
100 LLVM_DEBUG({
101 llvm::dbgs() << "Recursively initializing nested operation: "
102 << op.getName() << "\n";
103 });
104 if (failed(Result: initializeRecursively(op: &op))) {
105 LLVM_DEBUG({
106 llvm::dbgs() << "Failed to initialize nested operation: "
107 << op.getName() << "\n";
108 });
109 return failure();
110 }
111 }
112 }
113 }
114
115 LLVM_DEBUG({
116 llvm::dbgs()
117 << "Successfully completed recursive initialization for operation: "
118 << op->getName() << "\n";
119 });
120 return success();
121}
122
123LogicalResult
124AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint *point) {
125 if (!point->isBlockStart())
126 return visitOperation(op: point->getPrevOp());
127 visitBlock(block: point->getBlock());
128 return success();
129}
130
131LogicalResult
132AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
133 // Exit early on operations with no results.
134 if (op->getNumResults() == 0)
135 return success();
136
137 // If the containing block is not executable, bail out.
138 if (op->getBlock() != nullptr &&
139 !getOrCreate<Executable>(anchor: getProgramPointBefore(block: op->getBlock()))->isLive())
140 return success();
141
142 // Get the result lattices.
143 SmallVector<AbstractSparseLattice *> resultLattices;
144 resultLattices.reserve(N: op->getNumResults());
145 for (Value result : op->getResults()) {
146 AbstractSparseLattice *resultLattice = getLatticeElement(value: result);
147 resultLattices.push_back(Elt: resultLattice);
148 }
149
150 // The results of a region branch operation are determined by control-flow.
151 if (auto branch = dyn_cast<RegionBranchOpInterface>(Val: op)) {
152 visitRegionSuccessors(point: getProgramPointAfter(op: branch), branch,
153 /*successor=*/RegionBranchPoint::parent(),
154 lattices: resultLattices);
155 return success();
156 }
157
158 // Grab the lattice elements of the operands.
159 SmallVector<const AbstractSparseLattice *> operandLattices;
160 operandLattices.reserve(N: op->getNumOperands());
161 for (Value operand : op->getOperands()) {
162 AbstractSparseLattice *operandLattice = getLatticeElement(value: operand);
163 operandLattice->useDefSubscribe(analysis: this);
164 operandLattices.push_back(Elt: operandLattice);
165 }
166
167 if (auto call = dyn_cast<CallOpInterface>(Val: op))
168 return visitCallOperation(call, operandLattices, resultLattices);
169
170 // Invoke the operation transfer function.
171 return visitOperationImpl(op, operandLattices, resultLattices);
172}
173
174void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
175 // Exit early on blocks with no arguments.
176 if (block->getNumArguments() == 0)
177 return;
178
179 // If the block is not executable, bail out.
180 if (!getOrCreate<Executable>(anchor: getProgramPointBefore(block))->isLive())
181 return;
182
183 // Get the argument lattices.
184 SmallVector<AbstractSparseLattice *> argLattices;
185 argLattices.reserve(N: block->getNumArguments());
186 for (BlockArgument argument : block->getArguments()) {
187 AbstractSparseLattice *argLattice = getLatticeElement(value: argument);
188 argLattices.push_back(Elt: argLattice);
189 }
190
191 // The argument lattices of entry blocks are set by region control-flow or the
192 // callgraph.
193 if (block->isEntryBlock()) {
194 // Check if this block is the entry block of a callable region.
195 auto callable = dyn_cast<CallableOpInterface>(Val: block->getParentOp());
196 if (callable && callable.getCallableRegion() == block->getParent())
197 return visitCallableOperation(callable, argLattices);
198
199 // Check if the lattices can be determined from region control flow.
200 if (auto branch = dyn_cast<RegionBranchOpInterface>(Val: block->getParentOp())) {
201 return visitRegionSuccessors(point: getProgramPointBefore(block), branch,
202 successor: block->getParent(), lattices: argLattices);
203 }
204
205 // Otherwise, we can't reason about the data-flow.
206 return visitNonControlFlowArgumentsImpl(op: block->getParentOp(),
207 successor: RegionSuccessor(block->getParent()),
208 argLattices, /*firstIndex=*/0);
209 }
210
211 // Iterate over the predecessors of the non-entry block.
212 for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end();
213 it != e; ++it) {
214 Block *predecessor = *it;
215
216 // If the edge from the predecessor block to the current block is not live,
217 // bail out.
218 auto *edgeExecutable =
219 getOrCreate<Executable>(anchor: getLatticeAnchor<CFGEdge>(args&: predecessor, args&: block));
220 edgeExecutable->blockContentSubscribe(analysis: this);
221 if (!edgeExecutable->isLive())
222 continue;
223
224 // Check if we can reason about the data-flow from the predecessor.
225 if (auto branch =
226 dyn_cast<BranchOpInterface>(Val: predecessor->getTerminator())) {
227 SuccessorOperands operands =
228 branch.getSuccessorOperands(index: it.getSuccessorIndex());
229 for (auto [idx, lattice] : llvm::enumerate(First&: argLattices)) {
230 if (Value operand = operands[idx]) {
231 join(lhs: lattice,
232 rhs: *getLatticeElementFor(point: getProgramPointBefore(block), value: operand));
233 } else {
234 // Conservatively consider internally produced arguments as entry
235 // points.
236 setAllToEntryStates(lattice);
237 }
238 }
239 } else {
240 return setAllToEntryStates(argLattices);
241 }
242 }
243}
244
245LogicalResult AbstractSparseForwardDataFlowAnalysis::visitCallOperation(
246 CallOpInterface call,
247 ArrayRef<const AbstractSparseLattice *> operandLattices,
248 ArrayRef<AbstractSparseLattice *> resultLattices) {
249 // If the call operation is to an external function, attempt to infer the
250 // results from the call arguments.
251 auto callable =
252 dyn_cast_if_present<CallableOpInterface>(Val: call.resolveCallable());
253 if (!getSolverConfig().isInterprocedural() ||
254 (callable && !callable.getCallableRegion())) {
255 visitExternalCallImpl(call, argumentLattices: operandLattices, resultLattices);
256 return success();
257 }
258
259 // Otherwise, the results of a call operation are determined by the
260 // callgraph.
261 const auto *predecessors = getOrCreateFor<PredecessorState>(
262 dependent: getProgramPointAfter(op: call), anchor: getProgramPointAfter(op: call));
263 // If not all return sites are known, then conservatively assume we can't
264 // reason about the data-flow.
265 if (!predecessors->allPredecessorsKnown()) {
266 setAllToEntryStates(resultLattices);
267 return success();
268 }
269 for (Operation *predecessor : predecessors->getKnownPredecessors())
270 for (auto &&[operand, resLattice] :
271 llvm::zip(t: predecessor->getOperands(), u&: resultLattices))
272 join(lhs: resLattice,
273 rhs: *getLatticeElementFor(point: getProgramPointAfter(op: call), value: operand));
274 return success();
275}
276
277void AbstractSparseForwardDataFlowAnalysis::visitCallableOperation(
278 CallableOpInterface callable,
279 ArrayRef<AbstractSparseLattice *> argLattices) {
280 Block *entryBlock = &callable.getCallableRegion()->front();
281 const auto *callsites = getOrCreateFor<PredecessorState>(
282 dependent: getProgramPointBefore(block: entryBlock), anchor: getProgramPointAfter(op: callable));
283 // If not all callsites are known, conservatively mark all lattices as
284 // having reached their pessimistic fixpoints.
285 if (!callsites->allPredecessorsKnown() ||
286 !getSolverConfig().isInterprocedural()) {
287 return setAllToEntryStates(argLattices);
288 }
289 for (Operation *callsite : callsites->getKnownPredecessors()) {
290 auto call = cast<CallOpInterface>(Val: callsite);
291 for (auto it : llvm::zip(t: call.getArgOperands(), u&: argLattices))
292 join(lhs: std::get<1>(t&: it),
293 rhs: *getLatticeElementFor(point: getProgramPointBefore(block: entryBlock),
294 value: std::get<0>(t&: it)));
295 }
296}
297
298void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
299 ProgramPoint *point, RegionBranchOpInterface branch,
300 RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
301 const auto *predecessors = getOrCreateFor<PredecessorState>(dependent: point, anchor: point);
302 assert(predecessors->allPredecessorsKnown() &&
303 "unexpected unresolved region successors");
304
305 for (Operation *op : predecessors->getKnownPredecessors()) {
306 // Get the incoming successor operands.
307 std::optional<OperandRange> operands;
308
309 // Check if the predecessor is the parent op.
310 if (op == branch) {
311 operands = branch.getEntrySuccessorOperands(point: successor);
312 // Otherwise, try to deduce the operands from a region return-like op.
313 } else if (auto regionTerminator =
314 dyn_cast<RegionBranchTerminatorOpInterface>(Val: op)) {
315 operands = regionTerminator.getSuccessorOperands(point: successor);
316 }
317
318 if (!operands) {
319 // We can't reason about the data-flow.
320 return setAllToEntryStates(lattices);
321 }
322
323 ValueRange inputs = predecessors->getSuccessorInputs(predecessor: op);
324 assert(inputs.size() == operands->size() &&
325 "expected the same number of successor inputs as operands");
326
327 unsigned firstIndex = 0;
328 if (inputs.size() != lattices.size()) {
329 if (!point->isBlockStart()) {
330 if (!inputs.empty())
331 firstIndex = cast<OpResult>(Val: inputs.front()).getResultNumber();
332 visitNonControlFlowArgumentsImpl(
333 op: branch,
334 successor: RegionSuccessor(
335 branch->getResults().slice(n: firstIndex, m: inputs.size())),
336 argLattices: lattices, firstIndex);
337 } else {
338 if (!inputs.empty())
339 firstIndex = cast<BlockArgument>(Val: inputs.front()).getArgNumber();
340 Region *region = point->getBlock()->getParent();
341 visitNonControlFlowArgumentsImpl(
342 op: branch,
343 successor: RegionSuccessor(region, region->getArguments().slice(
344 N: firstIndex, M: inputs.size())),
345 argLattices: lattices, firstIndex);
346 }
347 }
348
349 for (auto it : llvm::zip(t&: *operands, u: lattices.drop_front(N: firstIndex)))
350 join(lhs: std::get<1>(t&: it), rhs: *getLatticeElementFor(point, value: std::get<0>(t&: it)));
351 }
352}
353
354const AbstractSparseLattice *
355AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint *point,
356 Value value) {
357 AbstractSparseLattice *state = getLatticeElement(value);
358 addDependency(state, point);
359 return state;
360}
361
362void AbstractSparseForwardDataFlowAnalysis::setAllToEntryStates(
363 ArrayRef<AbstractSparseLattice *> lattices) {
364 for (AbstractSparseLattice *lattice : lattices)
365 setToEntryState(lattice);
366}
367
368void AbstractSparseForwardDataFlowAnalysis::join(
369 AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) {
370 propagateIfChanged(state: lhs, changed: lhs->join(rhs));
371}
372
373//===----------------------------------------------------------------------===//
374// AbstractSparseBackwardDataFlowAnalysis
375//===----------------------------------------------------------------------===//
376
377AbstractSparseBackwardDataFlowAnalysis::AbstractSparseBackwardDataFlowAnalysis(
378 DataFlowSolver &solver, SymbolTableCollection &symbolTable)
379 : DataFlowAnalysis(solver), symbolTable(symbolTable) {
380 registerAnchorKind<CFGEdge>();
381}
382
383LogicalResult
384AbstractSparseBackwardDataFlowAnalysis::initialize(Operation *top) {
385 return initializeRecursively(op: top);
386}
387
388LogicalResult
389AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) {
390 if (failed(Result: visitOperation(op)))
391 return failure();
392
393 for (Region &region : op->getRegions()) {
394 for (Block &block : region) {
395 getOrCreate<Executable>(anchor: getProgramPointBefore(block: &block))
396 ->blockContentSubscribe(analysis: this);
397 // Initialize ops in reverse order, so we can do as much initial
398 // propagation as possible without having to go through the
399 // solver queue.
400 for (auto it = block.rbegin(); it != block.rend(); it++)
401 if (failed(Result: initializeRecursively(op: &*it)))
402 return failure();
403 }
404 }
405 return success();
406}
407
408LogicalResult
409AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint *point) {
410 // For backward dataflow, we don't have to do any work for the blocks
411 // themselves. CFG edges between blocks are processed by the BranchOp
412 // logic in `visitOperation`, and entry blocks for functions are tied
413 // to the CallOp arguments by visitOperation.
414 if (point->isBlockStart())
415 return success();
416 return visitOperation(op: point->getPrevOp());
417}
418
419SmallVector<AbstractSparseLattice *>
420AbstractSparseBackwardDataFlowAnalysis::getLatticeElements(ValueRange values) {
421 SmallVector<AbstractSparseLattice *> resultLattices;
422 resultLattices.reserve(N: values.size());
423 for (Value result : values) {
424 AbstractSparseLattice *resultLattice = getLatticeElement(value: result);
425 resultLattices.push_back(Elt: resultLattice);
426 }
427 return resultLattices;
428}
429
430SmallVector<const AbstractSparseLattice *>
431AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
432 ProgramPoint *point, ValueRange values) {
433 SmallVector<const AbstractSparseLattice *> resultLattices;
434 resultLattices.reserve(N: values.size());
435 for (Value result : values) {
436 const AbstractSparseLattice *resultLattice =
437 getLatticeElementFor(point, value: result);
438 resultLattices.push_back(Elt: resultLattice);
439 }
440 return resultLattices;
441}
442
443static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
444 return MutableArrayRef<OpOperand>(operands.getBase(), operands.size());
445}
446
447LogicalResult
448AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
449 LLVM_DEBUG({
450 llvm::dbgs() << "Visiting operation: " << op->getName() << " with "
451 << op->getNumOperands() << " operands and "
452 << op->getNumResults() << " results"
453 << "\n";
454 });
455
456 // If we're in a dead block, bail out.
457 if (op->getBlock() != nullptr &&
458 !getOrCreate<Executable>(anchor: getProgramPointBefore(block: op->getBlock()))
459 ->isLive()) {
460 LLVM_DEBUG({
461 llvm::dbgs() << "Operation is in dead block, bailing out"
462 << "\n";
463 });
464 return success();
465 }
466
467 LLVM_DEBUG({
468 llvm::dbgs() << "Creating lattice elements for " << op->getNumOperands()
469 << " operands and " << op->getNumResults() << " results"
470 << "\n";
471 });
472 SmallVector<AbstractSparseLattice *> operandLattices =
473 getLatticeElements(values: op->getOperands());
474 SmallVector<const AbstractSparseLattice *> resultLattices =
475 getLatticeElementsFor(point: getProgramPointAfter(op), values: op->getResults());
476
477 // Block arguments of region branch operations flow back into the operands
478 // of the parent op
479 if (auto branch = dyn_cast<RegionBranchOpInterface>(Val: op)) {
480 LLVM_DEBUG({
481 llvm::dbgs() << "Processing RegionBranchOpInterface operation"
482 << "\n";
483 });
484 visitRegionSuccessors(branch, operands: operandLattices);
485 return success();
486 }
487
488 if (auto branch = dyn_cast<BranchOpInterface>(Val: op)) {
489 LLVM_DEBUG({
490 llvm::dbgs() << "Processing BranchOpInterface operation with "
491 << op->getNumSuccessors() << " successors"
492 << "\n";
493 });
494
495 // Block arguments of successor blocks flow back into our operands.
496
497 // We remember all operands not forwarded to any block in a BitVector.
498 // We can't just cut out a range here, since the non-forwarded ops might
499 // be non-contiguous (if there's more than one successor).
500 BitVector unaccounted(op->getNumOperands(), true);
501
502 for (auto [index, block] : llvm::enumerate(First: op->getSuccessors())) {
503 SuccessorOperands successorOperands = branch.getSuccessorOperands(index);
504 OperandRange forwarded = successorOperands.getForwardedOperands();
505 if (!forwarded.empty()) {
506 MutableArrayRef<OpOperand> operands = op->getOpOperands().slice(
507 N: forwarded.getBeginOperandIndex(), M: forwarded.size());
508 for (OpOperand &operand : operands) {
509 unaccounted.reset(Idx: operand.getOperandNumber());
510 if (std::optional<BlockArgument> blockArg =
511 detail::getBranchSuccessorArgument(
512 operands: successorOperands, operandIndex: operand.getOperandNumber(), successor: block)) {
513 meet(lhs: getLatticeElement(value: operand.get()),
514 rhs: *getLatticeElementFor(point: getProgramPointAfter(op), value: *blockArg));
515 }
516 }
517 }
518 }
519 // Operands not forwarded to successor blocks are typically parameters
520 // of the branch operation itself (for example the boolean for if/else).
521 for (int index : unaccounted.set_bits()) {
522 OpOperand &operand = op->getOpOperand(idx: index);
523 visitBranchOperand(operand);
524 }
525 return success();
526 }
527
528 // For function calls, connect the arguments of the entry blocks to the
529 // operands of the call op that are forwarded to these arguments.
530 if (auto call = dyn_cast<CallOpInterface>(Val: op)) {
531 LLVM_DEBUG({
532 llvm::dbgs() << "Processing CallOpInterface operation"
533 << "\n";
534 });
535 Operation *callableOp = call.resolveCallableInTable(symbolTable: &symbolTable);
536 if (auto callable = dyn_cast_or_null<CallableOpInterface>(Val: callableOp)) {
537 // Not all operands of a call op forward to arguments. Such operands are
538 // stored in `unaccounted`.
539 BitVector unaccounted(op->getNumOperands(), true);
540
541 // If the call invokes an external function (or a function treated as
542 // external due to config), defer to the corresponding extension hook.
543 // By default, it just does `visitCallOperand` for all operands.
544 OperandRange argOperands = call.getArgOperands();
545 MutableArrayRef<OpOperand> argOpOperands =
546 operandsToOpOperands(operands&: argOperands);
547 Region *region = callable.getCallableRegion();
548 if (!region || region->empty() ||
549 !getSolverConfig().isInterprocedural()) {
550 visitExternalCallImpl(call, operandLattices, resultLattices);
551 return success();
552 }
553
554 // Otherwise, propagate information from the entry point of the function
555 // back to operands whenever possible.
556 Block &block = region->front();
557 for (auto [blockArg, argOpOperand] :
558 llvm::zip(t: block.getArguments(), u&: argOpOperands)) {
559 meet(lhs: getLatticeElement(value: argOpOperand.get()),
560 rhs: *getLatticeElementFor(point: getProgramPointAfter(op), value: blockArg));
561 unaccounted.reset(Idx: argOpOperand.getOperandNumber());
562 }
563
564 // Handle the operands of the call op that aren't forwarded to any
565 // arguments.
566 for (int index : unaccounted.set_bits()) {
567 OpOperand &opOperand = op->getOpOperand(idx: index);
568 visitCallOperand(operand&: opOperand);
569 }
570 return success();
571 }
572 }
573
574 // When the region of an op implementing `RegionBranchOpInterface` has a
575 // terminator implementing `RegionBranchTerminatorOpInterface` or a
576 // return-like terminator, the region's successors' arguments flow back into
577 // the "successor operands" of this terminator.
578 //
579 // A successor operand with respect to an op implementing
580 // `RegionBranchOpInterface` is an operand that is forwarded to a region
581 // successor's input. There are two types of successor operands: the operands
582 // of this op itself and the operands of the terminators of the regions of
583 // this op.
584 if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(Val: op)) {
585 LLVM_DEBUG({
586 llvm::dbgs() << "Processing RegionBranchTerminatorOpInterface operation"
587 << "\n";
588 });
589 if (auto branch = dyn_cast<RegionBranchOpInterface>(Val: op->getParentOp())) {
590 visitRegionSuccessorsFromTerminator(terminator, branch);
591 return success();
592 }
593 }
594
595 if (op->hasTrait<OpTrait::ReturnLike>()) {
596 LLVM_DEBUG({
597 llvm::dbgs() << "Processing ReturnLike operation"
598 << "\n";
599 });
600 // Going backwards, the operands of the return are derived from the
601 // results of all CallOps calling this CallableOp.
602 if (auto callable = dyn_cast<CallableOpInterface>(Val: op->getParentOp())) {
603 LLVM_DEBUG({
604 llvm::dbgs() << "Callable parent found, visiting callable operation"
605 << "\n";
606 });
607 return visitCallableOperation(op, callable, operandLattices);
608 }
609 }
610
611 LLVM_DEBUG({
612 llvm::dbgs() << "Using default visitOperationImpl for operation: "
613 << op->getName() << "\n";
614 });
615 return visitOperationImpl(op, operandLattices, resultLattices);
616}
617
618LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallableOperation(
619 Operation *op, CallableOpInterface callable,
620 ArrayRef<AbstractSparseLattice *> operandLattices) {
621 const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
622 dependent: getProgramPointAfter(op), anchor: getProgramPointAfter(op: callable));
623 if (callsites->allPredecessorsKnown()) {
624 for (Operation *call : callsites->getKnownPredecessors()) {
625 SmallVector<const AbstractSparseLattice *> callResultLattices =
626 getLatticeElementsFor(point: getProgramPointAfter(op), values: call->getResults());
627 for (auto [op, result] : llvm::zip(t&: operandLattices, u&: callResultLattices))
628 meet(lhs: op, rhs: *result);
629 }
630 } else {
631 // If we don't know all the callers, we can't know where the
632 // returned values go. Note that, in particular, this will trigger
633 // for the return ops of any public functions.
634 setAllToExitStates(operandLattices);
635 }
636 return success();
637}
638
639void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
640 RegionBranchOpInterface branch,
641 ArrayRef<AbstractSparseLattice *> operandLattices) {
642 Operation *op = branch.getOperation();
643 SmallVector<RegionSuccessor> successors;
644 SmallVector<Attribute> operands(op->getNumOperands(), nullptr);
645 branch.getEntrySuccessorRegions(operands, regions&: successors);
646
647 // All operands not forwarded to any successor. This set can be non-contiguous
648 // in the presence of multiple successors.
649 BitVector unaccounted(op->getNumOperands(), true);
650
651 for (RegionSuccessor &successor : successors) {
652 OperandRange operands = branch.getEntrySuccessorOperands(point: successor);
653 MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
654 ValueRange inputs = successor.getSuccessorInputs();
655 for (auto [operand, input] : llvm::zip(t&: opoperands, u&: inputs)) {
656 meet(lhs: getLatticeElement(value: operand.get()),
657 rhs: *getLatticeElementFor(point: getProgramPointAfter(op), value: input));
658 unaccounted.reset(Idx: operand.getOperandNumber());
659 }
660 }
661 // All operands not forwarded to regions are typically parameters of the
662 // branch operation itself (for example the boolean for if/else).
663 for (int index : unaccounted.set_bits()) {
664 visitBranchOperand(operand&: op->getOpOperand(idx: index));
665 }
666}
667
668void AbstractSparseBackwardDataFlowAnalysis::
669 visitRegionSuccessorsFromTerminator(
670 RegionBranchTerminatorOpInterface terminator,
671 RegionBranchOpInterface branch) {
672 assert(isa<RegionBranchTerminatorOpInterface>(terminator) &&
673 "expected a `RegionBranchTerminatorOpInterface` op");
674 assert(terminator->getParentOp() == branch.getOperation() &&
675 "expected `branch` to be the parent op of `terminator`");
676
677 SmallVector<Attribute> operandAttributes(terminator->getNumOperands(),
678 nullptr);
679 SmallVector<RegionSuccessor> successors;
680 terminator.getSuccessorRegions(operands: operandAttributes, regions&: successors);
681 // All operands not forwarded to any successor. This set can be
682 // non-contiguous in the presence of multiple successors.
683 BitVector unaccounted(terminator->getNumOperands(), true);
684
685 for (const RegionSuccessor &successor : successors) {
686 ValueRange inputs = successor.getSuccessorInputs();
687 OperandRange operands = terminator.getSuccessorOperands(point: successor);
688 MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
689 for (auto [opOperand, input] : llvm::zip(t&: opOperands, u&: inputs)) {
690 meet(lhs: getLatticeElement(value: opOperand.get()),
691 rhs: *getLatticeElementFor(point: getProgramPointAfter(op: terminator), value: input));
692 unaccounted.reset(Idx: const_cast<OpOperand &>(opOperand).getOperandNumber());
693 }
694 }
695 // Visit operands of the branch op not forwarded to the next region.
696 // (Like e.g. the boolean of `scf.conditional`)
697 for (int index : unaccounted.set_bits()) {
698 visitBranchOperand(operand&: terminator->getOpOperand(idx: index));
699 }
700}
701
702const AbstractSparseLattice *
703AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(
704 ProgramPoint *point, Value value) {
705 AbstractSparseLattice *state = getLatticeElement(value);
706 addDependency(state, point);
707 return state;
708}
709
710void AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates(
711 ArrayRef<AbstractSparseLattice *> lattices) {
712 for (AbstractSparseLattice *lattice : lattices)
713 setToExitState(lattice);
714}
715
716void AbstractSparseBackwardDataFlowAnalysis::meet(
717 AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) {
718 propagateIfChanged(state: lhs, changed: lhs->meet(rhs));
719}
720

source code of mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp