| 1 | //===- ConstantPropagationAnalysis.cpp - Constant propagation analysis ----===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" |
| 10 | #include "mlir/Analysis/DataFlow/SparseAnalysis.h" |
| 11 | #include "mlir/IR/BuiltinAttributes.h" |
| 12 | #include "mlir/IR/OpDefinition.h" |
| 13 | #include "mlir/IR/Operation.h" |
| 14 | #include "mlir/IR/Value.h" |
| 15 | #include "mlir/Support/LLVM.h" |
| 16 | #include "llvm/ADT/STLExtras.h" |
| 17 | #include "llvm/Support/Casting.h" |
| 18 | #include "llvm/Support/Debug.h" |
| 19 | #include <cassert> |
| 20 | |
| 21 | #define DEBUG_TYPE "constant-propagation" |
| 22 | |
| 23 | using namespace mlir; |
| 24 | using namespace mlir::dataflow; |
| 25 | |
| 26 | //===----------------------------------------------------------------------===// |
| 27 | // ConstantValue |
| 28 | //===----------------------------------------------------------------------===// |
| 29 | |
| 30 | void ConstantValue::print(raw_ostream &os) const { |
| 31 | if (isUninitialized()) { |
| 32 | os << "<UNINITIALIZED>" ; |
| 33 | return; |
| 34 | } |
| 35 | if (getConstantValue() == nullptr) { |
| 36 | os << "<UNKNOWN>" ; |
| 37 | return; |
| 38 | } |
| 39 | return getConstantValue().print(os); |
| 40 | } |
| 41 | |
| 42 | //===----------------------------------------------------------------------===// |
| 43 | // SparseConstantPropagation |
| 44 | //===----------------------------------------------------------------------===// |
| 45 | |
| 46 | LogicalResult SparseConstantPropagation::visitOperation( |
| 47 | Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands, |
| 48 | ArrayRef<Lattice<ConstantValue> *> results) { |
| 49 | LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n" ); |
| 50 | |
| 51 | // Don't try to simulate the results of a region operation as we can't |
| 52 | // guarantee that folding will be out-of-place. We don't allow in-place |
| 53 | // folds as the desire here is for simulated execution, and not general |
| 54 | // folding. |
| 55 | if (op->getNumRegions()) { |
| 56 | setAllToEntryStates(results); |
| 57 | return success(); |
| 58 | } |
| 59 | |
| 60 | SmallVector<Attribute, 8> constantOperands; |
| 61 | constantOperands.reserve(N: op->getNumOperands()); |
| 62 | for (auto *operandLattice : operands) { |
| 63 | if (operandLattice->getValue().isUninitialized()) |
| 64 | return success(); |
| 65 | constantOperands.push_back(Elt: operandLattice->getValue().getConstantValue()); |
| 66 | } |
| 67 | |
| 68 | // Save the original operands and attributes just in case the operation |
| 69 | // folds in-place. The constant passed in may not correspond to the real |
| 70 | // runtime value, so in-place updates are not allowed. |
| 71 | SmallVector<Value, 8> originalOperands(op->getOperands()); |
| 72 | DictionaryAttr originalAttrs = op->getAttrDictionary(); |
| 73 | |
| 74 | // Simulate the result of folding this operation to a constant. If folding |
| 75 | // fails or was an in-place fold, mark the results as overdefined. |
| 76 | SmallVector<OpFoldResult, 8> foldResults; |
| 77 | foldResults.reserve(N: op->getNumResults()); |
| 78 | if (failed(Result: op->fold(operands: constantOperands, results&: foldResults))) { |
| 79 | setAllToEntryStates(results); |
| 80 | return success(); |
| 81 | } |
| 82 | |
| 83 | // If the folding was in-place, mark the results as overdefined and reset |
| 84 | // the operation. We don't allow in-place folds as the desire here is for |
| 85 | // simulated execution, and not general folding. |
| 86 | if (foldResults.empty()) { |
| 87 | op->setOperands(originalOperands); |
| 88 | op->setAttrs(originalAttrs); |
| 89 | setAllToEntryStates(results); |
| 90 | return success(); |
| 91 | } |
| 92 | |
| 93 | // Merge the fold results into the lattice for this operation. |
| 94 | assert(foldResults.size() == op->getNumResults() && "invalid result size" ); |
| 95 | for (const auto it : llvm::zip(t&: results, u&: foldResults)) { |
| 96 | Lattice<ConstantValue> *lattice = std::get<0>(t: it); |
| 97 | |
| 98 | // Merge in the result of the fold, either a constant or a value. |
| 99 | OpFoldResult foldResult = std::get<1>(t: it); |
| 100 | if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(Val&: foldResult)) { |
| 101 | LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n" ); |
| 102 | propagateIfChanged(state: lattice, |
| 103 | changed: lattice->join(rhs: ConstantValue(attr, op->getDialect()))); |
| 104 | } else { |
| 105 | LLVM_DEBUG(llvm::dbgs() |
| 106 | << "Folded to value: " << cast<Value>(foldResult) << "\n" ); |
| 107 | AbstractSparseForwardDataFlowAnalysis::join( |
| 108 | lhs: lattice, rhs: *getLatticeElement(value: cast<Value>(Val&: foldResult))); |
| 109 | } |
| 110 | } |
| 111 | return success(); |
| 112 | } |
| 113 | |
| 114 | void SparseConstantPropagation::setToEntryState( |
| 115 | Lattice<ConstantValue> *lattice) { |
| 116 | propagateIfChanged(state: lattice, |
| 117 | changed: lattice->join(rhs: ConstantValue::getUnknownConstant())); |
| 118 | } |
| 119 | |