1//===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the SPIR-V module combiner library.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/Linking/ModuleCombiner.h"
14
15#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18#include "mlir/IR/Attributes.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/SymbolTable.h"
21#include "llvm/ADT/ArrayRef.h"
22#include "llvm/ADT/Hashing.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/StringExtras.h"
25#include "llvm/ADT/StringMap.h"
26
27using namespace mlir;
28
29static constexpr unsigned maxFreeID = 1 << 20;
30
31/// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
32/// suffix in `lastUsedID`.
33static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
34 spirv::ModuleOp module) {
35 SmallString<64> newSymName(oldSymName);
36 newSymName.push_back(Elt: '_');
37
38 MLIRContext *ctx = module->getContext();
39
40 while (lastUsedID < maxFreeID) {
41 auto possible = StringAttr::get(ctx, newSymName + Twine(++lastUsedID));
42 if (!SymbolTable::lookupSymbolIn(module, possible))
43 return possible;
44 }
45
46 return StringAttr::get(ctx, newSymName);
47}
48
49/// Checks if a symbol with the same name as `op` already exists in `source`.
50/// If so, renames `op` and updates all its references in `target`.
51static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
52 spirv::ModuleOp target,
53 spirv::ModuleOp source,
54 unsigned &lastUsedID) {
55 if (!SymbolTable::lookupSymbolIn(source, op.getName()))
56 return success();
57
58 StringRef oldSymName = op.getName();
59 StringAttr newSymName = renameSymbol(oldSymName, lastUsedID, target);
60
61 if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
62 return op.emitError("unable to update all symbol uses for ")
63 << oldSymName << " to " << newSymName;
64
65 SymbolTable::setSymbolName(op, newSymName);
66 return success();
67}
68
69/// Computes a hash code to represent `symbolOp` based on all its attributes
70/// except for the symbol name.
71///
72/// Note: We use the operation's name (not the symbol name) as part of the hash
73/// computation. This prevents, for example, mistakenly considering a global
74/// variable and a spec constant as duplicates because their descriptor set +
75/// binding and spec_id, respectively, happen to hash to the same value.
76static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
77 auto range =
78 llvm::make_filter_range(symbolOp->getAttrs(), [](NamedAttribute attr) {
79 return attr.getName() != SymbolTable::getSymbolAttrName();
80 });
81
82 return llvm::hash_combine(symbolOp->getName(),
83 llvm::hash_combine_range(range));
84}
85
86namespace mlir {
87namespace spirv {
88
89OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
90 OpBuilder &combinedModuleBuilder,
91 SymbolRenameListener symRenameListener) {
92 if (inputModules.empty())
93 return nullptr;
94
95 spirv::ModuleOp firstModule = inputModules.front();
96 auto addressingModel = firstModule.getAddressingModel();
97 auto memoryModel = firstModule.getMemoryModel();
98 auto vceTriple = firstModule.getVceTriple();
99
100 // First check whether there are conflicts between addressing/memory model.
101 // Return early if so.
102 for (auto module : inputModules) {
103 if (module.getAddressingModel() != addressingModel ||
104 module.getMemoryModel() != memoryModel ||
105 module.getVceTriple() != vceTriple) {
106 module.emitError("input modules differ in addressing model, memory "
107 "model, and/or VCE triple");
108 return nullptr;
109 }
110 }
111
112 auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
113 firstModule.getLoc(), addressingModel, memoryModel, vceTriple);
114 combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
115
116 // In some cases, a symbol in the (current state of the) combined module is
117 // renamed in order to enable the conflicting symbol in the input module
118 // being merged. For example, if the conflict is between a global variable in
119 // the current combined module and a function in the input module, the global
120 // variable is renamed. In order to notify listeners of the symbol updates in
121 // such cases, we need to keep track of the module from which the renamed
122 // symbol in the combined module originated. This map keeps such information.
123 llvm::StringMap<spirv::ModuleOp> symNameToModuleMap;
124
125 unsigned lastUsedID = 0;
126
127 for (auto inputModule : inputModules) {
128 OwningOpRef<spirv::ModuleOp> moduleClone = inputModule.clone();
129
130 // In the combined module, rename all symbols that conflict with symbols
131 // from the current input module. This renaming applies to all ops except
132 // for spirv.funcs. This way, if the conflicting op in the input module is
133 // non-spirv.func, we rename that symbol instead and maintain the spirv.func
134 // in the combined module name as it is.
135 for (auto &op : *combinedModule.getBody()) {
136 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
137 if (!symbolOp)
138 continue;
139
140 StringRef oldSymName = symbolOp.getName();
141
142 if (!isa<FuncOp>(op) &&
143 failed(updateSymbolAndAllUses(symbolOp, combinedModule, *moduleClone,
144 lastUsedID)))
145 return nullptr;
146
147 StringRef newSymName = symbolOp.getName();
148
149 if (symRenameListener && oldSymName != newSymName) {
150 spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
151
152 if (!originalModule) {
153 inputModule.emitError(
154 "unable to find original spirv::ModuleOp for symbol ")
155 << oldSymName;
156 return nullptr;
157 }
158
159 symRenameListener(originalModule, oldSymName, newSymName);
160
161 // Since the symbol name is updated, there is no need to maintain the
162 // entry that associates the old symbol name with the original module.
163 symNameToModuleMap.erase(oldSymName);
164 // Instead, add a new entry to map the new symbol name to the original
165 // module in case it gets renamed again later.
166 symNameToModuleMap[newSymName] = originalModule;
167 }
168 }
169
170 // In the current input module, rename all symbols that conflict with
171 // symbols from the combined module. This includes renaming spirv.funcs.
172 for (auto &op : *moduleClone->getBody()) {
173 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
174 if (!symbolOp)
175 continue;
176
177 StringRef oldSymName = symbolOp.getName();
178
179 if (failed(updateSymbolAndAllUses(symbolOp, *moduleClone, combinedModule,
180 lastUsedID)))
181 return nullptr;
182
183 StringRef newSymName = symbolOp.getName();
184
185 if (symRenameListener) {
186 if (oldSymName != newSymName)
187 symRenameListener(inputModule, oldSymName, newSymName);
188
189 // Insert the module associated with the symbol name.
190 auto emplaceResult =
191 symNameToModuleMap.try_emplace(newSymName, inputModule);
192
193 // If an entry with the same symbol name is already present, this must
194 // be a problem with the implementation, specially clean-up of the map
195 // while iterating over the combined module above.
196 if (!emplaceResult.second) {
197 inputModule.emitError("did not expect to find an entry for symbol ")
198 << symbolOp.getName();
199 return nullptr;
200 }
201 }
202 }
203
204 // Clone all the module's ops to the combined module.
205 for (auto &op : *moduleClone->getBody())
206 combinedModuleBuilder.insert(op.clone());
207 }
208
209 // Deduplicate identical global variables, spec constants, and functions.
210 DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp;
211 SmallVector<SymbolOpInterface, 0> eraseList;
212
213 for (auto &op : *combinedModule.getBody()) {
214 SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
215 if (!symbolOp)
216 continue;
217
218 // Do not support ops with operands or results.
219 // Global variables, spec constants, and functions won't have
220 // operands/results, but just for safety here.
221 if (op.getNumOperands() != 0 || op.getNumResults() != 0)
222 continue;
223
224 // Deduplicating functions are not supported yet.
225 if (isa<FuncOp>(op))
226 continue;
227
228 auto result = hashToSymbolOp.try_emplace(computeHash(symbolOp), symbolOp);
229 if (result.second)
230 continue;
231
232 SymbolOpInterface replacementSymOp = result.first->second;
233
234 if (failed(SymbolTable::replaceAllSymbolUses(
235 symbolOp, replacementSymOp.getNameAttr(), combinedModule))) {
236 symbolOp.emitError("unable to update all symbol uses for ")
237 << symbolOp.getName() << " to " << replacementSymOp.getName();
238 return nullptr;
239 }
240
241 eraseList.push_back(symbolOp);
242 }
243
244 for (auto symbolOp : eraseList)
245 symbolOp.erase();
246
247 return combinedModule;
248}
249
250} // namespace spirv
251} // namespace mlir
252

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp