//===- ConstantPropagationAnalysis.cpp - Constant propagation analysis ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 | |
9 | #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" |
10 | #include "mlir/Analysis/DataFlow/SparseAnalysis.h" |
11 | #include "mlir/IR/BuiltinAttributes.h" |
12 | #include "mlir/IR/OpDefinition.h" |
13 | #include "mlir/IR/Operation.h" |
14 | #include "mlir/IR/Value.h" |
15 | #include "mlir/Support/LLVM.h" |
16 | #include "llvm/ADT/STLExtras.h" |
17 | #include "llvm/Support/Casting.h" |
18 | #include "llvm/Support/Debug.h" |
19 | #include <cassert> |
20 | |
21 | #define DEBUG_TYPE "constant-propagation" |
22 | |
23 | using namespace mlir; |
24 | using namespace mlir::dataflow; |
25 | |
//===----------------------------------------------------------------------===//
// ConstantValue
//===----------------------------------------------------------------------===//
29 | |
30 | void ConstantValue::print(raw_ostream &os) const { |
31 | if (isUninitialized()) { |
32 | os << "<UNINITIALIZED>" ; |
33 | return; |
34 | } |
35 | if (getConstantValue() == nullptr) { |
36 | os << "<UNKNOWN>" ; |
37 | return; |
38 | } |
39 | return getConstantValue().print(os); |
40 | } |
41 | |
//===----------------------------------------------------------------------===//
// SparseConstantPropagation
//===----------------------------------------------------------------------===//
45 | |
46 | LogicalResult SparseConstantPropagation::visitOperation( |
47 | Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands, |
48 | ArrayRef<Lattice<ConstantValue> *> results) { |
49 | LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n" ); |
50 | |
51 | // Don't try to simulate the results of a region operation as we can't |
52 | // guarantee that folding will be out-of-place. We don't allow in-place |
53 | // folds as the desire here is for simulated execution, and not general |
54 | // folding. |
55 | if (op->getNumRegions()) { |
56 | setAllToEntryStates(results); |
57 | return success(); |
58 | } |
59 | |
60 | SmallVector<Attribute, 8> constantOperands; |
61 | constantOperands.reserve(N: op->getNumOperands()); |
62 | for (auto *operandLattice : operands) { |
63 | if (operandLattice->getValue().isUninitialized()) |
64 | return success(); |
65 | constantOperands.push_back(Elt: operandLattice->getValue().getConstantValue()); |
66 | } |
67 | |
68 | // Save the original operands and attributes just in case the operation |
69 | // folds in-place. The constant passed in may not correspond to the real |
70 | // runtime value, so in-place updates are not allowed. |
71 | SmallVector<Value, 8> originalOperands(op->getOperands()); |
72 | DictionaryAttr originalAttrs = op->getAttrDictionary(); |
73 | |
74 | // Simulate the result of folding this operation to a constant. If folding |
75 | // fails or was an in-place fold, mark the results as overdefined. |
76 | SmallVector<OpFoldResult, 8> foldResults; |
77 | foldResults.reserve(N: op->getNumResults()); |
78 | if (failed(Result: op->fold(operands: constantOperands, results&: foldResults))) { |
79 | setAllToEntryStates(results); |
80 | return success(); |
81 | } |
82 | |
83 | // If the folding was in-place, mark the results as overdefined and reset |
84 | // the operation. We don't allow in-place folds as the desire here is for |
85 | // simulated execution, and not general folding. |
86 | if (foldResults.empty()) { |
87 | op->setOperands(originalOperands); |
88 | op->setAttrs(originalAttrs); |
89 | setAllToEntryStates(results); |
90 | return success(); |
91 | } |
92 | |
93 | // Merge the fold results into the lattice for this operation. |
94 | assert(foldResults.size() == op->getNumResults() && "invalid result size" ); |
95 | for (const auto it : llvm::zip(t&: results, u&: foldResults)) { |
96 | Lattice<ConstantValue> *lattice = std::get<0>(t: it); |
97 | |
98 | // Merge in the result of the fold, either a constant or a value. |
99 | OpFoldResult foldResult = std::get<1>(t: it); |
100 | if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(Val&: foldResult)) { |
101 | LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n" ); |
102 | propagateIfChanged(state: lattice, |
103 | changed: lattice->join(rhs: ConstantValue(attr, op->getDialect()))); |
104 | } else { |
105 | LLVM_DEBUG(llvm::dbgs() |
106 | << "Folded to value: " << cast<Value>(foldResult) << "\n" ); |
107 | AbstractSparseForwardDataFlowAnalysis::join( |
108 | lhs: lattice, rhs: *getLatticeElement(value: cast<Value>(Val&: foldResult))); |
109 | } |
110 | } |
111 | return success(); |
112 | } |
113 | |
114 | void SparseConstantPropagation::setToEntryState( |
115 | Lattice<ConstantValue> *lattice) { |
116 | propagateIfChanged(state: lattice, |
117 | changed: lattice->join(rhs: ConstantValue::getUnknownConstant())); |
118 | } |
119 | |