1 | //===- ConstantPropagationAnalysis.cpp - Constant propagation analysis ----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" |
10 | #include "mlir/Analysis/DataFlow/SparseAnalysis.h" |
11 | #include "mlir/IR/BuiltinAttributes.h" |
12 | #include "mlir/IR/OpDefinition.h" |
13 | #include "mlir/IR/Operation.h" |
14 | #include "mlir/IR/Value.h" |
15 | #include "mlir/Support/LLVM.h" |
16 | #include "mlir/Support/LogicalResult.h" |
17 | #include "llvm/ADT/STLExtras.h" |
18 | #include "llvm/Support/Casting.h" |
19 | #include "llvm/Support/Debug.h" |
20 | #include <cassert> |
21 | |
22 | #define DEBUG_TYPE "constant-propagation" |
23 | |
24 | using namespace mlir; |
25 | using namespace mlir::dataflow; |
26 | |
27 | //===----------------------------------------------------------------------===// |
28 | // ConstantValue |
29 | //===----------------------------------------------------------------------===// |
30 | |
31 | void ConstantValue::print(raw_ostream &os) const { |
32 | if (isUninitialized()) { |
33 | os << "<UNINITIALIZED>" ; |
34 | return; |
35 | } |
36 | if (getConstantValue() == nullptr) { |
37 | os << "<UNKNOWN>" ; |
38 | return; |
39 | } |
40 | return getConstantValue().print(os); |
41 | } |
42 | |
43 | //===----------------------------------------------------------------------===// |
44 | // SparseConstantPropagation |
45 | //===----------------------------------------------------------------------===// |
46 | |
47 | void SparseConstantPropagation::visitOperation( |
48 | Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands, |
49 | ArrayRef<Lattice<ConstantValue> *> results) { |
50 | LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n" ); |
51 | |
52 | // Don't try to simulate the results of a region operation as we can't |
53 | // guarantee that folding will be out-of-place. We don't allow in-place |
54 | // folds as the desire here is for simulated execution, and not general |
55 | // folding. |
56 | if (op->getNumRegions()) { |
57 | setAllToEntryStates(results); |
58 | return; |
59 | } |
60 | |
61 | SmallVector<Attribute, 8> constantOperands; |
62 | constantOperands.reserve(N: op->getNumOperands()); |
63 | for (auto *operandLattice : operands) { |
64 | if (operandLattice->getValue().isUninitialized()) |
65 | return; |
66 | constantOperands.push_back(Elt: operandLattice->getValue().getConstantValue()); |
67 | } |
68 | |
69 | // Save the original operands and attributes just in case the operation |
70 | // folds in-place. The constant passed in may not correspond to the real |
71 | // runtime value, so in-place updates are not allowed. |
72 | SmallVector<Value, 8> originalOperands(op->getOperands()); |
73 | DictionaryAttr originalAttrs = op->getAttrDictionary(); |
74 | |
75 | // Simulate the result of folding this operation to a constant. If folding |
76 | // fails or was an in-place fold, mark the results as overdefined. |
77 | SmallVector<OpFoldResult, 8> foldResults; |
78 | foldResults.reserve(N: op->getNumResults()); |
79 | if (failed(result: op->fold(operands: constantOperands, results&: foldResults))) { |
80 | setAllToEntryStates(results); |
81 | return; |
82 | } |
83 | |
84 | // If the folding was in-place, mark the results as overdefined and reset |
85 | // the operation. We don't allow in-place folds as the desire here is for |
86 | // simulated execution, and not general folding. |
87 | if (foldResults.empty()) { |
88 | op->setOperands(originalOperands); |
89 | op->setAttrs(originalAttrs); |
90 | setAllToEntryStates(results); |
91 | return; |
92 | } |
93 | |
94 | // Merge the fold results into the lattice for this operation. |
95 | assert(foldResults.size() == op->getNumResults() && "invalid result size" ); |
96 | for (const auto it : llvm::zip(t&: results, u&: foldResults)) { |
97 | Lattice<ConstantValue> *lattice = std::get<0>(t: it); |
98 | |
99 | // Merge in the result of the fold, either a constant or a value. |
100 | OpFoldResult foldResult = std::get<1>(t: it); |
101 | if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(Val&: foldResult)) { |
102 | LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n" ); |
103 | propagateIfChanged(lattice, |
104 | lattice->join(rhs: ConstantValue(attr, op->getDialect()))); |
105 | } else { |
106 | LLVM_DEBUG(llvm::dbgs() |
107 | << "Folded to value: " << foldResult.get<Value>() << "\n" ); |
108 | AbstractSparseForwardDataFlowAnalysis::join( |
109 | lattice, *getLatticeElement(foldResult.get<Value>())); |
110 | } |
111 | } |
112 | } |
113 | |
114 | void SparseConstantPropagation::setToEntryState( |
115 | Lattice<ConstantValue> *lattice) { |
116 | propagateIfChanged(lattice, |
117 | lattice->join(rhs: ConstantValue::getUnknownConstant())); |
118 | } |
119 | |