//===- TestSparseBackwardDataFlowAnalysis.cpp - Test written-to analysis --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
10#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
11#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
12#include "mlir/Dialect/MemRef/IR/MemRef.h"
13#include "mlir/Interfaces/SideEffectInterfaces.h"
14#include "mlir/Pass/Pass.h"
15
16using namespace mlir;
17using namespace mlir::dataflow;
18
19namespace {
20
21/// Lattice value storing the a set of memory resources that something
22/// is written to.
23struct WrittenToLatticeValue {
24 bool operator==(const WrittenToLatticeValue &other) {
25 return this->writes == other.writes;
26 }
27
28 static WrittenToLatticeValue meet(const WrittenToLatticeValue &lhs,
29 const WrittenToLatticeValue &rhs) {
30 WrittenToLatticeValue res = lhs;
31 (void)res.addWrites(writes: rhs.writes);
32
33 return res;
34 }
35
36 static WrittenToLatticeValue join(const WrittenToLatticeValue &lhs,
37 const WrittenToLatticeValue &rhs) {
38 // Should not be triggered by this test, but required by `Lattice<T>`
39 llvm_unreachable("Join should not be triggered by this test");
40 }
41
42 ChangeResult addWrites(const SetVector<StringAttr> &writes) {
43 int sizeBefore = this->writes.size();
44 this->writes.insert_range(writes);
45 int sizeAfter = this->writes.size();
46 return sizeBefore == sizeAfter ? ChangeResult::NoChange
47 : ChangeResult::Change;
48 }
49
50 void print(raw_ostream &os) const {
51 os << "[";
52 llvm::interleave(
53 writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
54 os << "]";
55 }
56
57 void clear() { writes.clear(); }
58
59 SetVector<StringAttr> writes;
60};
61
62/// This lattice represents, for a given value, the set of memory resources that
63/// this value, or anything derived from this value, is potentially written to.
64struct WrittenTo : public Lattice<WrittenToLatticeValue> {
65 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
66 using Lattice::Lattice;
67};
68
69/// An analysis that, by going backwards along the dataflow graph, annotates
70/// each value with all the memory resources it (or anything derived from it)
71/// is eventually written to.
72class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
73public:
74 WrittenToAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable,
75 bool assumeFuncWrites)
76 : SparseBackwardDataFlowAnalysis(solver, symbolTable),
77 assumeFuncWrites(assumeFuncWrites) {}
78
79 LogicalResult visitOperation(Operation *op, ArrayRef<WrittenTo *> operands,
80 ArrayRef<const WrittenTo *> results) override;
81
82 void visitBranchOperand(OpOperand &operand) override;
83
84 void visitCallOperand(OpOperand &operand) override;
85
86 void visitExternalCall(CallOpInterface call, ArrayRef<WrittenTo *> operands,
87 ArrayRef<const WrittenTo *> results) override;
88
89 void setToExitState(WrittenTo *lattice) override {
90 lattice->getValue().clear();
91 }
92
93private:
94 bool assumeFuncWrites;
95};
96
97LogicalResult
98WrittenToAnalysis::visitOperation(Operation *op, ArrayRef<WrittenTo *> operands,
99 ArrayRef<const WrittenTo *> results) {
100 if (auto store = dyn_cast<memref::StoreOp>(op)) {
101 SetVector<StringAttr> newWrites;
102 newWrites.insert(op->getAttrOfType<StringAttr>("tag_name"));
103 propagateIfChanged(state: operands[0],
104 changed: operands[0]->getValue().addWrites(newWrites));
105 return success();
106 } // By default, every result of an op depends on every operand.
107 for (const WrittenTo *r : results) {
108 for (WrittenTo *operand : operands) {
109 meet(operand, *r);
110 }
111 addDependency(const_cast<WrittenTo *>(r), getProgramPointAfter(op));
112 }
113 return success();
114}
115
116void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) {
117 // Mark branch operands as "brancharg%d", with %d the operand number.
118 WrittenTo *lattice = getLatticeElement(value: operand.get());
119 SetVector<StringAttr> newWrites;
120 newWrites.insert(
121 StringAttr::get(operand.getOwner()->getContext(),
122 "brancharg" + Twine(operand.getOperandNumber())));
123 propagateIfChanged(state: lattice, changed: lattice->getValue().addWrites(newWrites));
124}
125
126void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
127 // Mark call operands as "callarg%d", with %d the operand number.
128 WrittenTo *lattice = getLatticeElement(value: operand.get());
129 SetVector<StringAttr> newWrites;
130 newWrites.insert(
131 StringAttr::get(operand.getOwner()->getContext(),
132 "callarg" + Twine(operand.getOperandNumber())));
133 propagateIfChanged(state: lattice, changed: lattice->getValue().addWrites(newWrites));
134}
135
136void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
137 ArrayRef<WrittenTo *> operands,
138 ArrayRef<const WrittenTo *> results) {
139 if (!assumeFuncWrites) {
140 return SparseBackwardDataFlowAnalysis::visitExternalCall(call: call, argumentLattices: operands,
141 resultLattices: results);
142 }
143
144 for (WrittenTo *lattice : operands) {
145 SetVector<StringAttr> newWrites;
146 StringAttr name = call->getAttrOfType<StringAttr>("tag_name");
147 if (!name) {
148 name = StringAttr::get(call->getContext(),
149 call.getOperation()->getName().getStringRef());
150 }
151 newWrites.insert(name);
152 propagateIfChanged(state: lattice, changed: lattice->getValue().addWrites(newWrites));
153 }
154}
155
156} // end anonymous namespace
157
158namespace {
159struct TestWrittenToPass
160 : public PassWrapper<TestWrittenToPass, OperationPass<>> {
161 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWrittenToPass)
162
163 TestWrittenToPass() = default;
164 TestWrittenToPass(const TestWrittenToPass &other) : PassWrapper(other) {
165 interprocedural = other.interprocedural;
166 assumeFuncWrites = other.assumeFuncWrites;
167 }
168
169 StringRef getArgument() const override { return "test-written-to"; }
170
171 Option<bool> interprocedural{
172 *this, "interprocedural", llvm::cl::init(Val: true),
173 llvm::cl::desc("perform interprocedural analysis")};
174 Option<bool> assumeFuncWrites{
175 *this, "assume-func-writes", llvm::cl::init(Val: false),
176 llvm::cl::desc(
177 "assume external functions have write effect on all arguments")};
178
179 void runOnOperation() override {
180 Operation *op = getOperation();
181
182 SymbolTableCollection symbolTable;
183
184 DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural));
185 solver.load<DeadCodeAnalysis>();
186 solver.load<SparseConstantPropagation>();
187 solver.load<WrittenToAnalysis>(args&: symbolTable, args&: assumeFuncWrites);
188 if (failed(Result: solver.initializeAndRun(top: op)))
189 return signalPassFailure();
190
191 raw_ostream &os = llvm::outs();
192 op->walk(callback: [&](Operation *op) {
193 auto tag = op->getAttrOfType<StringAttr>("tag");
194 if (!tag)
195 return;
196 os << "test_tag: " << tag.getValue() << ":\n";
197 for (auto [index, operand] : llvm::enumerate(First: op->getOperands())) {
198 const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(anchor: operand);
199 assert(writtenTo && "expected a sparse lattice");
200 os << " operand #" << index << ": ";
201 writtenTo->print(os);
202 os << "\n";
203 }
204 for (auto [index, operand] : llvm::enumerate(First: op->getResults())) {
205 const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(anchor: operand);
206 assert(writtenTo && "expected a sparse lattice");
207 os << " result #" << index << ": ";
208 writtenTo->print(os);
209 os << "\n";
210 }
211 });
212 }
213};
214} // end anonymous namespace
215
216namespace mlir {
217namespace test {
218void registerTestWrittenToPass() { PassRegistration<TestWrittenToPass>(); }
219} // end namespace test
220} // end namespace mlir
221

// Source: mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp