1//===- MemoryOps.cpp - MLIR SPIR-V Memory Ops ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Defines the memory operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
14#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15
16#include "SPIRVOpUtils.h"
17#include "SPIRVParsingUtils.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19#include "mlir/IR/Diagnostics.h"
20
21#include "llvm/ADT/StringExtras.h"
22#include "llvm/Support/Casting.h"
23
24using namespace mlir::spirv::AttrNames;
25
26namespace mlir::spirv {
27
28/// Parses optional memory access (a.k.a. memory operand) attributes attached to
29/// a memory access operand/pointer. Specifically, parses the following syntax:
30/// (`[` memory-access `]`)?
31/// where:
32/// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
33/// integer-literal | `"NonTemporal"`
34template <typename MemoryOpTy>
35ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
36 OperationState &state) {
37 // Parse an optional list of attributes staring with '['
38 if (parser.parseOptionalLSquare()) {
39 // Nothing to do
40 return success();
41 }
42
43 spirv::MemoryAccess memoryAccessAttr;
44 StringAttr memoryAccessAttrName =
45 MemoryOpTy::getMemoryAccessAttrName(state.name);
46 if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
47 memoryAccessAttr, parser, state, memoryAccessAttrName))
48 return failure();
49
50 if (spirv::bitEnumContainsAll(memoryAccessAttr,
51 spirv::MemoryAccess::Aligned)) {
52 // Parse integer attribute for alignment.
53 Attribute alignmentAttr;
54 StringAttr alignmentAttrName = MemoryOpTy::getAlignmentAttrName(state.name);
55 Type i32Type = parser.getBuilder().getIntegerType(32);
56 if (parser.parseComma() ||
57 parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
58 state.attributes)) {
59 return failure();
60 }
61 }
62 return parser.parseRSquare();
63}
64
65// TODO Make sure to merge this and the previous function into one template
66// parameterized by memory access attribute name and alignment. Doing so now
67// results in VS2017 in producing an internal error (at the call site) that's
68// not detailed enough to understand what is happening.
69template <typename MemoryOpTy>
70static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
71 OperationState &state) {
72 // Parse an optional list of attributes staring with '['
73 if (parser.parseOptionalLSquare()) {
74 // Nothing to do
75 return success();
76 }
77
78 spirv::MemoryAccess memoryAccessAttr;
79 StringRef memoryAccessAttrName =
80 MemoryOpTy::getSourceMemoryAccessAttrName(state.name);
81 if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
82 memoryAccessAttr, parser, state, memoryAccessAttrName))
83 return failure();
84
85 if (spirv::bitEnumContainsAll(memoryAccessAttr,
86 spirv::MemoryAccess::Aligned)) {
87 // Parse integer attribute for alignment.
88 Attribute alignmentAttr;
89 StringAttr alignmentAttrName =
90 MemoryOpTy::getSourceAlignmentAttrName(state.name);
91 Type i32Type = parser.getBuilder().getIntegerType(32);
92 if (parser.parseComma() ||
93 parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
94 state.attributes)) {
95 return failure();
96 }
97 }
98 return parser.parseRSquare();
99}
100
101// TODO Make sure to merge this and the previous function into one template
102// parameterized by memory access attribute name and alignment. Doing so now
103// results in VS2017 in producing an internal error (at the call site) that's
104// not detailed enough to understand what is happening.
105template <typename MemoryOpTy>
106static void printSourceMemoryAccessAttribute(
107 MemoryOpTy memoryOp, OpAsmPrinter &printer,
108 SmallVectorImpl<StringRef> &elidedAttrs,
109 std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
110 std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
111
112 printer << ", ";
113
114 // Print optional memory access attribute.
115 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
116 : memoryOp.getMemoryAccess())) {
117 elidedAttrs.push_back(Elt: memoryOp.getSourceMemoryAccessAttrName());
118
119 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
120
121 if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
122 // Print integer alignment attribute.
123 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
124 : memoryOp.getAlignment())) {
125 elidedAttrs.push_back(Elt: memoryOp.getSourceAlignmentAttrName());
126 printer << ", " << *alignment;
127 }
128 }
129 printer << "]";
130 }
131 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
132}
133
134template <typename MemoryOpTy>
135static void printMemoryAccessAttribute(
136 MemoryOpTy memoryOp, OpAsmPrinter &printer,
137 SmallVectorImpl<StringRef> &elidedAttrs,
138 std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
139 std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
140 // Print optional memory access attribute.
141 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
142 : memoryOp.getMemoryAccess())) {
143 elidedAttrs.push_back(Elt: memoryOp.getMemoryAccessAttrName());
144
145 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
146
147 if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
148 // Print integer alignment attribute.
149 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
150 : memoryOp.getAlignment())) {
151 elidedAttrs.push_back(Elt: memoryOp.getAlignmentAttrName());
152 printer << ", " << *alignment;
153 }
154 }
155 printer << "]";
156 }
157 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
158}
159
160template <typename LoadStoreOpTy>
161static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
162 Value val) {
163 // ODS already checks ptr is spirv::PointerType. Just check that the pointee
164 // type of the pointer and the type of the value are the same
165 //
166 // TODO: Check that the value type satisfies restrictions of
167 // SPIR-V OpLoad/OpStore operations
168 if (val.getType() !=
169 llvm::cast<spirv::PointerType>(Val: ptr.getType()).getPointeeType()) {
170 return op.emitOpError("mismatch in result type and pointer type");
171 }
172 return success();
173}
174
175template <typename MemoryOpTy>
176static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
177 // ODS checks for attributes values. Just need to verify that if the
178 // memory-access attribute is Aligned, then the alignment attribute must be
179 // present.
180 auto *op = memoryOp.getOperation();
181 auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName());
182 if (!memAccessAttr) {
183 // Alignment attribute shouldn't be present if memory access attribute is
184 // not present.
185 if (op->getAttr(memoryOp.getAlignmentAttrName())) {
186 return memoryOp.emitOpError(
187 "invalid alignment specification without aligned memory access "
188 "specification");
189 }
190 return success();
191 }
192
193 auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
194
195 if (!memAccess) {
196 return memoryOp.emitOpError("invalid memory access specifier: ")
197 << memAccessAttr;
198 }
199
200 if (spirv::bitEnumContainsAll(memAccess.getValue(),
201 spirv::MemoryAccess::Aligned)) {
202 if (!op->getAttr(memoryOp.getAlignmentAttrName())) {
203 return memoryOp.emitOpError("missing alignment value");
204 }
205 } else {
206 if (op->getAttr(memoryOp.getAlignmentAttrName())) {
207 return memoryOp.emitOpError(
208 "invalid alignment specification with non-aligned memory access "
209 "specification");
210 }
211 }
212 return success();
213}
214
215// TODO Make sure to merge this and the previous function into one template
216// parameterized by memory access attribute name and alignment. Doing so now
217// results in VS2017 in producing an internal error (at the call site) that's
218// not detailed enough to understand what is happening.
219template <typename MemoryOpTy>
220static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
221 // ODS checks for attributes values. Just need to verify that if the
222 // memory-access attribute is Aligned, then the alignment attribute must be
223 // present.
224 auto *op = memoryOp.getOperation();
225 auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName());
226 if (!memAccessAttr) {
227 // Alignment attribute shouldn't be present if memory access attribute is
228 // not present.
229 if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
230 return memoryOp.emitOpError(
231 "invalid alignment specification without aligned memory access "
232 "specification");
233 }
234 return success();
235 }
236
237 auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
238
239 if (!memAccess) {
240 return memoryOp.emitOpError("invalid memory access specifier: ")
241 << memAccess;
242 }
243
244 if (spirv::bitEnumContainsAll(memAccess.getValue(),
245 spirv::MemoryAccess::Aligned)) {
246 if (!op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
247 return memoryOp.emitOpError("missing alignment value");
248 }
249 } else {
250 if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
251 return memoryOp.emitOpError(
252 "invalid alignment specification with non-aligned memory access "
253 "specification");
254 }
255 }
256 return success();
257}
258
259//===----------------------------------------------------------------------===//
260// spirv.AccessChainOp
261//===----------------------------------------------------------------------===//
262
263static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
264 auto ptrType = llvm::dyn_cast<spirv::PointerType>(Val&: type);
265 if (!ptrType) {
266 emitError(loc: baseLoc, message: "'spirv.AccessChain' op expected a pointer "
267 "to composite type, but provided ")
268 << type;
269 return nullptr;
270 }
271
272 auto resultType = ptrType.getPointeeType();
273 auto resultStorageClass = ptrType.getStorageClass();
274 int32_t index = 0;
275
276 for (auto indexSSA : indices) {
277 auto cType = llvm::dyn_cast<spirv::CompositeType>(Val&: resultType);
278 if (!cType) {
279 emitError(
280 loc: baseLoc,
281 message: "'spirv.AccessChain' op cannot extract from non-composite type ")
282 << resultType << " with index " << index;
283 return nullptr;
284 }
285 index = 0;
286 if (llvm::isa<spirv::StructType>(Val: resultType)) {
287 Operation *op = indexSSA.getDefiningOp();
288 if (!op) {
289 emitError(loc: baseLoc, message: "'spirv.AccessChain' op index must be an "
290 "integer spirv.Constant to access "
291 "element of spirv.struct");
292 return nullptr;
293 }
294
295 // TODO: this should be relaxed to allow
296 // integer literals of other bitwidths.
297 if (failed(Result: spirv::extractValueFromConstOp(op, value&: index))) {
298 emitError(
299 loc: baseLoc,
300 message: "'spirv.AccessChain' index must be an integer spirv.Constant to "
301 "access element of spirv.struct, but provided ")
302 << op->getName();
303 return nullptr;
304 }
305 if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
306 emitError(loc: baseLoc, message: "'spirv.AccessChain' op index ")
307 << index << " out of bounds for " << resultType;
308 return nullptr;
309 }
310 }
311 resultType = cType.getElementType(index);
312 }
313 return spirv::PointerType::get(resultType, resultStorageClass);
314}
315
316void AccessChainOp::build(OpBuilder &builder, OperationState &state,
317 Value basePtr, ValueRange indices) {
318 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
319 assert(type && "Unable to deduce return type based on basePtr and indices");
320 build(builder, state, type, basePtr, indices);
321}
322
323template <typename Op>
324static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
325 printer << ' ' << op.getBasePtr() << '[' << indices
326 << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
327}
328
329template <typename Op>
330static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
331 auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
332 indices, accessChainOp.getLoc());
333 if (!resultType)
334 return failure();
335
336 auto providedResultType =
337 llvm::dyn_cast<spirv::PointerType>(accessChainOp.getType());
338 if (!providedResultType)
339 return accessChainOp.emitOpError(
340 "result type must be a pointer, but provided")
341 << providedResultType;
342
343 if (resultType != providedResultType)
344 return accessChainOp.emitOpError("invalid result type: expected ")
345 << resultType << ", but provided " << providedResultType;
346
347 return success();
348}
349
350LogicalResult AccessChainOp::verify() {
351 return verifyAccessChain(*this, getIndices());
352}
353
354//===----------------------------------------------------------------------===//
355// spirv.LoadOp
356//===----------------------------------------------------------------------===//
357
358void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr,
359 MemoryAccessAttr memoryAccess, IntegerAttr alignment) {
360 auto ptrType = llvm::cast<spirv::PointerType>(basePtr.getType());
361 build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
362 alignment);
363}
364
365ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
366 // Parse the storage class specification
367 spirv::StorageClass storageClass;
368 OpAsmParser::UnresolvedOperand ptrInfo;
369 Type elementType;
370 if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
371 parseMemoryAccessAttributes<LoadOp>(parser, result) ||
372 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
373 parser.parseType(elementType)) {
374 return failure();
375 }
376
377 auto ptrType = spirv::PointerType::get(elementType, storageClass);
378 if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
379 return failure();
380 }
381
382 result.addTypes(elementType);
383 return success();
384}
385
386void LoadOp::print(OpAsmPrinter &printer) {
387 SmallVector<StringRef, 4> elidedAttrs;
388 StringRef sc = stringifyStorageClass(
389 llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
390 printer << " \"" << sc << "\" " << getPtr();
391
392 printMemoryAccessAttribute(*this, printer, elidedAttrs);
393
394 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
395 printer << " : " << getType();
396}
397
398LogicalResult LoadOp::verify() {
399 // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
400 // type with fixed size; i.e., it cannot be, nor include, any
401 // OpTypeRuntimeArray types."
402 if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) {
403 return failure();
404 }
405 return verifyMemoryAccessAttribute(*this);
406}
407
408//===----------------------------------------------------------------------===//
409// spirv.StoreOp
410//===----------------------------------------------------------------------===//
411
412ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
413 // Parse the storage class specification
414 spirv::StorageClass storageClass;
415 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
416 auto loc = parser.getCurrentLocation();
417 Type elementType;
418 if (parseEnumStrAttr(storageClass, parser) ||
419 parser.parseOperandList(operandInfo, 2) ||
420 parseMemoryAccessAttributes<StoreOp>(parser, result) ||
421 parser.parseColon() || parser.parseType(elementType)) {
422 return failure();
423 }
424
425 auto ptrType = spirv::PointerType::get(elementType, storageClass);
426 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
427 result.operands)) {
428 return failure();
429 }
430 return success();
431}
432
433void StoreOp::print(OpAsmPrinter &printer) {
434 SmallVector<StringRef, 4> elidedAttrs;
435 StringRef sc = stringifyStorageClass(
436 llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
437 printer << " \"" << sc << "\" " << getPtr() << ", " << getValue();
438
439 printMemoryAccessAttribute(*this, printer, elidedAttrs);
440
441 printer << " : " << getValue().getType();
442 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
443}
444
445LogicalResult StoreOp::verify() {
446 // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
447 // OpTypePointer whose Type operand is the same as the type of Object."
448 if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue())))
449 return failure();
450 return verifyMemoryAccessAttribute(*this);
451}
452
453//===----------------------------------------------------------------------===//
454// spirv.CopyMemory
455//===----------------------------------------------------------------------===//
456
457void CopyMemoryOp::print(OpAsmPrinter &printer) {
458 printer << ' ';
459
460 StringRef targetStorageClass = stringifyStorageClass(
461 llvm::cast<spirv::PointerType>(getTarget().getType()).getStorageClass());
462 printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";
463
464 StringRef sourceStorageClass = stringifyStorageClass(
465 llvm::cast<spirv::PointerType>(getSource().getType()).getStorageClass());
466 printer << " \"" << sourceStorageClass << "\" " << getSource();
467
468 SmallVector<StringRef, 4> elidedAttrs;
469 printMemoryAccessAttribute(*this, printer, elidedAttrs);
470 printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
471 getSourceMemoryAccess(),
472 getSourceAlignment());
473
474 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
475
476 Type pointeeType =
477 llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
478 printer << " : " << pointeeType;
479}
480
481ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) {
482 spirv::StorageClass targetStorageClass;
483 OpAsmParser::UnresolvedOperand targetPtrInfo;
484
485 spirv::StorageClass sourceStorageClass;
486 OpAsmParser::UnresolvedOperand sourcePtrInfo;
487
488 Type elementType;
489
490 if (parseEnumStrAttr(targetStorageClass, parser) ||
491 parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
492 parseEnumStrAttr(sourceStorageClass, parser) ||
493 parser.parseOperand(sourcePtrInfo) ||
494 parseMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
495 return failure();
496 }
497
498 if (!parser.parseOptionalComma()) {
499 // Parse 2nd memory access attributes.
500 if (parseSourceMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
501 return failure();
502 }
503 }
504
505 if (parser.parseColon() || parser.parseType(elementType))
506 return failure();
507
508 if (parser.parseOptionalAttrDict(result.attributes))
509 return failure();
510
511 auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
512 auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
513
514 if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
515 parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
516 return failure();
517 }
518
519 return success();
520}
521
522LogicalResult CopyMemoryOp::verify() {
523 Type targetType =
524 llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
525
526 Type sourceType =
527 llvm::cast<spirv::PointerType>(getSource().getType()).getPointeeType();
528
529 if (targetType != sourceType)
530 return emitOpError("both operands must be pointers to the same type");
531
532 if (failed(verifyMemoryAccessAttribute(*this)))
533 return failure();
534
535 // TODO - According to the spec:
536 //
537 // If two masks are present, the first applies to Target and cannot include
538 // MakePointerVisible, and the second applies to Source and cannot include
539 // MakePointerAvailable.
540 //
541 // Add such verification here.
542
543 return verifySourceMemoryAccessAttribute(*this);
544}
545
546//===----------------------------------------------------------------------===//
547// spirv.InBoundsPtrAccessChainOp
548//===----------------------------------------------------------------------===//
549
550void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
551 Value basePtr, Value element,
552 ValueRange indices) {
553 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
554 assert(type && "Unable to deduce return type based on basePtr and indices");
555 build(builder, state, type, basePtr, element, indices);
556}
557
558LogicalResult InBoundsPtrAccessChainOp::verify() {
559 return verifyAccessChain(*this, getIndices());
560}
561
562//===----------------------------------------------------------------------===//
563// spirv.PtrAccessChainOp
564//===----------------------------------------------------------------------===//
565
566void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
567 Value basePtr, Value element, ValueRange indices) {
568 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
569 assert(type && "Unable to deduce return type based on basePtr and indices");
570 build(builder, state, type, basePtr, element, indices);
571}
572
573LogicalResult PtrAccessChainOp::verify() {
574 return verifyAccessChain(*this, getIndices());
575}
576
577//===----------------------------------------------------------------------===//
578// spirv.Variable
579//===----------------------------------------------------------------------===//
580
581ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) {
582 // Parse optional initializer
583 std::optional<OpAsmParser::UnresolvedOperand> initInfo;
584 if (succeeded(parser.parseOptionalKeyword("init"))) {
585 initInfo = OpAsmParser::UnresolvedOperand();
586 if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
587 parser.parseRParen())
588 return failure();
589 }
590
591 if (parseVariableDecorations(parser, result)) {
592 return failure();
593 }
594
595 // Parse result pointer type
596 Type type;
597 if (parser.parseColon())
598 return failure();
599 auto loc = parser.getCurrentLocation();
600 if (parser.parseType(type))
601 return failure();
602
603 auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
604 if (!ptrType)
605 return parser.emitError(loc, "expected spirv.ptr type");
606 result.addTypes(ptrType);
607
608 // Resolve the initializer operand
609 if (initInfo) {
610 if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
611 result.operands))
612 return failure();
613 }
614
615 auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
616 ptrType.getStorageClass());
617 result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
618
619 return success();
620}
621
622void VariableOp::print(OpAsmPrinter &printer) {
623 SmallVector<StringRef, 4> elidedAttrs{
624 spirv::attributeName<spirv::StorageClass>()};
625 // Print optional initializer
626 if (getNumOperands() != 0)
627 printer << " init(" << getInitializer() << ")";
628
629 printVariableDecorations(*this, printer, elidedAttrs);
630 printer << " : " << getType();
631}
632
633LogicalResult VariableOp::verify() {
634 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
635 // object. It cannot be Generic. It must be the same as the Storage Class
636 // operand of the Result Type."
637 if (getStorageClass() != spirv::StorageClass::Function) {
638 return emitOpError(
639 "can only be used to model function-level variables. Use "
640 "spirv.GlobalVariable for module-level variables.");
641 }
642
643 auto pointerType = llvm::cast<spirv::PointerType>(getPointer().getType());
644 if (getStorageClass() != pointerType.getStorageClass())
645 return emitOpError(
646 "storage class must match result pointer's storage class");
647
648 if (getNumOperands() != 0) {
649 // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
650 // a global (module scope) OpVariable instruction".
651 auto *initOp = getOperand(0).getDefiningOp();
652 if (!initOp || !isa<spirv::ConstantOp, // for normal constant
653 spirv::ReferenceOfOp, // for spec constant
654 spirv::AddressOfOp>(initOp))
655 return emitOpError("initializer must be the result of a "
656 "constant or spirv.GlobalVariable op");
657 }
658
659 auto getDecorationAttr = [op = getOperation()](spirv::Decoration decoration) {
660 return op->getAttr(
661 llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration)));
662 };
663
664 // TODO: generate these strings using ODS.
665 for (auto decoration :
666 {spirv::Decoration::DescriptorSet, spirv::Decoration::Binding,
667 spirv::Decoration::BuiltIn}) {
668 if (auto attr = getDecorationAttr(decoration))
669 return emitOpError("cannot have '")
670 << llvm::convertToSnakeFromCamelCase(
671 stringifyDecoration(decoration))
672 << "' attribute (only allowed in spirv.GlobalVariable)";
673 }
674
675 // From SPV_KHR_physical_storage_buffer:
676 // > If an OpVariable's pointee type is a pointer (or array of pointers) in
677 // > PhysicalStorageBuffer storage class, then the variable must be decorated
678 // > with exactly one of AliasedPointer or RestrictPointer.
679 auto pointeePtrType = dyn_cast<spirv::PointerType>(getPointeeType());
680 if (!pointeePtrType) {
681 if (auto pointeeArrayType = dyn_cast<spirv::ArrayType>(getPointeeType())) {
682 pointeePtrType =
683 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
684 }
685 }
686
687 if (pointeePtrType && pointeePtrType.getStorageClass() ==
688 spirv::StorageClass::PhysicalStorageBuffer) {
689 bool hasAliasedPtr =
690 getDecorationAttr(spirv::Decoration::AliasedPointer) != nullptr;
691 bool hasRestrictPtr =
692 getDecorationAttr(spirv::Decoration::RestrictPointer) != nullptr;
693
694 if (!hasAliasedPtr && !hasRestrictPtr)
695 return emitOpError() << " with physical buffer pointer must be decorated "
696 "either 'AliasedPointer' or 'RestrictPointer'";
697
698 if (hasAliasedPtr && hasRestrictPtr)
699 return emitOpError()
700 << " with physical buffer pointer must have exactly one "
701 "aliasing decoration";
702 }
703
704 return success();
705}
706
707} // namespace mlir::spirv
708

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp