1//===- Serializer.h - MLIR SPIR-V Serializer ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file declares the MLIR SPIR-V module to SPIR-V binary serializer.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_LIB_TARGET_SPIRV_SERIALIZATION_SERIALIZER_H
14#define MLIR_LIB_TARGET_SPIRV_SERIALIZATION_SERIALIZER_H
15
16#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/Target/SPIRV/Serialization.h"
19#include "llvm/ADT/SetVector.h"
20#include "llvm/ADT/SmallVector.h"
21#include "llvm/Support/raw_ostream.h"
22
23namespace mlir {
24namespace spirv {
25
26void encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
27 ArrayRef<uint32_t> operands);
28
29/// A SPIR-V module serializer.
30///
31/// A SPIR-V binary module is a single linear stream of instructions; each
32/// instruction is composed of 32-bit words with the layout:
33///
34/// | <word-count>|<opcode> | <operand> | <operand> | ... |
35/// | <------ word -------> | <-- word --> | <-- word --> | ... |
36///
37/// For the first word, the 16 high-order bits are the word count of the
38/// instruction, the 16 low-order bits are the opcode enumerant. The
39/// instructions then belong to different sections, which must be laid out in
40/// the particular order as specified in "2.4 Logical Layout of a Module" of
41/// the SPIR-V spec.
42class Serializer {
43public:
44 /// Creates a serializer for the given SPIR-V `module`.
45 explicit Serializer(spirv::ModuleOp module,
46 const SerializationOptions &options);
47
48 /// Serializes the remembered SPIR-V module.
49 LogicalResult serialize();
50
51 /// Collects the final SPIR-V `binary`.
52 void collect(SmallVectorImpl<uint32_t> &binary);
53
54#ifndef NDEBUG
55 /// (For debugging) prints each value and its corresponding result <id>.
56 void printValueIDMap(raw_ostream &os);
57#endif
58
59private:
60 // Note that there are two main categories of methods in this class:
61 // * process*() methods are meant to fully serialize a SPIR-V module entity
62 // (header, type, op, etc.). They update internal vectors containing
63 // different binary sections. They are not meant to be called except the
64 // top-level serialization loop.
65 // * prepare*() methods are meant to be helpers that prepare for serializing
66 // certain entity. They may or may not update internal vectors containing
67 // different binary sections. They are meant to be called among themselves
68 // or by other process*() methods for subtasks.
69
70 //===--------------------------------------------------------------------===//
71 // <id>
72 //===--------------------------------------------------------------------===//
73
74 // Note that it is illegal to use id <0> in SPIR-V binary module. Various
75 // methods in this class, if using SPIR-V word (uint32_t) as interface,
76 // check or return id <0> to indicate error in processing.
77
78 /// Consumes the next unused <id>. This method will never return 0.
79 uint32_t getNextID() { return nextID++; }
80
81 //===--------------------------------------------------------------------===//
82 // Module structure
83 //===--------------------------------------------------------------------===//
84
85 uint32_t getSpecConstID(StringRef constName) const {
86 return specConstIDMap.lookup(Key: constName);
87 }
88
89 uint32_t getVariableID(StringRef varName) const {
90 return globalVarIDMap.lookup(Key: varName);
91 }
92
93 uint32_t getFunctionID(StringRef fnName) const {
94 return funcIDMap.lookup(Key: fnName);
95 }
96
97 /// Gets the <id> for the function with the given name. Assigns the next
98 /// available <id> if the function haven't been deserialized.
99 uint32_t getOrCreateFunctionID(StringRef fnName);
100
101 void processCapability();
102
103 void processDebugInfo();
104
105 void processExtension();
106
107 void processMemoryModel();
108
109 LogicalResult processConstantOp(spirv::ConstantOp op);
110
111 LogicalResult processConstantCompositeReplicateOp(
112 spirv::EXTConstantCompositeReplicateOp op);
113
114 LogicalResult processSpecConstantOp(spirv::SpecConstantOp op);
115
116 LogicalResult
117 processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op);
118
119 LogicalResult processSpecConstantCompositeReplicateOp(
120 spirv::EXTSpecConstantCompositeReplicateOp op);
121
122 LogicalResult
123 processSpecConstantOperationOp(spirv::SpecConstantOperationOp op);
124
125 /// SPIR-V dialect supports OpUndef using spirv.UndefOp that produces a SSA
126 /// value to use with other operations. The SPIR-V spec recommends that
127 /// OpUndef be generated at module level. The serialization generates an
128 /// OpUndef for each type needed at module level.
129 LogicalResult processUndefOp(spirv::UndefOp op);
130
131 /// Emit OpName for the given `resultID`.
132 LogicalResult processName(uint32_t resultID, StringRef name);
133
134 /// Processes a SPIR-V function op.
135 LogicalResult processFuncOp(spirv::FuncOp op);
136 LogicalResult processFuncParameter(spirv::FuncOp op);
137
138 LogicalResult processVariableOp(spirv::VariableOp op);
139
140 /// Process a SPIR-V GlobalVariableOp
141 LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp);
142
143 /// Process attributes that translate to decorations on the result <id>
144 LogicalResult processDecorationAttr(Location loc, uint32_t resultID,
145 Decoration decoration, Attribute attr);
146 LogicalResult processDecoration(Location loc, uint32_t resultID,
147 NamedAttribute attr);
148
149 template <typename DType>
150 LogicalResult processTypeDecoration(Location loc, DType type,
151 uint32_t resultId) {
152 return emitError(loc, message: "unhandled decoration for type:") << type;
153 }
154
155 /// Process member decoration
156 LogicalResult processMemberDecoration(
157 uint32_t structID,
158 const spirv::StructType::MemberDecorationInfo &memberDecorationInfo);
159
160 //===--------------------------------------------------------------------===//
161 // Types
162 //===--------------------------------------------------------------------===//
163
164 uint32_t getTypeID(Type type) const { return typeIDMap.lookup(Val: type); }
165
166 Type getVoidType() { return mlirBuilder.getNoneType(); }
167
168 bool isVoidType(Type type) const { return isa<NoneType>(Val: type); }
169
170 /// Returns true if the given type is a pointer type to a struct in some
171 /// interface storage class.
172 bool isInterfaceStructPtrType(Type type) const;
173
174 /// Main dispatch method for serializing a type. The result <id> of the
175 /// serialized type will be returned as `typeID`.
176 LogicalResult processType(Location loc, Type type, uint32_t &typeID);
177 LogicalResult processTypeImpl(Location loc, Type type, uint32_t &typeID,
178 SetVector<StringRef> &serializationCtx);
179
180 /// Method for preparing basic SPIR-V type serialization. Returns the type's
181 /// opcode and operands for the instruction via `typeEnum` and `operands`.
182 LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID,
183 spirv::Opcode &typeEnum,
184 SmallVectorImpl<uint32_t> &operands,
185 bool &deferSerialization,
186 SetVector<StringRef> &serializationCtx);
187
188 LogicalResult prepareFunctionType(Location loc, FunctionType type,
189 spirv::Opcode &typeEnum,
190 SmallVectorImpl<uint32_t> &operands);
191
192 //===--------------------------------------------------------------------===//
193 // Constant
194 //===--------------------------------------------------------------------===//
195
196 uint32_t getConstantID(Attribute value) const {
197 return constIDMap.lookup(Val: value);
198 }
199
200 uint32_t getConstantCompositeReplicateID(
201 std::pair<Attribute, Type> valueTypePair) const {
202 return constCompositeReplicateIDMap.lookup(Val: valueTypePair);
203 }
204
205 /// Main dispatch method for processing a constant with the given `constType`
206 /// and `valueAttr`. `constType` is needed here because we can interpret the
207 /// `valueAttr` as a different type than the type of `valueAttr` itself; for
208 /// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType
209 /// constants.
210 uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr);
211
212 /// Prepares array attribute serialization. This method emits corresponding
213 /// OpConstant* and returns the result <id> associated with it. Returns 0 if
214 /// failed.
215 uint32_t prepareArrayConstant(Location loc, Type constType, ArrayAttr attr);
216
217 /// Prepares bool/int/float DenseElementsAttr serialization. This method
218 /// iterates the DenseElementsAttr to construct the constant array, and
219 /// returns the result <id> associated with it. Returns 0 if failed. Note
220 /// that the size of `index` must match the rank.
221 /// TODO: Consider to enhance splat elements cases. For splat cases,
222 /// we don't need to loop over all elements, especially when the splat value
223 /// is zero. We can use OpConstantNull when the value is zero.
224 uint32_t prepareDenseElementsConstant(Location loc, Type constType,
225 DenseElementsAttr valueAttr, int dim,
226 MutableArrayRef<uint64_t> index);
227
228 /// Prepares scalar attribute serialization. This method emits corresponding
229 /// OpConstant* and returns the result <id> associated with it. Returns 0 if
230 /// the attribute is not for a scalar bool/integer/float value. If `isSpec` is
231 /// true, then the constant will be serialized as a specialization constant.
232 uint32_t prepareConstantScalar(Location loc, Attribute valueAttr,
233 bool isSpec = false);
234
235 uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr,
236 bool isSpec = false);
237
238 uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr,
239 bool isSpec = false);
240
241 uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr,
242 bool isSpec = false);
243
244 /// Prepares `spirv.EXTConstantCompositeReplicateOp` serialization. This
245 /// method emits OpConstantCompositeReplicateEXT and returns the result <id>
246 /// associated with it.
247 uint32_t prepareConstantCompositeReplicate(Location loc, Type resultType,
248 Attribute valueAttr);
249
250 //===--------------------------------------------------------------------===//
251 // Control flow
252 //===--------------------------------------------------------------------===//
253
254 /// Returns the result <id> for the given block.
255 uint32_t getBlockID(Block *block) const { return blockIDMap.lookup(Val: block); }
256
257 /// Returns the result <id> for the given block. If no <id> has been assigned,
258 /// assigns the next available <id>
259 uint32_t getOrCreateBlockID(Block *block);
260
261#ifndef NDEBUG
262 /// (For debugging) prints the block with its result <id>.
263 void printBlock(Block *block, raw_ostream &os);
264#endif
265
266 /// Processes the given `block` and emits SPIR-V instructions for all ops
267 /// inside. Does not emit OpLabel for this block if `omitLabel` is true.
268 /// `emitMerge` is a callback that will be invoked before handling the
269 /// terminator op to inject the Op*Merge instruction if this is a SPIR-V
270 /// selection/loop header block.
271 LogicalResult processBlock(Block *block, bool omitLabel = false,
272 function_ref<LogicalResult()> emitMerge = nullptr);
273
274 /// Emits OpPhi instructions for the given block if it has block arguments.
275 LogicalResult emitPhiForBlockArguments(Block *block);
276
277 LogicalResult processSelectionOp(spirv::SelectionOp selectionOp);
278
279 LogicalResult processLoopOp(spirv::LoopOp loopOp);
280
281 LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp);
282
283 LogicalResult processBranchOp(spirv::BranchOp branchOp);
284
285 //===--------------------------------------------------------------------===//
286 // Operations
287 //===--------------------------------------------------------------------===//
288
289 LogicalResult encodeExtensionInstruction(Operation *op,
290 StringRef extensionSetName,
291 uint32_t opcode,
292 ArrayRef<uint32_t> operands);
293
294 uint32_t getValueID(Value val) const { return valueIDMap.lookup(Val: val); }
295
296 LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp);
297
298 LogicalResult processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp);
299
300 /// Main dispatch method for serializing an operation.
301 LogicalResult processOperation(Operation *op);
302
303 /// Serializes an operation `op` as core instruction with `opcode` if
304 /// `extInstSet` is empty. Otherwise serializes it as an extended instruction
305 /// with `opcode` from `extInstSet`.
306 /// This method is a generic one for dispatching any SPIR-V ops that has no
307 /// variadic operands and attributes in TableGen definitions.
308 LogicalResult processOpWithoutGrammarAttr(Operation *op, StringRef extInstSet,
309 uint32_t opcode);
310
311 /// Dispatches to the serialization function for an operation in SPIR-V
312 /// dialect that is a mirror of an instruction in the SPIR-V spec. This is
313 /// auto-generated from ODS. Dispatch is handled for all operations in SPIR-V
314 /// dialect that have hasOpcode == 1.
315 LogicalResult dispatchToAutogenSerialization(Operation *op);
316
317 /// Serializes an operation in the SPIR-V dialect that is a mirror of an
318 /// instruction in the SPIR-V spec. This is auto generated if hasOpcode == 1
319 /// and autogenSerialization == 1 in ODS.
320 template <typename OpTy>
321 LogicalResult processOp(OpTy op) {
322 return op.emitError("unsupported op serialization");
323 }
324
325 //===--------------------------------------------------------------------===//
326 // Utilities
327 //===--------------------------------------------------------------------===//
328
329 /// Emits an OpDecorate instruction to decorate the given `target` with the
330 /// given `decoration`.
331 LogicalResult emitDecoration(uint32_t target, spirv::Decoration decoration,
332 ArrayRef<uint32_t> params = {});
333
334 /// Emits an OpLine instruction with the given `loc` location information into
335 /// the given `binary` vector.
336 LogicalResult emitDebugLine(SmallVectorImpl<uint32_t> &binary, Location loc);
337
338private:
339 /// The SPIR-V module to be serialized.
340 spirv::ModuleOp module;
341
342 /// An MLIR builder for getting MLIR constructs.
343 mlir::Builder mlirBuilder;
344
345 /// Serialization options.
346 SerializationOptions options;
347
348 /// A flag which indicates if the last processed instruction was a merge
349 /// instruction.
350 /// According to SPIR-V spec: "If a branch merge instruction is used, the last
351 /// OpLine in the block must be before its merge instruction".
352 bool lastProcessedWasMergeInst = false;
353
354 /// The <id> of the OpString instruction, which specifies a file name, for
355 /// use by other debug instructions.
356 uint32_t fileID = 0;
357
358 /// The next available result <id>.
359 uint32_t nextID = 1;
360
361 // The following are for different SPIR-V instruction sections. They follow
362 // the logical layout of a SPIR-V module.
363
364 SmallVector<uint32_t, 4> capabilities;
365 SmallVector<uint32_t, 0> extensions;
366 SmallVector<uint32_t, 0> extendedSets;
367 SmallVector<uint32_t, 3> memoryModel;
368 SmallVector<uint32_t, 0> entryPoints;
369 SmallVector<uint32_t, 4> executionModes;
370 SmallVector<uint32_t, 0> debug;
371 SmallVector<uint32_t, 0> names;
372 SmallVector<uint32_t, 0> decorations;
373 SmallVector<uint32_t, 0> typesGlobalValues;
374 SmallVector<uint32_t, 0> functions;
375
376 /// Recursive struct references are serialized as OpTypePointer instructions
377 /// to the recursive struct type. However, the OpTypePointer instruction
378 /// cannot be emitted before the recursive struct's OpTypeStruct.
379 /// RecursiveStructPointerInfo stores the data needed to emit such
380 /// OpTypePointer instructions after forward references to such types.
381 struct RecursiveStructPointerInfo {
382 uint32_t pointerTypeID;
383 spirv::StorageClass storageClass;
384 };
385
386 // Maps spirv::StructType to its recursive reference member info.
387 DenseMap<Type, SmallVector<RecursiveStructPointerInfo, 0>>
388 recursiveStructInfos;
389
390 /// `functionHeader` contains all the instructions that must be in the first
391 /// block in the function, and `functionBody` contains the rest. After
392 /// processing FuncOp, the encoded instructions of a function are appended to
393 /// `functions`. An example of instructions in `functionHeader` in order:
394 /// OpFunction ...
395 /// OpFunctionParameter ...
396 /// OpFunctionParameter ...
397 /// OpLabel ...
398 /// OpVariable ...
399 /// OpVariable ...
400 SmallVector<uint32_t, 0> functionHeader;
401 SmallVector<uint32_t, 0> functionBody;
402
403 /// Map from type used in SPIR-V module to their <id>s.
404 DenseMap<Type, uint32_t> typeIDMap;
405
406 /// Map from constant values to their <id>s.
407 DenseMap<Attribute, uint32_t> constIDMap;
408
409 /// Map from a replicated composite constant's value and type to their <id>s.
410 DenseMap<std::pair<Attribute, Type>, uint32_t> constCompositeReplicateIDMap;
411
412 /// Map from specialization constant names to their <id>s.
413 llvm::StringMap<uint32_t> specConstIDMap;
414
415 /// Map from GlobalVariableOps name to <id>s.
416 llvm::StringMap<uint32_t> globalVarIDMap;
417
418 /// Map from FuncOps name to <id>s.
419 llvm::StringMap<uint32_t> funcIDMap;
420
421 /// Map from blocks to their <id>s.
422 DenseMap<Block *, uint32_t> blockIDMap;
423
424 /// Map from the Type to the <id> that represents undef value of that type.
425 DenseMap<Type, uint32_t> undefValIDMap;
426
427 /// Map from results of normal operations to their <id>s.
428 DenseMap<Value, uint32_t> valueIDMap;
429
430 /// Map from extended instruction set name to <id>s.
431 llvm::StringMap<uint32_t> extendedInstSetIDMap;
432
433 /// Map from values used in OpPhi instructions to their offset in the
434 /// `functions` section.
435 ///
436 /// When processing a block with arguments, we need to emit OpPhi
437 /// instructions to record the predecessor block <id>s and the values they
438 /// send to the block in question. But it's not guaranteed all values are
439 /// visited and thus assigned result <id>s. So we need this list to capture
440 /// the offsets into `functions` where a value is used so that we can fix it
441 /// up later after processing all the blocks in a function.
442 ///
443 /// More concretely, say if we are visiting the following blocks:
444 ///
445 /// ```mlir
446 /// ^phi(%arg0: i32):
447 /// ...
448 /// ^parent1:
449 /// ...
450 /// spirv.Branch ^phi(%val0: i32)
451 /// ^parent2:
452 /// ...
453 /// spirv.Branch ^phi(%val1: i32)
454 /// ```
455 ///
456 /// When we are serializing the `^phi` block, we need to emit at the beginning
457 /// of the block OpPhi instructions which has the following parameters:
458 ///
459 /// OpPhi id-for-i32 id-for-%arg0 id-for-%val0 id-for-^parent1
460 /// id-for-%val1 id-for-^parent2
461 ///
462 /// But we don't know the <id> for %val0 and %val1 yet. One way is to visit
463 /// all the blocks twice and use the first visit to assign an <id> to each
464 /// value. But it's paying the overheads just for OpPhi emission. Instead,
465 /// we still visit the blocks once for emission. When we emit the OpPhi
466 /// instructions, we use 0 as a placeholder for the <id>s for %val0 and %val1.
467 /// At the same time, we record their offsets in the emitted binary (which is
468 /// placed inside `functions`) here. And then after emitting all blocks, we
469 /// replace the dummy <id> 0 with the real result <id> by overwriting
470 /// `functions[offset]`.
471 DenseMap<Value, SmallVector<size_t, 1>> deferredPhiValues;
472};
473} // namespace spirv
474} // namespace mlir
475
476#endif // MLIR_LIB_TARGET_SPIRV_SERIALIZATION_SERIALIZER_H
477

source code of mlir/lib/Target/SPIRV/Serialization/Serializer.h