1//===- PipelineGlobalOpsPass.cpp - Pipeline Global Ops Pass ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
10
11#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
12#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "mlir/Pass/Pass.h"
15#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16
17namespace mlir {
18namespace ml_program {
19#define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALSPASS
20#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
21
22namespace {
23
24class MLProgramPipelineGlobals
25 : public impl::MLProgramPipelineGlobalsPassBase<MLProgramPipelineGlobals> {
26public:
27 void runOnOperation() override;
28
29private:
30 LogicalResult buildGlobalMap(ModuleOp op);
31
32 void processBlock(Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
33 llvm::DenseSet<SymbolRefAttr> &symbolStore);
34
35 llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> loadSymbolsMap;
36 llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> storeSymbolsMap;
37};
38
39// Traverses upwards searchign for the operation mapped by the symbol.
40static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) {
41 for (auto *op = baseOp; op; op = op->getParentOp()) {
42 auto *lookup = SymbolTable::lookupNearestSymbolFrom(op, symbol);
43 if (lookup)
44 return lookup;
45 }
46 return nullptr;
47}
48
49// Builds map from a symbol to MLProgram global symbols loaded or stored
50// during processing.
51LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
52 llvm::DenseMap<SymbolRefAttr, Operation *> callableMap;
53 auto res = module->walk([&](Operation *op) {
54 if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
55 auto callable = caller.getCallableForCallee();
56 // For now we do not know how to handle Value based tracing, so fail.
57 if (mlir::isa<Value>(callable)) {
58 return WalkResult::interrupt();
59 }
60
61 auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
62 auto *func = getFromSymbol(op, symbol);
63 callableMap[symbol] = func;
64 }
65 return WalkResult::advance();
66 });
67
68 if (res.wasInterrupted()) {
69 return failure();
70 }
71
72 // First grab all symbols loaded or stored by each function. This
73 // will not handle calls initially.
74 llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opLoadSymbols;
75 llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opStoreSymbols;
76 for (auto callable : callableMap) {
77 llvm::DenseSet<SymbolRefAttr> loadSymbols;
78 llvm::DenseSet<SymbolRefAttr> storeSymbols;
79
80 callable.getSecond()->walk(
81 [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });
82
83 callable.getSecond()->walk(
84 [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });
85
86 opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
87 opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
88 }
89
90 // For each callable function we find each global loaded/stored within the
91 // function or a nested called function. This includes recursion checking to
92 // avoid infinitely recursing.
93 for (auto callable : callableMap) {
94 SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first);
95 llvm::SmallVector<SymbolRefAttr> work = {thisSymbol};
96 llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol};
97 llvm::DenseSet<SymbolRefAttr> loadSymbols;
98 llvm::DenseSet<SymbolRefAttr> storeSymbols;
99
100 for (size_t i = 0; i < work.size(); ++i) {
101 callableMap[work[i]]->walk([&](CallOpInterface call) {
102 auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
103 if (visited.insert(symbol).second)
104 work.push_back(symbol);
105 });
106
107 loadSymbols.insert_range(opLoadSymbols[work[i]]);
108
109 storeSymbols.insert_range(opStoreSymbols[work[i]]);
110 }
111
112 loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
113 storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
114 }
115
116 return success();
117}
118
119// Process each operation in the block deleting unneeded loads / stores,
120// recursing on subblocks and checking function calls.
121void MLProgramPipelineGlobals::processBlock(
122 Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
123 llvm::DenseSet<SymbolRefAttr> &symbolStore) {
124
125 llvm::DenseMap<SymbolRefAttr, Value> previousLoads;
126 llvm::DenseMap<SymbolRefAttr, Operation *> previousStores;
127 llvm::SmallVector<Operation *> toDelete;
128 for (auto &op : block) {
129 // If this is a global load, remap to a previous value if known
130 // and delete this load. Remember that this value is the currently
131 // known load.
132 if (auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
133 auto ref = load.getGlobal();
134 symbolLoad.insert(ref);
135 if (previousLoads.contains(ref)) {
136 toDelete.push_back(&op);
137 load.getResult().replaceAllUsesWith(previousLoads[ref]);
138 } else {
139 previousLoads[ref] = load.getResult();
140 }
141 continue;
142 }
143
144 // Delete a previous store if it exists and is not needed, update
145 // the most recent known value for this global ref.
146 if (auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
147 auto ref = store.getGlobal();
148 symbolStore.insert(ref);
149 auto it = previousStores.find(ref);
150 if (it != previousStores.end()) {
151 toDelete.push_back(it->getSecond());
152 }
153
154 previousLoads[ref] = store.getValue();
155 previousStores[ref] = &op;
156 continue;
157 }
158
159 // If a function is called, clear known values for loads/stores used by
160 // the function or its sub-functions.
161 if (auto call = mlir::dyn_cast<CallOpInterface>(op)) {
162 auto loadSymbols =
163 loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
164 auto storeSymbols =
165 storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
166
167 for (auto sym : loadSymbols) {
168 previousStores.erase(sym);
169 }
170
171 for (auto sym : storeSymbols) {
172 previousLoads.erase(sym);
173 previousStores.erase(sym);
174 }
175 continue;
176 }
177
178 // If the op has sub-regions, recurse inside. We make no guarantees whether
179 // the recursion occurs.
180 llvm::DenseSet<SymbolRefAttr> opSymbolLoad;
181 llvm::DenseSet<SymbolRefAttr> opSymbolStore;
182 for (auto &region : op.getRegions()) {
183 for (auto &block : region) {
184 processBlock(block, opSymbolLoad, opSymbolStore);
185 }
186 }
187
188 // Update current state from the subblock.
189 for (auto change : opSymbolLoad) {
190 symbolLoad.insert(change);
191 previousStores.erase(change);
192 }
193
194 for (auto change : opSymbolStore) {
195 symbolStore.insert(change);
196 previousLoads.erase(change);
197 previousStores.erase(change);
198 }
199 }
200
201 for (auto *op : toDelete) {
202 op->erase();
203 }
204}
205
206void MLProgramPipelineGlobals::runOnOperation() {
207 auto targetOp = getOperation();
208 if (failed(buildGlobalMap(module: targetOp))) {
209 return;
210 }
211
212 for (auto &funcOp : *targetOp.getBody()) {
213 for (auto &region : funcOp.getRegions()) {
214 for (auto &block : region.getBlocks()) {
215 llvm::DenseSet<SymbolRefAttr> symbolsLoaded;
216 llvm::DenseSet<SymbolRefAttr> symbolsStored;
217 processBlock(block, symbolsLoaded, symbolsStored);
218 }
219 }
220 }
221}
222
223} // namespace
224
225} // namespace ml_program
226} // namespace mlir
227

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp