1//===- SparseAnalysis.h - Sparse data-flow analysis -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements sparse data-flow analysis using the data-flow analysis
// framework. It declares the lattices attached to SSA values and both the
// forward and the backward sparse analyses. The forward analysis is conditional
// and uses the results of dead code analysis to prune dead code during the
// analysis.
12//
13//===----------------------------------------------------------------------===//
14
15#ifndef MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H
16#define MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H
17
18#include "mlir/Analysis/DataFlowFramework.h"
19#include "mlir/IR/SymbolTable.h"
20#include "mlir/Interfaces/CallInterfaces.h"
21#include "mlir/Interfaces/ControlFlowInterfaces.h"
22#include "llvm/ADT/SmallPtrSet.h"
23
24namespace mlir {
25namespace dataflow {
26
27//===----------------------------------------------------------------------===//
28// AbstractSparseLattice
29//===----------------------------------------------------------------------===//
30
31/// This class represents an abstract lattice. A lattice contains information
32/// about an SSA value and is what's propagated across the IR by sparse
33/// data-flow analysis.
34class AbstractSparseLattice : public AnalysisState {
35public:
36 /// Lattices can only be created for values.
37 AbstractSparseLattice(Value value) : AnalysisState(value) {}
38
39 /// Return the program point this lattice is located at.
40 Value getPoint() const { return AnalysisState::getPoint().get<Value>(); }
41
42 /// Join the information contained in 'rhs' into this lattice. Returns
43 /// if the value of the lattice changed.
44 virtual ChangeResult join(const AbstractSparseLattice &rhs) {
45 return ChangeResult::NoChange;
46 }
47
48 /// Meet (intersect) the information in this lattice with 'rhs'. Returns
49 /// if the value of the lattice changed.
50 virtual ChangeResult meet(const AbstractSparseLattice &rhs) {
51 return ChangeResult::NoChange;
52 }
53
54 /// When the lattice gets updated, propagate an update to users of the value
55 /// using its use-def chain to subscribed analyses.
56 void onUpdate(DataFlowSolver *solver) const override;
57
58 /// Subscribe an analysis to updates of the lattice. When the lattice changes,
59 /// subscribed analyses are re-invoked on all users of the value. This is
60 /// more efficient than relying on the dependency map.
61 void useDefSubscribe(DataFlowAnalysis *analysis) {
62 useDefSubscribers.insert(X: analysis);
63 }
64
65private:
66 /// A set of analyses that should be updated when this lattice changes.
67 SetVector<DataFlowAnalysis *, SmallVector<DataFlowAnalysis *, 4>,
68 SmallPtrSet<DataFlowAnalysis *, 4>>
69 useDefSubscribers;
70};
71
72//===----------------------------------------------------------------------===//
73// Lattice
74//===----------------------------------------------------------------------===//
75
76/// This class represents a lattice holding a specific value of type `ValueT`.
77/// Lattice values (`ValueT`) are required to adhere to the following:
78///
79/// * static ValueT join(const ValueT &lhs, const ValueT &rhs);
80/// - This method conservatively joins the information held by `lhs`
81/// and `rhs` into a new value. This method is required to be monotonic.
82/// * bool operator==(const ValueT &rhs) const;
83///
84template <typename ValueT>
85class Lattice : public AbstractSparseLattice {
86public:
87 using AbstractSparseLattice::AbstractSparseLattice;
88
89 /// Return the program point this lattice is located at.
90 Value getPoint() const { return point.get<Value>(); }
91
92 /// Return the value held by this lattice. This requires that the value is
93 /// initialized.
94 ValueT &getValue() { return value; }
95 const ValueT &getValue() const {
96 return const_cast<Lattice<ValueT> *>(this)->getValue();
97 }
98
99 using LatticeT = Lattice<ValueT>;
100
101 /// Join the information contained in the 'rhs' lattice into this
102 /// lattice. Returns if the state of the current lattice changed.
103 ChangeResult join(const AbstractSparseLattice &rhs) override {
104 return join(static_cast<const LatticeT &>(rhs).getValue());
105 }
106
107 /// Meet (intersect) the information contained in the 'rhs' lattice with
108 /// this lattice. Returns if the state of the current lattice changed.
109 ChangeResult meet(const AbstractSparseLattice &rhs) override {
110 return meet(static_cast<const LatticeT &>(rhs).getValue());
111 }
112
113 /// Join the information contained in the 'rhs' value into this
114 /// lattice. Returns if the state of the current lattice changed.
115 ChangeResult join(const ValueT &rhs) {
116 // Otherwise, join rhs with the current optimistic value.
117 ValueT newValue = ValueT::join(value, rhs);
118 assert(ValueT::join(newValue, value) == newValue &&
119 "expected `join` to be monotonic");
120 assert(ValueT::join(newValue, rhs) == newValue &&
121 "expected `join` to be monotonic");
122
123 // Update the current optimistic value if something changed.
124 if (newValue == value)
125 return ChangeResult::NoChange;
126
127 value = newValue;
128 return ChangeResult::Change;
129 }
130
131 /// Trait to check if `T` provides a `meet` method. Needed since for forward
132 /// analysis, lattices will only have a `join`, no `meet`, but we want to use
133 /// the same `Lattice` class for both directions.
134 template <typename T, typename... Args>
135 using has_meet = decltype(std::declval<T>().meet());
136 template <typename T>
137 using lattice_has_meet = llvm::is_detected<has_meet, T>;
138
139 /// Meet (intersect) the information contained in the 'rhs' value with this
140 /// lattice. Returns if the state of the current lattice changed. If the
141 /// lattice elements don't have a `meet` method, this is a no-op (see below.)
142 template <typename VT, std::enable_if_t<lattice_has_meet<VT>::value>>
143 ChangeResult meet(const VT &rhs) {
144 ValueT newValue = ValueT::meet(value, rhs);
145 assert(ValueT::meet(newValue, value) == newValue &&
146 "expected `meet` to be monotonic");
147 assert(ValueT::meet(newValue, rhs) == newValue &&
148 "expected `meet` to be monotonic");
149
150 // Update the current optimistic value if something changed.
151 if (newValue == value)
152 return ChangeResult::NoChange;
153
154 value = newValue;
155 return ChangeResult::Change;
156 }
157
158 template <typename VT>
159 ChangeResult meet(const VT &rhs) {
160 return ChangeResult::NoChange;
161 }
162
163 /// Print the lattice element.
164 void print(raw_ostream &os) const override { value.print(os); }
165
166private:
167 /// The currently computed value that is optimistically assumed to be true.
168 ValueT value;
169};
170
171//===----------------------------------------------------------------------===//
172// AbstractSparseForwardDataFlowAnalysis
173//===----------------------------------------------------------------------===//
174
175/// Base class for sparse forward data-flow analyses. A sparse analysis
176/// implements a transfer function on operations from the lattices of the
177/// operands to the lattices of the results. This analysis will propagate
178/// lattices across control-flow edges and the callgraph using liveness
179/// information.
180class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
181public:
182 /// Initialize the analysis by visiting every owner of an SSA value: all
183 /// operations and blocks.
184 LogicalResult initialize(Operation *top) override;
185
186 /// Visit a program point. If this is a block and all control-flow
187 /// predecessors or callsites are known, then the arguments lattices are
188 /// propagated from them. If this is a call operation or an operation with
189 /// region control-flow, then its result lattices are set accordingly.
190 /// Otherwise, the operation transfer function is invoked.
191 LogicalResult visit(ProgramPoint point) override;
192
193protected:
194 explicit AbstractSparseForwardDataFlowAnalysis(DataFlowSolver &solver);
195
196 /// The operation transfer function. Given the operand lattices, this
197 /// function is expected to set the result lattices.
198 virtual void
199 visitOperationImpl(Operation *op,
200 ArrayRef<const AbstractSparseLattice *> operandLattices,
201 ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
202
203 /// The transfer function for calls to external functions.
204 virtual void visitExternalCallImpl(
205 CallOpInterface call,
206 ArrayRef<const AbstractSparseLattice *> argumentLattices,
207 ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
208
209 /// Given an operation with region control-flow, the lattices of the operands,
210 /// and a region successor, compute the lattice values for block arguments
211 /// that are not accounted for by the branching control flow (ex. the bounds
212 /// of loops).
213 virtual void visitNonControlFlowArgumentsImpl(
214 Operation *op, const RegionSuccessor &successor,
215 ArrayRef<AbstractSparseLattice *> argLattices, unsigned firstIndex) = 0;
216
217 /// Get the lattice element of a value.
218 virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;
219
220 /// Get a read-only lattice element for a value and add it as a dependency to
221 /// a program point.
222 const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point,
223 Value value);
224
225 /// Set the given lattice element(s) at control flow entry point(s).
226 virtual void setToEntryState(AbstractSparseLattice *lattice) = 0;
227 void setAllToEntryStates(ArrayRef<AbstractSparseLattice *> lattices);
228
229 /// Join the lattice element and propagate and update if it changed.
230 void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
231
232private:
233 /// Recursively initialize the analysis on nested operations and blocks.
234 LogicalResult initializeRecursively(Operation *op);
235
236 /// Visit an operation. If this is a call operation or an operation with
237 /// region control-flow, then its result lattices are set accordingly.
238 /// Otherwise, the operation transfer function is invoked.
239 void visitOperation(Operation *op);
240
241 /// Visit a block to compute the lattice values of its arguments. If this is
242 /// an entry block, then the argument values are determined from the block's
243 /// "predecessors" as set by `PredecessorState`. The predecessors can be
244 /// region terminators or callable callsites. Otherwise, the values are
245 /// determined from block predecessors.
246 void visitBlock(Block *block);
247
248 /// Visit a program point `point` with predecessors within a region branch
249 /// operation `branch`, which can either be the entry block of one of the
250 /// regions or the parent operation itself, and set either the argument or
251 /// parent result lattices.
252 void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch,
253 RegionBranchPoint successor,
254 ArrayRef<AbstractSparseLattice *> lattices);
255};
256
257//===----------------------------------------------------------------------===//
258// SparseForwardDataFlowAnalysis
259//===----------------------------------------------------------------------===//
260
261/// A sparse forward data-flow analysis for propagating SSA value lattices
262/// across the IR by implementing transfer functions for operations.
263///
264/// `StateT` is expected to be a subclass of `AbstractSparseLattice`.
265template <typename StateT>
266class SparseForwardDataFlowAnalysis
267 : public AbstractSparseForwardDataFlowAnalysis {
268 static_assert(
269 std::is_base_of<AbstractSparseLattice, StateT>::value,
270 "analysis state class expected to subclass AbstractSparseLattice");
271
272public:
273 explicit SparseForwardDataFlowAnalysis(DataFlowSolver &solver)
274 : AbstractSparseForwardDataFlowAnalysis(solver) {}
275
276 /// Visit an operation with the lattices of its operands. This function is
277 /// expected to set the lattices of the operation's results.
278 virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
279 ArrayRef<StateT *> results) = 0;
280
281 /// Visit a call operation to an externally defined function given the
282 /// lattices of its arguments.
283 virtual void visitExternalCall(CallOpInterface call,
284 ArrayRef<const StateT *> argumentLattices,
285 ArrayRef<StateT *> resultLattices) {
286 setAllToEntryStates(resultLattices);
287 }
288
289 /// Given an operation with possible region control-flow, the lattices of the
290 /// operands, and a region successor, compute the lattice values for block
291 /// arguments that are not accounted for by the branching control flow (ex.
292 /// the bounds of loops). By default, this method marks all such lattice
293 /// elements as having reached a pessimistic fixpoint. `firstIndex` is the
294 /// index of the first element of `argLattices` that is set by control-flow.
295 virtual void visitNonControlFlowArguments(Operation *op,
296 const RegionSuccessor &successor,
297 ArrayRef<StateT *> argLattices,
298 unsigned firstIndex) {
299 setAllToEntryStates(argLattices.take_front(firstIndex));
300 setAllToEntryStates(argLattices.drop_front(
301 firstIndex + successor.getSuccessorInputs().size()));
302 }
303
304protected:
305 /// Get the lattice element for a value.
306 StateT *getLatticeElement(Value value) override {
307 return getOrCreate<StateT>(value);
308 }
309
310 /// Get the lattice element for a value and create a dependency on the
311 /// provided program point.
312 const StateT *getLatticeElementFor(ProgramPoint point, Value value) {
313 return static_cast<const StateT *>(
314 AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(point,
315 value));
316 }
317
318 /// Set the given lattice element(s) at control flow entry point(s).
319 virtual void setToEntryState(StateT *lattice) = 0;
320 void setAllToEntryStates(ArrayRef<StateT *> lattices) {
321 AbstractSparseForwardDataFlowAnalysis::setAllToEntryStates(
322 {reinterpret_cast<AbstractSparseLattice *const *>(lattices.begin()),
323 lattices.size()});
324 }
325
326private:
327 /// Type-erased wrappers that convert the abstract lattice operands to derived
328 /// lattices and invoke the virtual hooks operating on the derived lattices.
329 void visitOperationImpl(
330 Operation *op, ArrayRef<const AbstractSparseLattice *> operandLattices,
331 ArrayRef<AbstractSparseLattice *> resultLattices) override {
332 visitOperation(
333 op,
334 operands: {reinterpret_cast<const StateT *const *>(operandLattices.begin()),
335 operandLattices.size()},
336 results: {reinterpret_cast<StateT *const *>(resultLattices.begin()),
337 resultLattices.size()});
338 }
339 void visitExternalCallImpl(
340 CallOpInterface call,
341 ArrayRef<const AbstractSparseLattice *> argumentLattices,
342 ArrayRef<AbstractSparseLattice *> resultLattices) override {
343 visitExternalCall(
344 call,
345 {reinterpret_cast<const StateT *const *>(argumentLattices.begin()),
346 argumentLattices.size()},
347 {reinterpret_cast<StateT *const *>(resultLattices.begin()),
348 resultLattices.size()});
349 }
350 void visitNonControlFlowArgumentsImpl(
351 Operation *op, const RegionSuccessor &successor,
352 ArrayRef<AbstractSparseLattice *> argLattices,
353 unsigned firstIndex) override {
354 visitNonControlFlowArguments(
355 op, successor,
356 argLattices: {reinterpret_cast<StateT *const *>(argLattices.begin()),
357 argLattices.size()},
358 firstIndex);
359 }
360 void setToEntryState(AbstractSparseLattice *lattice) override {
361 return setToEntryState(reinterpret_cast<StateT *>(lattice));
362 }
363};
364
365//===----------------------------------------------------------------------===//
366// AbstractSparseBackwardDataFlowAnalysis
367//===----------------------------------------------------------------------===//
368
369/// Base class for sparse backward data-flow analyses. Similar to
370/// AbstractSparseForwardDataFlowAnalysis, but walks bottom to top.
371class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
372public:
373 /// Initialize the analysis by visiting the operation and everything nested
374 /// under it.
375 LogicalResult initialize(Operation *top) override;
376
377 /// Visit a program point. If this is a call operation or an operation with
378 /// block or region control-flow, then operand lattices are set accordingly.
379 /// Otherwise, invokes the operation transfer function (`visitOperationImpl`).
380 LogicalResult visit(ProgramPoint point) override;
381
382protected:
383 explicit AbstractSparseBackwardDataFlowAnalysis(
384 DataFlowSolver &solver, SymbolTableCollection &symbolTable);
385
386 /// The operation transfer function. Given the result lattices, this
387 /// function is expected to set the operand lattices.
388 virtual void visitOperationImpl(
389 Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
390 ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
391
392 /// The transfer function for calls to external functions.
393 virtual void visitExternalCallImpl(
394 CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices,
395 ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
396
397 // Visit operands on branch instructions that are not forwarded.
398 virtual void visitBranchOperand(OpOperand &operand) = 0;
399
400 // Visit operands on call instructions that are not forwarded.
401 virtual void visitCallOperand(OpOperand &operand) = 0;
402
403 /// Set the given lattice element(s) at control flow exit point(s).
404 virtual void setToExitState(AbstractSparseLattice *lattice) = 0;
405
406 /// Set the given lattice element(s) at control flow exit point(s).
407 void setAllToExitStates(ArrayRef<AbstractSparseLattice *> lattices);
408
409 /// Get the lattice element for a value.
410 virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;
411
412 /// Get the lattice elements for a range of values.
413 SmallVector<AbstractSparseLattice *> getLatticeElements(ValueRange values);
414
415 /// Join the lattice element and propagate and update if it changed.
416 void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
417
418private:
419 /// Recursively initialize the analysis on nested operations and blocks.
420 LogicalResult initializeRecursively(Operation *op);
421
422 /// Visit an operation. If this is a call operation or an operation with
423 /// region control-flow, then its operand lattices are set accordingly.
424 /// Otherwise, the operation transfer function is invoked.
425 void visitOperation(Operation *op);
426
427 /// Visit a block.
428 void visitBlock(Block *block);
429
430 /// Visit an op with regions (like e.g. `scf.while`)
431 void visitRegionSuccessors(RegionBranchOpInterface branch,
432 ArrayRef<AbstractSparseLattice *> operands);
433
434 /// Visit a `RegionBranchTerminatorOpInterface` to compute the lattice values
435 /// of its operands, given its parent op `branch`. The lattice value of an
436 /// operand is determined based on the corresponding arguments in
437 /// `terminator`'s region successor(s).
438 void visitRegionSuccessorsFromTerminator(
439 RegionBranchTerminatorOpInterface terminator,
440 RegionBranchOpInterface branch);
441
442 /// Get the lattice element for a value, and also set up
443 /// dependencies so that the analysis on the given ProgramPoint is re-invoked
444 /// if the value changes.
445 const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point,
446 Value value);
447
448 /// Get the lattice elements for a range of values, and also set up
449 /// dependencies so that the analysis on the given ProgramPoint is re-invoked
450 /// if any of the values change.
451 SmallVector<const AbstractSparseLattice *>
452 getLatticeElementsFor(ProgramPoint point, ValueRange values);
453
454 SymbolTableCollection &symbolTable;
455};
456
457//===----------------------------------------------------------------------===//
458// SparseBackwardDataFlowAnalysis
459//===----------------------------------------------------------------------===//
460
461/// A sparse (backward) data-flow analysis for propagating SSA value lattices
462/// backwards across the IR by implementing transfer functions for operations.
463///
464/// `StateT` is expected to be a subclass of `AbstractSparseLattice`.
465template <typename StateT>
466class SparseBackwardDataFlowAnalysis
467 : public AbstractSparseBackwardDataFlowAnalysis {
468public:
469 explicit SparseBackwardDataFlowAnalysis(DataFlowSolver &solver,
470 SymbolTableCollection &symbolTable)
471 : AbstractSparseBackwardDataFlowAnalysis(solver, symbolTable) {}
472
473 /// Visit an operation with the lattices of its results. This function is
474 /// expected to set the lattices of the operation's operands.
475 virtual void visitOperation(Operation *op, ArrayRef<StateT *> operands,
476 ArrayRef<const StateT *> results) = 0;
477
478 /// Visit a call to an external function. This function is expected to set
479 /// lattice values of the call operands. By default, calls `visitCallOperand`
480 /// for all operands.
481 virtual void visitExternalCall(CallOpInterface call,
482 ArrayRef<StateT *> argumentLattices,
483 ArrayRef<const StateT *> resultLattices) {
484 (void)argumentLattices;
485 (void)resultLattices;
486 for (OpOperand &operand : call->getOpOperands()) {
487 visitCallOperand(operand);
488 }
489 };
490
491protected:
492 /// Get the lattice element for a value.
493 StateT *getLatticeElement(Value value) override {
494 return getOrCreate<StateT>(value);
495 }
496
497 /// Set the given lattice element(s) at control flow exit point(s).
498 virtual void setToExitState(StateT *lattice) = 0;
499 void setToExitState(AbstractSparseLattice *lattice) override {
500 return setToExitState(reinterpret_cast<StateT *>(lattice));
501 }
502 void setAllToExitStates(ArrayRef<StateT *> lattices) {
503 AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates(
504 {reinterpret_cast<AbstractSparseLattice *const *>(lattices.begin()),
505 lattices.size()});
506 }
507
508private:
509 /// Type-erased wrappers that convert the abstract lattice operands to derived
510 /// lattices and invoke the virtual hooks operating on the derived lattices.
511 void visitOperationImpl(
512 Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
513 ArrayRef<const AbstractSparseLattice *> resultLattices) override {
514 visitOperation(
515 op,
516 operands: {reinterpret_cast<StateT *const *>(operandLattices.begin()),
517 operandLattices.size()},
518 results: {reinterpret_cast<const StateT *const *>(resultLattices.begin()),
519 resultLattices.size()});
520 }
521
522 void visitExternalCallImpl(
523 CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices,
524 ArrayRef<const AbstractSparseLattice *> resultLattices) override {
525 visitExternalCall(
526 call,
527 {reinterpret_cast<StateT *const *>(operandLattices.begin()),
528 operandLattices.size()},
529 {reinterpret_cast<const StateT *const *>(resultLattices.begin()),
530 resultLattices.size()});
531 }
532};
533
534} // end namespace dataflow
535} // end namespace mlir
536
537#endif // MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H
538

// Source: mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h