1//===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the SPIR-V module combiner library.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/Linking/ModuleCombiner.h"
14
15#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18#include "mlir/IR/Attributes.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/SymbolTable.h"
21#include "llvm/ADT/ArrayRef.h"
22#include "llvm/ADT/Hashing.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/StringExtras.h"
25#include "llvm/ADT/StringMap.h"
26
27using namespace mlir;
28
/// Largest numeric suffix that `renameSymbol` appends when generating a fresh
/// symbol name; once the shared counter reaches it, no further suffixes are
/// tried.
static constexpr unsigned maxFreeID = 1U << 20;
30
31/// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
32/// suffix in `lastUsedID`.
33static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
34 spirv::ModuleOp module) {
35 SmallString<64> newSymName(oldSymName);
36 newSymName.push_back(Elt: '_');
37
38 MLIRContext *ctx = module->getContext();
39
40 while (lastUsedID < maxFreeID) {
41 auto possible = StringAttr::get(ctx, newSymName + Twine(++lastUsedID));
42 if (!SymbolTable::lookupSymbolIn(module, possible))
43 return possible;
44 }
45
46 return StringAttr::get(ctx, newSymName);
47}
48
49/// Checks if a symbol with the same name as `op` already exists in `source`.
50/// If so, renames `op` and updates all its references in `target`.
51static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
52 spirv::ModuleOp target,
53 spirv::ModuleOp source,
54 unsigned &lastUsedID) {
55 if (!SymbolTable::lookupSymbolIn(source, op.getName()))
56 return success();
57
58 StringRef oldSymName = op.getName();
59 StringAttr newSymName = renameSymbol(oldSymName, lastUsedID, target);
60
61 if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
62 return op.emitError("unable to update all symbol uses for ")
63 << oldSymName << " to " << newSymName;
64
65 SymbolTable::setSymbolName(op, newSymName);
66 return success();
67}
68
69/// Computes a hash code to represent `symbolOp` based on all its attributes
70/// except for the symbol name.
71///
72/// Note: We use the operation's name (not the symbol name) as part of the hash
73/// computation. This prevents, for example, mistakenly considering a global
74/// variable and a spec constant as duplicates because their descriptor set +
75/// binding and spec_id, respectively, happen to hash to the same value.
76static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
77 auto range =
78 llvm::make_filter_range(symbolOp->getAttrs(), [](NamedAttribute attr) {
79 return attr.getName() != SymbolTable::getSymbolAttrName();
80 });
81
82 return llvm::hash_combine(
83 symbolOp->getName(),
84 llvm::hash_combine_range(range.begin(), range.end()));
85}
86
87namespace mlir {
88namespace spirv {
89
90OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
91 OpBuilder &combinedModuleBuilder,
92 SymbolRenameListener symRenameListener) {
93 if (inputModules.empty())
94 return nullptr;
95
96 spirv::ModuleOp firstModule = inputModules.front();
97 auto addressingModel = firstModule.getAddressingModel();
98 auto memoryModel = firstModule.getMemoryModel();
99 auto vceTriple = firstModule.getVceTriple();
100
101 // First check whether there are conflicts between addressing/memory model.
102 // Return early if so.
103 for (auto module : inputModules) {
104 if (module.getAddressingModel() != addressingModel ||
105 module.getMemoryModel() != memoryModel ||
106 module.getVceTriple() != vceTriple) {
107 module.emitError("input modules differ in addressing model, memory "
108 "model, and/or VCE triple");
109 return nullptr;
110 }
111 }
112
113 auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
114 firstModule.getLoc(), addressingModel, memoryModel, vceTriple);
115 combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
116
117 // In some cases, a symbol in the (current state of the) combined module is
118 // renamed in order to enable the conflicting symbol in the input module
119 // being merged. For example, if the conflict is between a global variable in
120 // the current combined module and a function in the input module, the global
121 // variable is renamed. In order to notify listeners of the symbol updates in
122 // such cases, we need to keep track of the module from which the renamed
123 // symbol in the combined module originated. This map keeps such information.
124 llvm::StringMap<spirv::ModuleOp> symNameToModuleMap;
125
126 unsigned lastUsedID = 0;
127
128 for (auto inputModule : inputModules) {
129 OwningOpRef<spirv::ModuleOp> moduleClone = inputModule.clone();
130
131 // In the combined module, rename all symbols that conflict with symbols
132 // from the current input module. This renaming applies to all ops except
133 // for spirv.funcs. This way, if the conflicting op in the input module is
134 // non-spirv.func, we rename that symbol instead and maintain the spirv.func
135 // in the combined module name as it is.
136 for (auto &op : *combinedModule.getBody()) {
137 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
138 if (!symbolOp)
139 continue;
140
141 StringRef oldSymName = symbolOp.getName();
142
143 if (!isa<FuncOp>(op) &&
144 failed(updateSymbolAndAllUses(symbolOp, combinedModule, *moduleClone,
145 lastUsedID)))
146 return nullptr;
147
148 StringRef newSymName = symbolOp.getName();
149
150 if (symRenameListener && oldSymName != newSymName) {
151 spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
152
153 if (!originalModule) {
154 inputModule.emitError(
155 "unable to find original spirv::ModuleOp for symbol ")
156 << oldSymName;
157 return nullptr;
158 }
159
160 symRenameListener(originalModule, oldSymName, newSymName);
161
162 // Since the symbol name is updated, there is no need to maintain the
163 // entry that associates the old symbol name with the original module.
164 symNameToModuleMap.erase(oldSymName);
165 // Instead, add a new entry to map the new symbol name to the original
166 // module in case it gets renamed again later.
167 symNameToModuleMap[newSymName] = originalModule;
168 }
169 }
170
171 // In the current input module, rename all symbols that conflict with
172 // symbols from the combined module. This includes renaming spirv.funcs.
173 for (auto &op : *moduleClone->getBody()) {
174 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
175 if (!symbolOp)
176 continue;
177
178 StringRef oldSymName = symbolOp.getName();
179
180 if (failed(updateSymbolAndAllUses(symbolOp, *moduleClone, combinedModule,
181 lastUsedID)))
182 return nullptr;
183
184 StringRef newSymName = symbolOp.getName();
185
186 if (symRenameListener) {
187 if (oldSymName != newSymName)
188 symRenameListener(inputModule, oldSymName, newSymName);
189
190 // Insert the module associated with the symbol name.
191 auto emplaceResult =
192 symNameToModuleMap.try_emplace(newSymName, inputModule);
193
194 // If an entry with the same symbol name is already present, this must
195 // be a problem with the implementation, specially clean-up of the map
196 // while iterating over the combined module above.
197 if (!emplaceResult.second) {
198 inputModule.emitError("did not expect to find an entry for symbol ")
199 << symbolOp.getName();
200 return nullptr;
201 }
202 }
203 }
204
205 // Clone all the module's ops to the combined module.
206 for (auto &op : *moduleClone->getBody())
207 combinedModuleBuilder.insert(op.clone());
208 }
209
210 // Deduplicate identical global variables, spec constants, and functions.
211 DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp;
212 SmallVector<SymbolOpInterface, 0> eraseList;
213
214 for (auto &op : *combinedModule.getBody()) {
215 SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
216 if (!symbolOp)
217 continue;
218
219 // Do not support ops with operands or results.
220 // Global variables, spec constants, and functions won't have
221 // operands/results, but just for safety here.
222 if (op.getNumOperands() != 0 || op.getNumResults() != 0)
223 continue;
224
225 // Deduplicating functions are not supported yet.
226 if (isa<FuncOp>(op))
227 continue;
228
229 auto result = hashToSymbolOp.try_emplace(computeHash(symbolOp), symbolOp);
230 if (result.second)
231 continue;
232
233 SymbolOpInterface replacementSymOp = result.first->second;
234
235 if (failed(SymbolTable::replaceAllSymbolUses(
236 symbolOp, replacementSymOp.getNameAttr(), combinedModule))) {
237 symbolOp.emitError("unable to update all symbol uses for ")
238 << symbolOp.getName() << " to " << replacementSymOp.getName();
239 return nullptr;
240 }
241
242 eraseList.push_back(symbolOp);
243 }
244
245 for (auto symbolOp : eraseList)
246 symbolOp.erase();
247
248 return combinedModule;
249}
250
251} // namespace spirv
252} // namespace mlir
253

// Source: mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp