1//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14
15#include "SPIRVOpUtils.h"
16#include "SPIRVParsingUtils.h"
17
18#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
20#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
21#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
22#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
23#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
24#include "mlir/IR/Builders.h"
25#include "mlir/IR/BuiltinTypes.h"
26#include "mlir/IR/Matchers.h"
27#include "mlir/IR/OpDefinition.h"
28#include "mlir/IR/OpImplementation.h"
29#include "mlir/IR/Operation.h"
30#include "mlir/IR/TypeUtilities.h"
31#include "mlir/Interfaces/FunctionImplementation.h"
32#include "mlir/Support/LogicalResult.h"
33#include "llvm/ADT/APFloat.h"
34#include "llvm/ADT/APInt.h"
35#include "llvm/ADT/ArrayRef.h"
36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/StringExtras.h"
38#include "llvm/ADT/TypeSwitch.h"
39#include <cassert>
40#include <numeric>
41#include <optional>
42#include <type_traits>
43
44using namespace mlir;
45using namespace mlir::spirv::AttrNames;
46
47//===----------------------------------------------------------------------===//
48// Common utility functions
49//===----------------------------------------------------------------------===//
50
51LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
52 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
53 if (!constOp) {
54 return failure();
55 }
56 auto valueAttr = constOp.getValue();
57 auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
58 if (!integerValueAttr) {
59 return failure();
60 }
61
62 if (integerValueAttr.getType().isSignlessInteger())
63 value = integerValueAttr.getInt();
64 else
65 value = integerValueAttr.getSInt();
66
67 return success();
68}
69
70LogicalResult
71spirv::verifyMemorySemantics(Operation *op,
72 spirv::MemorySemantics memorySemantics) {
73 // According to the SPIR-V specification:
74 // "Despite being a mask and allowing multiple bits to be combined, it is
75 // invalid for more than one of these four bits to be set: Acquire, Release,
76 // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
77 // Release semantics is done by setting the AcquireRelease bit, not by setting
78 // two bits."
79 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
80 spirv::MemorySemantics::Release |
81 spirv::MemorySemantics::AcquireRelease |
82 spirv::MemorySemantics::SequentiallyConsistent;
83
84 auto bitCount =
85 llvm::popcount(Value: static_cast<uint32_t>(memorySemantics & atMostOneInSet));
86 if (bitCount > 1) {
87 return op->emitError(
88 message: "expected at most one of these four memory constraints "
89 "to be set: `Acquire`, `Release`,"
90 "`AcquireRelease` or `SequentiallyConsistent`");
91 }
92 return success();
93}
94
95void spirv::printVariableDecorations(Operation *op, OpAsmPrinter &printer,
96 SmallVectorImpl<StringRef> &elidedAttrs) {
97 // Print optional descriptor binding
98 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
99 stringifyDecoration(spirv::Decoration::DescriptorSet));
100 auto bindingName = llvm::convertToSnakeFromCamelCase(
101 stringifyDecoration(spirv::Decoration::Binding));
102 auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
103 auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
104 if (descriptorSet && binding) {
105 elidedAttrs.push_back(Elt: descriptorSetName);
106 elidedAttrs.push_back(Elt: bindingName);
107 printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
108 << ")";
109 }
110
111 // Print BuiltIn attribute if present
112 auto builtInName = llvm::convertToSnakeFromCamelCase(
113 stringifyDecoration(spirv::Decoration::BuiltIn));
114 if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
115 printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
116 elidedAttrs.push_back(Elt: builtInName);
117 }
118
119 printer.printOptionalAttrDict(attrs: op->getAttrs(), elidedAttrs);
120}
121
122static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
123 OperationState &result) {
124 SmallVector<OpAsmParser::UnresolvedOperand, 2> ops;
125 Type type;
126 // If the operand list is in-between parentheses, then we have a generic form.
127 // (see the fallback in `printOneResultOp`).
128 SMLoc loc = parser.getCurrentLocation();
129 if (!parser.parseOptionalLParen()) {
130 if (parser.parseOperandList(result&: ops) || parser.parseRParen() ||
131 parser.parseOptionalAttrDict(result&: result.attributes) ||
132 parser.parseColon() || parser.parseType(result&: type))
133 return failure();
134 auto fnType = llvm::dyn_cast<FunctionType>(type);
135 if (!fnType) {
136 parser.emitError(loc, message: "expected function type");
137 return failure();
138 }
139 if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
140 return failure();
141 result.addTypes(fnType.getResults());
142 return success();
143 }
144 return failure(isFailure: parser.parseOperandList(result&: ops) ||
145 parser.parseOptionalAttrDict(result&: result.attributes) ||
146 parser.parseColonType(result&: type) ||
147 parser.resolveOperands(operands&: ops, type, result&: result.operands) ||
148 parser.addTypeToList(type, result&: result.types));
149}
150
151static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
152 assert(op->getNumResults() == 1 && "op should have one result");
153
154 // If not all the operand and result types are the same, just use the
155 // generic assembly form to avoid omitting information in printing.
156 auto resultType = op->getResult(idx: 0).getType();
157 if (llvm::any_of(Range: op->getOperandTypes(),
158 P: [&](Type type) { return type != resultType; })) {
159 p.printGenericOp(op, /*printOpName=*/false);
160 return;
161 }
162
163 p << ' ';
164 p.printOperands(container: op->getOperands());
165 p.printOptionalAttrDict(attrs: op->getAttrs());
166 // Now we can output only one type for all operands and the result.
167 p << " : " << resultType;
168}
169
170template <typename Op>
171static LogicalResult verifyImageOperands(Op imageOp,
172 spirv::ImageOperandsAttr attr,
173 Operation::operand_range operands) {
174 if (!attr) {
175 if (operands.empty())
176 return success();
177
178 return imageOp.emitError("the Image Operands should encode what operands "
179 "follow, as per Image Operands");
180 }
181
182 // TODO: Add the validation rules for the following Image Operands.
183 spirv::ImageOperands noSupportOperands =
184 spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
185 spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
186 spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
187 spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
188 spirv::ImageOperands::MakeTexelAvailable |
189 spirv::ImageOperands::MakeTexelVisible |
190 spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
191
192 if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
193 llvm_unreachable("unimplemented operands of Image Operands");
194
195 return success();
196}
197
198template <typename BlockReadWriteOpTy>
199static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
200 Value ptr, Value val) {
201 auto valType = val.getType();
202 if (auto valVecTy = llvm::dyn_cast<VectorType>(valType))
203 valType = valVecTy.getElementType();
204
205 if (valType !=
206 llvm::cast<spirv::PointerType>(Val: ptr.getType()).getPointeeType()) {
207 return op.emitOpError("mismatch in result type and pointer type");
208 }
209 return success();
210}
211
212/// Walks the given type hierarchy with the given indices, potentially down
213/// to component granularity, to select an element type. Returns null type and
214/// emits errors with the given loc on failure.
215static Type
216getElementType(Type type, ArrayRef<int32_t> indices,
217 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
218 if (indices.empty()) {
219 emitErrorFn("expected at least one index for spirv.CompositeExtract");
220 return nullptr;
221 }
222
223 for (auto index : indices) {
224 if (auto cType = llvm::dyn_cast<spirv::CompositeType>(Val&: type)) {
225 if (cType.hasCompileTimeKnownNumElements() &&
226 (index < 0 ||
227 static_cast<uint64_t>(index) >= cType.getNumElements())) {
228 emitErrorFn("index ") << index << " out of bounds for " << type;
229 return nullptr;
230 }
231 type = cType.getElementType(index);
232 } else {
233 emitErrorFn("cannot extract from non-composite type ")
234 << type << " with index " << index;
235 return nullptr;
236 }
237 }
238 return type;
239}
240
241static Type
242getElementType(Type type, Attribute indices,
243 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
244 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
245 if (!indicesArrayAttr) {
246 emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
247 return nullptr;
248 }
249 if (indicesArrayAttr.empty()) {
250 emitErrorFn("expected at least one index for spirv.CompositeExtract");
251 return nullptr;
252 }
253
254 SmallVector<int32_t, 2> indexVals;
255 for (auto indexAttr : indicesArrayAttr) {
256 auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
257 if (!indexIntAttr) {
258 emitErrorFn("expected an 32-bit integer for index, but found '")
259 << indexAttr << "'";
260 return nullptr;
261 }
262 indexVals.push_back(indexIntAttr.getInt());
263 }
264 return getElementType(type, indices: indexVals, emitErrorFn);
265}
266
267static Type getElementType(Type type, Attribute indices, Location loc) {
268 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
269 return ::mlir::emitError(loc, message: err);
270 };
271 return getElementType(type, indices, emitErrorFn: errorFn);
272}
273
274static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
275 SMLoc loc) {
276 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
277 return parser.emitError(loc, message: err);
278 };
279 return getElementType(type, indices, emitErrorFn: errorFn);
280}
281
282template <typename ExtendedBinaryOp>
283static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
284 auto resultType = llvm::cast<spirv::StructType>(op.getType());
285 if (resultType.getNumElements() != 2)
286 return op.emitOpError("expected result struct type containing two members");
287
288 if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
289 resultType.getElementType(0),
290 resultType.getElementType(1)}))
291 return op.emitOpError(
292 "expected all operand types and struct member types are the same");
293
294 return success();
295}
296
297static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
298 OperationState &result) {
299 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
300 if (parser.parseOptionalAttrDict(result&: result.attributes) ||
301 parser.parseOperandList(result&: operands) || parser.parseColon())
302 return failure();
303
304 Type resultType;
305 SMLoc loc = parser.getCurrentLocation();
306 if (parser.parseType(result&: resultType))
307 return failure();
308
309 auto structType = llvm::dyn_cast<spirv::StructType>(Val&: resultType);
310 if (!structType || structType.getNumElements() != 2)
311 return parser.emitError(loc, message: "expected spirv.struct type with two members");
312
313 SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
314 if (parser.resolveOperands(operands, types&: operandTypes, loc, result&: result.operands))
315 return failure();
316
317 result.addTypes(newTypes: resultType);
318 return success();
319}
320
321static void printArithmeticExtendedBinaryOp(Operation *op,
322 OpAsmPrinter &printer) {
323 printer << ' ';
324 printer.printOptionalAttrDict(attrs: op->getAttrs());
325 printer.printOperands(container: op->getOperands());
326 printer << " : " << op->getResultTypes().front();
327}
328
329static LogicalResult verifyShiftOp(Operation *op) {
330 if (op->getOperand(idx: 0).getType() != op->getResult(idx: 0).getType()) {
331 return op->emitError(message: "expected the same type for the first operand and "
332 "result, but provided ")
333 << op->getOperand(idx: 0).getType() << " and "
334 << op->getResult(idx: 0).getType();
335 }
336 return success();
337}
338
339//===----------------------------------------------------------------------===//
340// spirv.mlir.addressof
341//===----------------------------------------------------------------------===//
342
343void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
344 spirv::GlobalVariableOp var) {
345 build(builder, state, var.getType(), SymbolRefAttr::get(var));
346}
347
348LogicalResult spirv::AddressOfOp::verify() {
349 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
350 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
351 getVariableAttr()));
352 if (!varOp) {
353 return emitOpError("expected spirv.GlobalVariable symbol");
354 }
355 if (getPointer().getType() != varOp.getType()) {
356 return emitOpError(
357 "result type mismatch with the referenced global variable's type");
358 }
359 return success();
360}
361
362//===----------------------------------------------------------------------===//
363// spirv.CompositeConstruct
364//===----------------------------------------------------------------------===//
365
366LogicalResult spirv::CompositeConstructOp::verify() {
367 operand_range constituents = this->getConstituents();
368
369 // There are 4 cases with varying verification rules:
370 // 1. Cooperative Matrices (1 constituent)
371 // 2. Structs (1 constituent for each member)
372 // 3. Arrays (1 constituent for each array element)
373 // 4. Vectors (1 constituent (sub-)element for each vector element)
374
375 auto coopElementType =
376 llvm::TypeSwitch<Type, Type>(getType())
377 .Case<spirv::CooperativeMatrixType, spirv::JointMatrixINTELType>(
378 [](auto coopType) { return coopType.getElementType(); })
379 .Default([](Type) { return nullptr; });
380
381 // Case 1. -- matrices.
382 if (coopElementType) {
383 if (constituents.size() != 1)
384 return emitOpError("has incorrect number of operands: expected ")
385 << "1, but provided " << constituents.size();
386 if (coopElementType != constituents.front().getType())
387 return emitOpError("operand type mismatch: expected operand type ")
388 << coopElementType << ", but provided "
389 << constituents.front().getType();
390 return success();
391 }
392
393 // Case 2./3./4. -- number of constituents matches the number of elements.
394 auto cType = llvm::cast<spirv::CompositeType>(getType());
395 if (constituents.size() == cType.getNumElements()) {
396 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
397 if (constituents[index].getType() != cType.getElementType(index)) {
398 return emitOpError("operand type mismatch: expected operand type ")
399 << cType.getElementType(index) << ", but provided "
400 << constituents[index].getType();
401 }
402 }
403 return success();
404 }
405
406 // Case 4. -- check that all constituents add up tp the expected vector type.
407 auto resultType = llvm::dyn_cast<VectorType>(cType);
408 if (!resultType)
409 return emitOpError(
410 "expected to return a vector or cooperative matrix when the number of "
411 "constituents is less than what the result needs");
412
413 SmallVector<unsigned> sizes;
414 for (Value component : constituents) {
415 if (!llvm::isa<VectorType>(component.getType()) &&
416 !component.getType().isIntOrFloat())
417 return emitOpError("operand type mismatch: expected operand to have "
418 "a scalar or vector type, but provided ")
419 << component.getType();
420
421 Type elementType = component.getType();
422 if (auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
423 sizes.push_back(vectorType.getNumElements());
424 elementType = vectorType.getElementType();
425 } else {
426 sizes.push_back(1);
427 }
428
429 if (elementType != resultType.getElementType())
430 return emitOpError("operand element type mismatch: expected to be ")
431 << resultType.getElementType() << ", but provided " << elementType;
432 }
433 unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
434 if (totalCount != cType.getNumElements())
435 return emitOpError("has incorrect number of operands: expected ")
436 << cType.getNumElements() << ", but provided " << totalCount;
437 return success();
438}
439
440//===----------------------------------------------------------------------===//
441// spirv.CompositeExtractOp
442//===----------------------------------------------------------------------===//
443
444void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
445 Value composite,
446 ArrayRef<int32_t> indices) {
447 auto indexAttr = builder.getI32ArrayAttr(indices);
448 auto elementType =
449 getElementType(composite.getType(), indexAttr, state.location);
450 if (!elementType) {
451 return;
452 }
453 build(builder, state, elementType, composite, indexAttr);
454}
455
456ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
457 OperationState &result) {
458 OpAsmParser::UnresolvedOperand compositeInfo;
459 Attribute indicesAttr;
460 StringRef indicesAttrName =
461 spirv::CompositeExtractOp::getIndicesAttrName(result.name);
462 Type compositeType;
463 SMLoc attrLocation;
464
465 if (parser.parseOperand(compositeInfo) ||
466 parser.getCurrentLocation(&attrLocation) ||
467 parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
468 parser.parseColonType(compositeType) ||
469 parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
470 return failure();
471 }
472
473 Type resultType =
474 getElementType(compositeType, indicesAttr, parser, attrLocation);
475 if (!resultType) {
476 return failure();
477 }
478 result.addTypes(resultType);
479 return success();
480}
481
482void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
483 printer << ' ' << getComposite() << getIndices() << " : "
484 << getComposite().getType();
485}
486
487LogicalResult spirv::CompositeExtractOp::verify() {
488 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
489 auto resultType =
490 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
491 if (!resultType)
492 return failure();
493
494 if (resultType != getType()) {
495 return emitOpError("invalid result type: expected ")
496 << resultType << " but provided " << getType();
497 }
498
499 return success();
500}
501
502//===----------------------------------------------------------------------===//
503// spirv.CompositeInsert
504//===----------------------------------------------------------------------===//
505
506void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
507 Value object, Value composite,
508 ArrayRef<int32_t> indices) {
509 auto indexAttr = builder.getI32ArrayAttr(indices);
510 build(builder, state, composite.getType(), object, composite, indexAttr);
511}
512
513ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
514 OperationState &result) {
515 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
516 Type objectType, compositeType;
517 Attribute indicesAttr;
518 StringRef indicesAttrName =
519 spirv::CompositeInsertOp::getIndicesAttrName(result.name);
520 auto loc = parser.getCurrentLocation();
521
522 return failure(
523 parser.parseOperandList(operands, 2) ||
524 parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
525 parser.parseColonType(objectType) ||
526 parser.parseKeywordType("into", compositeType) ||
527 parser.resolveOperands(operands, {objectType, compositeType}, loc,
528 result.operands) ||
529 parser.addTypesToList(compositeType, result.types));
530}
531
532LogicalResult spirv::CompositeInsertOp::verify() {
533 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
534 auto objectType =
535 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
536 if (!objectType)
537 return failure();
538
539 if (objectType != getObject().getType()) {
540 return emitOpError("object operand type should be ")
541 << objectType << ", but found " << getObject().getType();
542 }
543
544 if (getComposite().getType() != getType()) {
545 return emitOpError("result type should be the same as "
546 "the composite type, but found ")
547 << getComposite().getType() << " vs " << getType();
548 }
549
550 return success();
551}
552
553void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
554 printer << " " << getObject() << ", " << getComposite() << getIndices()
555 << " : " << getObject().getType() << " into "
556 << getComposite().getType();
557}
558
559//===----------------------------------------------------------------------===//
560// spirv.Constant
561//===----------------------------------------------------------------------===//
562
563ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
564 OperationState &result) {
565 Attribute value;
566 StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
567 if (parser.parseAttribute(value, valueAttrName, result.attributes))
568 return failure();
569
570 Type type = NoneType::get(parser.getContext());
571 if (auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
572 type = typedAttr.getType();
573 if (llvm::isa<NoneType, TensorType>(type)) {
574 if (parser.parseColonType(type))
575 return failure();
576 }
577
578 return parser.addTypeToList(type, result.types);
579}
580
581void spirv::ConstantOp::print(OpAsmPrinter &printer) {
582 printer << ' ' << getValue();
583 if (llvm::isa<spirv::ArrayType>(getType()))
584 printer << " : " << getType();
585}
586
587static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
588 Type opType) {
589 if (llvm::isa<IntegerAttr, FloatAttr>(Val: value)) {
590 auto valueType = llvm::cast<TypedAttr>(value).getType();
591 if (valueType != opType)
592 return op.emitOpError("result type (")
593 << opType << ") does not match value type (" << valueType << ")";
594 return success();
595 }
596 if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
597 auto valueType = llvm::cast<TypedAttr>(value).getType();
598 if (valueType == opType)
599 return success();
600 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(Val&: opType);
601 auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
602 if (!arrayType)
603 return op.emitOpError("result or element type (")
604 << opType << ") does not match value type (" << valueType
605 << "), must be the same or spirv.array";
606
607 int numElements = arrayType.getNumElements();
608 auto opElemType = arrayType.getElementType();
609 while (auto t = llvm::dyn_cast<spirv::ArrayType>(Val&: opElemType)) {
610 numElements *= t.getNumElements();
611 opElemType = t.getElementType();
612 }
613 if (!opElemType.isIntOrFloat())
614 return op.emitOpError("only support nested array result type");
615
616 auto valueElemType = shapedType.getElementType();
617 if (valueElemType != opElemType) {
618 return op.emitOpError("result element type (")
619 << opElemType << ") does not match value element type ("
620 << valueElemType << ")";
621 }
622
623 if (numElements != shapedType.getNumElements()) {
624 return op.emitOpError("result number of elements (")
625 << numElements << ") does not match value number of elements ("
626 << shapedType.getNumElements() << ")";
627 }
628 return success();
629 }
630 if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
631 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(Val&: opType);
632 if (!arrayType)
633 return op.emitOpError(
634 "must have spirv.array result type for array value");
635 Type elemType = arrayType.getElementType();
636 for (Attribute element : arrayAttr.getValue()) {
637 // Verify array elements recursively.
638 if (failed(verifyConstantType(op, element, elemType)))
639 return failure();
640 }
641 return success();
642 }
643 return op.emitOpError("cannot have attribute: ") << value;
644}
645
646LogicalResult spirv::ConstantOp::verify() {
647 // ODS already generates checks to make sure the result type is valid. We just
648 // need to additionally check that the value's attribute type is consistent
649 // with the result type.
650 return verifyConstantType(*this, getValueAttr(), getType());
651}
652
653bool spirv::ConstantOp::isBuildableWith(Type type) {
654 // Must be valid SPIR-V type first.
655 if (!llvm::isa<spirv::SPIRVType>(type))
656 return false;
657
658 if (isa<SPIRVDialect>(type.getDialect())) {
659 // TODO: support constant struct
660 return llvm::isa<spirv::ArrayType>(type);
661 }
662
663 return true;
664}
665
666spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
667 OpBuilder &builder) {
668 if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
669 unsigned width = intType.getWidth();
670 if (width == 1)
671 return builder.create<spirv::ConstantOp>(loc, type,
672 builder.getBoolAttr(false));
673 return builder.create<spirv::ConstantOp>(
674 loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
675 }
676 if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
677 return builder.create<spirv::ConstantOp>(
678 loc, type, builder.getFloatAttr(floatType, 0.0));
679 }
680 if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
681 Type elemType = vectorType.getElementType();
682 if (llvm::isa<IntegerType>(elemType)) {
683 return builder.create<spirv::ConstantOp>(
684 loc, type,
685 DenseElementsAttr::get(vectorType,
686 IntegerAttr::get(elemType, 0).getValue()));
687 }
688 if (llvm::isa<FloatType>(elemType)) {
689 return builder.create<spirv::ConstantOp>(
690 loc, type,
691 DenseFPElementsAttr::get(vectorType,
692 FloatAttr::get(elemType, 0.0).getValue()));
693 }
694 }
695
696 llvm_unreachable("unimplemented types for ConstantOp::getZero()");
697}
698
699spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
700 OpBuilder &builder) {
701 if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
702 unsigned width = intType.getWidth();
703 if (width == 1)
704 return builder.create<spirv::ConstantOp>(loc, type,
705 builder.getBoolAttr(true));
706 return builder.create<spirv::ConstantOp>(
707 loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
708 }
709 if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
710 return builder.create<spirv::ConstantOp>(
711 loc, type, builder.getFloatAttr(floatType, 1.0));
712 }
713 if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
714 Type elemType = vectorType.getElementType();
715 if (llvm::isa<IntegerType>(elemType)) {
716 return builder.create<spirv::ConstantOp>(
717 loc, type,
718 DenseElementsAttr::get(vectorType,
719 IntegerAttr::get(elemType, 1).getValue()));
720 }
721 if (llvm::isa<FloatType>(elemType)) {
722 return builder.create<spirv::ConstantOp>(
723 loc, type,
724 DenseFPElementsAttr::get(vectorType,
725 FloatAttr::get(elemType, 1.0).getValue()));
726 }
727 }
728
729 llvm_unreachable("unimplemented types for ConstantOp::getOne()");
730}
731
732void mlir::spirv::ConstantOp::getAsmResultNames(
733 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
734 Type type = getType();
735
736 SmallString<32> specialNameBuffer;
737 llvm::raw_svector_ostream specialName(specialNameBuffer);
738 specialName << "cst";
739
740 IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
741
742 if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
743 if (intTy && intTy.getWidth() == 1) {
744 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
745 }
746
747 if (intTy.isSignless()) {
748 specialName << intCst.getInt();
749 } else if (intTy.isUnsigned()) {
750 specialName << intCst.getUInt();
751 } else {
752 specialName << intCst.getSInt();
753 }
754 }
755
756 if (intTy || llvm::isa<FloatType>(type)) {
757 specialName << '_' << type;
758 }
759
760 if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
761 specialName << "_vec_";
762 specialName << vecType.getDimSize(0);
763
764 Type elementType = vecType.getElementType();
765
766 if (llvm::isa<IntegerType>(elementType) ||
767 llvm::isa<FloatType>(elementType)) {
768 specialName << "x" << elementType;
769 }
770 }
771
772 setNameFn(getResult(), specialName.str());
773}
774
775void mlir::spirv::AddressOfOp::getAsmResultNames(
776 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
777 SmallString<32> specialNameBuffer;
778 llvm::raw_svector_ostream specialName(specialNameBuffer);
779 specialName << getVariable() << "_addr";
780 setNameFn(getResult(), specialName.str());
781}
782
783//===----------------------------------------------------------------------===//
784// spirv.ControlBarrierOp
785//===----------------------------------------------------------------------===//
786
787LogicalResult spirv::ControlBarrierOp::verify() {
788 return verifyMemorySemantics(getOperation(), getMemorySemantics());
789}
790
791//===----------------------------------------------------------------------===//
792// spirv.EntryPoint
793//===----------------------------------------------------------------------===//
794
795void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
796 spirv::ExecutionModel executionModel,
797 spirv::FuncOp function,
798 ArrayRef<Attribute> interfaceVars) {
799 build(builder, state,
800 spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
801 SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
802}
803
804ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
805 OperationState &result) {
806 spirv::ExecutionModel execModel;
807 SmallVector<OpAsmParser::UnresolvedOperand, 0> identifiers;
808 SmallVector<Type, 0> idTypes;
809 SmallVector<Attribute, 4> interfaceVars;
810
811 FlatSymbolRefAttr fn;
812 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
813 parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
814 return failure();
815 }
816
817 if (!parser.parseOptionalComma()) {
818 // Parse the interface variables
819 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
820 // The name of the interface variable attribute isnt important
821 FlatSymbolRefAttr var;
822 NamedAttrList attrs;
823 if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
824 return failure();
825 interfaceVars.push_back(var);
826 return success();
827 }))
828 return failure();
829 }
830 result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
831 parser.getBuilder().getArrayAttr(interfaceVars));
832 return success();
833}
834
835void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
836 printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
837 printer.printSymbolName(getFn());
838 auto interfaceVars = getInterface().getValue();
839 if (!interfaceVars.empty()) {
840 printer << ", ";
841 llvm::interleaveComma(interfaceVars, printer);
842 }
843}
844
845LogicalResult spirv::EntryPointOp::verify() {
846 // Checks for fn and interface symbol reference are done in spirv::ModuleOp
847 // verification.
848 return success();
849}
850
851//===----------------------------------------------------------------------===//
852// spirv.ExecutionMode
853//===----------------------------------------------------------------------===//
854
855void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
856 spirv::FuncOp function,
857 spirv::ExecutionMode executionMode,
858 ArrayRef<int32_t> params) {
859 build(builder, state, SymbolRefAttr::get(function),
860 spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
861 builder.getI32ArrayAttr(params));
862}
863
864ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
865 OperationState &result) {
866 spirv::ExecutionMode execMode;
867 Attribute fn;
868 if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
869 parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
870 return failure();
871 }
872
873 SmallVector<int32_t, 4> values;
874 Type i32Type = parser.getBuilder().getIntegerType(32);
875 while (!parser.parseOptionalComma()) {
876 NamedAttrList attr;
877 Attribute value;
878 if (parser.parseAttribute(value, i32Type, "value", attr)) {
879 return failure();
880 }
881 values.push_back(llvm::cast<IntegerAttr>(value).getInt());
882 }
883 StringRef valuesAttrName =
884 spirv::ExecutionModeOp::getValuesAttrName(result.name);
885 result.addAttribute(valuesAttrName,
886 parser.getBuilder().getI32ArrayAttr(values));
887 return success();
888}
889
890void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
891 printer << " ";
892 printer.printSymbolName(getFn());
893 printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
894 auto values = this->getValues();
895 if (values.empty())
896 return;
897 printer << ", ";
898 llvm::interleaveComma(values, printer, [&](Attribute a) {
899 printer << llvm::cast<IntegerAttr>(a).getInt();
900 });
901}
902
903//===----------------------------------------------------------------------===//
904// spirv.func
905//===----------------------------------------------------------------------===//
906
/// Parses the custom assembly form of a SPIR-V function:
///   @name(<args>) [-> <results>] "<FunctionControl>"
///       [attributes {...}] [{ <body> }]
///
/// The function-type attribute is built from the parsed argument and result
/// types. Per-argument and per-result attributes from the signature are
/// attached via function_interface_impl. A missing body leaves the region
/// empty (an external function).
ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  auto &builder = parser.getBuilder();

  // Parse the name as a symbol.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  // Parse the function signature. Variadic signatures are rejected.
  bool isVariadic = false;
  if (function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs))
    return failure();

  // Build the function type from the parsed argument and result types.
  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto fnType = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(fnType));

  // Parse the quoted function control value, the counterpart of what print()
  // emits after the signature.
  // NOTE(review): this was described as "optional"; whether parseEnumStrAttr
  // accepts an absent value is not visible from here.
  spirv::FunctionControl fnControl;
  if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
    return failure();

  // If additional attributes are present, parse them.
  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  // Attach the argument and result attributes collected from the signature.
  assert(resultAttrs.size() == resultTypes.size());
  function_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  // Parse the optional function body, using the signature's arguments as the
  // entry block arguments.
  auto *body = result.addRegion();
  OptionalParseResult parseResult =
      parser.parseOptionalRegion(*body, entryArgs);
  // An absent body succeeds; a present but malformed body fails.
  return failure(parseResult.has_value() && failed(*parseResult));
}
954
955void spirv::FuncOp::print(OpAsmPrinter &printer) {
956 // Print function name, signature, and control.
957 printer << " ";
958 printer.printSymbolName(getSymName());
959 auto fnType = getFunctionType();
960 function_interface_impl::printFunctionSignature(
961 printer, *this, fnType.getInputs(),
962 /*isVariadic=*/false, fnType.getResults());
963 printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
964 << "\"";
965 function_interface_impl::printFunctionAttributes(
966 printer, *this,
967 {spirv::attributeName<spirv::FunctionControl>(),
968 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
969 getFunctionControlAttrName()});
970
971 // Print the body if this is not an external function.
972 Region &body = this->getBody();
973 if (!body.empty()) {
974 printer << ' ';
975 printer.printRegion(body, /*printEntryBlockArgs=*/false,
976 /*printBlockTerminators=*/true);
977 }
978}
979
980LogicalResult spirv::FuncOp::verifyType() {
981 FunctionType fnType = getFunctionType();
982 if (fnType.getNumResults() > 1)
983 return emitOpError("cannot have more than one result");
984
985 auto hasDecorationAttr = [&](spirv::Decoration decoration,
986 unsigned argIndex) {
987 auto func = llvm::cast<FunctionOpInterface>(getOperation());
988 for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
989 if (argAttr.getName() != spirv::DecorationAttr::name)
990 continue;
991 if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
992 return decAttr.getValue() == decoration;
993 }
994 return false;
995 };
996
997 for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
998 Type param = fnType.getInputs()[i];
999 auto inputPtrType = dyn_cast<spirv::PointerType>(param);
1000 if (!inputPtrType)
1001 continue;
1002
1003 auto pointeePtrType =
1004 dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
1005 if (pointeePtrType) {
1006 // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1007 // > If an OpFunctionParameter is a pointer (or contains a pointer)
1008 // > and the type it points to is a pointer in the PhysicalStorageBuffer
1009 // > storage class, the function parameter must be decorated with exactly
1010 // > one of AliasedPointer or RestrictPointer.
1011 if (pointeePtrType.getStorageClass() !=
1012 spirv::StorageClass::PhysicalStorageBuffer)
1013 continue;
1014
1015 bool hasAliasedPtr =
1016 hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
1017 bool hasRestrictPtr =
1018 hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
1019 if (!hasAliasedPtr && !hasRestrictPtr)
1020 return emitOpError()
1021 << "with a pointer points to a physical buffer pointer must "
1022 "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1023 continue;
1024 }
1025 // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1026 // > If an OpFunctionParameter is a pointer (or contains a pointer) in
1027 // > the PhysicalStorageBuffer storage class, the function parameter must
1028 // > be decorated with exactly one of Aliased or Restrict.
1029 if (auto pointeeArrayType =
1030 dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1031 pointeePtrType =
1032 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1033 } else {
1034 pointeePtrType = inputPtrType;
1035 }
1036
1037 if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1038 spirv::StorageClass::PhysicalStorageBuffer)
1039 continue;
1040
1041 bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1042 bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1043 if (!hasAliased && !hasRestrict)
1044 return emitOpError() << "with physical buffer pointer must be decorated "
1045 "either 'Aliased' or 'Restrict'";
1046 }
1047
1048 return success();
1049}
1050
1051LogicalResult spirv::FuncOp::verifyBody() {
1052 FunctionType fnType = getFunctionType();
1053
1054 auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1055 if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1056 if (fnType.getNumResults() != 0)
1057 return retOp.emitOpError("cannot be used in functions returning value");
1058 } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1059 if (fnType.getNumResults() != 1)
1060 return retOp.emitOpError(
1061 "returns 1 value but enclosing function requires ")
1062 << fnType.getNumResults() << " results";
1063
1064 auto retOperandType = retOp.getValue().getType();
1065 auto fnResultType = fnType.getResult(0);
1066 if (retOperandType != fnResultType)
1067 return retOp.emitOpError(" return value's type (")
1068 << retOperandType << ") mismatch with function's result type ("
1069 << fnResultType << ")";
1070 }
1071 return WalkResult::advance();
1072 });
1073
1074 // TODO: verify other bits like linkage type.
1075
1076 return failure(walkResult.wasInterrupted());
1077}
1078
1079void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1080 StringRef name, FunctionType type,
1081 spirv::FunctionControl control,
1082 ArrayRef<NamedAttribute> attrs) {
1083 state.addAttribute(SymbolTable::getSymbolAttrName(),
1084 builder.getStringAttr(name));
1085 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1086 state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1087 builder.getAttr<spirv::FunctionControlAttr>(control));
1088 state.attributes.append(attrs.begin(), attrs.end());
1089 state.addRegion();
1090}
1091
1092//===----------------------------------------------------------------------===//
1093// spirv.GLFClampOp
1094//===----------------------------------------------------------------------===//
1095
1096ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
1097 OperationState &result) {
1098 return parseOneResultSameOperandTypeOp(parser, result);
1099}
1100void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1101
1102//===----------------------------------------------------------------------===//
1103// spirv.GLUClampOp
1104//===----------------------------------------------------------------------===//
1105
1106ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
1107 OperationState &result) {
1108 return parseOneResultSameOperandTypeOp(parser, result);
1109}
1110void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1111
1112//===----------------------------------------------------------------------===//
1113// spirv.GLSClampOp
1114//===----------------------------------------------------------------------===//
1115
1116ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
1117 OperationState &result) {
1118 return parseOneResultSameOperandTypeOp(parser, result);
1119}
1120void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1121
1122//===----------------------------------------------------------------------===//
1123// spirv.GLFmaOp
1124//===----------------------------------------------------------------------===//
1125
1126ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
1127 return parseOneResultSameOperandTypeOp(parser, result);
1128}
1129void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1130
1131//===----------------------------------------------------------------------===//
1132// spirv.GlobalVariable
1133//===----------------------------------------------------------------------===//
1134
1135void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1136 Type type, StringRef name,
1137 unsigned descriptorSet, unsigned binding) {
1138 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1139 state.addAttribute(
1140 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1141 builder.getI32IntegerAttr(descriptorSet));
1142 state.addAttribute(
1143 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1144 builder.getI32IntegerAttr(binding));
1145}
1146
1147void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1148 Type type, StringRef name,
1149 spirv::BuiltIn builtin) {
1150 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1151 state.addAttribute(
1152 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1153 builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1154}
1155
/// Parses `@name [<initializer-attr-name>(@sym)] <decorations> : <type>`.
///
/// The optional initializer clause is introduced by a keyword spelled like the
/// initializer attribute name, and stores a flat symbol reference. Decorations
/// are handled by parseVariableDecorations. The trailing type must be a
/// spirv::PointerType and is stored as the op's type attribute.
ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
                                           OperationState &result) {
  // Parse variable name.
  StringAttr nameAttr;
  StringRef initializerAttrName =
      spirv::GlobalVariableOp::getInitializerAttrName(result.name);
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes)) {
    return failure();
  }

  // Parse optional initializer: `<initializerAttrName>(@symbol)`.
  if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
    FlatSymbolRefAttr initSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(initSymbol, Type(), initializerAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }

  if (parseVariableDecorations(parser, result)) {
    return failure();
  }

  Type type;
  StringRef typeAttrName =
      spirv::GlobalVariableOp::getTypeAttrName(result.name);
  // Capture the location before the type so the diagnostic points at it.
  auto loc = parser.getCurrentLocation();
  if (parser.parseColonType(type)) {
    return failure();
  }
  if (!llvm::isa<spirv::PointerType>(type)) {
    return parser.emitError(loc, "expected spirv.ptr type");
  }
  result.addAttribute(typeAttrName, TypeAttr::get(type));

  return success();
}
1195
1196void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
1197 SmallVector<StringRef, 4> elidedAttrs{
1198 spirv::attributeName<spirv::StorageClass>()};
1199
1200 // Print variable name.
1201 printer << ' ';
1202 printer.printSymbolName(getSymName());
1203 elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1204
1205 StringRef initializerAttrName = this->getInitializerAttrName();
1206 // Print optional initializer
1207 if (auto initializer = this->getInitializer()) {
1208 printer << " " << initializerAttrName << '(';
1209 printer.printSymbolName(*initializer);
1210 printer << ')';
1211 elidedAttrs.push_back(initializerAttrName);
1212 }
1213
1214 StringRef typeAttrName = this->getTypeAttrName();
1215 elidedAttrs.push_back(typeAttrName);
1216 spirv::printVariableDecorations(*this, printer, elidedAttrs);
1217 printer << " : " << getType();
1218}
1219
1220LogicalResult spirv::GlobalVariableOp::verify() {
1221 if (!llvm::isa<spirv::PointerType>(getType()))
1222 return emitOpError("result must be of a !spv.ptr type");
1223
1224 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1225 // object. It cannot be Generic. It must be the same as the Storage Class
1226 // operand of the Result Type."
1227 // Also, Function storage class is reserved by spirv.Variable.
1228 auto storageClass = this->storageClass();
1229 if (storageClass == spirv::StorageClass::Generic ||
1230 storageClass == spirv::StorageClass::Function) {
1231 return emitOpError("storage class cannot be '")
1232 << stringifyStorageClass(storageClass) << "'";
1233 }
1234
1235 if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
1236 this->getInitializerAttrName())) {
1237 Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
1238 (*this)->getParentOp(), init.getAttr());
1239 // TODO: Currently only variable initialization with specialization
1240 // constants and other variables is supported. They could be normal
1241 // constants in the module scope as well.
1242 if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
1243 spirv::SpecConstantCompositeOp>(initOp)) {
1244 return emitOpError("initializer must be result of a "
1245 "spirv.SpecConstant or spirv.GlobalVariable or "
1246 "spirv.SpecConstantCompositeOp op");
1247 }
1248 }
1249
1250 return success();
1251}
1252
1253//===----------------------------------------------------------------------===//
1254// spirv.INTEL.SubgroupBlockRead
1255//===----------------------------------------------------------------------===//
1256
/// Parses `"<StorageClass>" %ptr : <value-type>`.
///
/// The pointer operand's type is reconstructed from the storage class and the
/// value type. For a vector value type the pointee is the vector's element
/// type; otherwise it is the value type itself. The result has the value type.
ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser,
                                                   OperationState &result) {
  // Parse the storage class specification
  spirv::StorageClass storageClass;
  OpAsmParser::UnresolvedOperand ptrInfo;
  Type elementType;
  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
      parser.parseColon() || parser.parseType(elementType)) {
    return failure();
  }

  // Vector reads use a pointer to the scalar element type.
  auto ptrType = spirv::PointerType::get(elementType, storageClass);
  if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
    ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);

  if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
    return failure();
  }

  result.addTypes(elementType);
  return success();
}
1279
1280void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) {
1281 printer << " " << getPtr() << " : " << getType();
1282}
1283
1284LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
1285 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1286 return failure();
1287
1288 return success();
1289}
1290
1291//===----------------------------------------------------------------------===//
1292// spirv.INTEL.SubgroupBlockWrite
1293//===----------------------------------------------------------------------===//
1294
/// Parses `"<StorageClass>" %ptr, %value : <value-type>`.
///
/// Exactly two operands are required. The pointer type is reconstructed from
/// the storage class and the value type: for a vector value type the pointee is
/// the vector's element type. The op has no results.
ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
                                                    OperationState &result) {
  // Parse the storage class specification
  spirv::StorageClass storageClass;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
  auto loc = parser.getCurrentLocation();
  Type elementType;
  if (parseEnumStrAttr(storageClass, parser) ||
      parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
      parser.parseType(elementType)) {
    return failure();
  }

  // Vector writes use a pointer to the scalar element type.
  auto ptrType = spirv::PointerType::get(elementType, storageClass);
  if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
    ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);

  // Operand 0 is the pointer, operand 1 the value being written.
  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
                             result.operands)) {
    return failure();
  }
  return success();
}
1318
1319void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
1320 printer << " " << getPtr() << ", " << getValue() << " : "
1321 << getValue().getType();
1322}
1323
1324LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
1325 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1326 return failure();
1327
1328 return success();
1329}
1330
1331//===----------------------------------------------------------------------===//
1332// spirv.IAddCarryOp
1333//===----------------------------------------------------------------------===//
1334
1335LogicalResult spirv::IAddCarryOp::verify() {
1336 return ::verifyArithmeticExtendedBinaryOp(*this);
1337}
1338
1339ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
1340 OperationState &result) {
1341 return ::parseArithmeticExtendedBinaryOp(parser, result);
1342}
1343
1344void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1345 ::printArithmeticExtendedBinaryOp(*this, printer);
1346}
1347
1348//===----------------------------------------------------------------------===//
1349// spirv.ISubBorrowOp
1350//===----------------------------------------------------------------------===//
1351
1352LogicalResult spirv::ISubBorrowOp::verify() {
1353 return ::verifyArithmeticExtendedBinaryOp(*this);
1354}
1355
1356ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
1357 OperationState &result) {
1358 return ::parseArithmeticExtendedBinaryOp(parser, result);
1359}
1360
1361void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
1362 ::printArithmeticExtendedBinaryOp(*this, printer);
1363}
1364
1365//===----------------------------------------------------------------------===//
1366// spirv.SMulExtended
1367//===----------------------------------------------------------------------===//
1368
1369LogicalResult spirv::SMulExtendedOp::verify() {
1370 return ::verifyArithmeticExtendedBinaryOp(*this);
1371}
1372
1373ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
1374 OperationState &result) {
1375 return ::parseArithmeticExtendedBinaryOp(parser, result);
1376}
1377
1378void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) {
1379 ::printArithmeticExtendedBinaryOp(*this, printer);
1380}
1381
1382//===----------------------------------------------------------------------===//
1383// spirv.UMulExtended
1384//===----------------------------------------------------------------------===//
1385
1386LogicalResult spirv::UMulExtendedOp::verify() {
1387 return ::verifyArithmeticExtendedBinaryOp(*this);
1388}
1389
1390ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
1391 OperationState &result) {
1392 return ::parseArithmeticExtendedBinaryOp(parser, result);
1393}
1394
1395void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
1396 ::printArithmeticExtendedBinaryOp(*this, printer);
1397}
1398
1399//===----------------------------------------------------------------------===//
1400// spirv.MemoryBarrierOp
1401//===----------------------------------------------------------------------===//
1402
1403LogicalResult spirv::MemoryBarrierOp::verify() {
1404 return verifyMemorySemantics(getOperation(), getMemorySemantics());
1405}
1406
1407//===----------------------------------------------------------------------===//
1408// spirv.module
1409//===----------------------------------------------------------------------===//
1410
1411void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1412 std::optional<StringRef> name) {
1413 OpBuilder::InsertionGuard guard(builder);
1414 builder.createBlock(state.addRegion());
1415 if (name) {
1416 state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
1417 builder.getStringAttr(*name));
1418 }
1419}
1420
1421void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1422 spirv::AddressingModel addressingModel,
1423 spirv::MemoryModel memoryModel,
1424 std::optional<VerCapExtAttr> vceTriple,
1425 std::optional<StringRef> name) {
1426 state.addAttribute(
1427 "addressing_model",
1428 builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
1429 state.addAttribute("memory_model",
1430 builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
1431 OpBuilder::InsertionGuard guard(builder);
1432 builder.createBlock(state.addRegion());
1433 if (vceTriple)
1434 state.addAttribute(getVCETripleAttrName(), *vceTriple);
1435 if (name)
1436 state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
1437 builder.getStringAttr(*name));
1438}
1439
/// Parses `[@name] <AddressingModel> <MemoryModel> [requires <vce-triple>]
/// [attributes {...}] { <body> }`.
///
/// The two models are parsed as bare keywords. If the parsed body region is
/// empty, a single empty block is added so the module always has one block.
ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  Region *body = result.addRegion();

  // If the name is present, parse it. The result is intentionally discarded
  // because the name is optional.
  StringAttr nameAttr;
  (void)parser.parseOptionalSymbolName(
      nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);

  // Parse attributes
  spirv::AddressingModel addrModel;
  spirv::MemoryModel memoryModel;
  if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
                                                              result) ||
      spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
                                                          result))
    return failure();

  // Optional `requires #spirv.vce<...>`-style triple attribute.
  if (succeeded(parser.parseOptionalKeyword("requires"))) {
    spirv::VerCapExtAttr vceTriple;
    if (parser.parseAttribute(vceTriple,
                              spirv::ModuleOp::getVCETripleAttrName(),
                              result.attributes))
      return failure();
  }

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
      parser.parseRegion(*body, /*arguments=*/{}))
    return failure();

  // Make sure we have at least one block.
  if (body->empty())
    body->push_back(new Block());

  return success();
}
1476
1477void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1478 if (std::optional<StringRef> name = getName()) {
1479 printer << ' ';
1480 printer.printSymbolName(*name);
1481 }
1482
1483 SmallVector<StringRef, 2> elidedAttrs;
1484
1485 printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1486 << spirv::stringifyMemoryModel(getMemoryModel());
1487 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1488 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1489 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1490 mlir::SymbolTable::getSymbolAttrName()});
1491
1492 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1493 printer << " requires " << *triple;
1494 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1495 }
1496
1497 printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
1498 printer << ' ';
1499 printer.printRegion(getRegion());
1500}
1501
1502LogicalResult spirv::ModuleOp::verifyRegions() {
1503 Dialect *dialect = (*this)->getDialect();
1504 DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
1505 entryPoints;
1506 mlir::SymbolTable table(*this);
1507
1508 for (auto &op : *getBody()) {
1509 if (op.getDialect() != dialect)
1510 return op.emitError("'spirv.module' can only contain spirv.* ops");
1511
1512 // For EntryPoint op, check that the function and execution model is not
1513 // duplicated in EntryPointOps. Also verify that the interface specified
1514 // comes from globalVariables here to make this check cheaper.
1515 if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1516 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1517 if (!funcOp) {
1518 return entryPointOp.emitError("function '")
1519 << entryPointOp.getFn() << "' not found in 'spirv.module'";
1520 }
1521 if (auto interface = entryPointOp.getInterface()) {
1522 for (Attribute varRef : interface) {
1523 auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1524 if (!varSymRef) {
1525 return entryPointOp.emitError(
1526 "expected symbol reference for interface "
1527 "specification instead of '")
1528 << varRef;
1529 }
1530 auto variableOp =
1531 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1532 if (!variableOp) {
1533 return entryPointOp.emitError("expected spirv.GlobalVariable "
1534 "symbol reference instead of'")
1535 << varSymRef << "'";
1536 }
1537 }
1538 }
1539
1540 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1541 funcOp, entryPointOp.getExecutionModel());
1542 auto entryPtIt = entryPoints.find(key);
1543 if (entryPtIt != entryPoints.end()) {
1544 return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1545 }
1546 entryPoints[key] = entryPointOp;
1547 } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1548 // If the function is external and does not have 'Import'
1549 // linkage_attributes(LinkageAttributes), throw an error. 'Import'
1550 // LinkageAttributes is used to import external functions.
1551 auto linkageAttr = funcOp.getLinkageAttributes();
1552 auto hasImportLinkage =
1553 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1554 spirv::LinkageType::Import);
1555 if (funcOp.isExternal() && !hasImportLinkage)
1556 return op.emitError(
1557 "'spirv.module' cannot contain external functions "
1558 "without 'Import' linkage_attributes (LinkageAttributes)");
1559
1560 // TODO: move this check to spirv.func.
1561 for (auto &block : funcOp)
1562 for (auto &op : block) {
1563 if (op.getDialect() != dialect)
1564 return op.emitError(
1565 "functions in 'spirv.module' can only contain spirv.* ops");
1566 }
1567 }
1568 }
1569
1570 return success();
1571}
1572
1573//===----------------------------------------------------------------------===//
1574// spirv.mlir.referenceof
1575//===----------------------------------------------------------------------===//
1576
1577LogicalResult spirv::ReferenceOfOp::verify() {
1578 auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1579 (*this)->getParentOp(), getSpecConstAttr());
1580 Type constType;
1581
1582 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1583 if (specConstOp)
1584 constType = specConstOp.getDefaultValue().getType();
1585
1586 auto specConstCompositeOp =
1587 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1588 if (specConstCompositeOp)
1589 constType = specConstCompositeOp.getType();
1590
1591 if (!specConstOp && !specConstCompositeOp)
1592 return emitOpError(
1593 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1594
1595 if (getReference().getType() != constType)
1596 return emitOpError("result type mismatch with the referenced "
1597 "specialization constant's type");
1598
1599 return success();
1600}
1601
1602//===----------------------------------------------------------------------===//
1603// spirv.SpecConstant
1604//===----------------------------------------------------------------------===//
1605
/// Parses `@name [<kSpecIdAttrName>(<int>)] = <default-value-attr>`.
///
/// The optional spec id is stored as an IntegerAttr under kSpecIdAttrName. The
/// default value is stored under the op's default-value attribute name.
ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
                                         OperationState &result) {
  StringAttr nameAttr;
  Attribute valueAttr;
  StringRef defaultValueAttrName =
      spirv::SpecConstantOp::getDefaultValueAttrName(result.name);

  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  // Parse optional spec_id.
  if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
    IntegerAttr specIdAttr;
    if (parser.parseLParen() ||
        parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
        parser.parseRParen())
      return failure();
  }

  // `= <attr>`: the default value; its validity is checked in verify().
  if (parser.parseEqual() ||
      parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
    return failure();

  return success();
}
1632
1633void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
1634 printer << ' ';
1635 printer.printSymbolName(getSymName());
1636 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1637 printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1638 printer << " = " << getDefaultValue();
1639}
1640
1641LogicalResult spirv::SpecConstantOp::verify() {
1642 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1643 if (specID.getValue().isNegative())
1644 return emitOpError("SpecId cannot be negative");
1645
1646 auto value = getDefaultValue();
1647 if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1648 // Make sure bitwidth is allowed.
1649 if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1650 return emitOpError("default value bitwidth disallowed");
1651 return success();
1652 }
1653 return emitOpError(
1654 "default value can only be a bool, integer, or float scalar");
1655}
1656
1657//===----------------------------------------------------------------------===//
1658// spirv.VectorShuffle
1659//===----------------------------------------------------------------------===//
1660
1661LogicalResult spirv::VectorShuffleOp::verify() {
1662 VectorType resultType = llvm::cast<VectorType>(getType());
1663
1664 size_t numResultElements = resultType.getNumElements();
1665 if (numResultElements != getComponents().size())
1666 return emitOpError("result type element count (")
1667 << numResultElements
1668 << ") mismatch with the number of component selectors ("
1669 << getComponents().size() << ")";
1670
1671 size_t totalSrcElements =
1672 llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
1673 llvm::cast<VectorType>(getVector2().getType()).getNumElements();
1674
1675 for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1676 uint32_t index = selector.getZExtValue();
1677 if (index >= totalSrcElements &&
1678 index != std::numeric_limits<uint32_t>().max())
1679 return emitOpError("component selector ")
1680 << index << " out of range: expected to be in [0, "
1681 << totalSrcElements << ") or 0xffffffff";
1682 }
1683 return success();
1684}
1685
1686//===----------------------------------------------------------------------===//
1687// spirv.MatrixTimesScalar
1688//===----------------------------------------------------------------------===//
1689
1690LogicalResult spirv::MatrixTimesScalarOp::verify() {
1691 Type elementType =
1692 llvm::TypeSwitch<Type, Type>(getMatrix().getType())
1693 .Case<spirv::CooperativeMatrixType, spirv::MatrixType>(
1694 [](auto matrixType) { return matrixType.getElementType(); })
1695 .Default([](Type) { return nullptr; });
1696
1697 assert(elementType && "Unhandled type");
1698
1699 // Check that the scalar type is the same as the matrix element type.
1700 if (getScalar().getType() != elementType)
1701 return emitOpError("input matrix components' type and scaling value must "
1702 "have the same type");
1703
1704 return success();
1705}
1706
1707//===----------------------------------------------------------------------===//
1708// spirv.Transpose
1709//===----------------------------------------------------------------------===//
1710
1711LogicalResult spirv::TransposeOp::verify() {
1712 auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1713 auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1714
1715 // Verify that the input and output matrices have correct shapes.
1716 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1717 return emitError("input matrix rows count must be equal to "
1718 "output matrix columns count");
1719
1720 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1721 return emitError("input matrix columns count must be equal to "
1722 "output matrix rows count");
1723
1724 // Verify that the input and output matrices have the same component type
1725 if (inputMatrix.getElementType() != resultMatrix.getElementType())
1726 return emitError("input and output matrices must have the same "
1727 "component type");
1728
1729 return success();
1730}
1731
1732//===----------------------------------------------------------------------===//
1733// spirv.MatrixTimesMatrix
1734//===----------------------------------------------------------------------===//
1735
1736LogicalResult spirv::MatrixTimesMatrixOp::verify() {
1737 auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
1738 auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
1739 auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1740
1741 // left matrix columns' count and right matrix rows' count must be equal
1742 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1743 return emitError("left matrix columns' count must be equal to "
1744 "the right matrix rows' count");
1745
1746 // right and result matrices columns' count must be the same
1747 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1748 return emitError(
1749 "right and result matrices must have equal columns' count");
1750
1751 // right and result matrices component type must be the same
1752 if (rightMatrix.getElementType() != resultMatrix.getElementType())
1753 return emitError("right and result matrices' component type must"
1754 " be the same");
1755
1756 // left and result matrices component type must be the same
1757 if (leftMatrix.getElementType() != resultMatrix.getElementType())
1758 return emitError("left and result matrices' component type"
1759 " must be the same");
1760
1761 // left and result matrices rows count must be the same
1762 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1763 return emitError("left and result matrices must have equal rows' count");
1764
1765 return success();
1766}
1767
1768//===----------------------------------------------------------------------===//
1769// spirv.SpecConstantComposite
1770//===----------------------------------------------------------------------===//
1771
1772ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
1773 OperationState &result) {
1774
1775 StringAttr compositeName;
1776 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1777 result.attributes))
1778 return failure();
1779
1780 if (parser.parseLParen())
1781 return failure();
1782
1783 SmallVector<Attribute, 4> constituents;
1784
1785 do {
1786 // The name of the constituent attribute isn't important
1787 const char *attrName = "spec_const";
1788 FlatSymbolRefAttr specConstRef;
1789 NamedAttrList attrs;
1790
1791 if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
1792 return failure();
1793
1794 constituents.push_back(specConstRef);
1795 } while (!parser.parseOptionalComma());
1796
1797 if (parser.parseRParen())
1798 return failure();
1799
1800 StringAttr compositeSpecConstituentsName =
1801 spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
1802 result.addAttribute(compositeSpecConstituentsName,
1803 parser.getBuilder().getArrayAttr(constituents));
1804
1805 Type type;
1806 if (parser.parseColonType(type))
1807 return failure();
1808
1809 StringAttr typeAttrName =
1810 spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
1811 result.addAttribute(typeAttrName, TypeAttr::get(type));
1812
1813 return success();
1814}
1815
1816void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
1817 printer << " ";
1818 printer.printSymbolName(getSymName());
1819 printer << " (";
1820 auto constituents = this->getConstituents().getValue();
1821
1822 if (!constituents.empty())
1823 llvm::interleaveComma(constituents, printer);
1824
1825 printer << ") : " << getType();
1826}
1827
1828LogicalResult spirv::SpecConstantCompositeOp::verify() {
1829 auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
1830 auto constituents = this->getConstituents().getValue();
1831
1832 if (!cType)
1833 return emitError("result type must be a composite type, but provided ")
1834 << getType();
1835
1836 if (llvm::isa<spirv::CooperativeMatrixType>(cType))
1837 return emitError("unsupported composite type ") << cType;
1838 if (llvm::isa<spirv::JointMatrixINTELType>(cType))
1839 return emitError("unsupported composite type ") << cType;
1840 if (constituents.size() != cType.getNumElements())
1841 return emitError("has incorrect number of operands: expected ")
1842 << cType.getNumElements() << ", but provided "
1843 << constituents.size();
1844
1845 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1846 auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1847
1848 auto constituentSpecConstOp =
1849 dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
1850 (*this)->getParentOp(), constituent.getAttr()));
1851
1852 if (constituentSpecConstOp.getDefaultValue().getType() !=
1853 cType.getElementType(index))
1854 return emitError("has incorrect types of operands: expected ")
1855 << cType.getElementType(index) << ", but provided "
1856 << constituentSpecConstOp.getDefaultValue().getType();
1857 }
1858
1859 return success();
1860}
1861
1862//===----------------------------------------------------------------------===//
1863// spirv.SpecConstantOperation
1864//===----------------------------------------------------------------------===//
1865
1866ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
1867 OperationState &result) {
1868 Region *body = result.addRegion();
1869
1870 if (parser.parseKeyword("wraps"))
1871 return failure();
1872
1873 body->push_back(new Block);
1874 Block &block = body->back();
1875 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
1876
1877 if (!wrappedOp)
1878 return failure();
1879
1880 OpBuilder builder(parser.getContext());
1881 builder.setInsertionPointToEnd(&block);
1882 builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
1883 result.location = wrappedOp->getLoc();
1884
1885 result.addTypes(wrappedOp->getResult(0).getType());
1886
1887 if (parser.parseOptionalAttrDict(result.attributes))
1888 return failure();
1889
1890 return success();
1891}
1892
1893void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) {
1894 printer << " wraps ";
1895 printer.printGenericOp(&getBody().front().front());
1896}
1897
1898LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
1899 Block &block = getRegion().getBlocks().front();
1900
1901 if (block.getOperations().size() != 2)
1902 return emitOpError("expected exactly 2 nested ops");
1903
1904 Operation &enclosedOp = block.getOperations().front();
1905
1906 if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>())
1907 return emitOpError("invalid enclosed op");
1908
1909 for (auto operand : enclosedOp.getOperands())
1910 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
1911 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
1912 return emitOpError(
1913 "invalid operand, must be defined by a constant operation");
1914
1915 return success();
1916}
1917
1918//===----------------------------------------------------------------------===//
1919// spirv.GL.FrexpStruct
1920//===----------------------------------------------------------------------===//
1921
1922LogicalResult spirv::GLFrexpStructOp::verify() {
1923 spirv::StructType structTy =
1924 llvm::dyn_cast<spirv::StructType>(getResult().getType());
1925
1926 if (structTy.getNumElements() != 2)
1927 return emitError("result type must be a struct type with two memebers");
1928
1929 Type significandTy = structTy.getElementType(0);
1930 Type exponentTy = structTy.getElementType(1);
1931 VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
1932 IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
1933
1934 Type operandTy = getOperand().getType();
1935 VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
1936 FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
1937
1938 if (significandTy != operandTy)
1939 return emitError("member zero of the resulting struct type must be the "
1940 "same type as the operand");
1941
1942 if (exponentVecTy) {
1943 IntegerType componentIntTy =
1944 llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
1945 if (!componentIntTy || componentIntTy.getWidth() != 32)
1946 return emitError("member one of the resulting struct type must"
1947 "be a scalar or vector of 32 bit integer type");
1948 } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
1949 return emitError("member one of the resulting struct type "
1950 "must be a scalar or vector of 32 bit integer type");
1951 }
1952
1953 // Check that the two member types have the same number of components
1954 if (operandVecTy && exponentVecTy &&
1955 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
1956 return success();
1957
1958 if (operandFTy && exponentIntTy)
1959 return success();
1960
1961 return emitError("member one of the resulting struct type must have the same "
1962 "number of components as the operand type");
1963}
1964
1965//===----------------------------------------------------------------------===//
1966// spirv.GL.Ldexp
1967//===----------------------------------------------------------------------===//
1968
1969LogicalResult spirv::GLLdexpOp::verify() {
1970 Type significandType = getX().getType();
1971 Type exponentType = getExp().getType();
1972
1973 if (llvm::isa<FloatType>(significandType) !=
1974 llvm::isa<IntegerType>(exponentType))
1975 return emitOpError("operands must both be scalars or vectors");
1976
1977 auto getNumElements = [](Type type) -> unsigned {
1978 if (auto vectorType = llvm::dyn_cast<VectorType>(type))
1979 return vectorType.getNumElements();
1980 return 1;
1981 };
1982
1983 if (getNumElements(significandType) != getNumElements(exponentType))
1984 return emitOpError("operands must have the same number of elements");
1985
1986 return success();
1987}
1988
1989//===----------------------------------------------------------------------===//
1990// spirv.ImageDrefGather
1991//===----------------------------------------------------------------------===//
1992
1993LogicalResult spirv::ImageDrefGatherOp::verify() {
1994 VectorType resultType = llvm::cast<VectorType>(getResult().getType());
1995 auto sampledImageType =
1996 llvm::cast<spirv::SampledImageType>(getSampledimage().getType());
1997 auto imageType =
1998 llvm::cast<spirv::ImageType>(sampledImageType.getImageType());
1999
2000 if (resultType.getNumElements() != 4)
2001 return emitOpError("result type must be a vector of four components");
2002
2003 Type elementType = resultType.getElementType();
2004 Type sampledElementType = imageType.getElementType();
2005 if (!llvm::isa<NoneType>(sampledElementType) &&
2006 elementType != sampledElementType)
2007 return emitOpError(
2008 "the component type of result must be the same as sampled type of the "
2009 "underlying image type");
2010
2011 spirv::Dim imageDim = imageType.getDim();
2012 spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
2013
2014 if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
2015 imageDim != spirv::Dim::Rect)
2016 return emitOpError(
2017 "the Dim operand of the underlying image type must be 2D, Cube, or "
2018 "Rect");
2019
2020 if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
2021 return emitOpError("the MS operand of the underlying image type must be 0");
2022
2023 spirv::ImageOperandsAttr attr = getImageoperandsAttr();
2024 auto operandArguments = getOperandArguments();
2025
2026 return verifyImageOperands(*this, attr, operandArguments);
2027}
2028
2029//===----------------------------------------------------------------------===//
2030// spirv.ShiftLeftLogicalOp
2031//===----------------------------------------------------------------------===//
2032
2033LogicalResult spirv::ShiftLeftLogicalOp::verify() {
2034 return verifyShiftOp(*this);
2035}
2036
2037//===----------------------------------------------------------------------===//
2038// spirv.ShiftRightArithmeticOp
2039//===----------------------------------------------------------------------===//
2040
2041LogicalResult spirv::ShiftRightArithmeticOp::verify() {
2042 return verifyShiftOp(*this);
2043}
2044
2045//===----------------------------------------------------------------------===//
2046// spirv.ShiftRightLogicalOp
2047//===----------------------------------------------------------------------===//
2048
2049LogicalResult spirv::ShiftRightLogicalOp::verify() {
2050 return verifyShiftOp(*this);
2051}
2052
2053//===----------------------------------------------------------------------===//
2054// spirv.ImageQuerySize
2055//===----------------------------------------------------------------------===//
2056
2057LogicalResult spirv::ImageQuerySizeOp::verify() {
2058 spirv::ImageType imageType =
2059 llvm::cast<spirv::ImageType>(getImage().getType());
2060 Type resultType = getResult().getType();
2061
2062 spirv::Dim dim = imageType.getDim();
2063 spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
2064 spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
2065 switch (dim) {
2066 case spirv::Dim::Dim1D:
2067 case spirv::Dim::Dim2D:
2068 case spirv::Dim::Dim3D:
2069 case spirv::Dim::Cube:
2070 if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
2071 samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
2072 samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
2073 return emitError(
2074 "if Dim is 1D, 2D, 3D, or Cube, "
2075 "it must also have either an MS of 1 or a Sampled of 0 or 2");
2076 break;
2077 case spirv::Dim::Buffer:
2078 case spirv::Dim::Rect:
2079 break;
2080 default:
2081 return emitError("the Dim operand of the image type must "
2082 "be 1D, 2D, 3D, Buffer, Cube, or Rect");
2083 }
2084
2085 unsigned componentNumber = 0;
2086 switch (dim) {
2087 case spirv::Dim::Dim1D:
2088 case spirv::Dim::Buffer:
2089 componentNumber = 1;
2090 break;
2091 case spirv::Dim::Dim2D:
2092 case spirv::Dim::Cube:
2093 case spirv::Dim::Rect:
2094 componentNumber = 2;
2095 break;
2096 case spirv::Dim::Dim3D:
2097 componentNumber = 3;
2098 break;
2099 default:
2100 break;
2101 }
2102
2103 if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
2104 componentNumber += 1;
2105
2106 unsigned resultComponentNumber = 1;
2107 if (auto resultVectorType = llvm::dyn_cast<VectorType>(resultType))
2108 resultComponentNumber = resultVectorType.getNumElements();
2109
2110 if (componentNumber != resultComponentNumber)
2111 return emitError("expected the result to have ")
2112 << componentNumber << " component(s), but found "
2113 << resultComponentNumber << " component(s)";
2114
2115 return success();
2116}
2117
2118//===----------------------------------------------------------------------===//
2119// spirv.VectorTimesScalarOp
2120//===----------------------------------------------------------------------===//
2121
2122LogicalResult spirv::VectorTimesScalarOp::verify() {
2123 if (getVector().getType() != getType())
2124 return emitOpError("vector operand and result type mismatch");
2125 auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
2126 if (getScalar().getType() != scalarType)
2127 return emitOpError("scalar operand and result element type match");
2128 return success();
2129}
2130

source code of mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp