1 | //===- PipelineGlobalOpsPass.cpp - Pipeline Global Ops Pass ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/MLProgram/Transforms/Passes.h" |
10 | |
11 | #include "mlir/Dialect/MLProgram/IR/MLProgram.h" |
12 | #include "mlir/Dialect/MLProgram/Transforms/Passes.h" |
13 | #include "mlir/IR/BuiltinOps.h" |
14 | #include "mlir/Pass/Pass.h" |
15 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
16 | |
17 | namespace mlir { |
18 | namespace ml_program { |
19 | #define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALS |
20 | #include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc" |
21 | |
22 | namespace { |
23 | |
24 | class MLProgramPipelineGlobals |
25 | : public impl::MLProgramPipelineGlobalsBase<MLProgramPipelineGlobals> { |
26 | public: |
27 | void runOnOperation() override; |
28 | |
29 | private: |
30 | LogicalResult buildGlobalMap(ModuleOp op); |
31 | |
32 | void processBlock(Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad, |
33 | llvm::DenseSet<SymbolRefAttr> &symbolStore); |
34 | |
35 | llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> loadSymbolsMap; |
36 | llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> storeSymbolsMap; |
37 | }; |
38 | |
39 | // Traverses upwards searchign for the operation mapped by the symbol. |
40 | static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) { |
41 | for (auto *op = baseOp; op; op = op->getParentOp()) { |
42 | auto *lookup = SymbolTable::lookupNearestSymbolFrom(op, symbol); |
43 | if (lookup) |
44 | return lookup; |
45 | } |
46 | return nullptr; |
47 | } |
48 | |
49 | // Builds map from a symbol to MLProgram global symbols loaded or stored |
50 | // during processing. |
51 | LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) { |
52 | llvm::DenseMap<SymbolRefAttr, Operation *> callableMap; |
53 | auto res = module->walk([&](Operation *op) { |
54 | if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) { |
55 | auto callable = caller.getCallableForCallee(); |
56 | // For now we do not know how to handle Value based tracing, so fail. |
57 | if (mlir::isa<Value>(callable)) { |
58 | return WalkResult::interrupt(); |
59 | } |
60 | |
61 | auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable); |
62 | auto *func = getFromSymbol(op, symbol); |
63 | callableMap[symbol] = func; |
64 | } |
65 | return WalkResult::advance(); |
66 | }); |
67 | |
68 | if (res.wasInterrupted()) { |
69 | return failure(); |
70 | } |
71 | |
72 | // First grab all symbols loaded or stored by each function. This |
73 | // will not handle calls initially. |
74 | llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opLoadSymbols; |
75 | llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opStoreSymbols; |
76 | for (auto callable : callableMap) { |
77 | llvm::DenseSet<SymbolRefAttr> loadSymbols; |
78 | llvm::DenseSet<SymbolRefAttr> storeSymbols; |
79 | |
80 | callable.getSecond()->walk( |
81 | [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); }); |
82 | |
83 | callable.getSecond()->walk( |
84 | [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); }); |
85 | |
86 | opLoadSymbols[callable.getFirst()] = std::move(loadSymbols); |
87 | opStoreSymbols[callable.getFirst()] = std::move(storeSymbols); |
88 | } |
89 | |
90 | // For each callable function we find each global loaded/stored within the |
91 | // function or a nested called function. This includes recursion checking to |
92 | // avoid infinitely recursing. |
93 | for (auto callable : callableMap) { |
94 | SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first); |
95 | llvm::SmallVector<SymbolRefAttr> work = {thisSymbol}; |
96 | llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol}; |
97 | llvm::DenseSet<SymbolRefAttr> loadSymbols; |
98 | llvm::DenseSet<SymbolRefAttr> storeSymbols; |
99 | |
100 | for (size_t i = 0; i < work.size(); ++i) { |
101 | callableMap[work[i]]->walk([&](CallOpInterface call) { |
102 | auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee()); |
103 | if (!visited.contains(symbol)) { |
104 | visited.insert(symbol); |
105 | work.push_back(symbol); |
106 | } |
107 | }); |
108 | |
109 | for (auto load : opLoadSymbols[work[i]]) |
110 | loadSymbols.insert(load); |
111 | |
112 | for (auto store : opStoreSymbols[work[i]]) |
113 | storeSymbols.insert(store); |
114 | } |
115 | |
116 | loadSymbolsMap[thisSymbol] = std::move(loadSymbols); |
117 | storeSymbolsMap[thisSymbol] = std::move(storeSymbols); |
118 | } |
119 | |
120 | return success(); |
121 | } |
122 | |
123 | // Process each operation in the block deleting unneeded loads / stores, |
124 | // recursing on subblocks and checking function calls. |
125 | void MLProgramPipelineGlobals::processBlock( |
126 | Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad, |
127 | llvm::DenseSet<SymbolRefAttr> &symbolStore) { |
128 | |
129 | llvm::DenseMap<SymbolRefAttr, Value> previousLoads; |
130 | llvm::DenseMap<SymbolRefAttr, Operation *> previousStores; |
131 | llvm::SmallVector<Operation *> toDelete; |
132 | for (auto &op : block) { |
133 | // If this is a global load, remap to a previous value if known |
134 | // and delete this load. Remember that this value is the currently |
135 | // known load. |
136 | if (auto load = mlir::dyn_cast<GlobalLoadOp>(op)) { |
137 | auto ref = load.getGlobal(); |
138 | symbolLoad.insert(ref); |
139 | if (previousLoads.contains(ref)) { |
140 | toDelete.push_back(&op); |
141 | load.getResult().replaceAllUsesWith(previousLoads[ref]); |
142 | } else { |
143 | previousLoads[ref] = load.getResult(); |
144 | } |
145 | continue; |
146 | } |
147 | |
148 | // Delete a previous store if it exists and is not needed, update |
149 | // the most recent known value for this global ref. |
150 | if (auto store = mlir::dyn_cast<GlobalStoreOp>(op)) { |
151 | auto ref = store.getGlobal(); |
152 | symbolStore.insert(ref); |
153 | if (previousStores.contains(ref)) { |
154 | toDelete.push_back(previousStores.find(ref)->getSecond()); |
155 | } |
156 | |
157 | previousLoads[ref] = store.getValue(); |
158 | previousStores[ref] = &op; |
159 | continue; |
160 | } |
161 | |
162 | // If a function is called, clear known values for loads/stores used by |
163 | // the function or its sub-functions. |
164 | if (auto call = mlir::dyn_cast<CallOpInterface>(op)) { |
165 | auto loadSymbols = |
166 | loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())]; |
167 | auto storeSymbols = |
168 | storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())]; |
169 | |
170 | for (auto sym : loadSymbols) { |
171 | previousStores.erase(sym); |
172 | } |
173 | |
174 | for (auto sym : storeSymbols) { |
175 | previousLoads.erase(sym); |
176 | previousStores.erase(sym); |
177 | } |
178 | continue; |
179 | } |
180 | |
181 | // If the op has sub-regions, recurse inside. We make no guarantees whether |
182 | // the recursion occurs. |
183 | llvm::DenseSet<SymbolRefAttr> opSymbolLoad; |
184 | llvm::DenseSet<SymbolRefAttr> opSymbolStore; |
185 | for (auto ®ion : op.getRegions()) { |
186 | for (auto &block : region) { |
187 | processBlock(block, opSymbolLoad, opSymbolStore); |
188 | } |
189 | } |
190 | |
191 | // Update current state from the subblock. |
192 | for (auto change : opSymbolLoad) { |
193 | symbolLoad.insert(change); |
194 | previousStores.erase(change); |
195 | } |
196 | |
197 | for (auto change : opSymbolStore) { |
198 | symbolStore.insert(change); |
199 | previousLoads.erase(change); |
200 | previousStores.erase(change); |
201 | } |
202 | } |
203 | |
204 | for (auto *op : toDelete) { |
205 | op->erase(); |
206 | } |
207 | } |
208 | |
209 | void MLProgramPipelineGlobals::runOnOperation() { |
210 | auto targetOp = getOperation(); |
211 | if (failed(buildGlobalMap(module: targetOp))) { |
212 | return; |
213 | } |
214 | |
215 | for (auto &funcOp : *targetOp.getBody()) { |
216 | for (auto ®ion : funcOp.getRegions()) { |
217 | for (auto &block : region.getBlocks()) { |
218 | llvm::DenseSet<SymbolRefAttr> symbolsLoaded; |
219 | llvm::DenseSet<SymbolRefAttr> symbolsStored; |
220 | processBlock(block, symbolsLoaded, symbolsStored); |
221 | } |
222 | } |
223 | } |
224 | } |
225 | |
226 | } // namespace |
227 | |
228 | std::unique_ptr<OperationPass<mlir::ModuleOp>> |
229 | createMLProgramPipelineGlobalsPass() { |
230 | return std::make_unique<MLProgramPipelineGlobals>(); |
231 | } |
232 | |
233 | } // namespace ml_program |
234 | } // namespace mlir |
235 | |