1//===- ConstantPropagationAnalysis.cpp - Constant propagation analysis ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
10#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
11#include "mlir/IR/BuiltinAttributes.h"
12#include "mlir/IR/OpDefinition.h"
13#include "mlir/IR/Operation.h"
14#include "mlir/IR/Value.h"
15#include "mlir/Support/LLVM.h"
16#include "mlir/Support/LogicalResult.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/Support/Casting.h"
19#include "llvm/Support/Debug.h"
20#include <cassert>
21
22#define DEBUG_TYPE "constant-propagation"
23
24using namespace mlir;
25using namespace mlir::dataflow;
26
27//===----------------------------------------------------------------------===//
28// ConstantValue
29//===----------------------------------------------------------------------===//
30
31void ConstantValue::print(raw_ostream &os) const {
32 if (isUninitialized()) {
33 os << "<UNINITIALIZED>";
34 return;
35 }
36 if (getConstantValue() == nullptr) {
37 os << "<UNKNOWN>";
38 return;
39 }
40 return getConstantValue().print(os);
41}
42
43//===----------------------------------------------------------------------===//
44// SparseConstantPropagation
45//===----------------------------------------------------------------------===//
46
47void SparseConstantPropagation::visitOperation(
48 Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
49 ArrayRef<Lattice<ConstantValue> *> results) {
50 LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
51
52 // Don't try to simulate the results of a region operation as we can't
53 // guarantee that folding will be out-of-place. We don't allow in-place
54 // folds as the desire here is for simulated execution, and not general
55 // folding.
56 if (op->getNumRegions()) {
57 setAllToEntryStates(results);
58 return;
59 }
60
61 SmallVector<Attribute, 8> constantOperands;
62 constantOperands.reserve(N: op->getNumOperands());
63 for (auto *operandLattice : operands) {
64 if (operandLattice->getValue().isUninitialized())
65 return;
66 constantOperands.push_back(Elt: operandLattice->getValue().getConstantValue());
67 }
68
69 // Save the original operands and attributes just in case the operation
70 // folds in-place. The constant passed in may not correspond to the real
71 // runtime value, so in-place updates are not allowed.
72 SmallVector<Value, 8> originalOperands(op->getOperands());
73 DictionaryAttr originalAttrs = op->getAttrDictionary();
74
75 // Simulate the result of folding this operation to a constant. If folding
76 // fails or was an in-place fold, mark the results as overdefined.
77 SmallVector<OpFoldResult, 8> foldResults;
78 foldResults.reserve(N: op->getNumResults());
79 if (failed(result: op->fold(operands: constantOperands, results&: foldResults))) {
80 setAllToEntryStates(results);
81 return;
82 }
83
84 // If the folding was in-place, mark the results as overdefined and reset
85 // the operation. We don't allow in-place folds as the desire here is for
86 // simulated execution, and not general folding.
87 if (foldResults.empty()) {
88 op->setOperands(originalOperands);
89 op->setAttrs(originalAttrs);
90 setAllToEntryStates(results);
91 return;
92 }
93
94 // Merge the fold results into the lattice for this operation.
95 assert(foldResults.size() == op->getNumResults() && "invalid result size");
96 for (const auto it : llvm::zip(t&: results, u&: foldResults)) {
97 Lattice<ConstantValue> *lattice = std::get<0>(t: it);
98
99 // Merge in the result of the fold, either a constant or a value.
100 OpFoldResult foldResult = std::get<1>(t: it);
101 if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(Val&: foldResult)) {
102 LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
103 propagateIfChanged(lattice,
104 lattice->join(rhs: ConstantValue(attr, op->getDialect())));
105 } else {
106 LLVM_DEBUG(llvm::dbgs()
107 << "Folded to value: " << foldResult.get<Value>() << "\n");
108 AbstractSparseForwardDataFlowAnalysis::join(
109 lattice, *getLatticeElement(foldResult.get<Value>()));
110 }
111 }
112}
113
114void SparseConstantPropagation::setToEntryState(
115 Lattice<ConstantValue> *lattice) {
116 propagateIfChanged(lattice,
117 lattice->join(rhs: ConstantValue::getUnknownConstant()));
118}
119

source code of mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp