1//===- ConstantPropagationAnalysis.cpp - Constant propagation analysis ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
10#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
11#include "mlir/IR/BuiltinAttributes.h"
12#include "mlir/IR/OpDefinition.h"
13#include "mlir/IR/Operation.h"
14#include "mlir/IR/Value.h"
15#include "mlir/Support/LLVM.h"
16#include "llvm/ADT/STLExtras.h"
17#include "llvm/Support/Casting.h"
18#include "llvm/Support/Debug.h"
19#include <cassert>
20
21#define DEBUG_TYPE "constant-propagation"
22
23using namespace mlir;
24using namespace mlir::dataflow;
25
26//===----------------------------------------------------------------------===//
27// ConstantValue
28//===----------------------------------------------------------------------===//
29
30void ConstantValue::print(raw_ostream &os) const {
31 if (isUninitialized()) {
32 os << "<UNINITIALIZED>";
33 return;
34 }
35 if (getConstantValue() == nullptr) {
36 os << "<UNKNOWN>";
37 return;
38 }
39 return getConstantValue().print(os);
40}
41
42//===----------------------------------------------------------------------===//
43// SparseConstantPropagation
44//===----------------------------------------------------------------------===//
45
46LogicalResult SparseConstantPropagation::visitOperation(
47 Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
48 ArrayRef<Lattice<ConstantValue> *> results) {
49 LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
50
51 // Don't try to simulate the results of a region operation as we can't
52 // guarantee that folding will be out-of-place. We don't allow in-place
53 // folds as the desire here is for simulated execution, and not general
54 // folding.
55 if (op->getNumRegions()) {
56 setAllToEntryStates(results);
57 return success();
58 }
59
60 SmallVector<Attribute, 8> constantOperands;
61 constantOperands.reserve(N: op->getNumOperands());
62 for (auto *operandLattice : operands) {
63 if (operandLattice->getValue().isUninitialized())
64 return success();
65 constantOperands.push_back(Elt: operandLattice->getValue().getConstantValue());
66 }
67
68 // Save the original operands and attributes just in case the operation
69 // folds in-place. The constant passed in may not correspond to the real
70 // runtime value, so in-place updates are not allowed.
71 SmallVector<Value, 8> originalOperands(op->getOperands());
72 DictionaryAttr originalAttrs = op->getAttrDictionary();
73
74 // Simulate the result of folding this operation to a constant. If folding
75 // fails or was an in-place fold, mark the results as overdefined.
76 SmallVector<OpFoldResult, 8> foldResults;
77 foldResults.reserve(N: op->getNumResults());
78 if (failed(Result: op->fold(operands: constantOperands, results&: foldResults))) {
79 setAllToEntryStates(results);
80 return success();
81 }
82
83 // If the folding was in-place, mark the results as overdefined and reset
84 // the operation. We don't allow in-place folds as the desire here is for
85 // simulated execution, and not general folding.
86 if (foldResults.empty()) {
87 op->setOperands(originalOperands);
88 op->setAttrs(originalAttrs);
89 setAllToEntryStates(results);
90 return success();
91 }
92
93 // Merge the fold results into the lattice for this operation.
94 assert(foldResults.size() == op->getNumResults() && "invalid result size");
95 for (const auto it : llvm::zip(t&: results, u&: foldResults)) {
96 Lattice<ConstantValue> *lattice = std::get<0>(t: it);
97
98 // Merge in the result of the fold, either a constant or a value.
99 OpFoldResult foldResult = std::get<1>(t: it);
100 if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(Val&: foldResult)) {
101 LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
102 propagateIfChanged(state: lattice,
103 changed: lattice->join(rhs: ConstantValue(attr, op->getDialect())));
104 } else {
105 LLVM_DEBUG(llvm::dbgs()
106 << "Folded to value: " << cast<Value>(foldResult) << "\n");
107 AbstractSparseForwardDataFlowAnalysis::join(
108 lhs: lattice, rhs: *getLatticeElement(value: cast<Value>(Val&: foldResult)));
109 }
110 }
111 return success();
112}
113
114void SparseConstantPropagation::setToEntryState(
115 Lattice<ConstantValue> *lattice) {
116 propagateIfChanged(state: lattice,
117 changed: lattice->join(rhs: ConstantValue::getUnknownConstant()));
118}
119

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp