1//===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the SPIR-V module combiner library.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/Linking/ModuleCombiner.h"
14
15#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16#include "mlir/IR/Attributes.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/SymbolTable.h"
19#include "llvm/ADT/Hashing.h"
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/ADT/StringMap.h"
22
23using namespace mlir;
24
25static constexpr unsigned maxFreeID = 1 << 20;
26
27/// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
28/// suffix in `lastUsedID`.
29static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
30 spirv::ModuleOp module) {
31 SmallString<64> newSymName(oldSymName);
32 newSymName.push_back(Elt: '_');
33
34 MLIRContext *ctx = module->getContext();
35
36 while (lastUsedID < maxFreeID) {
37 auto possible = StringAttr::get(context: ctx, bytes: newSymName + Twine(++lastUsedID));
38 if (!SymbolTable::lookupSymbolIn(op: module, symbol: possible))
39 return possible;
40 }
41
42 return StringAttr::get(context: ctx, bytes: newSymName);
43}
44
45/// Checks if a symbol with the same name as `op` already exists in `source`.
46/// If so, renames `op` and updates all its references in `target`.
47static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
48 spirv::ModuleOp target,
49 spirv::ModuleOp source,
50 unsigned &lastUsedID) {
51 if (!SymbolTable::lookupSymbolIn(op: source, symbol: op.getName()))
52 return success();
53
54 StringRef oldSymName = op.getName();
55 StringAttr newSymName = renameSymbol(oldSymName, lastUsedID, module: target);
56
57 if (failed(Result: SymbolTable::replaceAllSymbolUses(oldSymbol: op, newSymbolName: newSymName, from: target)))
58 return op.emitError(message: "unable to update all symbol uses for ")
59 << oldSymName << " to " << newSymName;
60
61 SymbolTable::setSymbolName(symbol: op, name: newSymName);
62 return success();
63}
64
65/// Computes a hash code to represent `symbolOp` based on all its attributes
66/// except for the symbol name.
67///
68/// Note: We use the operation's name (not the symbol name) as part of the hash
69/// computation. This prevents, for example, mistakenly considering a global
70/// variable and a spec constant as duplicates because their descriptor set +
71/// binding and spec_id, respectively, happen to hash to the same value.
72static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
73 auto range =
74 llvm::make_filter_range(Range: symbolOp->getAttrs(), Pred: [](NamedAttribute attr) {
75 return attr.getName() != SymbolTable::getSymbolAttrName();
76 });
77
78 return llvm::hash_combine(args: symbolOp->getName(),
79 args: llvm::hash_combine_range(R&: range));
80}
81
82namespace mlir {
83namespace spirv {
84
85OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
86 OpBuilder &combinedModuleBuilder,
87 SymbolRenameListener symRenameListener) {
88 if (inputModules.empty())
89 return nullptr;
90
91 spirv::ModuleOp firstModule = inputModules.front();
92 auto addressingModel = firstModule.getAddressingModel();
93 auto memoryModel = firstModule.getMemoryModel();
94 auto vceTriple = firstModule.getVceTriple();
95
96 // First check whether there are conflicts between addressing/memory model.
97 // Return early if so.
98 for (auto module : inputModules) {
99 if (module.getAddressingModel() != addressingModel ||
100 module.getMemoryModel() != memoryModel ||
101 module.getVceTriple() != vceTriple) {
102 module.emitError(message: "input modules differ in addressing model, memory "
103 "model, and/or VCE triple");
104 return nullptr;
105 }
106 }
107
108 auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
109 location: firstModule.getLoc(), args&: addressingModel, args&: memoryModel, args&: vceTriple);
110 combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
111
112 // In some cases, a symbol in the (current state of the) combined module is
113 // renamed in order to enable the conflicting symbol in the input module
114 // being merged. For example, if the conflict is between a global variable in
115 // the current combined module and a function in the input module, the global
116 // variable is renamed. In order to notify listeners of the symbol updates in
117 // such cases, we need to keep track of the module from which the renamed
118 // symbol in the combined module originated. This map keeps such information.
119 llvm::StringMap<spirv::ModuleOp> symNameToModuleMap;
120
121 unsigned lastUsedID = 0;
122
123 for (auto inputModule : inputModules) {
124 OwningOpRef<spirv::ModuleOp> moduleClone = inputModule.clone();
125
126 // In the combined module, rename all symbols that conflict with symbols
127 // from the current input module. This renaming applies to all ops except
128 // for spirv.funcs. This way, if the conflicting op in the input module is
129 // non-spirv.func, we rename that symbol instead and maintain the spirv.func
130 // in the combined module name as it is.
131 for (auto &op : *combinedModule.getBody()) {
132 auto symbolOp = dyn_cast<SymbolOpInterface>(Val&: op);
133 if (!symbolOp)
134 continue;
135
136 StringRef oldSymName = symbolOp.getName();
137
138 if (!isa<FuncOp>(Val: op) &&
139 failed(Result: updateSymbolAndAllUses(op: symbolOp, target: combinedModule, source: *moduleClone,
140 lastUsedID)))
141 return nullptr;
142
143 StringRef newSymName = symbolOp.getName();
144
145 if (symRenameListener && oldSymName != newSymName) {
146 spirv::ModuleOp originalModule = symNameToModuleMap.lookup(Key: oldSymName);
147
148 if (!originalModule) {
149 inputModule.emitError(
150 message: "unable to find original spirv::ModuleOp for symbol ")
151 << oldSymName;
152 return nullptr;
153 }
154
155 symRenameListener(originalModule, oldSymName, newSymName);
156
157 // Since the symbol name is updated, there is no need to maintain the
158 // entry that associates the old symbol name with the original module.
159 symNameToModuleMap.erase(Key: oldSymName);
160 // Instead, add a new entry to map the new symbol name to the original
161 // module in case it gets renamed again later.
162 symNameToModuleMap[newSymName] = originalModule;
163 }
164 }
165
166 // In the current input module, rename all symbols that conflict with
167 // symbols from the combined module. This includes renaming spirv.funcs.
168 for (auto &op : *moduleClone->getBody()) {
169 auto symbolOp = dyn_cast<SymbolOpInterface>(Val&: op);
170 if (!symbolOp)
171 continue;
172
173 StringRef oldSymName = symbolOp.getName();
174
175 if (failed(Result: updateSymbolAndAllUses(op: symbolOp, target: *moduleClone, source: combinedModule,
176 lastUsedID)))
177 return nullptr;
178
179 StringRef newSymName = symbolOp.getName();
180
181 if (symRenameListener) {
182 if (oldSymName != newSymName)
183 symRenameListener(inputModule, oldSymName, newSymName);
184
185 // Insert the module associated with the symbol name.
186 auto emplaceResult =
187 symNameToModuleMap.try_emplace(Key: newSymName, Args&: inputModule);
188
189 // If an entry with the same symbol name is already present, this must
190 // be a problem with the implementation, specially clean-up of the map
191 // while iterating over the combined module above.
192 if (!emplaceResult.second) {
193 inputModule.emitError(message: "did not expect to find an entry for symbol ")
194 << symbolOp.getName();
195 return nullptr;
196 }
197 }
198 }
199
200 // Clone all the module's ops to the combined module.
201 for (auto &op : *moduleClone->getBody())
202 combinedModuleBuilder.insert(op: op.clone());
203 }
204
205 // Deduplicate identical global variables, spec constants, and functions.
206 DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp;
207 SmallVector<SymbolOpInterface, 0> eraseList;
208
209 for (auto &op : *combinedModule.getBody()) {
210 SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(Val&: op);
211 if (!symbolOp)
212 continue;
213
214 // Do not support ops with operands or results.
215 // Global variables, spec constants, and functions won't have
216 // operands/results, but just for safety here.
217 if (op.getNumOperands() != 0 || op.getNumResults() != 0)
218 continue;
219
220 // Deduplicating functions are not supported yet.
221 if (isa<FuncOp>(Val: op))
222 continue;
223
224 auto result = hashToSymbolOp.try_emplace(Key: computeHash(symbolOp), Args&: symbolOp);
225 if (result.second)
226 continue;
227
228 SymbolOpInterface replacementSymOp = result.first->second;
229
230 if (failed(Result: SymbolTable::replaceAllSymbolUses(
231 oldSymbol: symbolOp, newSymbolName: replacementSymOp.getNameAttr(), from: combinedModule))) {
232 symbolOp.emitError(message: "unable to update all symbol uses for ")
233 << symbolOp.getName() << " to " << replacementSymOp.getName();
234 return nullptr;
235 }
236
237 eraseList.push_back(Elt: symbolOp);
238 }
239
240 for (auto symbolOp : eraseList)
241 symbolOp.erase();
242
243 return combinedModule;
244}
245
246} // namespace spirv
247} // namespace mlir
248

source code of mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp