//===- TestSparseBackwardDataFlowAnalysis.cpp - Test backward analysis ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
10#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
11#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
12#include "mlir/Dialect/MemRef/IR/MemRef.h"
13#include "mlir/Interfaces/SideEffectInterfaces.h"
14#include "mlir/Pass/Pass.h"
15
16using namespace mlir;
17using namespace mlir::dataflow;
18
19namespace {
20
21/// This lattice represents, for a given value, the set of memory resources that
22/// this value, or anything derived from this value, is potentially written to.
23struct WrittenTo : public AbstractSparseLattice {
24 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
25 using AbstractSparseLattice::AbstractSparseLattice;
26
27 void print(raw_ostream &os) const override {
28 os << "[";
29 llvm::interleave(
30 writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
31 os << "]";
32 }
33 ChangeResult addWrites(const SetVector<StringAttr> &writes) {
34 int sizeBefore = this->writes.size();
35 this->writes.insert(writes.begin(), writes.end());
36 int sizeAfter = this->writes.size();
37 return sizeBefore == sizeAfter ? ChangeResult::NoChange
38 : ChangeResult::Change;
39 }
40 ChangeResult meet(const AbstractSparseLattice &other) override {
41 const auto *rhs = reinterpret_cast<const WrittenTo *>(&other);
42 return addWrites(writes: rhs->writes);
43 }
44
45 SetVector<StringAttr> writes;
46};
47
48/// An analysis that, by going backwards along the dataflow graph, annotates
49/// each value with all the memory resources it (or anything derived from it)
50/// is eventually written to.
51class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
52public:
53 WrittenToAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable,
54 bool assumeFuncWrites)
55 : SparseBackwardDataFlowAnalysis(solver, symbolTable),
56 assumeFuncWrites(assumeFuncWrites) {}
57
58 void visitOperation(Operation *op, ArrayRef<WrittenTo *> operands,
59 ArrayRef<const WrittenTo *> results) override;
60
61 void visitBranchOperand(OpOperand &operand) override;
62
63 void visitCallOperand(OpOperand &operand) override;
64
65 void visitExternalCall(CallOpInterface call, ArrayRef<WrittenTo *> operands,
66 ArrayRef<const WrittenTo *> results) override;
67
68 void setToExitState(WrittenTo *lattice) override { lattice->writes.clear(); }
69
70private:
71 bool assumeFuncWrites;
72};
73
74void WrittenToAnalysis::visitOperation(Operation *op,
75 ArrayRef<WrittenTo *> operands,
76 ArrayRef<const WrittenTo *> results) {
77 if (auto store = dyn_cast<memref::StoreOp>(op)) {
78 SetVector<StringAttr> newWrites;
79 newWrites.insert(op->getAttrOfType<StringAttr>("tag_name"));
80 propagateIfChanged(operands[0], operands[0]->addWrites(writes: newWrites));
81 return;
82 } // By default, every result of an op depends on every operand.
83 for (const WrittenTo *r : results) {
84 for (WrittenTo *operand : operands) {
85 meet(operand, *r);
86 }
87 addDependency(const_cast<WrittenTo *>(r), op);
88 }
89}
90
91void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) {
92 // Mark branch operands as "brancharg%d", with %d the operand number.
93 WrittenTo *lattice = getLatticeElement(operand.get());
94 SetVector<StringAttr> newWrites;
95 newWrites.insert(
96 StringAttr::get(operand.getOwner()->getContext(),
97 "brancharg" + Twine(operand.getOperandNumber())));
98 propagateIfChanged(lattice, lattice->addWrites(writes: newWrites));
99}
100
101void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
102 // Mark call operands as "callarg%d", with %d the operand number.
103 WrittenTo *lattice = getLatticeElement(operand.get());
104 SetVector<StringAttr> newWrites;
105 newWrites.insert(
106 StringAttr::get(operand.getOwner()->getContext(),
107 "callarg" + Twine(operand.getOperandNumber())));
108 propagateIfChanged(lattice, lattice->addWrites(writes: newWrites));
109}
110
111void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
112 ArrayRef<WrittenTo *> operands,
113 ArrayRef<const WrittenTo *> results) {
114 if (!assumeFuncWrites) {
115 return SparseBackwardDataFlowAnalysis::visitExternalCall(call, operands,
116 results);
117 }
118
119 for (WrittenTo *lattice : operands) {
120 SetVector<StringAttr> newWrites;
121 StringAttr name = call->getAttrOfType<StringAttr>("tag_name");
122 if (!name) {
123 name = StringAttr::get(call->getContext(),
124 call.getOperation()->getName().getStringRef());
125 }
126 newWrites.insert(name);
127 propagateIfChanged(lattice, lattice->addWrites(writes: newWrites));
128 }
129}
130
131} // end anonymous namespace
132
133namespace {
134struct TestWrittenToPass
135 : public PassWrapper<TestWrittenToPass, OperationPass<>> {
136 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWrittenToPass)
137
138 TestWrittenToPass() = default;
139 TestWrittenToPass(const TestWrittenToPass &other) : PassWrapper(other) {
140 interprocedural = other.interprocedural;
141 assumeFuncWrites = other.assumeFuncWrites;
142 }
143
144 StringRef getArgument() const override { return "test-written-to"; }
145
146 Option<bool> interprocedural{
147 *this, "interprocedural", llvm::cl::init(Val: true),
148 llvm::cl::desc("perform interprocedural analysis")};
149 Option<bool> assumeFuncWrites{
150 *this, "assume-func-writes", llvm::cl::init(Val: false),
151 llvm::cl::desc(
152 "assume external functions have write effect on all arguments")};
153
154 void runOnOperation() override {
155 Operation *op = getOperation();
156
157 SymbolTableCollection symbolTable;
158
159 DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural));
160 solver.load<DeadCodeAnalysis>();
161 solver.load<SparseConstantPropagation>();
162 solver.load<WrittenToAnalysis>(args&: symbolTable, args&: assumeFuncWrites);
163 if (failed(result: solver.initializeAndRun(top: op)))
164 return signalPassFailure();
165
166 raw_ostream &os = llvm::outs();
167 op->walk(callback: [&](Operation *op) {
168 auto tag = op->getAttrOfType<StringAttr>("tag");
169 if (!tag)
170 return;
171 os << "test_tag: " << tag.getValue() << ":\n";
172 for (auto [index, operand] : llvm::enumerate(First: op->getOperands())) {
173 const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(point: operand);
174 assert(writtenTo && "expected a sparse lattice");
175 os << " operand #" << index << ": ";
176 writtenTo->print(os);
177 os << "\n";
178 }
179 for (auto [index, operand] : llvm::enumerate(First: op->getResults())) {
180 const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(point: operand);
181 assert(writtenTo && "expected a sparse lattice");
182 os << " result #" << index << ": ";
183 writtenTo->print(os);
184 os << "\n";
185 }
186 });
187 }
188};
189} // end anonymous namespace
190
191namespace mlir {
192namespace test {
193void registerTestWrittenToPass() { PassRegistration<TestWrittenToPass>(); }
194} // end namespace test
195} // end namespace mlir
196

source code of mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp