1 | //===- MemoryOps.cpp - MLIR SPIR-V Memory Ops ----------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Defines the memory operations in the SPIR-V dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
14 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
15 | |
16 | #include "SPIRVOpUtils.h" |
17 | #include "SPIRVParsingUtils.h" |
18 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
19 | #include "mlir/IR/Diagnostics.h" |
20 | |
21 | #include "llvm/ADT/StringExtras.h" |
22 | #include "llvm/Support/Casting.h" |
23 | |
24 | using namespace mlir::spirv::AttrNames; |
25 | |
26 | namespace mlir::spirv { |
27 | |
28 | /// Parses optional memory access (a.k.a. memory operand) attributes attached to |
29 | /// a memory access operand/pointer. Specifically, parses the following syntax: |
30 | /// (`[` memory-access `]`)? |
31 | /// where: |
32 | /// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", ` |
33 | /// integer-literal | `"NonTemporal"` |
34 | template <typename MemoryOpTy> |
35 | ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, |
36 | OperationState &state) { |
37 | // Parse an optional list of attributes staring with '[' |
38 | if (parser.parseOptionalLSquare()) { |
39 | // Nothing to do |
40 | return success(); |
41 | } |
42 | |
43 | spirv::MemoryAccess memoryAccessAttr; |
44 | StringAttr memoryAccessAttrName = |
45 | MemoryOpTy::getMemoryAccessAttrName(state.name); |
46 | if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>( |
47 | memoryAccessAttr, parser, state, memoryAccessAttrName)) |
48 | return failure(); |
49 | |
50 | if (spirv::bitEnumContainsAll(memoryAccessAttr, |
51 | spirv::MemoryAccess::Aligned)) { |
52 | // Parse integer attribute for alignment. |
53 | Attribute alignmentAttr; |
54 | StringAttr alignmentAttrName = MemoryOpTy::getAlignmentAttrName(state.name); |
55 | Type i32Type = parser.getBuilder().getIntegerType(32); |
56 | if (parser.parseComma() || |
57 | parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName, |
58 | state.attributes)) { |
59 | return failure(); |
60 | } |
61 | } |
62 | return parser.parseRSquare(); |
63 | } |
64 | |
65 | // TODO Make sure to merge this and the previous function into one template |
66 | // parameterized by memory access attribute name and alignment. Doing so now |
67 | // results in VS2017 in producing an internal error (at the call site) that's |
68 | // not detailed enough to understand what is happening. |
69 | template <typename MemoryOpTy> |
70 | static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, |
71 | OperationState &state) { |
72 | // Parse an optional list of attributes staring with '[' |
73 | if (parser.parseOptionalLSquare()) { |
74 | // Nothing to do |
75 | return success(); |
76 | } |
77 | |
78 | spirv::MemoryAccess memoryAccessAttr; |
79 | StringRef memoryAccessAttrName = |
80 | MemoryOpTy::getSourceMemoryAccessAttrName(state.name); |
81 | if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>( |
82 | memoryAccessAttr, parser, state, memoryAccessAttrName)) |
83 | return failure(); |
84 | |
85 | if (spirv::bitEnumContainsAll(memoryAccessAttr, |
86 | spirv::MemoryAccess::Aligned)) { |
87 | // Parse integer attribute for alignment. |
88 | Attribute alignmentAttr; |
89 | StringAttr alignmentAttrName = |
90 | MemoryOpTy::getSourceAlignmentAttrName(state.name); |
91 | Type i32Type = parser.getBuilder().getIntegerType(32); |
92 | if (parser.parseComma() || |
93 | parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName, |
94 | state.attributes)) { |
95 | return failure(); |
96 | } |
97 | } |
98 | return parser.parseRSquare(); |
99 | } |
100 | |
101 | // TODO Make sure to merge this and the previous function into one template |
102 | // parameterized by memory access attribute name and alignment. Doing so now |
103 | // results in VS2017 in producing an internal error (at the call site) that's |
104 | // not detailed enough to understand what is happening. |
105 | template <typename MemoryOpTy> |
106 | static void printSourceMemoryAccessAttribute( |
107 | MemoryOpTy memoryOp, OpAsmPrinter &printer, |
108 | SmallVectorImpl<StringRef> &elidedAttrs, |
109 | std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt, |
110 | std::optional<uint32_t> alignmentAttrValue = std::nullopt) { |
111 | |
112 | printer << ", " ; |
113 | |
114 | // Print optional memory access attribute. |
115 | if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue |
116 | : memoryOp.getMemoryAccess())) { |
117 | elidedAttrs.push_back(Elt: memoryOp.getSourceMemoryAccessAttrName()); |
118 | |
119 | printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"" ; |
120 | |
121 | if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) { |
122 | // Print integer alignment attribute. |
123 | if (auto alignment = (alignmentAttrValue ? alignmentAttrValue |
124 | : memoryOp.getAlignment())) { |
125 | elidedAttrs.push_back(Elt: memoryOp.getSourceAlignmentAttrName()); |
126 | printer << ", " << *alignment; |
127 | } |
128 | } |
129 | printer << "]" ; |
130 | } |
131 | elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>()); |
132 | } |
133 | |
134 | template <typename MemoryOpTy> |
135 | static void printMemoryAccessAttribute( |
136 | MemoryOpTy memoryOp, OpAsmPrinter &printer, |
137 | SmallVectorImpl<StringRef> &elidedAttrs, |
138 | std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt, |
139 | std::optional<uint32_t> alignmentAttrValue = std::nullopt) { |
140 | // Print optional memory access attribute. |
141 | if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue |
142 | : memoryOp.getMemoryAccess())) { |
143 | elidedAttrs.push_back(Elt: memoryOp.getMemoryAccessAttrName()); |
144 | |
145 | printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"" ; |
146 | |
147 | if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) { |
148 | // Print integer alignment attribute. |
149 | if (auto alignment = (alignmentAttrValue ? alignmentAttrValue |
150 | : memoryOp.getAlignment())) { |
151 | elidedAttrs.push_back(Elt: memoryOp.getAlignmentAttrName()); |
152 | printer << ", " << *alignment; |
153 | } |
154 | } |
155 | printer << "]" ; |
156 | } |
157 | elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>()); |
158 | } |
159 | |
160 | template <typename LoadStoreOpTy> |
161 | static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, |
162 | Value val) { |
163 | // ODS already checks ptr is spirv::PointerType. Just check that the pointee |
164 | // type of the pointer and the type of the value are the same |
165 | // |
166 | // TODO: Check that the value type satisfies restrictions of |
167 | // SPIR-V OpLoad/OpStore operations |
168 | if (val.getType() != |
169 | llvm::cast<spirv::PointerType>(Val: ptr.getType()).getPointeeType()) { |
170 | return op.emitOpError("mismatch in result type and pointer type" ); |
171 | } |
172 | return success(); |
173 | } |
174 | |
175 | template <typename MemoryOpTy> |
176 | static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) { |
177 | // ODS checks for attributes values. Just need to verify that if the |
178 | // memory-access attribute is Aligned, then the alignment attribute must be |
179 | // present. |
180 | auto *op = memoryOp.getOperation(); |
181 | auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName()); |
182 | if (!memAccessAttr) { |
183 | // Alignment attribute shouldn't be present if memory access attribute is |
184 | // not present. |
185 | if (op->getAttr(memoryOp.getAlignmentAttrName())) { |
186 | return memoryOp.emitOpError( |
187 | "invalid alignment specification without aligned memory access " |
188 | "specification" ); |
189 | } |
190 | return success(); |
191 | } |
192 | |
193 | auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr); |
194 | |
195 | if (!memAccess) { |
196 | return memoryOp.emitOpError("invalid memory access specifier: " ) |
197 | << memAccessAttr; |
198 | } |
199 | |
200 | if (spirv::bitEnumContainsAll(memAccess.getValue(), |
201 | spirv::MemoryAccess::Aligned)) { |
202 | if (!op->getAttr(memoryOp.getAlignmentAttrName())) { |
203 | return memoryOp.emitOpError("missing alignment value" ); |
204 | } |
205 | } else { |
206 | if (op->getAttr(memoryOp.getAlignmentAttrName())) { |
207 | return memoryOp.emitOpError( |
208 | "invalid alignment specification with non-aligned memory access " |
209 | "specification" ); |
210 | } |
211 | } |
212 | return success(); |
213 | } |
214 | |
215 | // TODO Make sure to merge this and the previous function into one template |
216 | // parameterized by memory access attribute name and alignment. Doing so now |
217 | // results in VS2017 in producing an internal error (at the call site) that's |
218 | // not detailed enough to understand what is happening. |
219 | template <typename MemoryOpTy> |
220 | static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) { |
221 | // ODS checks for attributes values. Just need to verify that if the |
222 | // memory-access attribute is Aligned, then the alignment attribute must be |
223 | // present. |
224 | auto *op = memoryOp.getOperation(); |
225 | auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName()); |
226 | if (!memAccessAttr) { |
227 | // Alignment attribute shouldn't be present if memory access attribute is |
228 | // not present. |
229 | if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) { |
230 | return memoryOp.emitOpError( |
231 | "invalid alignment specification without aligned memory access " |
232 | "specification" ); |
233 | } |
234 | return success(); |
235 | } |
236 | |
237 | auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr); |
238 | |
239 | if (!memAccess) { |
240 | return memoryOp.emitOpError("invalid memory access specifier: " ) |
241 | << memAccess; |
242 | } |
243 | |
244 | if (spirv::bitEnumContainsAll(memAccess.getValue(), |
245 | spirv::MemoryAccess::Aligned)) { |
246 | if (!op->getAttr(memoryOp.getSourceAlignmentAttrName())) { |
247 | return memoryOp.emitOpError("missing alignment value" ); |
248 | } |
249 | } else { |
250 | if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) { |
251 | return memoryOp.emitOpError( |
252 | "invalid alignment specification with non-aligned memory access " |
253 | "specification" ); |
254 | } |
255 | } |
256 | return success(); |
257 | } |
258 | |
259 | //===----------------------------------------------------------------------===// |
260 | // spirv.AccessChainOp |
261 | //===----------------------------------------------------------------------===// |
262 | |
263 | static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) { |
264 | auto ptrType = llvm::dyn_cast<spirv::PointerType>(Val&: type); |
265 | if (!ptrType) { |
266 | emitError(loc: baseLoc, message: "'spirv.AccessChain' op expected a pointer " |
267 | "to composite type, but provided " ) |
268 | << type; |
269 | return nullptr; |
270 | } |
271 | |
272 | auto resultType = ptrType.getPointeeType(); |
273 | auto resultStorageClass = ptrType.getStorageClass(); |
274 | int32_t index = 0; |
275 | |
276 | for (auto indexSSA : indices) { |
277 | auto cType = llvm::dyn_cast<spirv::CompositeType>(Val&: resultType); |
278 | if (!cType) { |
279 | emitError( |
280 | loc: baseLoc, |
281 | message: "'spirv.AccessChain' op cannot extract from non-composite type " ) |
282 | << resultType << " with index " << index; |
283 | return nullptr; |
284 | } |
285 | index = 0; |
286 | if (llvm::isa<spirv::StructType>(Val: resultType)) { |
287 | Operation *op = indexSSA.getDefiningOp(); |
288 | if (!op) { |
289 | emitError(loc: baseLoc, message: "'spirv.AccessChain' op index must be an " |
290 | "integer spirv.Constant to access " |
291 | "element of spirv.struct" ); |
292 | return nullptr; |
293 | } |
294 | |
295 | // TODO: this should be relaxed to allow |
296 | // integer literals of other bitwidths. |
297 | if (failed(Result: spirv::extractValueFromConstOp(op, value&: index))) { |
298 | emitError( |
299 | loc: baseLoc, |
300 | message: "'spirv.AccessChain' index must be an integer spirv.Constant to " |
301 | "access element of spirv.struct, but provided " ) |
302 | << op->getName(); |
303 | return nullptr; |
304 | } |
305 | if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) { |
306 | emitError(loc: baseLoc, message: "'spirv.AccessChain' op index " ) |
307 | << index << " out of bounds for " << resultType; |
308 | return nullptr; |
309 | } |
310 | } |
311 | resultType = cType.getElementType(index); |
312 | } |
313 | return spirv::PointerType::get(resultType, resultStorageClass); |
314 | } |
315 | |
316 | void AccessChainOp::build(OpBuilder &builder, OperationState &state, |
317 | Value basePtr, ValueRange indices) { |
318 | auto type = getElementPtrType(basePtr.getType(), indices, state.location); |
319 | assert(type && "Unable to deduce return type based on basePtr and indices" ); |
320 | build(builder, state, type, basePtr, indices); |
321 | } |
322 | |
323 | template <typename Op> |
324 | static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) { |
325 | printer << ' ' << op.getBasePtr() << '[' << indices |
326 | << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes(); |
327 | } |
328 | |
329 | template <typename Op> |
330 | static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) { |
331 | auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(), |
332 | indices, accessChainOp.getLoc()); |
333 | if (!resultType) |
334 | return failure(); |
335 | |
336 | auto providedResultType = |
337 | llvm::dyn_cast<spirv::PointerType>(accessChainOp.getType()); |
338 | if (!providedResultType) |
339 | return accessChainOp.emitOpError( |
340 | "result type must be a pointer, but provided" ) |
341 | << providedResultType; |
342 | |
343 | if (resultType != providedResultType) |
344 | return accessChainOp.emitOpError("invalid result type: expected " ) |
345 | << resultType << ", but provided " << providedResultType; |
346 | |
347 | return success(); |
348 | } |
349 | |
350 | LogicalResult AccessChainOp::verify() { |
351 | return verifyAccessChain(*this, getIndices()); |
352 | } |
353 | |
354 | //===----------------------------------------------------------------------===// |
355 | // spirv.LoadOp |
356 | //===----------------------------------------------------------------------===// |
357 | |
358 | void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr, |
359 | MemoryAccessAttr memoryAccess, IntegerAttr alignment) { |
360 | auto ptrType = llvm::cast<spirv::PointerType>(basePtr.getType()); |
361 | build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess, |
362 | alignment); |
363 | } |
364 | |
365 | ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { |
366 | // Parse the storage class specification |
367 | spirv::StorageClass storageClass; |
368 | OpAsmParser::UnresolvedOperand ptrInfo; |
369 | Type elementType; |
370 | if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) || |
371 | parseMemoryAccessAttributes<LoadOp>(parser, result) || |
372 | parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || |
373 | parser.parseType(elementType)) { |
374 | return failure(); |
375 | } |
376 | |
377 | auto ptrType = spirv::PointerType::get(elementType, storageClass); |
378 | if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) { |
379 | return failure(); |
380 | } |
381 | |
382 | result.addTypes(elementType); |
383 | return success(); |
384 | } |
385 | |
386 | void LoadOp::print(OpAsmPrinter &printer) { |
387 | SmallVector<StringRef, 4> elidedAttrs; |
388 | StringRef sc = stringifyStorageClass( |
389 | llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass()); |
390 | printer << " \"" << sc << "\" " << getPtr(); |
391 | |
392 | printMemoryAccessAttribute(*this, printer, elidedAttrs); |
393 | |
394 | printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); |
395 | printer << " : " << getType(); |
396 | } |
397 | |
398 | LogicalResult LoadOp::verify() { |
399 | // SPIR-V spec : "Result Type is the type of the loaded object. It must be a |
400 | // type with fixed size; i.e., it cannot be, nor include, any |
401 | // OpTypeRuntimeArray types." |
402 | if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) { |
403 | return failure(); |
404 | } |
405 | return verifyMemoryAccessAttribute(*this); |
406 | } |
407 | |
408 | //===----------------------------------------------------------------------===// |
409 | // spirv.StoreOp |
410 | //===----------------------------------------------------------------------===// |
411 | |
412 | ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { |
413 | // Parse the storage class specification |
414 | spirv::StorageClass storageClass; |
415 | SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo; |
416 | auto loc = parser.getCurrentLocation(); |
417 | Type elementType; |
418 | if (parseEnumStrAttr(storageClass, parser) || |
419 | parser.parseOperandList(operandInfo, 2) || |
420 | parseMemoryAccessAttributes<StoreOp>(parser, result) || |
421 | parser.parseColon() || parser.parseType(elementType)) { |
422 | return failure(); |
423 | } |
424 | |
425 | auto ptrType = spirv::PointerType::get(elementType, storageClass); |
426 | if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc, |
427 | result.operands)) { |
428 | return failure(); |
429 | } |
430 | return success(); |
431 | } |
432 | |
433 | void StoreOp::print(OpAsmPrinter &printer) { |
434 | SmallVector<StringRef, 4> elidedAttrs; |
435 | StringRef sc = stringifyStorageClass( |
436 | llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass()); |
437 | printer << " \"" << sc << "\" " << getPtr() << ", " << getValue(); |
438 | |
439 | printMemoryAccessAttribute(*this, printer, elidedAttrs); |
440 | |
441 | printer << " : " << getValue().getType(); |
442 | printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); |
443 | } |
444 | |
445 | LogicalResult StoreOp::verify() { |
446 | // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an |
447 | // OpTypePointer whose Type operand is the same as the type of Object." |
448 | if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) |
449 | return failure(); |
450 | return verifyMemoryAccessAttribute(*this); |
451 | } |
452 | |
453 | //===----------------------------------------------------------------------===// |
454 | // spirv.CopyMemory |
455 | //===----------------------------------------------------------------------===// |
456 | |
457 | void CopyMemoryOp::print(OpAsmPrinter &printer) { |
458 | printer << ' '; |
459 | |
460 | StringRef targetStorageClass = stringifyStorageClass( |
461 | llvm::cast<spirv::PointerType>(getTarget().getType()).getStorageClass()); |
462 | printer << " \"" << targetStorageClass << "\" " << getTarget() << ", " ; |
463 | |
464 | StringRef sourceStorageClass = stringifyStorageClass( |
465 | llvm::cast<spirv::PointerType>(getSource().getType()).getStorageClass()); |
466 | printer << " \"" << sourceStorageClass << "\" " << getSource(); |
467 | |
468 | SmallVector<StringRef, 4> elidedAttrs; |
469 | printMemoryAccessAttribute(*this, printer, elidedAttrs); |
470 | printSourceMemoryAccessAttribute(*this, printer, elidedAttrs, |
471 | getSourceMemoryAccess(), |
472 | getSourceAlignment()); |
473 | |
474 | printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); |
475 | |
476 | Type pointeeType = |
477 | llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType(); |
478 | printer << " : " << pointeeType; |
479 | } |
480 | |
481 | ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) { |
482 | spirv::StorageClass targetStorageClass; |
483 | OpAsmParser::UnresolvedOperand targetPtrInfo; |
484 | |
485 | spirv::StorageClass sourceStorageClass; |
486 | OpAsmParser::UnresolvedOperand sourcePtrInfo; |
487 | |
488 | Type elementType; |
489 | |
490 | if (parseEnumStrAttr(targetStorageClass, parser) || |
491 | parser.parseOperand(targetPtrInfo) || parser.parseComma() || |
492 | parseEnumStrAttr(sourceStorageClass, parser) || |
493 | parser.parseOperand(sourcePtrInfo) || |
494 | parseMemoryAccessAttributes<CopyMemoryOp>(parser, result)) { |
495 | return failure(); |
496 | } |
497 | |
498 | if (!parser.parseOptionalComma()) { |
499 | // Parse 2nd memory access attributes. |
500 | if (parseSourceMemoryAccessAttributes<CopyMemoryOp>(parser, result)) { |
501 | return failure(); |
502 | } |
503 | } |
504 | |
505 | if (parser.parseColon() || parser.parseType(elementType)) |
506 | return failure(); |
507 | |
508 | if (parser.parseOptionalAttrDict(result.attributes)) |
509 | return failure(); |
510 | |
511 | auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass); |
512 | auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass); |
513 | |
514 | if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) || |
515 | parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) { |
516 | return failure(); |
517 | } |
518 | |
519 | return success(); |
520 | } |
521 | |
522 | LogicalResult CopyMemoryOp::verify() { |
523 | Type targetType = |
524 | llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType(); |
525 | |
526 | Type sourceType = |
527 | llvm::cast<spirv::PointerType>(getSource().getType()).getPointeeType(); |
528 | |
529 | if (targetType != sourceType) |
530 | return emitOpError("both operands must be pointers to the same type" ); |
531 | |
532 | if (failed(verifyMemoryAccessAttribute(*this))) |
533 | return failure(); |
534 | |
535 | // TODO - According to the spec: |
536 | // |
537 | // If two masks are present, the first applies to Target and cannot include |
538 | // MakePointerVisible, and the second applies to Source and cannot include |
539 | // MakePointerAvailable. |
540 | // |
541 | // Add such verification here. |
542 | |
543 | return verifySourceMemoryAccessAttribute(*this); |
544 | } |
545 | |
546 | //===----------------------------------------------------------------------===// |
547 | // spirv.InBoundsPtrAccessChainOp |
548 | //===----------------------------------------------------------------------===// |
549 | |
550 | void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state, |
551 | Value basePtr, Value element, |
552 | ValueRange indices) { |
553 | auto type = getElementPtrType(basePtr.getType(), indices, state.location); |
554 | assert(type && "Unable to deduce return type based on basePtr and indices" ); |
555 | build(builder, state, type, basePtr, element, indices); |
556 | } |
557 | |
558 | LogicalResult InBoundsPtrAccessChainOp::verify() { |
559 | return verifyAccessChain(*this, getIndices()); |
560 | } |
561 | |
562 | //===----------------------------------------------------------------------===// |
563 | // spirv.PtrAccessChainOp |
564 | //===----------------------------------------------------------------------===// |
565 | |
566 | void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state, |
567 | Value basePtr, Value element, ValueRange indices) { |
568 | auto type = getElementPtrType(basePtr.getType(), indices, state.location); |
569 | assert(type && "Unable to deduce return type based on basePtr and indices" ); |
570 | build(builder, state, type, basePtr, element, indices); |
571 | } |
572 | |
573 | LogicalResult PtrAccessChainOp::verify() { |
574 | return verifyAccessChain(*this, getIndices()); |
575 | } |
576 | |
577 | //===----------------------------------------------------------------------===// |
578 | // spirv.Variable |
579 | //===----------------------------------------------------------------------===// |
580 | |
581 | ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) { |
582 | // Parse optional initializer |
583 | std::optional<OpAsmParser::UnresolvedOperand> initInfo; |
584 | if (succeeded(parser.parseOptionalKeyword("init" ))) { |
585 | initInfo = OpAsmParser::UnresolvedOperand(); |
586 | if (parser.parseLParen() || parser.parseOperand(*initInfo) || |
587 | parser.parseRParen()) |
588 | return failure(); |
589 | } |
590 | |
591 | if (parseVariableDecorations(parser, result)) { |
592 | return failure(); |
593 | } |
594 | |
595 | // Parse result pointer type |
596 | Type type; |
597 | if (parser.parseColon()) |
598 | return failure(); |
599 | auto loc = parser.getCurrentLocation(); |
600 | if (parser.parseType(type)) |
601 | return failure(); |
602 | |
603 | auto ptrType = llvm::dyn_cast<spirv::PointerType>(type); |
604 | if (!ptrType) |
605 | return parser.emitError(loc, "expected spirv.ptr type" ); |
606 | result.addTypes(ptrType); |
607 | |
608 | // Resolve the initializer operand |
609 | if (initInfo) { |
610 | if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(), |
611 | result.operands)) |
612 | return failure(); |
613 | } |
614 | |
615 | auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>( |
616 | ptrType.getStorageClass()); |
617 | result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr); |
618 | |
619 | return success(); |
620 | } |
621 | |
622 | void VariableOp::print(OpAsmPrinter &printer) { |
623 | SmallVector<StringRef, 4> elidedAttrs{ |
624 | spirv::attributeName<spirv::StorageClass>()}; |
625 | // Print optional initializer |
626 | if (getNumOperands() != 0) |
627 | printer << " init(" << getInitializer() << ")" ; |
628 | |
629 | printVariableDecorations(*this, printer, elidedAttrs); |
630 | printer << " : " << getType(); |
631 | } |
632 | |
633 | LogicalResult VariableOp::verify() { |
634 | // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the |
635 | // object. It cannot be Generic. It must be the same as the Storage Class |
636 | // operand of the Result Type." |
637 | if (getStorageClass() != spirv::StorageClass::Function) { |
638 | return emitOpError( |
639 | "can only be used to model function-level variables. Use " |
640 | "spirv.GlobalVariable for module-level variables." ); |
641 | } |
642 | |
643 | auto pointerType = llvm::cast<spirv::PointerType>(getPointer().getType()); |
644 | if (getStorageClass() != pointerType.getStorageClass()) |
645 | return emitOpError( |
646 | "storage class must match result pointer's storage class" ); |
647 | |
648 | if (getNumOperands() != 0) { |
649 | // SPIR-V spec: "Initializer must be an <id> from a constant instruction or |
650 | // a global (module scope) OpVariable instruction". |
651 | auto *initOp = getOperand(0).getDefiningOp(); |
652 | if (!initOp || !isa<spirv::ConstantOp, // for normal constant |
653 | spirv::ReferenceOfOp, // for spec constant |
654 | spirv::AddressOfOp>(initOp)) |
655 | return emitOpError("initializer must be the result of a " |
656 | "constant or spirv.GlobalVariable op" ); |
657 | } |
658 | |
659 | auto getDecorationAttr = [op = getOperation()](spirv::Decoration decoration) { |
660 | return op->getAttr( |
661 | llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration))); |
662 | }; |
663 | |
664 | // TODO: generate these strings using ODS. |
665 | for (auto decoration : |
666 | {spirv::Decoration::DescriptorSet, spirv::Decoration::Binding, |
667 | spirv::Decoration::BuiltIn}) { |
668 | if (auto attr = getDecorationAttr(decoration)) |
669 | return emitOpError("cannot have '" ) |
670 | << llvm::convertToSnakeFromCamelCase( |
671 | stringifyDecoration(decoration)) |
672 | << "' attribute (only allowed in spirv.GlobalVariable)" ; |
673 | } |
674 | |
675 | // From SPV_KHR_physical_storage_buffer: |
676 | // > If an OpVariable's pointee type is a pointer (or array of pointers) in |
677 | // > PhysicalStorageBuffer storage class, then the variable must be decorated |
678 | // > with exactly one of AliasedPointer or RestrictPointer. |
679 | auto pointeePtrType = dyn_cast<spirv::PointerType>(getPointeeType()); |
680 | if (!pointeePtrType) { |
681 | if (auto pointeeArrayType = dyn_cast<spirv::ArrayType>(getPointeeType())) { |
682 | pointeePtrType = |
683 | dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType()); |
684 | } |
685 | } |
686 | |
687 | if (pointeePtrType && pointeePtrType.getStorageClass() == |
688 | spirv::StorageClass::PhysicalStorageBuffer) { |
689 | bool hasAliasedPtr = |
690 | getDecorationAttr(spirv::Decoration::AliasedPointer) != nullptr; |
691 | bool hasRestrictPtr = |
692 | getDecorationAttr(spirv::Decoration::RestrictPointer) != nullptr; |
693 | |
694 | if (!hasAliasedPtr && !hasRestrictPtr) |
695 | return emitOpError() << " with physical buffer pointer must be decorated " |
696 | "either 'AliasedPointer' or 'RestrictPointer'" ; |
697 | |
698 | if (hasAliasedPtr && hasRestrictPtr) |
699 | return emitOpError() |
700 | << " with physical buffer pointer must have exactly one " |
701 | "aliasing decoration" ; |
702 | } |
703 | |
704 | return success(); |
705 | } |
706 | |
707 | } // namespace mlir::spirv |
708 | |