//===- TestSparseBackwardDataFlowAnalysis.cpp - Test written-to analysis --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
10#include "mlir/Analysis/DataFlow/Utils.h"
11#include "mlir/Dialect/MemRef/IR/MemRef.h"
12#include "mlir/Interfaces/SideEffectInterfaces.h"
13#include "mlir/Pass/Pass.h"
14
15using namespace mlir;
16using namespace mlir::dataflow;
17
18namespace {
19
20/// Lattice value storing the a set of memory resources that something
21/// is written to.
22struct WrittenToLatticeValue {
23 bool operator==(const WrittenToLatticeValue &other) {
24 return this->writes == other.writes;
25 }
26
27 static WrittenToLatticeValue meet(const WrittenToLatticeValue &lhs,
28 const WrittenToLatticeValue &rhs) {
29 WrittenToLatticeValue res = lhs;
30 (void)res.addWrites(writes: rhs.writes);
31
32 return res;
33 }
34
35 static WrittenToLatticeValue join(const WrittenToLatticeValue &lhs,
36 const WrittenToLatticeValue &rhs) {
37 // Should not be triggered by this test, but required by `Lattice<T>`
38 llvm_unreachable("Join should not be triggered by this test");
39 }
40
41 ChangeResult addWrites(const SetVector<StringAttr> &writes) {
42 int sizeBefore = this->writes.size();
43 this->writes.insert_range(R: writes);
44 int sizeAfter = this->writes.size();
45 return sizeBefore == sizeAfter ? ChangeResult::NoChange
46 : ChangeResult::Change;
47 }
48
49 void print(raw_ostream &os) const {
50 os << "[";
51 llvm::interleave(
52 c: writes, os, each_fn: [&](const StringAttr &a) { os << a.str(); }, separator: " ");
53 os << "]";
54 }
55
56 void clear() { writes.clear(); }
57
58 SetVector<StringAttr> writes;
59};
60
61/// This lattice represents, for a given value, the set of memory resources that
62/// this value, or anything derived from this value, is potentially written to.
63struct WrittenTo : public Lattice<WrittenToLatticeValue> {
64 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
65 using Lattice::Lattice;
66};
67
68/// An analysis that, by going backwards along the dataflow graph, annotates
69/// each value with all the memory resources it (or anything derived from it)
70/// is eventually written to.
71class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
72public:
73 WrittenToAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable,
74 bool assumeFuncWrites)
75 : SparseBackwardDataFlowAnalysis(solver, symbolTable),
76 assumeFuncWrites(assumeFuncWrites) {}
77
78 LogicalResult visitOperation(Operation *op, ArrayRef<WrittenTo *> operands,
79 ArrayRef<const WrittenTo *> results) override;
80
81 void visitBranchOperand(OpOperand &operand) override;
82
83 void visitCallOperand(OpOperand &operand) override;
84
85 void visitExternalCall(CallOpInterface call, ArrayRef<WrittenTo *> operands,
86 ArrayRef<const WrittenTo *> results) override;
87
88 void setToExitState(WrittenTo *lattice) override {
89 lattice->getValue().clear();
90 }
91
92private:
93 bool assumeFuncWrites;
94};
95
96LogicalResult
97WrittenToAnalysis::visitOperation(Operation *op, ArrayRef<WrittenTo *> operands,
98 ArrayRef<const WrittenTo *> results) {
99 if (auto store = dyn_cast<memref::StoreOp>(Val: op)) {
100 SetVector<StringAttr> newWrites;
101 newWrites.insert(X: op->getAttrOfType<StringAttr>(name: "tag_name"));
102 propagateIfChanged(state: operands[0],
103 changed: operands[0]->getValue().addWrites(writes: newWrites));
104 return success();
105 } // By default, every result of an op depends on every operand.
106 for (const WrittenTo *r : results) {
107 for (WrittenTo *operand : operands) {
108 meet(lhs: operand, rhs: *r);
109 }
110 addDependency(state: const_cast<WrittenTo *>(r), point: getProgramPointAfter(op));
111 }
112 return success();
113}
114
115void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) {
116 // Mark branch operands as "brancharg%d", with %d the operand number.
117 WrittenTo *lattice = getLatticeElement(value: operand.get());
118 SetVector<StringAttr> newWrites;
119 newWrites.insert(
120 X: StringAttr::get(context: operand.getOwner()->getContext(),
121 bytes: "brancharg" + Twine(operand.getOperandNumber())));
122 propagateIfChanged(state: lattice, changed: lattice->getValue().addWrites(writes: newWrites));
123}
124
125void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
126 // Mark call operands as "callarg%d", with %d the operand number.
127 WrittenTo *lattice = getLatticeElement(value: operand.get());
128 SetVector<StringAttr> newWrites;
129 newWrites.insert(
130 X: StringAttr::get(context: operand.getOwner()->getContext(),
131 bytes: "callarg" + Twine(operand.getOperandNumber())));
132 propagateIfChanged(state: lattice, changed: lattice->getValue().addWrites(writes: newWrites));
133}
134
135void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
136 ArrayRef<WrittenTo *> operands,
137 ArrayRef<const WrittenTo *> results) {
138 if (!assumeFuncWrites) {
139 return SparseBackwardDataFlowAnalysis::visitExternalCall(call, argumentLattices: operands,
140 resultLattices: results);
141 }
142
143 for (WrittenTo *lattice : operands) {
144 SetVector<StringAttr> newWrites;
145 StringAttr name = call->getAttrOfType<StringAttr>(name: "tag_name");
146 if (!name) {
147 name = StringAttr::get(context: call->getContext(),
148 bytes: call.getOperation()->getName().getStringRef());
149 }
150 newWrites.insert(X: name);
151 propagateIfChanged(state: lattice, changed: lattice->getValue().addWrites(writes: newWrites));
152 }
153}
154
155} // end anonymous namespace
156
157namespace {
158struct TestWrittenToPass
159 : public PassWrapper<TestWrittenToPass, OperationPass<>> {
160 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWrittenToPass)
161
162 TestWrittenToPass() = default;
163 TestWrittenToPass(const TestWrittenToPass &other) : PassWrapper(other) {
164 interprocedural = other.interprocedural;
165 assumeFuncWrites = other.assumeFuncWrites;
166 }
167
168 StringRef getArgument() const override { return "test-written-to"; }
169
170 Option<bool> interprocedural{
171 *this, "interprocedural", llvm::cl::init(Val: true),
172 llvm::cl::desc("perform interprocedural analysis")};
173 Option<bool> assumeFuncWrites{
174 *this, "assume-func-writes", llvm::cl::init(Val: false),
175 llvm::cl::desc(
176 "assume external functions have write effect on all arguments")};
177
178 void runOnOperation() override {
179 Operation *op = getOperation();
180
181 SymbolTableCollection symbolTable;
182
183 DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural));
184 loadBaselineAnalyses(solver);
185 solver.load<WrittenToAnalysis>(args&: symbolTable, args&: assumeFuncWrites);
186 if (failed(Result: solver.initializeAndRun(top: op)))
187 return signalPassFailure();
188
189 raw_ostream &os = llvm::outs();
190 op->walk(callback: [&](Operation *op) {
191 auto tag = op->getAttrOfType<StringAttr>(name: "tag");
192 if (!tag)
193 return;
194 os << "test_tag: " << tag.getValue() << ":\n";
195 for (auto [index, operand] : llvm::enumerate(First: op->getOperands())) {
196 const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(anchor: operand);
197 assert(writtenTo && "expected a sparse lattice");
198 os << " operand #" << index << ": ";
199 writtenTo->print(os);
200 os << "\n";
201 }
202 for (auto [index, operand] : llvm::enumerate(First: op->getResults())) {
203 const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(anchor: operand);
204 assert(writtenTo && "expected a sparse lattice");
205 os << " result #" << index << ": ";
206 writtenTo->print(os);
207 os << "\n";
208 }
209 });
210 }
211};
212} // end anonymous namespace
213
214namespace mlir {
215namespace test {
216void registerTestWrittenToPass() { PassRegistration<TestWrittenToPass>(); }
217} // end namespace test
218} // end namespace mlir
219

source code of mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp