//===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::dataflow;

//===----------------------------------------------------------------------===//
// AbstractSparseLattice
//===----------------------------------------------------------------------===//

void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
  AnalysisState::onUpdate(solver);

  // Push all users of the value to the queue.
  for (Operation *user : cast<Value>(anchor).getUsers())
    for (DataFlowAnalysis *analysis : useDefSubscribers)
      solver->enqueue({solver->getProgramPointAfter(user), analysis});
}

//===----------------------------------------------------------------------===//
// AbstractSparseForwardDataFlowAnalysis
//===----------------------------------------------------------------------===//

AbstractSparseForwardDataFlowAnalysis::AbstractSparseForwardDataFlowAnalysis(
    DataFlowSolver &solver)
    : DataFlowAnalysis(solver) {
  registerAnchorKind<CFGEdge>();
}

LogicalResult
AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) {
  // Mark the entry block arguments as having reached their pessimistic
  // fixpoints.
  for (Region &region : top->getRegions()) {
    if (region.empty())
      continue;
    for (Value argument : region.front().getArguments())
      setToEntryState(getLatticeElement(argument));
  }

  return initializeRecursively(top);
}

LogicalResult
AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
  // Initialize the analysis by visiting every owner of an SSA value (all
  // operations and blocks).
  if (failed(visitOperation(op)))
    return failure();

  for (Region &region : op->getRegions()) {
    for (Block &block : region) {
      getOrCreate<Executable>(getProgramPointBefore(&block))
          ->blockContentSubscribe(this);
      visitBlock(&block);
      for (Operation &op : block)
        if (failed(initializeRecursively(&op)))
          return failure();
    }
  }

  return success();
}

LogicalResult
AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint *point) {
  if (!point->isBlockStart())
    return visitOperation(point->getPrevOp());
  visitBlock(point->getBlock());
  return success();
}

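// Transfer function for an operation: compute the result lattices from the
// operand lattices, deferring to region control flow for region branch ops
// and to the callgraph for call ops.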
LogicalResult
AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
  // Exit early on operations with no results.
  if (op->getNumResults() == 0)
    return success();

  // If the containing block is not executable, bail out.
  if (op->getBlock() != nullptr &&
      !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
    return success();

  // Get the result lattices.
  SmallVector<AbstractSparseLattice *> resultLattices;
  resultLattices.reserve(op->getNumResults());
  for (Value result : op->getResults()) {
    AbstractSparseLattice *resultLattice = getLatticeElement(result);
    resultLattices.push_back(resultLattice);
  }

  // The results of a region branch operation are determined by control-flow.
  if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
    visitRegionSuccessors(getProgramPointAfter(branch), branch,
                          /*successor=*/RegionBranchPoint::parent(),
                          resultLattices);
    return success();
  }

  // Grab the lattice elements of the operands.
  SmallVector<const AbstractSparseLattice *> operandLattices;
  operandLattices.reserve(op->getNumOperands());
  for (Value operand : op->getOperands()) {
    AbstractSparseLattice *operandLattice = getLatticeElement(operand);
    operandLattice->useDefSubscribe(this);
    operandLattices.push_back(operandLattice);
  }

  if (auto call = dyn_cast<CallOpInterface>(op))
    return visitCallOperation(call, operandLattices, resultLattices);

  // Invoke the operation transfer function.
  return visitOperationImpl(op, operandLattices, resultLattices);
}

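// Transfer function for block arguments: join into each argument's lattice
// the lattices of the values that may flow into it, using region control flow
// or the callgraph for entry blocks and live CFG predecessors otherwise.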
void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
  // Exit early on blocks with no arguments.
  if (block->getNumArguments() == 0)
    return;

  // If the block is not executable, bail out.
  if (!getOrCreate<Executable>(getProgramPointBefore(block))->isLive())
    return;

  // Get the argument lattices.
  SmallVector<AbstractSparseLattice *> argLattices;
  argLattices.reserve(block->getNumArguments());
  for (BlockArgument argument : block->getArguments()) {
    AbstractSparseLattice *argLattice = getLatticeElement(argument);
    argLattices.push_back(argLattice);
  }

  // The argument lattices of entry blocks are set by region control-flow or
  // the callgraph.
  if (block->isEntryBlock()) {
    // Check if this block is the entry block of a callable region.
    auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
    if (callable && callable.getCallableRegion() == block->getParent())
      return visitCallableOperation(callable, argLattices);

    // Check if the lattices can be determined from region control flow.
    if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
      return visitRegionSuccessors(getProgramPointBefore(block), branch,
                                   block->getParent(), argLattices);
    }

    // Otherwise, we can't reason about the data-flow.
    return visitNonControlFlowArgumentsImpl(block->getParentOp(),
                                            RegionSuccessor(block->getParent()),
                                            argLattices, /*firstIndex=*/0);
  }

  // Iterate over the predecessors of the non-entry block.
  for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end();
       it != e; ++it) {
    Block *predecessor = *it;

    // If the edge from the predecessor block to the current block is not live,
    // bail out.
    auto *edgeExecutable =
        getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(predecessor, block));
    edgeExecutable->blockContentSubscribe(this);
    if (!edgeExecutable->isLive())
      continue;

    // Check if we can reason about the data-flow from the predecessor.
    if (auto branch =
            dyn_cast<BranchOpInterface>(predecessor->getTerminator())) {
      SuccessorOperands operands =
          branch.getSuccessorOperands(it.getSuccessorIndex());
      for (auto [idx, lattice] : llvm::enumerate(argLattices)) {
        if (Value operand = operands[idx]) {
          join(lattice,
               *getLatticeElementFor(getProgramPointBefore(block), operand));
        } else {
          // Conservatively consider internally produced arguments as entry
          // points.
          setAllToEntryStates(lattice);
        }
      }
    } else {
      return setAllToEntryStates(argLattices);
    }
  }
}

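// Transfer function for call results: external or non-interprocedural calls
// are handled by `visitExternalCallImpl`; otherwise the result lattices are
// joined with the operands of the known return sites of the callee.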
LogicalResult AbstractSparseForwardDataFlowAnalysis::visitCallOperation(
    CallOpInterface call,
    ArrayRef<const AbstractSparseLattice *> operandLattices,
    ArrayRef<AbstractSparseLattice *> resultLattices) {
  // If the call operation is to an external function, attempt to infer the
  // results from the call arguments.
  auto callable =
      dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
  if (!getSolverConfig().isInterprocedural() ||
      (callable && !callable.getCallableRegion())) {
    visitExternalCallImpl(call, operandLattices, resultLattices);
    return success();
  }

  // Otherwise, the results of a call operation are determined by the
  // callgraph.
  const auto *predecessors = getOrCreateFor<PredecessorState>(
      getProgramPointAfter(call), getProgramPointAfter(call));
  // If not all return sites are known, then conservatively assume we can't
  // reason about the data-flow.
  if (!predecessors->allPredecessorsKnown()) {
    setAllToEntryStates(resultLattices);
    return success();
  }
  for (Operation *predecessor : predecessors->getKnownPredecessors())
    for (auto &&[operand, resLattice] :
         llvm::zip(predecessor->getOperands(), resultLattices))
      join(resLattice,
           *getLatticeElementFor(getProgramPointAfter(call), operand));
  return success();
}

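// Transfer function for the entry arguments of a callable region: join the
// lattices of the corresponding arguments at every known callsite, falling
// back to the entry state when callsites are unknown or the analysis is
// intraprocedural.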
void AbstractSparseForwardDataFlowAnalysis::visitCallableOperation(
    CallableOpInterface callable,
    ArrayRef<AbstractSparseLattice *> argLattices) {
  Block *entryBlock = &callable.getCallableRegion()->front();
  const auto *callsites = getOrCreateFor<PredecessorState>(
      getProgramPointBefore(entryBlock), getProgramPointAfter(callable));
  // If not all callsites are known, conservatively mark all lattices as
  // having reached their pessimistic fixpoints.
  if (!callsites->allPredecessorsKnown() ||
      !getSolverConfig().isInterprocedural()) {
    return setAllToEntryStates(argLattices);
  }
  for (Operation *callsite : callsites->getKnownPredecessors()) {
    auto call = cast<CallOpInterface>(callsite);
    for (auto it : llvm::zip(call.getArgOperands(), argLattices))
      join(std::get<1>(it),
           *getLatticeElementFor(getProgramPointBefore(entryBlock),
                                 std::get<0>(it)));
  }
}

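// Propagate operand lattices across region control-flow edges: for each known
// predecessor, join the forwarded operands into the lattices of the successor
// inputs (block arguments or parent op results).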
void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
    ProgramPoint *point, RegionBranchOpInterface branch,
    RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
  const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
  assert(predecessors->allPredecessorsKnown() &&
         "unexpected unresolved region successors");

  for (Operation *op : predecessors->getKnownPredecessors()) {
    // Get the incoming successor operands.
    std::optional<OperandRange> operands;

    // Check if the predecessor is the parent op.
    if (op == branch) {
      operands = branch.getEntrySuccessorOperands(successor);
      // Otherwise, try to deduce the operands from a region return-like op.
    } else if (auto regionTerminator =
                   dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
      operands = regionTerminator.getSuccessorOperands(successor);
    }

    if (!operands) {
      // We can't reason about the data-flow.
      return setAllToEntryStates(lattices);
    }

    ValueRange inputs = predecessors->getSuccessorInputs(op);
    assert(inputs.size() == operands->size() &&
           "expected the same number of successor inputs as operands");

    unsigned firstIndex = 0;
    if (inputs.size() != lattices.size()) {
      if (!point->isBlockStart()) {
        if (!inputs.empty())
          firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
        visitNonControlFlowArgumentsImpl(
            branch,
            RegionSuccessor(
                branch->getResults().slice(firstIndex, inputs.size())),
            lattices, firstIndex);
      } else {
        if (!inputs.empty())
          firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
        Region *region = point->getBlock()->getParent();
        visitNonControlFlowArgumentsImpl(
            branch,
            RegionSuccessor(region, region->getArguments().slice(
                                        firstIndex, inputs.size())),
            lattices, firstIndex);
      }
    }

    for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
      join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it)));
  }
}

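// Get the lattice element of a value, registering `point` as a dependent so
// that it is re-visited whenever the lattice changes.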
const AbstractSparseLattice *
AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint *point,
                                                            Value value) {
  AbstractSparseLattice *state = getLatticeElement(value);
  addDependency(state, point);
  return state;
}

void AbstractSparseForwardDataFlowAnalysis::setAllToEntryStates(
    ArrayRef<AbstractSparseLattice *> lattices) {
  for (AbstractSparseLattice *lattice : lattices)
    setToEntryState(lattice);
}

void AbstractSparseForwardDataFlowAnalysis::join(
    AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) {
  propagateIfChanged(lhs, lhs->join(rhs));
}

//===----------------------------------------------------------------------===//
// AbstractSparseBackwardDataFlowAnalysis
//===----------------------------------------------------------------------===//

AbstractSparseBackwardDataFlowAnalysis::AbstractSparseBackwardDataFlowAnalysis(
    DataFlowSolver &solver, SymbolTableCollection &symbolTable)
    : DataFlowAnalysis(solver), symbolTable(symbolTable) {
  registerAnchorKind<CFGEdge>();
}

LogicalResult
AbstractSparseBackwardDataFlowAnalysis::initialize(Operation *top) {
  return initializeRecursively(top);
}

LogicalResult
AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) {
  if (failed(visitOperation(op)))
    return failure();

  for (Region &region : op->getRegions()) {
    for (Block &block : region) {
      getOrCreate<Executable>(getProgramPointBefore(&block))
          ->blockContentSubscribe(this);
      // Initialize ops in reverse order, so we can do as much initial
      // propagation as possible without having to go through the
      // solver queue.
      for (auto it = block.rbegin(); it != block.rend(); it++)
        if (failed(initializeRecursively(&*it)))
          return failure();
    }
  }
  return success();
}

LogicalResult
AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint *point) {
  // For backward dataflow, we don't have to do any work for the blocks
  // themselves. CFG edges between blocks are processed by the BranchOp
  // logic in `visitOperation`, and entry blocks for functions are tied
  // to the CallOp arguments by visitOperation.
  if (point->isBlockStart())
    return success();
  return visitOperation(point->getPrevOp());
}

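// Collect the mutable lattice elements for a range of values.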
SmallVector<AbstractSparseLattice *>
AbstractSparseBackwardDataFlowAnalysis::getLatticeElements(ValueRange values) {
  SmallVector<AbstractSparseLattice *> resultLattices;
  resultLattices.reserve(values.size());
  for (Value result : values) {
    AbstractSparseLattice *resultLattice = getLatticeElement(result);
    resultLattices.push_back(resultLattice);
  }
  return resultLattices;
}

SmallVector<const AbstractSparseLattice *>
AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
    ProgramPoint *point, ValueRange values) {
  SmallVector<const AbstractSparseLattice *> resultLattices;
  resultLattices.reserve(values.size());
  for (Value result : values) {
    const AbstractSparseLattice *resultLattice =
        getLatticeElementFor(point, result);
    resultLattices.push_back(resultLattice);
  }
  return resultLattices;
}

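// View an OperandRange as the underlying contiguous array of OpOperands of
// the owning operation.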
static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
  return MutableArrayRef<OpOperand>(operands.getBase(), operands.size());
}

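// Transfer function for the backward analysis: dispatch on the kind of
// control flow the operation participates in (region branch, CFG branch,
// call, region terminator, return) and otherwise invoke the user-provided
// transfer function on the operand and result lattices.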
LogicalResult
AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
  // If we're in a dead block, bail out.
  if (op->getBlock() != nullptr &&
      !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
    return success();

  SmallVector<AbstractSparseLattice *> operandLattices =
      getLatticeElements(op->getOperands());
  SmallVector<const AbstractSparseLattice *> resultLattices =
      getLatticeElementsFor(getProgramPointAfter(op), op->getResults());

  // Block arguments of region branch operations flow back into the operands
  // of the parent op.
  if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
    visitRegionSuccessors(branch, operandLattices);
    return success();
  }

  if (auto branch = dyn_cast<BranchOpInterface>(op)) {
    // Block arguments of successor blocks flow back into our operands.

    // We remember all operands not forwarded to any block in a BitVector.
    // We can't just cut out a range here, since the non-forwarded ops might
    // be non-contiguous (if there's more than one successor).
    BitVector unaccounted(op->getNumOperands(), true);

    for (auto [index, block] : llvm::enumerate(op->getSuccessors())) {
      SuccessorOperands successorOperands = branch.getSuccessorOperands(index);
      OperandRange forwarded = successorOperands.getForwardedOperands();
      if (!forwarded.empty()) {
        MutableArrayRef<OpOperand> operands = op->getOpOperands().slice(
            forwarded.getBeginOperandIndex(), forwarded.size());
        for (OpOperand &operand : operands) {
          unaccounted.reset(operand.getOperandNumber());
          if (std::optional<BlockArgument> blockArg =
                  detail::getBranchSuccessorArgument(
                      successorOperands, operand.getOperandNumber(), block)) {
            meet(getLatticeElement(operand.get()),
                 *getLatticeElementFor(getProgramPointAfter(op), *blockArg));
          }
        }
      }
    }
    // Operands not forwarded to successor blocks are typically parameters
    // of the branch operation itself (for example the boolean for if/else).
    for (int index : unaccounted.set_bits()) {
      OpOperand &operand = op->getOpOperand(index);
      visitBranchOperand(operand);
    }
    return success();
  }

  // For function calls, connect the arguments of the entry blocks to the
  // operands of the call op that are forwarded to these arguments.
  if (auto call = dyn_cast<CallOpInterface>(op)) {
    Operation *callableOp = call.resolveCallableInTable(&symbolTable);
    if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
      // Not all operands of a call op forward to arguments. Such operands are
      // stored in `unaccounted`.
      BitVector unaccounted(op->getNumOperands(), true);

      // If the call invokes an external function (or a function treated as
      // external due to config), defer to the corresponding extension hook.
      // By default, it just does `visitCallOperand` for all operands.
      OperandRange argOperands = call.getArgOperands();
      MutableArrayRef<OpOperand> argOpOperands =
          operandsToOpOperands(argOperands);
      Region *region = callable.getCallableRegion();
      if (!region || region->empty() ||
          !getSolverConfig().isInterprocedural()) {
        visitExternalCallImpl(call, operandLattices, resultLattices);
        return success();
      }

      // Otherwise, propagate information from the entry point of the function
      // back to operands whenever possible.
      Block &block = region->front();
      for (auto [blockArg, argOpOperand] :
           llvm::zip(block.getArguments(), argOpOperands)) {
        meet(getLatticeElement(argOpOperand.get()),
             *getLatticeElementFor(getProgramPointAfter(op), blockArg));
        unaccounted.reset(argOpOperand.getOperandNumber());
      }

      // Handle the operands of the call op that aren't forwarded to any
      // arguments.
      for (int index : unaccounted.set_bits()) {
        OpOperand &opOperand = op->getOpOperand(index);
        visitCallOperand(opOperand);
      }
      return success();
    }
  }

  // When the region of an op implementing `RegionBranchOpInterface` has a
  // terminator implementing `RegionBranchTerminatorOpInterface` or a
  // return-like terminator, the region's successors' arguments flow back into
  // the "successor operands" of this terminator.
  //
  // A successor operand with respect to an op implementing
  // `RegionBranchOpInterface` is an operand that is forwarded to a region
  // successor's input. There are two types of successor operands: the operands
  // of this op itself and the operands of the terminators of the regions of
  // this op.
  if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
    if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
      visitRegionSuccessorsFromTerminator(terminator, branch);
      return success();
    }
  }

  if (op->hasTrait<OpTrait::ReturnLike>()) {
    // Going backwards, the operands of the return are derived from the
    // results of all CallOps calling this CallableOp.
    if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp()))
      return visitCallableOperation(op, callable, operandLattices);
  }

  return visitOperationImpl(op, operandLattices, resultLattices);
}

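// Going backward, the operands of a return-like op are met with the result
// lattices of all known callers of the enclosing callable.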
LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallableOperation(
    Operation *op, CallableOpInterface callable,
    ArrayRef<AbstractSparseLattice *> operandLattices) {
  const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
      getProgramPointAfter(op), getProgramPointAfter(callable));
  if (callsites->allPredecessorsKnown()) {
    for (Operation *call : callsites->getKnownPredecessors()) {
      SmallVector<const AbstractSparseLattice *> callResultLattices =
          getLatticeElementsFor(getProgramPointAfter(op), call->getResults());
      for (auto [op, result] : llvm::zip(operandLattices, callResultLattices))
        meet(op, *result);
    }
  } else {
    // If we don't know all the callers, we can't know where the
    // returned values go. Note that, in particular, this will trigger
    // for the return ops of any public functions.
    setAllToExitStates(operandLattices);
  }
  return success();
}

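// Meet the entry successor operands of a region branch op with the lattices
// of the region inputs they are forwarded to; operands not forwarded to any
// region are handled by `visitBranchOperand`.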
void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
    RegionBranchOpInterface branch,
    ArrayRef<AbstractSparseLattice *> operandLattices) {
  Operation *op = branch.getOperation();
  SmallVector<RegionSuccessor> successors;
  SmallVector<Attribute> operands(op->getNumOperands(), nullptr);
  branch.getEntrySuccessorRegions(operands, successors);

  // All operands not forwarded to any successor. This set can be
  // non-contiguous in the presence of multiple successors.
  BitVector unaccounted(op->getNumOperands(), true);

  for (RegionSuccessor &successor : successors) {
    OperandRange operands = branch.getEntrySuccessorOperands(successor);
    MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
    ValueRange inputs = successor.getSuccessorInputs();
    for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
      meet(getLatticeElement(operand.get()),
           *getLatticeElementFor(getProgramPointAfter(op), input));
      unaccounted.reset(operand.getOperandNumber());
    }
  }
  // All operands not forwarded to regions are typically parameters of the
  // branch operation itself (for example the boolean for if/else).
  for (int index : unaccounted.set_bits()) {
    visitBranchOperand(op->getOpOperand(index));
  }
}

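// Meet the successor operands of a region terminator with the lattices of the
// successor inputs they feed; operands not forwarded to any successor are
// handled by `visitBranchOperand`.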
void AbstractSparseBackwardDataFlowAnalysis::
    visitRegionSuccessorsFromTerminator(
        RegionBranchTerminatorOpInterface terminator,
        RegionBranchOpInterface branch) {
  assert(isa<RegionBranchTerminatorOpInterface>(terminator) &&
         "expected a `RegionBranchTerminatorOpInterface` op");
  assert(terminator->getParentOp() == branch.getOperation() &&
         "expected `branch` to be the parent op of `terminator`");

  SmallVector<Attribute> operandAttributes(terminator->getNumOperands(),
                                           nullptr);
  SmallVector<RegionSuccessor> successors;
  terminator.getSuccessorRegions(operandAttributes, successors);
  // All operands not forwarded to any successor. This set can be
  // non-contiguous in the presence of multiple successors.
  BitVector unaccounted(terminator->getNumOperands(), true);

  for (const RegionSuccessor &successor : successors) {
    ValueRange inputs = successor.getSuccessorInputs();
    OperandRange operands = terminator.getSuccessorOperands(successor);
    MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
    for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
      meet(getLatticeElement(opOperand.get()),
           *getLatticeElementFor(getProgramPointAfter(terminator), input));
      unaccounted.reset(const_cast<OpOperand &>(opOperand).getOperandNumber());
    }
  }
  // Visit operands of the branch op not forwarded to the next region.
  // (Like e.g. the boolean of `scf.condition`.)
  for (int index : unaccounted.set_bits()) {
    visitBranchOperand(terminator->getOpOperand(index));
  }
}

const AbstractSparseLattice *
AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(
    ProgramPoint *point, Value value) {
  AbstractSparseLattice *state = getLatticeElement(value);
  addDependency(state, point);
  return state;
}

void AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates(
    ArrayRef<AbstractSparseLattice *> lattices) {
  for (AbstractSparseLattice *lattice : lattices)
    setToExitState(lattice);
}

void AbstractSparseBackwardDataFlowAnalysis::meet(
    AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) {
  propagateIfChanged(lhs, lhs->meet(rhs));
}