1//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14
15#include "SPIRVOpUtils.h"
16#include "SPIRVParsingUtils.h"
17
18#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
20#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
21#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
22#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
23#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
24#include "mlir/IR/Builders.h"
25#include "mlir/IR/BuiltinTypes.h"
26#include "mlir/IR/Matchers.h"
27#include "mlir/IR/OpDefinition.h"
28#include "mlir/IR/OpImplementation.h"
29#include "mlir/IR/Operation.h"
30#include "mlir/IR/TypeUtilities.h"
31#include "mlir/Interfaces/FunctionImplementation.h"
32#include "llvm/ADT/APFloat.h"
33#include "llvm/ADT/APInt.h"
34#include "llvm/ADT/ArrayRef.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/StringExtras.h"
37#include "llvm/ADT/TypeSwitch.h"
38#include "llvm/Support/InterleavedRange.h"
39#include <cassert>
40#include <numeric>
41#include <optional>
42#include <type_traits>
43
44using namespace mlir;
45using namespace mlir::spirv::AttrNames;
46
47//===----------------------------------------------------------------------===//
48// Common utility functions
49//===----------------------------------------------------------------------===//
50
51LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
52 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
53 if (!constOp) {
54 return failure();
55 }
56 auto valueAttr = constOp.getValue();
57 auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
58 if (!integerValueAttr) {
59 return failure();
60 }
61
62 if (integerValueAttr.getType().isSignlessInteger())
63 value = integerValueAttr.getInt();
64 else
65 value = integerValueAttr.getSInt();
66
67 return success();
68}
69
70LogicalResult
71spirv::verifyMemorySemantics(Operation *op,
72 spirv::MemorySemantics memorySemantics) {
73 // According to the SPIR-V specification:
74 // "Despite being a mask and allowing multiple bits to be combined, it is
75 // invalid for more than one of these four bits to be set: Acquire, Release,
76 // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
77 // Release semantics is done by setting the AcquireRelease bit, not by setting
78 // two bits."
79 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
80 spirv::MemorySemantics::Release |
81 spirv::MemorySemantics::AcquireRelease |
82 spirv::MemorySemantics::SequentiallyConsistent;
83
84 auto bitCount =
85 llvm::popcount(Value: static_cast<uint32_t>(memorySemantics & atMostOneInSet));
86 if (bitCount > 1) {
87 return op->emitError(
88 message: "expected at most one of these four memory constraints "
89 "to be set: `Acquire`, `Release`,"
90 "`AcquireRelease` or `SequentiallyConsistent`");
91 }
92 return success();
93}
94
95void spirv::printVariableDecorations(Operation *op, OpAsmPrinter &printer,
96 SmallVectorImpl<StringRef> &elidedAttrs) {
97 // Print optional descriptor binding
98 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
99 stringifyDecoration(spirv::Decoration::DescriptorSet));
100 auto bindingName = llvm::convertToSnakeFromCamelCase(
101 stringifyDecoration(spirv::Decoration::Binding));
102 auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
103 auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
104 if (descriptorSet && binding) {
105 elidedAttrs.push_back(Elt: descriptorSetName);
106 elidedAttrs.push_back(Elt: bindingName);
107 printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
108 << ")";
109 }
110
111 // Print BuiltIn attribute if present
112 auto builtInName = llvm::convertToSnakeFromCamelCase(
113 stringifyDecoration(spirv::Decoration::BuiltIn));
114 if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
115 printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
116 elidedAttrs.push_back(Elt: builtInName);
117 }
118
119 printer.printOptionalAttrDict(attrs: op->getAttrs(), elidedAttrs);
120}
121
122static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
123 OperationState &result) {
124 SmallVector<OpAsmParser::UnresolvedOperand, 2> ops;
125 Type type;
126 // If the operand list is in-between parentheses, then we have a generic form.
127 // (see the fallback in `printOneResultOp`).
128 SMLoc loc = parser.getCurrentLocation();
129 if (!parser.parseOptionalLParen()) {
130 if (parser.parseOperandList(result&: ops) || parser.parseRParen() ||
131 parser.parseOptionalAttrDict(result&: result.attributes) ||
132 parser.parseColon() || parser.parseType(result&: type))
133 return failure();
134 auto fnType = llvm::dyn_cast<FunctionType>(type);
135 if (!fnType) {
136 parser.emitError(loc, message: "expected function type");
137 return failure();
138 }
139 if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
140 return failure();
141 result.addTypes(fnType.getResults());
142 return success();
143 }
144 return failure(IsFailure: parser.parseOperandList(result&: ops) ||
145 parser.parseOptionalAttrDict(result&: result.attributes) ||
146 parser.parseColonType(result&: type) ||
147 parser.resolveOperands(operands&: ops, type, result&: result.operands) ||
148 parser.addTypeToList(type, result&: result.types));
149}
150
151static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
152 assert(op->getNumResults() == 1 && "op should have one result");
153
154 // If not all the operand and result types are the same, just use the
155 // generic assembly form to avoid omitting information in printing.
156 auto resultType = op->getResult(idx: 0).getType();
157 if (llvm::any_of(Range: op->getOperandTypes(),
158 P: [&](Type type) { return type != resultType; })) {
159 p.printGenericOp(op, /*printOpName=*/false);
160 return;
161 }
162
163 p << ' ';
164 p.printOperands(container: op->getOperands());
165 p.printOptionalAttrDict(attrs: op->getAttrs());
166 // Now we can output only one type for all operands and the result.
167 p << " : " << resultType;
168}
169
170template <typename BlockReadWriteOpTy>
171static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
172 Value ptr, Value val) {
173 auto valType = val.getType();
174 if (auto valVecTy = llvm::dyn_cast<VectorType>(valType))
175 valType = valVecTy.getElementType();
176
177 if (valType !=
178 llvm::cast<spirv::PointerType>(Val: ptr.getType()).getPointeeType()) {
179 return op.emitOpError("mismatch in result type and pointer type");
180 }
181 return success();
182}
183
184/// Walks the given type hierarchy with the given indices, potentially down
185/// to component granularity, to select an element type. Returns null type and
186/// emits errors with the given loc on failure.
187static Type
188getElementType(Type type, ArrayRef<int32_t> indices,
189 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
190 if (indices.empty()) {
191 emitErrorFn("expected at least one index for spirv.CompositeExtract");
192 return nullptr;
193 }
194
195 for (auto index : indices) {
196 if (auto cType = llvm::dyn_cast<spirv::CompositeType>(Val&: type)) {
197 if (cType.hasCompileTimeKnownNumElements() &&
198 (index < 0 ||
199 static_cast<uint64_t>(index) >= cType.getNumElements())) {
200 emitErrorFn("index ") << index << " out of bounds for " << type;
201 return nullptr;
202 }
203 type = cType.getElementType(index);
204 } else {
205 emitErrorFn("cannot extract from non-composite type ")
206 << type << " with index " << index;
207 return nullptr;
208 }
209 }
210 return type;
211}
212
213static Type
214getElementType(Type type, Attribute indices,
215 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
216 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
217 if (!indicesArrayAttr) {
218 emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
219 return nullptr;
220 }
221 if (indicesArrayAttr.empty()) {
222 emitErrorFn("expected at least one index for spirv.CompositeExtract");
223 return nullptr;
224 }
225
226 SmallVector<int32_t, 2> indexVals;
227 for (auto indexAttr : indicesArrayAttr) {
228 auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
229 if (!indexIntAttr) {
230 emitErrorFn("expected an 32-bit integer for index, but found '")
231 << indexAttr << "'";
232 return nullptr;
233 }
234 indexVals.push_back(indexIntAttr.getInt());
235 }
236 return getElementType(type, indices: indexVals, emitErrorFn);
237}
238
239static Type getElementType(Type type, Attribute indices, Location loc) {
240 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
241 return ::mlir::emitError(loc, message: err);
242 };
243 return getElementType(type, indices, emitErrorFn: errorFn);
244}
245
246static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
247 SMLoc loc) {
248 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
249 return parser.emitError(loc, message: err);
250 };
251 return getElementType(type, indices, emitErrorFn: errorFn);
252}
253
254template <typename ExtendedBinaryOp>
255static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
256 auto resultType = llvm::cast<spirv::StructType>(op.getType());
257 if (resultType.getNumElements() != 2)
258 return op.emitOpError("expected result struct type containing two members");
259
260 if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
261 resultType.getElementType(0),
262 resultType.getElementType(1)}))
263 return op.emitOpError(
264 "expected all operand types and struct member types are the same");
265
266 return success();
267}
268
269static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
270 OperationState &result) {
271 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
272 if (parser.parseOptionalAttrDict(result&: result.attributes) ||
273 parser.parseOperandList(result&: operands) || parser.parseColon())
274 return failure();
275
276 Type resultType;
277 SMLoc loc = parser.getCurrentLocation();
278 if (parser.parseType(result&: resultType))
279 return failure();
280
281 auto structType = llvm::dyn_cast<spirv::StructType>(Val&: resultType);
282 if (!structType || structType.getNumElements() != 2)
283 return parser.emitError(loc, message: "expected spirv.struct type with two members");
284
285 SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
286 if (parser.resolveOperands(operands, types&: operandTypes, loc, result&: result.operands))
287 return failure();
288
289 result.addTypes(newTypes: resultType);
290 return success();
291}
292
293static void printArithmeticExtendedBinaryOp(Operation *op,
294 OpAsmPrinter &printer) {
295 printer << ' ';
296 printer.printOptionalAttrDict(attrs: op->getAttrs());
297 printer.printOperands(container: op->getOperands());
298 printer << " : " << op->getResultTypes().front();
299}
300
301static LogicalResult verifyShiftOp(Operation *op) {
302 if (op->getOperand(idx: 0).getType() != op->getResult(idx: 0).getType()) {
303 return op->emitError(message: "expected the same type for the first operand and "
304 "result, but provided ")
305 << op->getOperand(idx: 0).getType() << " and "
306 << op->getResult(idx: 0).getType();
307 }
308 return success();
309}
310
311//===----------------------------------------------------------------------===//
312// spirv.mlir.addressof
313//===----------------------------------------------------------------------===//
314
315void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
316 spirv::GlobalVariableOp var) {
317 build(builder, state, var.getType(), SymbolRefAttr::get(var));
318}
319
320LogicalResult spirv::AddressOfOp::verify() {
321 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
322 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
323 getVariableAttr()));
324 if (!varOp) {
325 return emitOpError("expected spirv.GlobalVariable symbol");
326 }
327 if (getPointer().getType() != varOp.getType()) {
328 return emitOpError(
329 "result type mismatch with the referenced global variable's type");
330 }
331 return success();
332}
333
334//===----------------------------------------------------------------------===//
335// spirv.CompositeConstruct
336//===----------------------------------------------------------------------===//
337
338LogicalResult spirv::CompositeConstructOp::verify() {
339 operand_range constituents = this->getConstituents();
340
341 // There are 4 cases with varying verification rules:
342 // 1. Cooperative Matrices (1 constituent)
343 // 2. Structs (1 constituent for each member)
344 // 3. Arrays (1 constituent for each array element)
345 // 4. Vectors (1 constituent (sub-)element for each vector element)
346
347 auto coopElementType =
348 llvm::TypeSwitch<Type, Type>(getType())
349 .Case<spirv::CooperativeMatrixType>(
350 [](auto coopType) { return coopType.getElementType(); })
351 .Default([](Type) { return nullptr; });
352
353 // Case 1. -- matrices.
354 if (coopElementType) {
355 if (constituents.size() != 1)
356 return emitOpError("has incorrect number of operands: expected ")
357 << "1, but provided " << constituents.size();
358 if (coopElementType != constituents.front().getType())
359 return emitOpError("operand type mismatch: expected operand type ")
360 << coopElementType << ", but provided "
361 << constituents.front().getType();
362 return success();
363 }
364
365 // Case 2./3./4. -- number of constituents matches the number of elements.
366 auto cType = llvm::cast<spirv::CompositeType>(getType());
367 if (constituents.size() == cType.getNumElements()) {
368 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
369 if (constituents[index].getType() != cType.getElementType(index)) {
370 return emitOpError("operand type mismatch: expected operand type ")
371 << cType.getElementType(index) << ", but provided "
372 << constituents[index].getType();
373 }
374 }
375 return success();
376 }
377
378 // Case 4. -- check that all constituents add up tp the expected vector type.
379 auto resultType = llvm::dyn_cast<VectorType>(cType);
380 if (!resultType)
381 return emitOpError(
382 "expected to return a vector or cooperative matrix when the number of "
383 "constituents is less than what the result needs");
384
385 SmallVector<unsigned> sizes;
386 for (Value component : constituents) {
387 if (!llvm::isa<VectorType>(component.getType()) &&
388 !component.getType().isIntOrFloat())
389 return emitOpError("operand type mismatch: expected operand to have "
390 "a scalar or vector type, but provided ")
391 << component.getType();
392
393 Type elementType = component.getType();
394 if (auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
395 sizes.push_back(vectorType.getNumElements());
396 elementType = vectorType.getElementType();
397 } else {
398 sizes.push_back(1);
399 }
400
401 if (elementType != resultType.getElementType())
402 return emitOpError("operand element type mismatch: expected to be ")
403 << resultType.getElementType() << ", but provided " << elementType;
404 }
405 unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
406 if (totalCount != cType.getNumElements())
407 return emitOpError("has incorrect number of operands: expected ")
408 << cType.getNumElements() << ", but provided " << totalCount;
409 return success();
410}
411
412//===----------------------------------------------------------------------===//
413// spirv.CompositeExtractOp
414//===----------------------------------------------------------------------===//
415
416void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
417 Value composite,
418 ArrayRef<int32_t> indices) {
419 auto indexAttr = builder.getI32ArrayAttr(indices);
420 auto elementType =
421 getElementType(composite.getType(), indexAttr, state.location);
422 if (!elementType) {
423 return;
424 }
425 build(builder, state, elementType, composite, indexAttr);
426}
427
428ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
429 OperationState &result) {
430 OpAsmParser::UnresolvedOperand compositeInfo;
431 Attribute indicesAttr;
432 StringRef indicesAttrName =
433 spirv::CompositeExtractOp::getIndicesAttrName(result.name);
434 Type compositeType;
435 SMLoc attrLocation;
436
437 if (parser.parseOperand(compositeInfo) ||
438 parser.getCurrentLocation(&attrLocation) ||
439 parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
440 parser.parseColonType(compositeType) ||
441 parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
442 return failure();
443 }
444
445 Type resultType =
446 getElementType(compositeType, indicesAttr, parser, attrLocation);
447 if (!resultType) {
448 return failure();
449 }
450 result.addTypes(resultType);
451 return success();
452}
453
454void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
455 printer << ' ' << getComposite() << getIndices() << " : "
456 << getComposite().getType();
457}
458
459LogicalResult spirv::CompositeExtractOp::verify() {
460 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
461 auto resultType =
462 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
463 if (!resultType)
464 return failure();
465
466 if (resultType != getType()) {
467 return emitOpError("invalid result type: expected ")
468 << resultType << " but provided " << getType();
469 }
470
471 return success();
472}
473
474//===----------------------------------------------------------------------===//
475// spirv.CompositeInsert
476//===----------------------------------------------------------------------===//
477
478void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
479 Value object, Value composite,
480 ArrayRef<int32_t> indices) {
481 auto indexAttr = builder.getI32ArrayAttr(indices);
482 build(builder, state, composite.getType(), object, composite, indexAttr);
483}
484
485ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
486 OperationState &result) {
487 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
488 Type objectType, compositeType;
489 Attribute indicesAttr;
490 StringRef indicesAttrName =
491 spirv::CompositeInsertOp::getIndicesAttrName(result.name);
492 auto loc = parser.getCurrentLocation();
493
494 return failure(
495 parser.parseOperandList(operands, 2) ||
496 parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
497 parser.parseColonType(objectType) ||
498 parser.parseKeywordType("into", compositeType) ||
499 parser.resolveOperands(operands, {objectType, compositeType}, loc,
500 result.operands) ||
501 parser.addTypesToList(compositeType, result.types));
502}
503
504LogicalResult spirv::CompositeInsertOp::verify() {
505 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
506 auto objectType =
507 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
508 if (!objectType)
509 return failure();
510
511 if (objectType != getObject().getType()) {
512 return emitOpError("object operand type should be ")
513 << objectType << ", but found " << getObject().getType();
514 }
515
516 if (getComposite().getType() != getType()) {
517 return emitOpError("result type should be the same as "
518 "the composite type, but found ")
519 << getComposite().getType() << " vs " << getType();
520 }
521
522 return success();
523}
524
525void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
526 printer << " " << getObject() << ", " << getComposite() << getIndices()
527 << " : " << getObject().getType() << " into "
528 << getComposite().getType();
529}
530
531//===----------------------------------------------------------------------===//
532// spirv.Constant
533//===----------------------------------------------------------------------===//
534
535ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
536 OperationState &result) {
537 Attribute value;
538 StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
539 if (parser.parseAttribute(value, valueAttrName, result.attributes))
540 return failure();
541
542 Type type = NoneType::get(parser.getContext());
543 if (auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
544 type = typedAttr.getType();
545 if (llvm::isa<NoneType, TensorType>(type)) {
546 if (parser.parseColonType(type))
547 return failure();
548 }
549
550 return parser.addTypeToList(type, result.types);
551}
552
553void spirv::ConstantOp::print(OpAsmPrinter &printer) {
554 printer << ' ' << getValue();
555 if (llvm::isa<spirv::ArrayType>(getType()))
556 printer << " : " << getType();
557}
558
559static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
560 Type opType) {
561 if (isa<spirv::CooperativeMatrixType>(Val: opType)) {
562 auto denseAttr = dyn_cast<DenseElementsAttr>(Val&: value);
563 if (!denseAttr || !denseAttr.isSplat())
564 return op.emitOpError("expected a splat dense attribute for cooperative "
565 "matrix constant, but found ")
566 << denseAttr;
567 }
568 if (llvm::isa<IntegerAttr, FloatAttr>(Val: value)) {
569 auto valueType = llvm::cast<TypedAttr>(value).getType();
570 if (valueType != opType)
571 return op.emitOpError("result type (")
572 << opType << ") does not match value type (" << valueType << ")";
573 return success();
574 }
575 if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
576 auto valueType = llvm::cast<TypedAttr>(value).getType();
577 if (valueType == opType)
578 return success();
579 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(Val&: opType);
580 auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
581 if (!arrayType)
582 return op.emitOpError("result or element type (")
583 << opType << ") does not match value type (" << valueType
584 << "), must be the same or spirv.array";
585
586 int numElements = arrayType.getNumElements();
587 auto opElemType = arrayType.getElementType();
588 while (auto t = llvm::dyn_cast<spirv::ArrayType>(Val&: opElemType)) {
589 numElements *= t.getNumElements();
590 opElemType = t.getElementType();
591 }
592 if (!opElemType.isIntOrFloat())
593 return op.emitOpError("only support nested array result type");
594
595 auto valueElemType = shapedType.getElementType();
596 if (valueElemType != opElemType) {
597 return op.emitOpError("result element type (")
598 << opElemType << ") does not match value element type ("
599 << valueElemType << ")";
600 }
601
602 if (numElements != shapedType.getNumElements()) {
603 return op.emitOpError("result number of elements (")
604 << numElements << ") does not match value number of elements ("
605 << shapedType.getNumElements() << ")";
606 }
607 return success();
608 }
609 if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
610 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(Val&: opType);
611 if (!arrayType)
612 return op.emitOpError(
613 "must have spirv.array result type for array value");
614 Type elemType = arrayType.getElementType();
615 for (Attribute element : arrayAttr.getValue()) {
616 // Verify array elements recursively.
617 if (failed(verifyConstantType(op, element, elemType)))
618 return failure();
619 }
620 return success();
621 }
622 return op.emitOpError("cannot have attribute: ") << value;
623}
624
625LogicalResult spirv::ConstantOp::verify() {
626 // ODS already generates checks to make sure the result type is valid. We just
627 // need to additionally check that the value's attribute type is consistent
628 // with the result type.
629 return verifyConstantType(*this, getValueAttr(), getType());
630}
631
632bool spirv::ConstantOp::isBuildableWith(Type type) {
633 // Must be valid SPIR-V type first.
634 if (!llvm::isa<spirv::SPIRVType>(type))
635 return false;
636
637 if (isa<SPIRVDialect>(type.getDialect())) {
638 // TODO: support constant struct
639 return llvm::isa<spirv::ArrayType>(type);
640 }
641
642 return true;
643}
644
645spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
646 OpBuilder &builder) {
647 if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
648 unsigned width = intType.getWidth();
649 if (width == 1)
650 return builder.create<spirv::ConstantOp>(loc, type,
651 builder.getBoolAttr(false));
652 return builder.create<spirv::ConstantOp>(
653 loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
654 }
655 if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
656 return builder.create<spirv::ConstantOp>(
657 loc, type, builder.getFloatAttr(floatType, 0.0));
658 }
659 if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
660 Type elemType = vectorType.getElementType();
661 if (llvm::isa<IntegerType>(elemType)) {
662 return builder.create<spirv::ConstantOp>(
663 loc, type,
664 DenseElementsAttr::get(vectorType,
665 IntegerAttr::get(elemType, 0).getValue()));
666 }
667 if (llvm::isa<FloatType>(elemType)) {
668 return builder.create<spirv::ConstantOp>(
669 loc, type,
670 DenseFPElementsAttr::get(vectorType,
671 FloatAttr::get(elemType, 0.0).getValue()));
672 }
673 }
674
675 llvm_unreachable("unimplemented types for ConstantOp::getZero()");
676}
677
678spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
679 OpBuilder &builder) {
680 if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
681 unsigned width = intType.getWidth();
682 if (width == 1)
683 return builder.create<spirv::ConstantOp>(loc, type,
684 builder.getBoolAttr(true));
685 return builder.create<spirv::ConstantOp>(
686 loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
687 }
688 if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
689 return builder.create<spirv::ConstantOp>(
690 loc, type, builder.getFloatAttr(floatType, 1.0));
691 }
692 if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
693 Type elemType = vectorType.getElementType();
694 if (llvm::isa<IntegerType>(elemType)) {
695 return builder.create<spirv::ConstantOp>(
696 loc, type,
697 DenseElementsAttr::get(vectorType,
698 IntegerAttr::get(elemType, 1).getValue()));
699 }
700 if (llvm::isa<FloatType>(elemType)) {
701 return builder.create<spirv::ConstantOp>(
702 loc, type,
703 DenseFPElementsAttr::get(vectorType,
704 FloatAttr::get(elemType, 1.0).getValue()));
705 }
706 }
707
708 llvm_unreachable("unimplemented types for ConstantOp::getOne()");
709}
710
711void mlir::spirv::ConstantOp::getAsmResultNames(
712 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
713 Type type = getType();
714
715 SmallString<32> specialNameBuffer;
716 llvm::raw_svector_ostream specialName(specialNameBuffer);
717 specialName << "cst";
718
719 IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
720
721 if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
722 if (intTy && intTy.getWidth() == 1) {
723 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
724 }
725
726 if (intTy.isSignless()) {
727 specialName << intCst.getInt();
728 } else if (intTy.isUnsigned()) {
729 specialName << intCst.getUInt();
730 } else {
731 specialName << intCst.getSInt();
732 }
733 }
734
735 if (intTy || llvm::isa<FloatType>(type)) {
736 specialName << '_' << type;
737 }
738
739 if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
740 specialName << "_vec_";
741 specialName << vecType.getDimSize(0);
742
743 Type elementType = vecType.getElementType();
744
745 if (llvm::isa<IntegerType>(elementType) ||
746 llvm::isa<FloatType>(elementType)) {
747 specialName << "x" << elementType;
748 }
749 }
750
751 setNameFn(getResult(), specialName.str());
752}
753
754void mlir::spirv::AddressOfOp::getAsmResultNames(
755 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
756 SmallString<32> specialNameBuffer;
757 llvm::raw_svector_ostream specialName(specialNameBuffer);
758 specialName << getVariable() << "_addr";
759 setNameFn(getResult(), specialName.str());
760}
761
762//===----------------------------------------------------------------------===//
763// spirv.ControlBarrierOp
764//===----------------------------------------------------------------------===//
765
766LogicalResult spirv::ControlBarrierOp::verify() {
767 return verifyMemorySemantics(getOperation(), getMemorySemantics());
768}
769
770//===----------------------------------------------------------------------===//
771// spirv.EntryPoint
772//===----------------------------------------------------------------------===//
773
774void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
775 spirv::ExecutionModel executionModel,
776 spirv::FuncOp function,
777 ArrayRef<Attribute> interfaceVars) {
778 build(builder, state,
779 spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
780 SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
781}
782
783ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
784 OperationState &result) {
785 spirv::ExecutionModel execModel;
786 SmallVector<Attribute, 4> interfaceVars;
787
788 FlatSymbolRefAttr fn;
789 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
790 parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
791 return failure();
792 }
793
794 if (!parser.parseOptionalComma()) {
795 // Parse the interface variables
796 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
797 // The name of the interface variable attribute isnt important
798 FlatSymbolRefAttr var;
799 NamedAttrList attrs;
800 if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
801 return failure();
802 interfaceVars.push_back(var);
803 return success();
804 }))
805 return failure();
806 }
807 result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
808 parser.getBuilder().getArrayAttr(interfaceVars));
809 return success();
810}
811
812void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
813 printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
814 printer.printSymbolName(getFn());
815 auto interfaceVars = getInterface().getValue();
816 if (!interfaceVars.empty())
817 printer << ", " << llvm::interleaved(interfaceVars);
818}
819
820LogicalResult spirv::EntryPointOp::verify() {
821 // Checks for fn and interface symbol reference are done in spirv::ModuleOp
822 // verification.
823 return success();
824}
825
826//===----------------------------------------------------------------------===//
827// spirv.ExecutionMode
828//===----------------------------------------------------------------------===//
829
830void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
831 spirv::FuncOp function,
832 spirv::ExecutionMode executionMode,
833 ArrayRef<int32_t> params) {
834 build(builder, state, SymbolRefAttr::get(function),
835 spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
836 builder.getI32ArrayAttr(params));
837}
838
839ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
840 OperationState &result) {
841 spirv::ExecutionMode execMode;
842 Attribute fn;
843 if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
844 parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
845 return failure();
846 }
847
848 SmallVector<int32_t, 4> values;
849 Type i32Type = parser.getBuilder().getIntegerType(32);
850 while (!parser.parseOptionalComma()) {
851 NamedAttrList attr;
852 Attribute value;
853 if (parser.parseAttribute(value, i32Type, "value", attr)) {
854 return failure();
855 }
856 values.push_back(llvm::cast<IntegerAttr>(value).getInt());
857 }
858 StringRef valuesAttrName =
859 spirv::ExecutionModeOp::getValuesAttrName(result.name);
860 result.addAttribute(valuesAttrName,
861 parser.getBuilder().getI32ArrayAttr(values));
862 return success();
863}
864
865void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
866 printer << " ";
867 printer.printSymbolName(getFn());
868 printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
869 ArrayAttr values = this->getValues();
870 if (!values.empty())
871 printer << ", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
872}
873
874//===----------------------------------------------------------------------===//
875// spirv.func
876//===----------------------------------------------------------------------===//
877
878ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
879 SmallVector<OpAsmParser::Argument> entryArgs;
880 SmallVector<DictionaryAttr> resultAttrs;
881 SmallVector<Type> resultTypes;
882 auto &builder = parser.getBuilder();
883
884 // Parse the name as a symbol.
885 StringAttr nameAttr;
886 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
887 result.attributes))
888 return failure();
889
890 // Parse the function signature.
891 bool isVariadic = false;
892 if (function_interface_impl::parseFunctionSignatureWithArguments(
893 parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
894 resultAttrs))
895 return failure();
896
897 SmallVector<Type> argTypes;
898 for (auto &arg : entryArgs)
899 argTypes.push_back(arg.type);
900 auto fnType = builder.getFunctionType(argTypes, resultTypes);
901 result.addAttribute(getFunctionTypeAttrName(result.name),
902 TypeAttr::get(fnType));
903
904 // Parse the optional function control keyword.
905 spirv::FunctionControl fnControl;
906 if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
907 return failure();
908
909 // If additional attributes are present, parse them.
910 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
911 return failure();
912
913 // Add the attributes to the function arguments.
914 assert(resultAttrs.size() == resultTypes.size());
915 call_interface_impl::addArgAndResultAttrs(
916 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
917 getResAttrsAttrName(result.name));
918
919 // Parse the optional function body.
920 auto *body = result.addRegion();
921 OptionalParseResult parseResult =
922 parser.parseOptionalRegion(*body, entryArgs);
923 return failure(parseResult.has_value() && failed(*parseResult));
924}
925
926void spirv::FuncOp::print(OpAsmPrinter &printer) {
927 // Print function name, signature, and control.
928 printer << " ";
929 printer.printSymbolName(getSymName());
930 auto fnType = getFunctionType();
931 function_interface_impl::printFunctionSignature(
932 printer, *this, fnType.getInputs(),
933 /*isVariadic=*/false, fnType.getResults());
934 printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
935 << "\"";
936 function_interface_impl::printFunctionAttributes(
937 printer, *this,
938 {spirv::attributeName<spirv::FunctionControl>(),
939 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
940 getFunctionControlAttrName()});
941
942 // Print the body if this is not an external function.
943 Region &body = this->getBody();
944 if (!body.empty()) {
945 printer << ' ';
946 printer.printRegion(body, /*printEntryBlockArgs=*/false,
947 /*printBlockTerminators=*/true);
948 }
949}
950
951LogicalResult spirv::FuncOp::verifyType() {
952 FunctionType fnType = getFunctionType();
953 if (fnType.getNumResults() > 1)
954 return emitOpError("cannot have more than one result");
955
956 auto hasDecorationAttr = [&](spirv::Decoration decoration,
957 unsigned argIndex) {
958 auto func = llvm::cast<FunctionOpInterface>(getOperation());
959 for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
960 if (argAttr.getName() != spirv::DecorationAttr::name)
961 continue;
962 if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
963 return decAttr.getValue() == decoration;
964 }
965 return false;
966 };
967
968 for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
969 Type param = fnType.getInputs()[i];
970 auto inputPtrType = dyn_cast<spirv::PointerType>(param);
971 if (!inputPtrType)
972 continue;
973
974 auto pointeePtrType =
975 dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
976 if (pointeePtrType) {
977 // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
978 // > If an OpFunctionParameter is a pointer (or contains a pointer)
979 // > and the type it points to is a pointer in the PhysicalStorageBuffer
980 // > storage class, the function parameter must be decorated with exactly
981 // > one of AliasedPointer or RestrictPointer.
982 if (pointeePtrType.getStorageClass() !=
983 spirv::StorageClass::PhysicalStorageBuffer)
984 continue;
985
986 bool hasAliasedPtr =
987 hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
988 bool hasRestrictPtr =
989 hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
990 if (!hasAliasedPtr && !hasRestrictPtr)
991 return emitOpError()
992 << "with a pointer points to a physical buffer pointer must "
993 "be decorated either 'AliasedPointer' or 'RestrictPointer'";
994 continue;
995 }
996 // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
997 // > If an OpFunctionParameter is a pointer (or contains a pointer) in
998 // > the PhysicalStorageBuffer storage class, the function parameter must
999 // > be decorated with exactly one of Aliased or Restrict.
1000 if (auto pointeeArrayType =
1001 dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1002 pointeePtrType =
1003 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1004 } else {
1005 pointeePtrType = inputPtrType;
1006 }
1007
1008 if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1009 spirv::StorageClass::PhysicalStorageBuffer)
1010 continue;
1011
1012 bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1013 bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1014 if (!hasAliased && !hasRestrict)
1015 return emitOpError() << "with physical buffer pointer must be decorated "
1016 "either 'Aliased' or 'Restrict'";
1017 }
1018
1019 return success();
1020}
1021
1022LogicalResult spirv::FuncOp::verifyBody() {
1023 FunctionType fnType = getFunctionType();
1024 if (!isExternal()) {
1025 Block &entryBlock = front();
1026
1027 unsigned numArguments = this->getNumArguments();
1028 if (entryBlock.getNumArguments() != numArguments)
1029 return emitOpError("entry block must have ")
1030 << numArguments << " arguments to match function signature";
1031
1032 for (auto [index, fnArgType, blockArgType] :
1033 llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
1034 if (blockArgType != fnArgType) {
1035 return emitOpError("type of entry block argument #")
1036 << index << '(' << blockArgType
1037 << ") must match the type of the corresponding argument in "
1038 << "function signature(" << fnArgType << ')';
1039 }
1040 }
1041 }
1042
1043 auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1044 if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1045 if (fnType.getNumResults() != 0)
1046 return retOp.emitOpError("cannot be used in functions returning value");
1047 } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1048 if (fnType.getNumResults() != 1)
1049 return retOp.emitOpError(
1050 "returns 1 value but enclosing function requires ")
1051 << fnType.getNumResults() << " results";
1052
1053 auto retOperandType = retOp.getValue().getType();
1054 auto fnResultType = fnType.getResult(0);
1055 if (retOperandType != fnResultType)
1056 return retOp.emitOpError(" return value's type (")
1057 << retOperandType << ") mismatch with function's result type ("
1058 << fnResultType << ")";
1059 }
1060 return WalkResult::advance();
1061 });
1062
1063 // TODO: verify other bits like linkage type.
1064
1065 return failure(walkResult.wasInterrupted());
1066}
1067
1068void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1069 StringRef name, FunctionType type,
1070 spirv::FunctionControl control,
1071 ArrayRef<NamedAttribute> attrs) {
1072 state.addAttribute(SymbolTable::getSymbolAttrName(),
1073 builder.getStringAttr(name));
1074 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1075 state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1076 builder.getAttr<spirv::FunctionControlAttr>(control));
1077 state.attributes.append(attrs.begin(), attrs.end());
1078 state.addRegion();
1079}
1080
1081//===----------------------------------------------------------------------===//
1082// spirv.GLFClampOp
1083//===----------------------------------------------------------------------===//
1084
1085ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
1086 OperationState &result) {
1087 return parseOneResultSameOperandTypeOp(parser, result);
1088}
1089void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1090
1091//===----------------------------------------------------------------------===//
1092// spirv.GLUClampOp
1093//===----------------------------------------------------------------------===//
1094
1095ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
1096 OperationState &result) {
1097 return parseOneResultSameOperandTypeOp(parser, result);
1098}
1099void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1100
1101//===----------------------------------------------------------------------===//
1102// spirv.GLSClampOp
1103//===----------------------------------------------------------------------===//
1104
1105ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
1106 OperationState &result) {
1107 return parseOneResultSameOperandTypeOp(parser, result);
1108}
1109void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1110
1111//===----------------------------------------------------------------------===//
1112// spirv.GLFmaOp
1113//===----------------------------------------------------------------------===//
1114
1115ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
1116 return parseOneResultSameOperandTypeOp(parser, result);
1117}
1118void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1119
1120//===----------------------------------------------------------------------===//
1121// spirv.GlobalVariable
1122//===----------------------------------------------------------------------===//
1123
1124void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1125 Type type, StringRef name,
1126 unsigned descriptorSet, unsigned binding) {
1127 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1128 state.addAttribute(
1129 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1130 builder.getI32IntegerAttr(descriptorSet));
1131 state.addAttribute(
1132 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1133 builder.getI32IntegerAttr(binding));
1134}
1135
1136void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1137 Type type, StringRef name,
1138 spirv::BuiltIn builtin) {
1139 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1140 state.addAttribute(
1141 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1142 builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1143}
1144
1145ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
1146 OperationState &result) {
1147 // Parse variable name.
1148 StringAttr nameAttr;
1149 StringRef initializerAttrName =
1150 spirv::GlobalVariableOp::getInitializerAttrName(result.name);
1151 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1152 result.attributes)) {
1153 return failure();
1154 }
1155
1156 // Parse optional initializer
1157 if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
1158 FlatSymbolRefAttr initSymbol;
1159 if (parser.parseLParen() ||
1160 parser.parseAttribute(initSymbol, Type(), initializerAttrName,
1161 result.attributes) ||
1162 parser.parseRParen())
1163 return failure();
1164 }
1165
1166 if (parseVariableDecorations(parser, result)) {
1167 return failure();
1168 }
1169
1170 Type type;
1171 StringRef typeAttrName =
1172 spirv::GlobalVariableOp::getTypeAttrName(result.name);
1173 auto loc = parser.getCurrentLocation();
1174 if (parser.parseColonType(type)) {
1175 return failure();
1176 }
1177 if (!llvm::isa<spirv::PointerType>(type)) {
1178 return parser.emitError(loc, "expected spirv.ptr type");
1179 }
1180 result.addAttribute(typeAttrName, TypeAttr::get(type));
1181
1182 return success();
1183}
1184
1185void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
1186 SmallVector<StringRef, 4> elidedAttrs{
1187 spirv::attributeName<spirv::StorageClass>()};
1188
1189 // Print variable name.
1190 printer << ' ';
1191 printer.printSymbolName(getSymName());
1192 elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1193
1194 StringRef initializerAttrName = this->getInitializerAttrName();
1195 // Print optional initializer
1196 if (auto initializer = this->getInitializer()) {
1197 printer << " " << initializerAttrName << '(';
1198 printer.printSymbolName(*initializer);
1199 printer << ')';
1200 elidedAttrs.push_back(initializerAttrName);
1201 }
1202
1203 StringRef typeAttrName = this->getTypeAttrName();
1204 elidedAttrs.push_back(typeAttrName);
1205 spirv::printVariableDecorations(*this, printer, elidedAttrs);
1206 printer << " : " << getType();
1207}
1208
1209LogicalResult spirv::GlobalVariableOp::verify() {
1210 if (!llvm::isa<spirv::PointerType>(getType()))
1211 return emitOpError("result must be of a !spv.ptr type");
1212
1213 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1214 // object. It cannot be Generic. It must be the same as the Storage Class
1215 // operand of the Result Type."
1216 // Also, Function storage class is reserved by spirv.Variable.
1217 auto storageClass = this->storageClass();
1218 if (storageClass == spirv::StorageClass::Generic ||
1219 storageClass == spirv::StorageClass::Function) {
1220 return emitOpError("storage class cannot be '")
1221 << stringifyStorageClass(storageClass) << "'";
1222 }
1223
1224 if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
1225 this->getInitializerAttrName())) {
1226 Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
1227 (*this)->getParentOp(), init.getAttr());
1228 // TODO: Currently only variable initialization with specialization
1229 // constants and other variables is supported. They could be normal
1230 // constants in the module scope as well.
1231 if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
1232 spirv::SpecConstantCompositeOp>(initOp)) {
1233 return emitOpError("initializer must be result of a "
1234 "spirv.SpecConstant or spirv.GlobalVariable or "
1235 "spirv.SpecConstantCompositeOp op");
1236 }
1237 }
1238
1239 return success();
1240}
1241
1242//===----------------------------------------------------------------------===//
1243// spirv.INTEL.SubgroupBlockRead
1244//===----------------------------------------------------------------------===//
1245
1246LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
1247 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1248 return failure();
1249
1250 return success();
1251}
1252
1253//===----------------------------------------------------------------------===//
1254// spirv.INTEL.SubgroupBlockWrite
1255//===----------------------------------------------------------------------===//
1256
1257ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
1258 OperationState &result) {
1259 // Parse the storage class specification
1260 spirv::StorageClass storageClass;
1261 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
1262 auto loc = parser.getCurrentLocation();
1263 Type elementType;
1264 if (parseEnumStrAttr(storageClass, parser) ||
1265 parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
1266 parser.parseType(elementType)) {
1267 return failure();
1268 }
1269
1270 auto ptrType = spirv::PointerType::get(elementType, storageClass);
1271 if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1272 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1273
1274 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
1275 result.operands)) {
1276 return failure();
1277 }
1278 return success();
1279}
1280
1281void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
1282 printer << " " << getPtr() << ", " << getValue() << " : "
1283 << getValue().getType();
1284}
1285
1286LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
1287 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1288 return failure();
1289
1290 return success();
1291}
1292
1293//===----------------------------------------------------------------------===//
1294// spirv.IAddCarryOp
1295//===----------------------------------------------------------------------===//
1296
1297LogicalResult spirv::IAddCarryOp::verify() {
1298 return ::verifyArithmeticExtendedBinaryOp(*this);
1299}
1300
1301ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
1302 OperationState &result) {
1303 return ::parseArithmeticExtendedBinaryOp(parser, result);
1304}
1305
1306void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1307 ::printArithmeticExtendedBinaryOp(*this, printer);
1308}
1309
1310//===----------------------------------------------------------------------===//
1311// spirv.ISubBorrowOp
1312//===----------------------------------------------------------------------===//
1313
1314LogicalResult spirv::ISubBorrowOp::verify() {
1315 return ::verifyArithmeticExtendedBinaryOp(*this);
1316}
1317
1318ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
1319 OperationState &result) {
1320 return ::parseArithmeticExtendedBinaryOp(parser, result);
1321}
1322
1323void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
1324 ::printArithmeticExtendedBinaryOp(*this, printer);
1325}
1326
1327//===----------------------------------------------------------------------===//
1328// spirv.SMulExtended
1329//===----------------------------------------------------------------------===//
1330
1331LogicalResult spirv::SMulExtendedOp::verify() {
1332 return ::verifyArithmeticExtendedBinaryOp(*this);
1333}
1334
1335ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
1336 OperationState &result) {
1337 return ::parseArithmeticExtendedBinaryOp(parser, result);
1338}
1339
1340void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) {
1341 ::printArithmeticExtendedBinaryOp(*this, printer);
1342}
1343
1344//===----------------------------------------------------------------------===//
1345// spirv.UMulExtended
1346//===----------------------------------------------------------------------===//
1347
1348LogicalResult spirv::UMulExtendedOp::verify() {
1349 return ::verifyArithmeticExtendedBinaryOp(*this);
1350}
1351
1352ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
1353 OperationState &result) {
1354 return ::parseArithmeticExtendedBinaryOp(parser, result);
1355}
1356
1357void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
1358 ::printArithmeticExtendedBinaryOp(*this, printer);
1359}
1360
1361//===----------------------------------------------------------------------===//
1362// spirv.MemoryBarrierOp
1363//===----------------------------------------------------------------------===//
1364
1365LogicalResult spirv::MemoryBarrierOp::verify() {
1366 return verifyMemorySemantics(getOperation(), getMemorySemantics());
1367}
1368
1369//===----------------------------------------------------------------------===//
1370// spirv.module
1371//===----------------------------------------------------------------------===//
1372
1373void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1374 std::optional<StringRef> name) {
1375 OpBuilder::InsertionGuard guard(builder);
1376 builder.createBlock(state.addRegion());
1377 if (name) {
1378 state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
1379 builder.getStringAttr(*name));
1380 }
1381}
1382
1383void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1384 spirv::AddressingModel addressingModel,
1385 spirv::MemoryModel memoryModel,
1386 std::optional<VerCapExtAttr> vceTriple,
1387 std::optional<StringRef> name) {
1388 state.addAttribute(
1389 "addressing_model",
1390 builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
1391 state.addAttribute("memory_model",
1392 builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
1393 OpBuilder::InsertionGuard guard(builder);
1394 builder.createBlock(state.addRegion());
1395 if (vceTriple)
1396 state.addAttribute(getVCETripleAttrName(), *vceTriple);
1397 if (name)
1398 state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
1399 builder.getStringAttr(*name));
1400}
1401
1402ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
1403 OperationState &result) {
1404 Region *body = result.addRegion();
1405
1406 // If the name is present, parse it.
1407 StringAttr nameAttr;
1408 (void)parser.parseOptionalSymbolName(
1409 nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
1410
1411 // Parse attributes
1412 spirv::AddressingModel addrModel;
1413 spirv::MemoryModel memoryModel;
1414 if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
1415 result) ||
1416 spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
1417 result))
1418 return failure();
1419
1420 if (succeeded(parser.parseOptionalKeyword("requires"))) {
1421 spirv::VerCapExtAttr vceTriple;
1422 if (parser.parseAttribute(vceTriple,
1423 spirv::ModuleOp::getVCETripleAttrName(),
1424 result.attributes))
1425 return failure();
1426 }
1427
1428 if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
1429 parser.parseRegion(*body, /*arguments=*/{}))
1430 return failure();
1431
1432 // Make sure we have at least one block.
1433 if (body->empty())
1434 body->push_back(new Block());
1435
1436 return success();
1437}
1438
1439void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1440 if (std::optional<StringRef> name = getName()) {
1441 printer << ' ';
1442 printer.printSymbolName(*name);
1443 }
1444
1445 SmallVector<StringRef, 2> elidedAttrs;
1446
1447 printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1448 << spirv::stringifyMemoryModel(getMemoryModel());
1449 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1450 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1451 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1452 mlir::SymbolTable::getSymbolAttrName()});
1453
1454 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1455 printer << " requires " << *triple;
1456 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1457 }
1458
1459 printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
1460 printer << ' ';
1461 printer.printRegion(getRegion());
1462}
1463
1464LogicalResult spirv::ModuleOp::verifyRegions() {
1465 Dialect *dialect = (*this)->getDialect();
1466 DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
1467 entryPoints;
1468 mlir::SymbolTable table(*this);
1469
1470 for (auto &op : *getBody()) {
1471 if (op.getDialect() != dialect)
1472 return op.emitError("'spirv.module' can only contain spirv.* ops");
1473
1474 // For EntryPoint op, check that the function and execution model is not
1475 // duplicated in EntryPointOps. Also verify that the interface specified
1476 // comes from globalVariables here to make this check cheaper.
1477 if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1478 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1479 if (!funcOp) {
1480 return entryPointOp.emitError("function '")
1481 << entryPointOp.getFn() << "' not found in 'spirv.module'";
1482 }
1483 if (auto interface = entryPointOp.getInterface()) {
1484 for (Attribute varRef : interface) {
1485 auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1486 if (!varSymRef) {
1487 return entryPointOp.emitError(
1488 "expected symbol reference for interface "
1489 "specification instead of '")
1490 << varRef;
1491 }
1492 auto variableOp =
1493 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1494 if (!variableOp) {
1495 return entryPointOp.emitError("expected spirv.GlobalVariable "
1496 "symbol reference instead of'")
1497 << varSymRef << "'";
1498 }
1499 }
1500 }
1501
1502 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1503 funcOp, entryPointOp.getExecutionModel());
1504 if (!entryPoints.try_emplace(key, entryPointOp).second)
1505 return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1506 } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1507 // If the function is external and does not have 'Import'
1508 // linkage_attributes(LinkageAttributes), throw an error. 'Import'
1509 // LinkageAttributes is used to import external functions.
1510 auto linkageAttr = funcOp.getLinkageAttributes();
1511 auto hasImportLinkage =
1512 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1513 spirv::LinkageType::Import);
1514 if (funcOp.isExternal() && !hasImportLinkage)
1515 return op.emitError(
1516 "'spirv.module' cannot contain external functions "
1517 "without 'Import' linkage_attributes (LinkageAttributes)");
1518
1519 // TODO: move this check to spirv.func.
1520 for (auto &block : funcOp)
1521 for (auto &op : block) {
1522 if (op.getDialect() != dialect)
1523 return op.emitError(
1524 "functions in 'spirv.module' can only contain spirv.* ops");
1525 }
1526 }
1527 }
1528
1529 return success();
1530}
1531
1532//===----------------------------------------------------------------------===//
1533// spirv.mlir.referenceof
1534//===----------------------------------------------------------------------===//
1535
1536LogicalResult spirv::ReferenceOfOp::verify() {
1537 auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1538 (*this)->getParentOp(), getSpecConstAttr());
1539 Type constType;
1540
1541 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1542 if (specConstOp)
1543 constType = specConstOp.getDefaultValue().getType();
1544
1545 auto specConstCompositeOp =
1546 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1547 if (specConstCompositeOp)
1548 constType = specConstCompositeOp.getType();
1549
1550 if (!specConstOp && !specConstCompositeOp)
1551 return emitOpError(
1552 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1553
1554 if (getReference().getType() != constType)
1555 return emitOpError("result type mismatch with the referenced "
1556 "specialization constant's type");
1557
1558 return success();
1559}
1560
1561//===----------------------------------------------------------------------===//
1562// spirv.SpecConstant
1563//===----------------------------------------------------------------------===//
1564
1565ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
1566 OperationState &result) {
1567 StringAttr nameAttr;
1568 Attribute valueAttr;
1569 StringRef defaultValueAttrName =
1570 spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
1571
1572 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1573 result.attributes))
1574 return failure();
1575
1576 // Parse optional spec_id.
1577 if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
1578 IntegerAttr specIdAttr;
1579 if (parser.parseLParen() ||
1580 parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
1581 parser.parseRParen())
1582 return failure();
1583 }
1584
1585 if (parser.parseEqual() ||
1586 parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
1587 return failure();
1588
1589 return success();
1590}
1591
1592void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
1593 printer << ' ';
1594 printer.printSymbolName(getSymName());
1595 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1596 printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1597 printer << " = " << getDefaultValue();
1598}
1599
1600LogicalResult spirv::SpecConstantOp::verify() {
1601 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1602 if (specID.getValue().isNegative())
1603 return emitOpError("SpecId cannot be negative");
1604
1605 auto value = getDefaultValue();
1606 if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1607 // Make sure bitwidth is allowed.
1608 if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1609 return emitOpError("default value bitwidth disallowed");
1610 return success();
1611 }
1612 return emitOpError(
1613 "default value can only be a bool, integer, or float scalar");
1614}
1615
1616//===----------------------------------------------------------------------===//
1617// spirv.VectorShuffle
1618//===----------------------------------------------------------------------===//
1619
1620LogicalResult spirv::VectorShuffleOp::verify() {
1621 VectorType resultType = llvm::cast<VectorType>(getType());
1622
1623 size_t numResultElements = resultType.getNumElements();
1624 if (numResultElements != getComponents().size())
1625 return emitOpError("result type element count (")
1626 << numResultElements
1627 << ") mismatch with the number of component selectors ("
1628 << getComponents().size() << ")";
1629
1630 size_t totalSrcElements =
1631 llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
1632 llvm::cast<VectorType>(getVector2().getType()).getNumElements();
1633
1634 for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1635 uint32_t index = selector.getZExtValue();
1636 if (index >= totalSrcElements &&
1637 index != std::numeric_limits<uint32_t>().max())
1638 return emitOpError("component selector ")
1639 << index << " out of range: expected to be in [0, "
1640 << totalSrcElements << ") or 0xffffffff";
1641 }
1642 return success();
1643}
1644
1645//===----------------------------------------------------------------------===//
1646// spirv.MatrixTimesScalar
1647//===----------------------------------------------------------------------===//
1648
1649LogicalResult spirv::MatrixTimesScalarOp::verify() {
1650 Type elementType =
1651 llvm::TypeSwitch<Type, Type>(getMatrix().getType())
1652 .Case<spirv::CooperativeMatrixType, spirv::MatrixType>(
1653 [](auto matrixType) { return matrixType.getElementType(); })
1654 .Default([](Type) { return nullptr; });
1655
1656 assert(elementType && "Unhandled type");
1657
1658 // Check that the scalar type is the same as the matrix element type.
1659 if (getScalar().getType() != elementType)
1660 return emitOpError("input matrix components' type and scaling value must "
1661 "have the same type");
1662
1663 return success();
1664}
1665
1666//===----------------------------------------------------------------------===//
1667// spirv.Transpose
1668//===----------------------------------------------------------------------===//
1669
1670LogicalResult spirv::TransposeOp::verify() {
1671 auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1672 auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1673
1674 // Verify that the input and output matrices have correct shapes.
1675 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1676 return emitError("input matrix rows count must be equal to "
1677 "output matrix columns count");
1678
1679 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1680 return emitError("input matrix columns count must be equal to "
1681 "output matrix rows count");
1682
1683 // Verify that the input and output matrices have the same component type
1684 if (inputMatrix.getElementType() != resultMatrix.getElementType())
1685 return emitError("input and output matrices must have the same "
1686 "component type");
1687
1688 return success();
1689}
1690
1691//===----------------------------------------------------------------------===//
1692// spirv.MatrixTimesVector
1693//===----------------------------------------------------------------------===//
1694
1695LogicalResult spirv::MatrixTimesVectorOp::verify() {
1696 auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1697 auto vectorType = llvm::cast<VectorType>(getVector().getType());
1698 auto resultType = llvm::cast<VectorType>(getType());
1699
1700 if (matrixType.getNumColumns() != vectorType.getNumElements())
1701 return emitOpError("matrix columns (")
1702 << matrixType.getNumColumns() << ") must match vector operand size ("
1703 << vectorType.getNumElements() << ")";
1704
1705 if (resultType.getNumElements() != matrixType.getNumRows())
1706 return emitOpError("result size (")
1707 << resultType.getNumElements() << ") must match the matrix rows ("
1708 << matrixType.getNumRows() << ")";
1709
1710 if (matrixType.getElementType() != resultType.getElementType())
1711 return emitOpError("matrix and result element types must match");
1712
1713 return success();
1714}
1715
1716//===----------------------------------------------------------------------===//
1717// spirv.VectorTimesMatrix
1718//===----------------------------------------------------------------------===//
1719
1720LogicalResult spirv::VectorTimesMatrixOp::verify() {
1721 auto vectorType = llvm::cast<VectorType>(getVector().getType());
1722 auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1723 auto resultType = llvm::cast<VectorType>(getType());
1724
1725 if (matrixType.getNumRows() != vectorType.getNumElements())
1726 return emitOpError("number of components in vector must equal the number "
1727 "of components in each column in matrix");
1728
1729 if (resultType.getNumElements() != matrixType.getNumColumns())
1730 return emitOpError("number of columns in matrix must equal the number of "
1731 "components in result");
1732
1733 if (matrixType.getElementType() != resultType.getElementType())
1734 return emitOpError("matrix must be a matrix with the same component type "
1735 "as the component type in result");
1736
1737 return success();
1738}
1739
1740//===----------------------------------------------------------------------===//
1741// spirv.MatrixTimesMatrix
1742//===----------------------------------------------------------------------===//
1743
1744LogicalResult spirv::MatrixTimesMatrixOp::verify() {
1745 auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
1746 auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
1747 auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1748
1749 // left matrix columns' count and right matrix rows' count must be equal
1750 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1751 return emitError("left matrix columns' count must be equal to "
1752 "the right matrix rows' count");
1753
1754 // right and result matrices columns' count must be the same
1755 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1756 return emitError(
1757 "right and result matrices must have equal columns' count");
1758
1759 // right and result matrices component type must be the same
1760 if (rightMatrix.getElementType() != resultMatrix.getElementType())
1761 return emitError("right and result matrices' component type must"
1762 " be the same");
1763
1764 // left and result matrices component type must be the same
1765 if (leftMatrix.getElementType() != resultMatrix.getElementType())
1766 return emitError("left and result matrices' component type"
1767 " must be the same");
1768
1769 // left and result matrices rows count must be the same
1770 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1771 return emitError("left and result matrices must have equal rows' count");
1772
1773 return success();
1774}
1775
1776//===----------------------------------------------------------------------===//
1777// spirv.SpecConstantComposite
1778//===----------------------------------------------------------------------===//
1779
1780ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
1781 OperationState &result) {
1782
1783 StringAttr compositeName;
1784 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1785 result.attributes))
1786 return failure();
1787
1788 if (parser.parseLParen())
1789 return failure();
1790
1791 SmallVector<Attribute, 4> constituents;
1792
1793 do {
1794 // The name of the constituent attribute isn't important
1795 const char *attrName = "spec_const";
1796 FlatSymbolRefAttr specConstRef;
1797 NamedAttrList attrs;
1798
1799 if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
1800 return failure();
1801
1802 constituents.push_back(specConstRef);
1803 } while (!parser.parseOptionalComma());
1804
1805 if (parser.parseRParen())
1806 return failure();
1807
1808 StringAttr compositeSpecConstituentsName =
1809 spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
1810 result.addAttribute(compositeSpecConstituentsName,
1811 parser.getBuilder().getArrayAttr(constituents));
1812
1813 Type type;
1814 if (parser.parseColonType(type))
1815 return failure();
1816
1817 StringAttr typeAttrName =
1818 spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
1819 result.addAttribute(typeAttrName, TypeAttr::get(type));
1820
1821 return success();
1822}
1823
1824void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
1825 printer << " ";
1826 printer.printSymbolName(getSymName());
1827 printer << " (" << llvm::interleaved(this->getConstituents().getValue())
1828 << ") : " << getType();
1829}
1830
1831LogicalResult spirv::SpecConstantCompositeOp::verify() {
1832 auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
1833 auto constituents = this->getConstituents().getValue();
1834
1835 if (!cType)
1836 return emitError("result type must be a composite type, but provided ")
1837 << getType();
1838
1839 if (llvm::isa<spirv::CooperativeMatrixType>(cType))
1840 return emitError("unsupported composite type ") << cType;
1841 if (constituents.size() != cType.getNumElements())
1842 return emitError("has incorrect number of operands: expected ")
1843 << cType.getNumElements() << ", but provided "
1844 << constituents.size();
1845
1846 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1847 auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1848
1849 auto constituentSpecConstOp =
1850 dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
1851 (*this)->getParentOp(), constituent.getAttr()));
1852
1853 if (constituentSpecConstOp.getDefaultValue().getType() !=
1854 cType.getElementType(index))
1855 return emitError("has incorrect types of operands: expected ")
1856 << cType.getElementType(index) << ", but provided "
1857 << constituentSpecConstOp.getDefaultValue().getType();
1858 }
1859
1860 return success();
1861}
1862
1863//===----------------------------------------------------------------------===//
1864// spirv.SpecConstantOperation
1865//===----------------------------------------------------------------------===//
1866
1867ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
1868 OperationState &result) {
1869 Region *body = result.addRegion();
1870
1871 if (parser.parseKeyword("wraps"))
1872 return failure();
1873
1874 body->push_back(new Block);
1875 Block &block = body->back();
1876 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
1877
1878 if (!wrappedOp)
1879 return failure();
1880
1881 OpBuilder builder(parser.getContext());
1882 builder.setInsertionPointToEnd(&block);
1883 builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
1884 result.location = wrappedOp->getLoc();
1885
1886 result.addTypes(wrappedOp->getResult(0).getType());
1887
1888 if (parser.parseOptionalAttrDict(result.attributes))
1889 return failure();
1890
1891 return success();
1892}
1893
1894void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) {
1895 printer << " wraps ";
1896 printer.printGenericOp(&getBody().front().front());
1897}
1898
1899LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
1900 Block &block = getRegion().getBlocks().front();
1901
1902 if (block.getOperations().size() != 2)
1903 return emitOpError("expected exactly 2 nested ops");
1904
1905 Operation &enclosedOp = block.getOperations().front();
1906
1907 if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>())
1908 return emitOpError("invalid enclosed op");
1909
1910 for (auto operand : enclosedOp.getOperands())
1911 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
1912 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
1913 return emitOpError(
1914 "invalid operand, must be defined by a constant operation");
1915
1916 return success();
1917}
1918
1919//===----------------------------------------------------------------------===//
1920// spirv.GL.FrexpStruct
1921//===----------------------------------------------------------------------===//
1922
1923LogicalResult spirv::GLFrexpStructOp::verify() {
1924 spirv::StructType structTy =
1925 llvm::dyn_cast<spirv::StructType>(getResult().getType());
1926
1927 if (structTy.getNumElements() != 2)
1928 return emitError("result type must be a struct type with two memebers");
1929
1930 Type significandTy = structTy.getElementType(0);
1931 Type exponentTy = structTy.getElementType(1);
1932 VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
1933 IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
1934
1935 Type operandTy = getOperand().getType();
1936 VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
1937 FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
1938
1939 if (significandTy != operandTy)
1940 return emitError("member zero of the resulting struct type must be the "
1941 "same type as the operand");
1942
1943 if (exponentVecTy) {
1944 IntegerType componentIntTy =
1945 llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
1946 if (!componentIntTy || componentIntTy.getWidth() != 32)
1947 return emitError("member one of the resulting struct type must"
1948 "be a scalar or vector of 32 bit integer type");
1949 } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
1950 return emitError("member one of the resulting struct type "
1951 "must be a scalar or vector of 32 bit integer type");
1952 }
1953
1954 // Check that the two member types have the same number of components
1955 if (operandVecTy && exponentVecTy &&
1956 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
1957 return success();
1958
1959 if (operandFTy && exponentIntTy)
1960 return success();
1961
1962 return emitError("member one of the resulting struct type must have the same "
1963 "number of components as the operand type");
1964}
1965
1966//===----------------------------------------------------------------------===//
1967// spirv.GL.Ldexp
1968//===----------------------------------------------------------------------===//
1969
1970LogicalResult spirv::GLLdexpOp::verify() {
1971 Type significandType = getX().getType();
1972 Type exponentType = getExp().getType();
1973
1974 if (llvm::isa<FloatType>(significandType) !=
1975 llvm::isa<IntegerType>(exponentType))
1976 return emitOpError("operands must both be scalars or vectors");
1977
1978 auto getNumElements = [](Type type) -> unsigned {
1979 if (auto vectorType = llvm::dyn_cast<VectorType>(type))
1980 return vectorType.getNumElements();
1981 return 1;
1982 };
1983
1984 if (getNumElements(significandType) != getNumElements(exponentType))
1985 return emitOpError("operands must have the same number of elements");
1986
1987 return success();
1988}
1989
1990//===----------------------------------------------------------------------===//
1991// spirv.ShiftLeftLogicalOp
1992//===----------------------------------------------------------------------===//
1993
1994LogicalResult spirv::ShiftLeftLogicalOp::verify() {
1995 return verifyShiftOp(*this);
1996}
1997
1998//===----------------------------------------------------------------------===//
1999// spirv.ShiftRightArithmeticOp
2000//===----------------------------------------------------------------------===//
2001
2002LogicalResult spirv::ShiftRightArithmeticOp::verify() {
2003 return verifyShiftOp(*this);
2004}
2005
2006//===----------------------------------------------------------------------===//
2007// spirv.ShiftRightLogicalOp
2008//===----------------------------------------------------------------------===//
2009
2010LogicalResult spirv::ShiftRightLogicalOp::verify() {
2011 return verifyShiftOp(*this);
2012}
2013
2014//===----------------------------------------------------------------------===//
2015// spirv.VectorTimesScalarOp
2016//===----------------------------------------------------------------------===//
2017
2018LogicalResult spirv::VectorTimesScalarOp::verify() {
2019 if (getVector().getType() != getType())
2020 return emitOpError("vector operand and result type mismatch");
2021 auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
2022 if (getScalar().getType() != scalarType)
2023 return emitOpError("scalar operand and result element type match");
2024 return success();
2025}
2026

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp