1 | //===- PipelineGlobalOpsPass.cpp - Pipeline Global Ops Pass ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/MLProgram/Transforms/Passes.h" |
10 | |
11 | #include "mlir/Dialect/MLProgram/IR/MLProgram.h" |
12 | #include "mlir/Dialect/MLProgram/Transforms/Passes.h" |
13 | #include "mlir/IR/BuiltinOps.h" |
14 | #include "mlir/Pass/Pass.h" |
15 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
16 | |
17 | namespace mlir { |
18 | namespace ml_program { |
19 | #define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALSPASS |
20 | #include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc" |
21 | |
22 | namespace { |
23 | |
24 | class MLProgramPipelineGlobals |
25 | : public impl::MLProgramPipelineGlobalsPassBase<MLProgramPipelineGlobals> { |
26 | public: |
27 | void runOnOperation() override; |
28 | |
29 | private: |
30 | LogicalResult buildGlobalMap(ModuleOp op); |
31 | |
32 | void processBlock(Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad, |
33 | llvm::DenseSet<SymbolRefAttr> &symbolStore); |
34 | |
35 | llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> loadSymbolsMap; |
36 | llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> storeSymbolsMap; |
37 | }; |
38 | |
39 | // Traverses upwards searchign for the operation mapped by the symbol. |
40 | static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) { |
41 | for (auto *op = baseOp; op; op = op->getParentOp()) { |
42 | auto *lookup = SymbolTable::lookupNearestSymbolFrom(op, symbol); |
43 | if (lookup) |
44 | return lookup; |
45 | } |
46 | return nullptr; |
47 | } |
48 | |
49 | // Builds map from a symbol to MLProgram global symbols loaded or stored |
50 | // during processing. |
51 | LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) { |
52 | llvm::DenseMap<SymbolRefAttr, Operation *> callableMap; |
53 | auto res = module->walk([&](Operation *op) { |
54 | if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) { |
55 | auto callable = caller.getCallableForCallee(); |
56 | // For now we do not know how to handle Value based tracing, so fail. |
57 | if (mlir::isa<Value>(callable)) { |
58 | return WalkResult::interrupt(); |
59 | } |
60 | |
61 | auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable); |
62 | auto *func = getFromSymbol(op, symbol); |
63 | callableMap[symbol] = func; |
64 | } |
65 | return WalkResult::advance(); |
66 | }); |
67 | |
68 | if (res.wasInterrupted()) { |
69 | return failure(); |
70 | } |
71 | |
72 | // First grab all symbols loaded or stored by each function. This |
73 | // will not handle calls initially. |
74 | llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opLoadSymbols; |
75 | llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opStoreSymbols; |
76 | for (auto callable : callableMap) { |
77 | llvm::DenseSet<SymbolRefAttr> loadSymbols; |
78 | llvm::DenseSet<SymbolRefAttr> storeSymbols; |
79 | |
80 | callable.getSecond()->walk( |
81 | [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); }); |
82 | |
83 | callable.getSecond()->walk( |
84 | [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); }); |
85 | |
86 | opLoadSymbols[callable.getFirst()] = std::move(loadSymbols); |
87 | opStoreSymbols[callable.getFirst()] = std::move(storeSymbols); |
88 | } |
89 | |
90 | // For each callable function we find each global loaded/stored within the |
91 | // function or a nested called function. This includes recursion checking to |
92 | // avoid infinitely recursing. |
93 | for (auto callable : callableMap) { |
94 | SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first); |
95 | llvm::SmallVector<SymbolRefAttr> work = {thisSymbol}; |
96 | llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol}; |
97 | llvm::DenseSet<SymbolRefAttr> loadSymbols; |
98 | llvm::DenseSet<SymbolRefAttr> storeSymbols; |
99 | |
100 | for (size_t i = 0; i < work.size(); ++i) { |
101 | callableMap[work[i]]->walk([&](CallOpInterface call) { |
102 | auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee()); |
103 | if (visited.insert(symbol).second) |
104 | work.push_back(symbol); |
105 | }); |
106 | |
107 | loadSymbols.insert_range(opLoadSymbols[work[i]]); |
108 | |
109 | storeSymbols.insert_range(opStoreSymbols[work[i]]); |
110 | } |
111 | |
112 | loadSymbolsMap[thisSymbol] = std::move(loadSymbols); |
113 | storeSymbolsMap[thisSymbol] = std::move(storeSymbols); |
114 | } |
115 | |
116 | return success(); |
117 | } |
118 | |
119 | // Process each operation in the block deleting unneeded loads / stores, |
120 | // recursing on subblocks and checking function calls. |
121 | void MLProgramPipelineGlobals::processBlock( |
122 | Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad, |
123 | llvm::DenseSet<SymbolRefAttr> &symbolStore) { |
124 | |
125 | llvm::DenseMap<SymbolRefAttr, Value> previousLoads; |
126 | llvm::DenseMap<SymbolRefAttr, Operation *> previousStores; |
127 | llvm::SmallVector<Operation *> toDelete; |
128 | for (auto &op : block) { |
129 | // If this is a global load, remap to a previous value if known |
130 | // and delete this load. Remember that this value is the currently |
131 | // known load. |
132 | if (auto load = mlir::dyn_cast<GlobalLoadOp>(op)) { |
133 | auto ref = load.getGlobal(); |
134 | symbolLoad.insert(ref); |
135 | if (previousLoads.contains(ref)) { |
136 | toDelete.push_back(&op); |
137 | load.getResult().replaceAllUsesWith(previousLoads[ref]); |
138 | } else { |
139 | previousLoads[ref] = load.getResult(); |
140 | } |
141 | continue; |
142 | } |
143 | |
144 | // Delete a previous store if it exists and is not needed, update |
145 | // the most recent known value for this global ref. |
146 | if (auto store = mlir::dyn_cast<GlobalStoreOp>(op)) { |
147 | auto ref = store.getGlobal(); |
148 | symbolStore.insert(ref); |
149 | auto it = previousStores.find(ref); |
150 | if (it != previousStores.end()) { |
151 | toDelete.push_back(it->getSecond()); |
152 | } |
153 | |
154 | previousLoads[ref] = store.getValue(); |
155 | previousStores[ref] = &op; |
156 | continue; |
157 | } |
158 | |
159 | // If a function is called, clear known values for loads/stores used by |
160 | // the function or its sub-functions. |
161 | if (auto call = mlir::dyn_cast<CallOpInterface>(op)) { |
162 | auto loadSymbols = |
163 | loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())]; |
164 | auto storeSymbols = |
165 | storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())]; |
166 | |
167 | for (auto sym : loadSymbols) { |
168 | previousStores.erase(sym); |
169 | } |
170 | |
171 | for (auto sym : storeSymbols) { |
172 | previousLoads.erase(sym); |
173 | previousStores.erase(sym); |
174 | } |
175 | continue; |
176 | } |
177 | |
178 | // If the op has sub-regions, recurse inside. We make no guarantees whether |
179 | // the recursion occurs. |
180 | llvm::DenseSet<SymbolRefAttr> opSymbolLoad; |
181 | llvm::DenseSet<SymbolRefAttr> opSymbolStore; |
182 | for (auto ®ion : op.getRegions()) { |
183 | for (auto &block : region) { |
184 | processBlock(block, opSymbolLoad, opSymbolStore); |
185 | } |
186 | } |
187 | |
188 | // Update current state from the subblock. |
189 | for (auto change : opSymbolLoad) { |
190 | symbolLoad.insert(change); |
191 | previousStores.erase(change); |
192 | } |
193 | |
194 | for (auto change : opSymbolStore) { |
195 | symbolStore.insert(change); |
196 | previousLoads.erase(change); |
197 | previousStores.erase(change); |
198 | } |
199 | } |
200 | |
201 | for (auto *op : toDelete) { |
202 | op->erase(); |
203 | } |
204 | } |
205 | |
206 | void MLProgramPipelineGlobals::runOnOperation() { |
207 | auto targetOp = getOperation(); |
208 | if (failed(buildGlobalMap(module: targetOp))) { |
209 | return; |
210 | } |
211 | |
212 | for (auto &funcOp : *targetOp.getBody()) { |
213 | for (auto ®ion : funcOp.getRegions()) { |
214 | for (auto &block : region.getBlocks()) { |
215 | llvm::DenseSet<SymbolRefAttr> symbolsLoaded; |
216 | llvm::DenseSet<SymbolRefAttr> symbolsStored; |
217 | processBlock(block, symbolsLoaded, symbolsStored); |
218 | } |
219 | } |
220 | } |
221 | } |
222 | |
223 | } // namespace |
224 | |
225 | } // namespace ml_program |
226 | } // namespace mlir |
227 | |