1//===- PipelineGlobalOpsPass.cpp - Pipeline Global Ops Pass ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
10
11#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
12#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "mlir/Pass/Pass.h"
15#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16
17namespace mlir {
18namespace ml_program {
19#define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALS
20#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
21
22namespace {
23
24class MLProgramPipelineGlobals
25 : public impl::MLProgramPipelineGlobalsBase<MLProgramPipelineGlobals> {
26public:
27 void runOnOperation() override;
28
29private:
30 LogicalResult buildGlobalMap(ModuleOp op);
31
32 void processBlock(Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
33 llvm::DenseSet<SymbolRefAttr> &symbolStore);
34
35 llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> loadSymbolsMap;
36 llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> storeSymbolsMap;
37};
38
39// Traverses upwards searchign for the operation mapped by the symbol.
40static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) {
41 for (auto *op = baseOp; op; op = op->getParentOp()) {
42 auto *lookup = SymbolTable::lookupNearestSymbolFrom(op, symbol);
43 if (lookup)
44 return lookup;
45 }
46 return nullptr;
47}
48
49// Builds map from a symbol to MLProgram global symbols loaded or stored
50// during processing.
51LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
52 llvm::DenseMap<SymbolRefAttr, Operation *> callableMap;
53 auto res = module->walk([&](Operation *op) {
54 if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
55 auto callable = caller.getCallableForCallee();
56 // For now we do not know how to handle Value based tracing, so fail.
57 if (mlir::isa<Value>(callable)) {
58 return WalkResult::interrupt();
59 }
60
61 auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
62 auto *func = getFromSymbol(op, symbol);
63 callableMap[symbol] = func;
64 }
65 return WalkResult::advance();
66 });
67
68 if (res.wasInterrupted()) {
69 return failure();
70 }
71
72 // First grab all symbols loaded or stored by each function. This
73 // will not handle calls initially.
74 llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opLoadSymbols;
75 llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opStoreSymbols;
76 for (auto callable : callableMap) {
77 llvm::DenseSet<SymbolRefAttr> loadSymbols;
78 llvm::DenseSet<SymbolRefAttr> storeSymbols;
79
80 callable.getSecond()->walk(
81 [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });
82
83 callable.getSecond()->walk(
84 [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });
85
86 opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
87 opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
88 }
89
90 // For each callable function we find each global loaded/stored within the
91 // function or a nested called function. This includes recursion checking to
92 // avoid infinitely recursing.
93 for (auto callable : callableMap) {
94 SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first);
95 llvm::SmallVector<SymbolRefAttr> work = {thisSymbol};
96 llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol};
97 llvm::DenseSet<SymbolRefAttr> loadSymbols;
98 llvm::DenseSet<SymbolRefAttr> storeSymbols;
99
100 for (size_t i = 0; i < work.size(); ++i) {
101 callableMap[work[i]]->walk([&](CallOpInterface call) {
102 auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
103 if (!visited.contains(symbol)) {
104 visited.insert(symbol);
105 work.push_back(symbol);
106 }
107 });
108
109 for (auto load : opLoadSymbols[work[i]])
110 loadSymbols.insert(load);
111
112 for (auto store : opStoreSymbols[work[i]])
113 storeSymbols.insert(store);
114 }
115
116 loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
117 storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
118 }
119
120 return success();
121}
122
123// Process each operation in the block deleting unneeded loads / stores,
124// recursing on subblocks and checking function calls.
125void MLProgramPipelineGlobals::processBlock(
126 Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
127 llvm::DenseSet<SymbolRefAttr> &symbolStore) {
128
129 llvm::DenseMap<SymbolRefAttr, Value> previousLoads;
130 llvm::DenseMap<SymbolRefAttr, Operation *> previousStores;
131 llvm::SmallVector<Operation *> toDelete;
132 for (auto &op : block) {
133 // If this is a global load, remap to a previous value if known
134 // and delete this load. Remember that this value is the currently
135 // known load.
136 if (auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
137 auto ref = load.getGlobal();
138 symbolLoad.insert(ref);
139 if (previousLoads.contains(ref)) {
140 toDelete.push_back(&op);
141 load.getResult().replaceAllUsesWith(previousLoads[ref]);
142 } else {
143 previousLoads[ref] = load.getResult();
144 }
145 continue;
146 }
147
148 // Delete a previous store if it exists and is not needed, update
149 // the most recent known value for this global ref.
150 if (auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
151 auto ref = store.getGlobal();
152 symbolStore.insert(ref);
153 if (previousStores.contains(ref)) {
154 toDelete.push_back(previousStores.find(ref)->getSecond());
155 }
156
157 previousLoads[ref] = store.getValue();
158 previousStores[ref] = &op;
159 continue;
160 }
161
162 // If a function is called, clear known values for loads/stores used by
163 // the function or its sub-functions.
164 if (auto call = mlir::dyn_cast<CallOpInterface>(op)) {
165 auto loadSymbols =
166 loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
167 auto storeSymbols =
168 storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
169
170 for (auto sym : loadSymbols) {
171 previousStores.erase(sym);
172 }
173
174 for (auto sym : storeSymbols) {
175 previousLoads.erase(sym);
176 previousStores.erase(sym);
177 }
178 continue;
179 }
180
181 // If the op has sub-regions, recurse inside. We make no guarantees whether
182 // the recursion occurs.
183 llvm::DenseSet<SymbolRefAttr> opSymbolLoad;
184 llvm::DenseSet<SymbolRefAttr> opSymbolStore;
185 for (auto &region : op.getRegions()) {
186 for (auto &block : region) {
187 processBlock(block, opSymbolLoad, opSymbolStore);
188 }
189 }
190
191 // Update current state from the subblock.
192 for (auto change : opSymbolLoad) {
193 symbolLoad.insert(change);
194 previousStores.erase(change);
195 }
196
197 for (auto change : opSymbolStore) {
198 symbolStore.insert(change);
199 previousLoads.erase(change);
200 previousStores.erase(change);
201 }
202 }
203
204 for (auto *op : toDelete) {
205 op->erase();
206 }
207}
208
209void MLProgramPipelineGlobals::runOnOperation() {
210 auto targetOp = getOperation();
211 if (failed(buildGlobalMap(module: targetOp))) {
212 return;
213 }
214
215 for (auto &funcOp : *targetOp.getBody()) {
216 for (auto &region : funcOp.getRegions()) {
217 for (auto &block : region.getBlocks()) {
218 llvm::DenseSet<SymbolRefAttr> symbolsLoaded;
219 llvm::DenseSet<SymbolRefAttr> symbolsStored;
220 processBlock(block, symbolsLoaded, symbolsStored);
221 }
222 }
223 }
224}
225
226} // namespace
227
228std::unique_ptr<OperationPass<mlir::ModuleOp>>
229createMLProgramPipelineGlobalsPass() {
230 return std::make_unique<MLProgramPipelineGlobals>();
231}
232
233} // namespace ml_program
234} // namespace mlir
235

source code of mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp