1//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14
15#include "SPIRVOpUtils.h"
16#include "SPIRVParsingUtils.h"
17
18#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
20#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
21#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
22#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
23#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
24#include "mlir/IR/Builders.h"
25#include "mlir/IR/BuiltinTypes.h"
26#include "mlir/IR/OpDefinition.h"
27#include "mlir/IR/OpImplementation.h"
28#include "mlir/IR/Operation.h"
29#include "mlir/IR/TypeUtilities.h"
30#include "mlir/Interfaces/FunctionImplementation.h"
31#include "llvm/ADT/APFloat.h"
32#include "llvm/ADT/APInt.h"
33#include "llvm/ADT/ArrayRef.h"
34#include "llvm/ADT/STLExtras.h"
35#include "llvm/ADT/StringExtras.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/InterleavedRange.h"
38#include <cassert>
39#include <numeric>
40#include <optional>
41
42using namespace mlir;
43using namespace mlir::spirv::AttrNames;
44
45//===----------------------------------------------------------------------===//
46// Common utility functions
47//===----------------------------------------------------------------------===//
48
49LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
50 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(Val: op);
51 if (!constOp) {
52 return failure();
53 }
54 auto valueAttr = constOp.getValue();
55 auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(Val&: valueAttr);
56 if (!integerValueAttr) {
57 return failure();
58 }
59
60 if (integerValueAttr.getType().isSignlessInteger())
61 value = integerValueAttr.getInt();
62 else
63 value = integerValueAttr.getSInt();
64
65 return success();
66}
67
68LogicalResult
69spirv::verifyMemorySemantics(Operation *op,
70 spirv::MemorySemantics memorySemantics) {
71 // According to the SPIR-V specification:
72 // "Despite being a mask and allowing multiple bits to be combined, it is
73 // invalid for more than one of these four bits to be set: Acquire, Release,
74 // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
75 // Release semantics is done by setting the AcquireRelease bit, not by setting
76 // two bits."
77 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
78 spirv::MemorySemantics::Release |
79 spirv::MemorySemantics::AcquireRelease |
80 spirv::MemorySemantics::SequentiallyConsistent;
81
82 auto bitCount =
83 llvm::popcount(Value: static_cast<uint32_t>(memorySemantics & atMostOneInSet));
84 if (bitCount > 1) {
85 return op->emitError(
86 message: "expected at most one of these four memory constraints "
87 "to be set: `Acquire`, `Release`,"
88 "`AcquireRelease` or `SequentiallyConsistent`");
89 }
90 return success();
91}
92
93void spirv::printVariableDecorations(Operation *op, OpAsmPrinter &printer,
94 SmallVectorImpl<StringRef> &elidedAttrs) {
95 // Print optional descriptor binding
96 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
97 input: stringifyDecoration(spirv::Decoration::DescriptorSet));
98 auto bindingName = llvm::convertToSnakeFromCamelCase(
99 input: stringifyDecoration(spirv::Decoration::Binding));
100 auto descriptorSet = op->getAttrOfType<IntegerAttr>(name: descriptorSetName);
101 auto binding = op->getAttrOfType<IntegerAttr>(name: bindingName);
102 if (descriptorSet && binding) {
103 elidedAttrs.push_back(Elt: descriptorSetName);
104 elidedAttrs.push_back(Elt: bindingName);
105 printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
106 << ")";
107 }
108
109 // Print BuiltIn attribute if present
110 auto builtInName = llvm::convertToSnakeFromCamelCase(
111 input: stringifyDecoration(spirv::Decoration::BuiltIn));
112 if (auto builtin = op->getAttrOfType<StringAttr>(name: builtInName)) {
113 printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
114 elidedAttrs.push_back(Elt: builtInName);
115 }
116
117 printer.printOptionalAttrDict(attrs: op->getAttrs(), elidedAttrs);
118}
119
120static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
121 OperationState &result) {
122 SmallVector<OpAsmParser::UnresolvedOperand, 2> ops;
123 Type type;
124 // If the operand list is in-between parentheses, then we have a generic form.
125 // (see the fallback in `printOneResultOp`).
126 SMLoc loc = parser.getCurrentLocation();
127 if (!parser.parseOptionalLParen()) {
128 if (parser.parseOperandList(result&: ops) || parser.parseRParen() ||
129 parser.parseOptionalAttrDict(result&: result.attributes) ||
130 parser.parseColon() || parser.parseType(result&: type))
131 return failure();
132 auto fnType = llvm::dyn_cast<FunctionType>(Val&: type);
133 if (!fnType) {
134 parser.emitError(loc, message: "expected function type");
135 return failure();
136 }
137 if (parser.resolveOperands(operands&: ops, types: fnType.getInputs(), loc, result&: result.operands))
138 return failure();
139 result.addTypes(newTypes: fnType.getResults());
140 return success();
141 }
142 return failure(IsFailure: parser.parseOperandList(result&: ops) ||
143 parser.parseOptionalAttrDict(result&: result.attributes) ||
144 parser.parseColonType(result&: type) ||
145 parser.resolveOperands(operands&: ops, type, result&: result.operands) ||
146 parser.addTypeToList(type, result&: result.types));
147}
148
149static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
150 assert(op->getNumResults() == 1 && "op should have one result");
151
152 // If not all the operand and result types are the same, just use the
153 // generic assembly form to avoid omitting information in printing.
154 auto resultType = op->getResult(idx: 0).getType();
155 if (llvm::any_of(Range: op->getOperandTypes(),
156 P: [&](Type type) { return type != resultType; })) {
157 p.printGenericOp(op, /*printOpName=*/false);
158 return;
159 }
160
161 p << ' ';
162 p.printOperands(container: op->getOperands());
163 p.printOptionalAttrDict(attrs: op->getAttrs());
164 // Now we can output only one type for all operands and the result.
165 p << " : " << resultType;
166}
167
168template <typename BlockReadWriteOpTy>
169static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
170 Value ptr, Value val) {
171 auto valType = val.getType();
172 if (auto valVecTy = llvm::dyn_cast<VectorType>(Val&: valType))
173 valType = valVecTy.getElementType();
174
175 if (valType !=
176 llvm::cast<spirv::PointerType>(Val: ptr.getType()).getPointeeType()) {
177 return op.emitOpError("mismatch in result type and pointer type");
178 }
179 return success();
180}
181
182/// Walks the given type hierarchy with the given indices, potentially down
183/// to component granularity, to select an element type. Returns null type and
184/// emits errors with the given loc on failure.
185static Type
186getElementType(Type type, ArrayRef<int32_t> indices,
187 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
188 if (indices.empty()) {
189 emitErrorFn("expected at least one index for spirv.CompositeExtract");
190 return nullptr;
191 }
192
193 for (auto index : indices) {
194 if (auto cType = llvm::dyn_cast<spirv::CompositeType>(Val&: type)) {
195 if (cType.hasCompileTimeKnownNumElements() &&
196 (index < 0 ||
197 static_cast<uint64_t>(index) >= cType.getNumElements())) {
198 emitErrorFn("index ") << index << " out of bounds for " << type;
199 return nullptr;
200 }
201 type = cType.getElementType(index);
202 } else {
203 emitErrorFn("cannot extract from non-composite type ")
204 << type << " with index " << index;
205 return nullptr;
206 }
207 }
208 return type;
209}
210
211static Type
212getElementType(Type type, Attribute indices,
213 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
214 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(Val&: indices);
215 if (!indicesArrayAttr) {
216 emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
217 return nullptr;
218 }
219 if (indicesArrayAttr.empty()) {
220 emitErrorFn("expected at least one index for spirv.CompositeExtract");
221 return nullptr;
222 }
223
224 SmallVector<int32_t, 2> indexVals;
225 for (auto indexAttr : indicesArrayAttr) {
226 auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(Val&: indexAttr);
227 if (!indexIntAttr) {
228 emitErrorFn("expected an 32-bit integer for index, but found '")
229 << indexAttr << "'";
230 return nullptr;
231 }
232 indexVals.push_back(Elt: indexIntAttr.getInt());
233 }
234 return getElementType(type, indices: indexVals, emitErrorFn);
235}
236
237static Type getElementType(Type type, Attribute indices, Location loc) {
238 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
239 return ::mlir::emitError(loc, message: err);
240 };
241 return getElementType(type, indices, emitErrorFn: errorFn);
242}
243
244static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
245 SMLoc loc) {
246 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
247 return parser.emitError(loc, message: err);
248 };
249 return getElementType(type, indices, emitErrorFn: errorFn);
250}
251
252template <typename ExtendedBinaryOp>
253static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
254 auto resultType = llvm::cast<spirv::StructType>(op.getType());
255 if (resultType.getNumElements() != 2)
256 return op.emitOpError("expected result struct type containing two members");
257
258 if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
259 resultType.getElementType(0),
260 resultType.getElementType(1)}))
261 return op.emitOpError(
262 "expected all operand types and struct member types are the same");
263
264 return success();
265}
266
267static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
268 OperationState &result) {
269 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
270 if (parser.parseOptionalAttrDict(result&: result.attributes) ||
271 parser.parseOperandList(result&: operands) || parser.parseColon())
272 return failure();
273
274 Type resultType;
275 SMLoc loc = parser.getCurrentLocation();
276 if (parser.parseType(result&: resultType))
277 return failure();
278
279 auto structType = llvm::dyn_cast<spirv::StructType>(Val&: resultType);
280 if (!structType || structType.getNumElements() != 2)
281 return parser.emitError(loc, message: "expected spirv.struct type with two members");
282
283 SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
284 if (parser.resolveOperands(operands, types&: operandTypes, loc, result&: result.operands))
285 return failure();
286
287 result.addTypes(newTypes: resultType);
288 return success();
289}
290
291static void printArithmeticExtendedBinaryOp(Operation *op,
292 OpAsmPrinter &printer) {
293 printer << ' ';
294 printer.printOptionalAttrDict(attrs: op->getAttrs());
295 printer.printOperands(container: op->getOperands());
296 printer << " : " << op->getResultTypes().front();
297}
298
299static LogicalResult verifyShiftOp(Operation *op) {
300 if (op->getOperand(idx: 0).getType() != op->getResult(idx: 0).getType()) {
301 return op->emitError(message: "expected the same type for the first operand and "
302 "result, but provided ")
303 << op->getOperand(idx: 0).getType() << " and "
304 << op->getResult(idx: 0).getType();
305 }
306 return success();
307}
308
309//===----------------------------------------------------------------------===//
310// spirv.mlir.addressof
311//===----------------------------------------------------------------------===//
312
313void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
314 spirv::GlobalVariableOp var) {
315 build(odsBuilder&: builder, odsState&: state, pointer: var.getType(), variable: SymbolRefAttr::get(symbol: var));
316}
317
318LogicalResult spirv::AddressOfOp::verify() {
319 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
320 Val: SymbolTable::lookupNearestSymbolFrom(from: (*this)->getParentOp(),
321 symbol: getVariableAttr()));
322 if (!varOp) {
323 return emitOpError(message: "expected spirv.GlobalVariable symbol");
324 }
325 if (getPointer().getType() != varOp.getType()) {
326 return emitOpError(
327 message: "result type mismatch with the referenced global variable's type");
328 }
329 return success();
330}
331
332//===----------------------------------------------------------------------===//
333// spirv.CompositeConstruct
334//===----------------------------------------------------------------------===//
335
336LogicalResult spirv::CompositeConstructOp::verify() {
337 operand_range constituents = this->getConstituents();
338
339 // There are 4 cases with varying verification rules:
340 // 1. Cooperative Matrices (1 constituent)
341 // 2. Structs (1 constituent for each member)
342 // 3. Arrays (1 constituent for each array element)
343 // 4. Vectors (1 constituent (sub-)element for each vector element)
344
345 auto coopElementType =
346 llvm::TypeSwitch<Type, Type>(getType())
347 .Case<spirv::CooperativeMatrixType>(
348 caseFn: [](auto coopType) { return coopType.getElementType(); })
349 .Default(defaultFn: [](Type) { return nullptr; });
350
351 // Case 1. -- matrices.
352 if (coopElementType) {
353 if (constituents.size() != 1)
354 return emitOpError(message: "has incorrect number of operands: expected ")
355 << "1, but provided " << constituents.size();
356 if (coopElementType != constituents.front().getType())
357 return emitOpError(message: "operand type mismatch: expected operand type ")
358 << coopElementType << ", but provided "
359 << constituents.front().getType();
360 return success();
361 }
362
363 // Case 2./3./4. -- number of constituents matches the number of elements.
364 auto cType = llvm::cast<spirv::CompositeType>(Val: getType());
365 if (constituents.size() == cType.getNumElements()) {
366 for (auto index : llvm::seq<uint32_t>(Begin: 0, End: constituents.size())) {
367 if (constituents[index].getType() != cType.getElementType(index)) {
368 return emitOpError(message: "operand type mismatch: expected operand type ")
369 << cType.getElementType(index) << ", but provided "
370 << constituents[index].getType();
371 }
372 }
373 return success();
374 }
375
376 // Case 4. -- check that all constituents add up tp the expected vector type.
377 auto resultType = llvm::dyn_cast<VectorType>(Val&: cType);
378 if (!resultType)
379 return emitOpError(
380 message: "expected to return a vector or cooperative matrix when the number of "
381 "constituents is less than what the result needs");
382
383 SmallVector<unsigned> sizes;
384 for (Value component : constituents) {
385 if (!llvm::isa<VectorType>(Val: component.getType()) &&
386 !component.getType().isIntOrFloat())
387 return emitOpError(message: "operand type mismatch: expected operand to have "
388 "a scalar or vector type, but provided ")
389 << component.getType();
390
391 Type elementType = component.getType();
392 if (auto vectorType = llvm::dyn_cast<VectorType>(Val: component.getType())) {
393 sizes.push_back(Elt: vectorType.getNumElements());
394 elementType = vectorType.getElementType();
395 } else {
396 sizes.push_back(Elt: 1);
397 }
398
399 if (elementType != resultType.getElementType())
400 return emitOpError(message: "operand element type mismatch: expected to be ")
401 << resultType.getElementType() << ", but provided " << elementType;
402 }
403 unsigned totalCount = std::accumulate(first: sizes.begin(), last: sizes.end(), init: 0);
404 if (totalCount != cType.getNumElements())
405 return emitOpError(message: "has incorrect number of operands: expected ")
406 << cType.getNumElements() << ", but provided " << totalCount;
407 return success();
408}
409
410//===----------------------------------------------------------------------===//
411// spirv.CompositeExtractOp
412//===----------------------------------------------------------------------===//
413
414void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
415 Value composite,
416 ArrayRef<int32_t> indices) {
417 auto indexAttr = builder.getI32ArrayAttr(values: indices);
418 auto elementType =
419 getElementType(type: composite.getType(), indices: indexAttr, loc: state.location);
420 if (!elementType) {
421 return;
422 }
423 build(odsBuilder&: builder, odsState&: state, component: elementType, composite, indices: indexAttr);
424}
425
426ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
427 OperationState &result) {
428 OpAsmParser::UnresolvedOperand compositeInfo;
429 Attribute indicesAttr;
430 StringRef indicesAttrName =
431 spirv::CompositeExtractOp::getIndicesAttrName(name: result.name);
432 Type compositeType;
433 SMLoc attrLocation;
434
435 if (parser.parseOperand(result&: compositeInfo) ||
436 parser.getCurrentLocation(loc: &attrLocation) ||
437 parser.parseAttribute(result&: indicesAttr, attrName: indicesAttrName, attrs&: result.attributes) ||
438 parser.parseColonType(result&: compositeType) ||
439 parser.resolveOperand(operand: compositeInfo, type: compositeType, result&: result.operands)) {
440 return failure();
441 }
442
443 Type resultType =
444 getElementType(type: compositeType, indices: indicesAttr, parser, loc: attrLocation);
445 if (!resultType) {
446 return failure();
447 }
448 result.addTypes(newTypes: resultType);
449 return success();
450}
451
452void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
453 printer << ' ' << getComposite() << getIndices() << " : "
454 << getComposite().getType();
455}
456
457LogicalResult spirv::CompositeExtractOp::verify() {
458 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(Val: getIndices());
459 auto resultType =
460 getElementType(type: getComposite().getType(), indices: indicesArrayAttr, loc: getLoc());
461 if (!resultType)
462 return failure();
463
464 if (resultType != getType()) {
465 return emitOpError(message: "invalid result type: expected ")
466 << resultType << " but provided " << getType();
467 }
468
469 return success();
470}
471
472//===----------------------------------------------------------------------===//
473// spirv.CompositeInsert
474//===----------------------------------------------------------------------===//
475
476void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
477 Value object, Value composite,
478 ArrayRef<int32_t> indices) {
479 auto indexAttr = builder.getI32ArrayAttr(values: indices);
480 build(odsBuilder&: builder, odsState&: state, result: composite.getType(), object, composite, indices: indexAttr);
481}
482
483ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
484 OperationState &result) {
485 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
486 Type objectType, compositeType;
487 Attribute indicesAttr;
488 StringRef indicesAttrName =
489 spirv::CompositeInsertOp::getIndicesAttrName(name: result.name);
490 auto loc = parser.getCurrentLocation();
491
492 return failure(
493 IsFailure: parser.parseOperandList(result&: operands, requiredOperandCount: 2) ||
494 parser.parseAttribute(result&: indicesAttr, attrName: indicesAttrName, attrs&: result.attributes) ||
495 parser.parseColonType(result&: objectType) ||
496 parser.parseKeywordType(keyword: "into", result&: compositeType) ||
497 parser.resolveOperands(operands, types: {objectType, compositeType}, loc,
498 result&: result.operands) ||
499 parser.addTypesToList(types: compositeType, result&: result.types));
500}
501
502LogicalResult spirv::CompositeInsertOp::verify() {
503 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(Val: getIndices());
504 auto objectType =
505 getElementType(type: getComposite().getType(), indices: indicesArrayAttr, loc: getLoc());
506 if (!objectType)
507 return failure();
508
509 if (objectType != getObject().getType()) {
510 return emitOpError(message: "object operand type should be ")
511 << objectType << ", but found " << getObject().getType();
512 }
513
514 if (getComposite().getType() != getType()) {
515 return emitOpError(message: "result type should be the same as "
516 "the composite type, but found ")
517 << getComposite().getType() << " vs " << getType();
518 }
519
520 return success();
521}
522
523void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
524 printer << " " << getObject() << ", " << getComposite() << getIndices()
525 << " : " << getObject().getType() << " into "
526 << getComposite().getType();
527}
528
529//===----------------------------------------------------------------------===//
530// spirv.Constant
531//===----------------------------------------------------------------------===//
532
533ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
534 OperationState &result) {
535 Attribute value;
536 StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(name: result.name);
537 if (parser.parseAttribute(result&: value, attrName: valueAttrName, attrs&: result.attributes))
538 return failure();
539
540 Type type = NoneType::get(context: parser.getContext());
541 if (auto typedAttr = llvm::dyn_cast<TypedAttr>(Val&: value))
542 type = typedAttr.getType();
543 if (llvm::isa<NoneType, TensorType>(Val: type)) {
544 if (parser.parseColonType(result&: type))
545 return failure();
546 }
547
548 if (llvm::isa<TensorArmType>(Val: type)) {
549 if (parser.parseOptionalColon().succeeded())
550 if (parser.parseType(result&: type))
551 return failure();
552 }
553
554 return parser.addTypeToList(type, result&: result.types);
555}
556
557void spirv::ConstantOp::print(OpAsmPrinter &printer) {
558 printer << ' ' << getValue();
559 if (llvm::isa<spirv::ArrayType>(Val: getType()))
560 printer << " : " << getType();
561}
562
563static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
564 Type opType) {
565 if (isa<spirv::CooperativeMatrixType>(Val: opType)) {
566 auto denseAttr = dyn_cast<DenseElementsAttr>(Val&: value);
567 if (!denseAttr || !denseAttr.isSplat())
568 return op.emitOpError(message: "expected a splat dense attribute for cooperative "
569 "matrix constant, but found ")
570 << denseAttr;
571 }
572 if (llvm::isa<IntegerAttr, FloatAttr>(Val: value)) {
573 auto valueType = llvm::cast<TypedAttr>(Val&: value).getType();
574 if (valueType != opType)
575 return op.emitOpError(message: "result type (")
576 << opType << ") does not match value type (" << valueType << ")";
577 return success();
578 }
579 if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(Val: value)) {
580 auto valueType = llvm::cast<TypedAttr>(Val&: value).getType();
581 if (valueType == opType)
582 return success();
583 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(Val&: opType);
584 auto shapedType = llvm::dyn_cast<ShapedType>(Val&: valueType);
585 if (!arrayType)
586 return op.emitOpError(message: "result or element type (")
587 << opType << ") does not match value type (" << valueType
588 << "), must be the same or spirv.array";
589
590 int numElements = arrayType.getNumElements();
591 auto opElemType = arrayType.getElementType();
592 while (auto t = llvm::dyn_cast<spirv::ArrayType>(Val&: opElemType)) {
593 numElements *= t.getNumElements();
594 opElemType = t.getElementType();
595 }
596 if (!opElemType.isIntOrFloat())
597 return op.emitOpError(message: "only support nested array result type");
598
599 auto valueElemType = shapedType.getElementType();
600 if (valueElemType != opElemType) {
601 return op.emitOpError(message: "result element type (")
602 << opElemType << ") does not match value element type ("
603 << valueElemType << ")";
604 }
605
606 if (numElements != shapedType.getNumElements()) {
607 return op.emitOpError(message: "result number of elements (")
608 << numElements << ") does not match value number of elements ("
609 << shapedType.getNumElements() << ")";
610 }
611 return success();
612 }
613 if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(Val&: value)) {
614 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(Val&: opType);
615 if (!arrayType)
616 return op.emitOpError(
617 message: "must have spirv.array result type for array value");
618 Type elemType = arrayType.getElementType();
619 for (Attribute element : arrayAttr.getValue()) {
620 // Verify array elements recursively.
621 if (failed(Result: verifyConstantType(op, value: element, opType: elemType)))
622 return failure();
623 }
624 return success();
625 }
626 return op.emitOpError(message: "cannot have attribute: ") << value;
627}
628
629LogicalResult spirv::ConstantOp::verify() {
630 // ODS already generates checks to make sure the result type is valid. We just
631 // need to additionally check that the value's attribute type is consistent
632 // with the result type.
633 return verifyConstantType(op: *this, value: getValueAttr(), opType: getType());
634}
635
636bool spirv::ConstantOp::isBuildableWith(Type type) {
637 // Must be valid SPIR-V type first.
638 if (!llvm::isa<spirv::SPIRVType>(Val: type))
639 return false;
640
641 if (isa<SPIRVDialect>(Val: type.getDialect())) {
642 // TODO: support constant struct
643 return llvm::isa<spirv::ArrayType>(Val: type);
644 }
645
646 return true;
647}
648
649spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
650 OpBuilder &builder) {
651 if (auto intType = llvm::dyn_cast<IntegerType>(Val&: type)) {
652 unsigned width = intType.getWidth();
653 if (width == 1)
654 return builder.create<spirv::ConstantOp>(location: loc, args&: type,
655 args: builder.getBoolAttr(value: false));
656 return builder.create<spirv::ConstantOp>(
657 location: loc, args&: type, args: builder.getIntegerAttr(type, value: APInt(width, 0)));
658 }
659 if (auto floatType = llvm::dyn_cast<FloatType>(Val&: type)) {
660 return builder.create<spirv::ConstantOp>(
661 location: loc, args&: type, args: builder.getFloatAttr(type: floatType, value: 0.0));
662 }
663 if (auto vectorType = llvm::dyn_cast<VectorType>(Val&: type)) {
664 Type elemType = vectorType.getElementType();
665 if (llvm::isa<IntegerType>(Val: elemType)) {
666 return builder.create<spirv::ConstantOp>(
667 location: loc, args&: type,
668 args: DenseElementsAttr::get(type: vectorType,
669 values: IntegerAttr::get(type: elemType, value: 0).getValue()));
670 }
671 if (llvm::isa<FloatType>(Val: elemType)) {
672 return builder.create<spirv::ConstantOp>(
673 location: loc, args&: type,
674 args: DenseFPElementsAttr::get(type: vectorType,
675 arg: FloatAttr::get(type: elemType, value: 0.0).getValue()));
676 }
677 }
678
679 llvm_unreachable("unimplemented types for ConstantOp::getZero()");
680}
681
682spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
683 OpBuilder &builder) {
684 if (auto intType = llvm::dyn_cast<IntegerType>(Val&: type)) {
685 unsigned width = intType.getWidth();
686 if (width == 1)
687 return builder.create<spirv::ConstantOp>(location: loc, args&: type,
688 args: builder.getBoolAttr(value: true));
689 return builder.create<spirv::ConstantOp>(
690 location: loc, args&: type, args: builder.getIntegerAttr(type, value: APInt(width, 1)));
691 }
692 if (auto floatType = llvm::dyn_cast<FloatType>(Val&: type)) {
693 return builder.create<spirv::ConstantOp>(
694 location: loc, args&: type, args: builder.getFloatAttr(type: floatType, value: 1.0));
695 }
696 if (auto vectorType = llvm::dyn_cast<VectorType>(Val&: type)) {
697 Type elemType = vectorType.getElementType();
698 if (llvm::isa<IntegerType>(Val: elemType)) {
699 return builder.create<spirv::ConstantOp>(
700 location: loc, args&: type,
701 args: DenseElementsAttr::get(type: vectorType,
702 values: IntegerAttr::get(type: elemType, value: 1).getValue()));
703 }
704 if (llvm::isa<FloatType>(Val: elemType)) {
705 return builder.create<spirv::ConstantOp>(
706 location: loc, args&: type,
707 args: DenseFPElementsAttr::get(type: vectorType,
708 arg: FloatAttr::get(type: elemType, value: 1.0).getValue()));
709 }
710 }
711
712 llvm_unreachable("unimplemented types for ConstantOp::getOne()");
713}
714
715void mlir::spirv::ConstantOp::getAsmResultNames(
716 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
717 Type type = getType();
718
719 SmallString<32> specialNameBuffer;
720 llvm::raw_svector_ostream specialName(specialNameBuffer);
721 specialName << "cst";
722
723 IntegerType intTy = llvm::dyn_cast<IntegerType>(Val&: type);
724
725 if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(Val: getValue())) {
726 if (intTy && intTy.getWidth() == 1) {
727 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
728 }
729
730 if (intTy.isSignless()) {
731 specialName << intCst.getInt();
732 } else if (intTy.isUnsigned()) {
733 specialName << intCst.getUInt();
734 } else {
735 specialName << intCst.getSInt();
736 }
737 }
738
739 if (intTy || llvm::isa<FloatType>(Val: type)) {
740 specialName << '_' << type;
741 }
742
743 if (auto vecType = llvm::dyn_cast<VectorType>(Val&: type)) {
744 specialName << "_vec_";
745 specialName << vecType.getDimSize(idx: 0);
746
747 Type elementType = vecType.getElementType();
748
749 if (llvm::isa<IntegerType>(Val: elementType) ||
750 llvm::isa<FloatType>(Val: elementType)) {
751 specialName << "x" << elementType;
752 }
753 }
754
755 setNameFn(getResult(), specialName.str());
756}
757
758void mlir::spirv::AddressOfOp::getAsmResultNames(
759 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
760 SmallString<32> specialNameBuffer;
761 llvm::raw_svector_ostream specialName(specialNameBuffer);
762 specialName << getVariable() << "_addr";
763 setNameFn(getResult(), specialName.str());
764}
765
766//===----------------------------------------------------------------------===//
767// spirv.EXTConstantCompositeReplicate
768//===----------------------------------------------------------------------===//
769
770LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
771 Type valueType;
772 if (auto typedAttr = dyn_cast<TypedAttr>(Val: getValue())) {
773 valueType = typedAttr.getType();
774 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(Val: getValue())) {
775 auto typedElemAttr = dyn_cast<TypedAttr>(Val: arrayAttr[0]);
776 if (!typedElemAttr)
777 return emitError(message: "value attribute is not typed");
778 valueType =
779 spirv::ArrayType::get(elementType: typedElemAttr.getType(), elementCount: arrayAttr.size());
780 } else {
781 return emitError(message: "unknown value attribute type");
782 }
783
784 auto compositeType = dyn_cast<spirv::CompositeType>(Val: getType());
785 if (!compositeType)
786 return emitError(message: "result type is not a composite type");
787
788 Type compositeElementType = compositeType.getElementType(0);
789
790 SmallVector<Type, 3> possibleTypes = {compositeElementType};
791 while (auto type = dyn_cast<spirv::CompositeType>(Val&: compositeElementType)) {
792 compositeElementType = type.getElementType(0);
793 possibleTypes.push_back(Elt: compositeElementType);
794 }
795
796 if (!is_contained(Range&: possibleTypes, Element: valueType)) {
797 return emitError(message: "expected value attribute type ")
798 << interleaved(R: possibleTypes, Separator: " or ") << ", but got: " << valueType;
799 }
800
801 return success();
802}
803
804//===----------------------------------------------------------------------===//
805// spirv.ControlBarrierOp
806//===----------------------------------------------------------------------===//
807
808LogicalResult spirv::ControlBarrierOp::verify() {
809 return verifyMemorySemantics(op: getOperation(), memorySemantics: getMemorySemantics());
810}
811
812//===----------------------------------------------------------------------===//
813// spirv.EntryPoint
814//===----------------------------------------------------------------------===//
815
816void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
817 spirv::ExecutionModel executionModel,
818 spirv::FuncOp function,
819 ArrayRef<Attribute> interfaceVars) {
820 build(odsBuilder&: builder, odsState&: state,
821 execution_model: spirv::ExecutionModelAttr::get(context: builder.getContext(), value: executionModel),
822 fn: SymbolRefAttr::get(symbol: function), interface: builder.getArrayAttr(value: interfaceVars));
823}
824
825ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
826 OperationState &result) {
827 spirv::ExecutionModel execModel;
828 SmallVector<Attribute, 4> interfaceVars;
829
830 FlatSymbolRefAttr fn;
831 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(value&: execModel, parser, state&: result) ||
832 parser.parseAttribute(result&: fn, type: Type(), attrName: kFnNameAttrName, attrs&: result.attributes)) {
833 return failure();
834 }
835
836 if (!parser.parseOptionalComma()) {
837 // Parse the interface variables
838 if (parser.parseCommaSeparatedList(parseElementFn: [&]() -> ParseResult {
839 // The name of the interface variable attribute isnt important
840 FlatSymbolRefAttr var;
841 NamedAttrList attrs;
842 if (parser.parseAttribute(result&: var, type: Type(), attrName: "var_symbol", attrs))
843 return failure();
844 interfaceVars.push_back(Elt: var);
845 return success();
846 }))
847 return failure();
848 }
849 result.addAttribute(name: spirv::EntryPointOp::getInterfaceAttrName(name: result.name),
850 attr: parser.getBuilder().getArrayAttr(value: interfaceVars));
851 return success();
852}
853
854void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
855 printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
856 printer.printSymbolName(symbolRef: getFn());
857 auto interfaceVars = getInterface().getValue();
858 if (!interfaceVars.empty())
859 printer << ", " << llvm::interleaved(R: interfaceVars);
860}
861
862LogicalResult spirv::EntryPointOp::verify() {
863 // Checks for fn and interface symbol reference are done in spirv::ModuleOp
864 // verification.
865 return success();
866}
867
868//===----------------------------------------------------------------------===//
869// spirv.ExecutionMode
870//===----------------------------------------------------------------------===//
871
872void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
873 spirv::FuncOp function,
874 spirv::ExecutionMode executionMode,
875 ArrayRef<int32_t> params) {
876 build(odsBuilder&: builder, odsState&: state, fn: SymbolRefAttr::get(symbol: function),
877 execution_mode: spirv::ExecutionModeAttr::get(context: builder.getContext(), value: executionMode),
878 values: builder.getI32ArrayAttr(values: params));
879}
880
881ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
882 OperationState &result) {
883 spirv::ExecutionMode execMode;
884 Attribute fn;
885 if (parser.parseAttribute(result&: fn, attrName: kFnNameAttrName, attrs&: result.attributes) ||
886 parseEnumStrAttr<spirv::ExecutionModeAttr>(value&: execMode, parser, state&: result)) {
887 return failure();
888 }
889
890 SmallVector<int32_t, 4> values;
891 Type i32Type = parser.getBuilder().getIntegerType(width: 32);
892 while (!parser.parseOptionalComma()) {
893 NamedAttrList attr;
894 Attribute value;
895 if (parser.parseAttribute(result&: value, type: i32Type, attrName: "value", attrs&: attr)) {
896 return failure();
897 }
898 values.push_back(Elt: llvm::cast<IntegerAttr>(Val&: value).getInt());
899 }
900 StringRef valuesAttrName =
901 spirv::ExecutionModeOp::getValuesAttrName(name: result.name);
902 result.addAttribute(name: valuesAttrName,
903 attr: parser.getBuilder().getI32ArrayAttr(values));
904 return success();
905}
906
907void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
908 printer << " ";
909 printer.printSymbolName(symbolRef: getFn());
910 printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
911 ArrayAttr values = this->getValues();
912 if (!values.empty())
913 printer << ", " << llvm::interleaved(R: values.getAsValueRange<IntegerAttr>());
914}
915
916//===----------------------------------------------------------------------===//
917// spirv.func
918//===----------------------------------------------------------------------===//
919
920ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
921 SmallVector<OpAsmParser::Argument> entryArgs;
922 SmallVector<DictionaryAttr> resultAttrs;
923 SmallVector<Type> resultTypes;
924 auto &builder = parser.getBuilder();
925
926 // Parse the name as a symbol.
927 StringAttr nameAttr;
928 if (parser.parseSymbolName(result&: nameAttr, attrName: SymbolTable::getSymbolAttrName(),
929 attrs&: result.attributes))
930 return failure();
931
932 // Parse the function signature.
933 bool isVariadic = false;
934 if (function_interface_impl::parseFunctionSignatureWithArguments(
935 parser, /*allowVariadic=*/false, arguments&: entryArgs, isVariadic, resultTypes,
936 resultAttrs))
937 return failure();
938
939 SmallVector<Type> argTypes;
940 for (auto &arg : entryArgs)
941 argTypes.push_back(Elt: arg.type);
942 auto fnType = builder.getFunctionType(inputs: argTypes, results: resultTypes);
943 result.addAttribute(name: getFunctionTypeAttrName(name: result.name),
944 attr: TypeAttr::get(type: fnType));
945
946 // Parse the optional function control keyword.
947 spirv::FunctionControl fnControl;
948 if (parseEnumStrAttr<spirv::FunctionControlAttr>(value&: fnControl, parser, state&: result))
949 return failure();
950
951 // If additional attributes are present, parse them.
952 if (parser.parseOptionalAttrDictWithKeyword(result&: result.attributes))
953 return failure();
954
955 // Add the attributes to the function arguments.
956 assert(resultAttrs.size() == resultTypes.size());
957 call_interface_impl::addArgAndResultAttrs(
958 builder, result, args: entryArgs, resultAttrs, argAttrsName: getArgAttrsAttrName(name: result.name),
959 resAttrsName: getResAttrsAttrName(name: result.name));
960
961 // Parse the optional function body.
962 auto *body = result.addRegion();
963 OptionalParseResult parseResult =
964 parser.parseOptionalRegion(region&: *body, arguments: entryArgs);
965 return failure(IsFailure: parseResult.has_value() && failed(Result: *parseResult));
966}
967
968void spirv::FuncOp::print(OpAsmPrinter &printer) {
969 // Print function name, signature, and control.
970 printer << " ";
971 printer.printSymbolName(symbolRef: getSymName());
972 auto fnType = getFunctionType();
973 function_interface_impl::printFunctionSignature(
974 p&: printer, op: *this, argTypes: fnType.getInputs(),
975 /*isVariadic=*/false, resultTypes: fnType.getResults());
976 printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
977 << "\"";
978 function_interface_impl::printFunctionAttributes(
979 p&: printer, op: *this,
980 elided: {spirv::attributeName<spirv::FunctionControl>(),
981 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
982 getFunctionControlAttrName()});
983
984 // Print the body if this is not an external function.
985 Region &body = this->getBody();
986 if (!body.empty()) {
987 printer << ' ';
988 printer.printRegion(blocks&: body, /*printEntryBlockArgs=*/false,
989 /*printBlockTerminators=*/true);
990 }
991}
992
993LogicalResult spirv::FuncOp::verifyType() {
994 FunctionType fnType = getFunctionType();
995 if (fnType.getNumResults() > 1)
996 return emitOpError(message: "cannot have more than one result");
997
998 auto hasDecorationAttr = [&](spirv::Decoration decoration,
999 unsigned argIndex) {
1000 auto func = llvm::cast<FunctionOpInterface>(Val: getOperation());
1001 for (auto argAttr : cast<FunctionOpInterface>(Val&: func).getArgAttrs(index: argIndex)) {
1002 if (argAttr.getName() != spirv::DecorationAttr::name)
1003 continue;
1004 if (auto decAttr = dyn_cast<spirv::DecorationAttr>(Val: argAttr.getValue()))
1005 return decAttr.getValue() == decoration;
1006 }
1007 return false;
1008 };
1009
1010 for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
1011 Type param = fnType.getInputs()[i];
1012 auto inputPtrType = dyn_cast<spirv::PointerType>(Val&: param);
1013 if (!inputPtrType)
1014 continue;
1015
1016 auto pointeePtrType =
1017 dyn_cast<spirv::PointerType>(Val: inputPtrType.getPointeeType());
1018 if (pointeePtrType) {
1019 // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1020 // > If an OpFunctionParameter is a pointer (or contains a pointer)
1021 // > and the type it points to is a pointer in the PhysicalStorageBuffer
1022 // > storage class, the function parameter must be decorated with exactly
1023 // > one of AliasedPointer or RestrictPointer.
1024 if (pointeePtrType.getStorageClass() !=
1025 spirv::StorageClass::PhysicalStorageBuffer)
1026 continue;
1027
1028 bool hasAliasedPtr =
1029 hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
1030 bool hasRestrictPtr =
1031 hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
1032 if (!hasAliasedPtr && !hasRestrictPtr)
1033 return emitOpError()
1034 << "with a pointer points to a physical buffer pointer must "
1035 "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1036 continue;
1037 }
1038 // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1039 // > If an OpFunctionParameter is a pointer (or contains a pointer) in
1040 // > the PhysicalStorageBuffer storage class, the function parameter must
1041 // > be decorated with exactly one of Aliased or Restrict.
1042 if (auto pointeeArrayType =
1043 dyn_cast<spirv::ArrayType>(Val: inputPtrType.getPointeeType())) {
1044 pointeePtrType =
1045 dyn_cast<spirv::PointerType>(Val: pointeeArrayType.getElementType());
1046 } else {
1047 pointeePtrType = inputPtrType;
1048 }
1049
1050 if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1051 spirv::StorageClass::PhysicalStorageBuffer)
1052 continue;
1053
1054 bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1055 bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1056 if (!hasAliased && !hasRestrict)
1057 return emitOpError() << "with physical buffer pointer must be decorated "
1058 "either 'Aliased' or 'Restrict'";
1059 }
1060
1061 return success();
1062}
1063
1064LogicalResult spirv::FuncOp::verifyBody() {
1065 FunctionType fnType = getFunctionType();
1066 if (!isExternal()) {
1067 Block &entryBlock = front();
1068
1069 unsigned numArguments = this->getNumArguments();
1070 if (entryBlock.getNumArguments() != numArguments)
1071 return emitOpError(message: "entry block must have ")
1072 << numArguments << " arguments to match function signature";
1073
1074 for (auto [index, fnArgType, blockArgType] :
1075 llvm::enumerate(First: getArgumentTypes(), Rest: entryBlock.getArgumentTypes())) {
1076 if (blockArgType != fnArgType) {
1077 return emitOpError(message: "type of entry block argument #")
1078 << index << '(' << blockArgType
1079 << ") must match the type of the corresponding argument in "
1080 << "function signature(" << fnArgType << ')';
1081 }
1082 }
1083 }
1084
1085 auto walkResult = walk(callback: [fnType](Operation *op) -> WalkResult {
1086 if (auto retOp = dyn_cast<spirv::ReturnOp>(Val: op)) {
1087 if (fnType.getNumResults() != 0)
1088 return retOp.emitOpError(message: "cannot be used in functions returning value");
1089 } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(Val: op)) {
1090 if (fnType.getNumResults() != 1)
1091 return retOp.emitOpError(
1092 message: "returns 1 value but enclosing function requires ")
1093 << fnType.getNumResults() << " results";
1094
1095 auto retOperandType = retOp.getValue().getType();
1096 auto fnResultType = fnType.getResult(i: 0);
1097 if (retOperandType != fnResultType)
1098 return retOp.emitOpError(message: " return value's type (")
1099 << retOperandType << ") mismatch with function's result type ("
1100 << fnResultType << ")";
1101 }
1102 return WalkResult::advance();
1103 });
1104
1105 // TODO: verify other bits like linkage type.
1106
1107 return failure(IsFailure: walkResult.wasInterrupted());
1108}
1109
1110void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1111 StringRef name, FunctionType type,
1112 spirv::FunctionControl control,
1113 ArrayRef<NamedAttribute> attrs) {
1114 state.addAttribute(name: SymbolTable::getSymbolAttrName(),
1115 attr: builder.getStringAttr(bytes: name));
1116 state.addAttribute(name: getFunctionTypeAttrName(name: state.name), attr: TypeAttr::get(type));
1117 state.addAttribute(name: spirv::attributeName<spirv::FunctionControl>(),
1118 attr: builder.getAttr<spirv::FunctionControlAttr>(args&: control));
1119 state.attributes.append(inStart: attrs.begin(), inEnd: attrs.end());
1120 state.addRegion();
1121}
1122
1123//===----------------------------------------------------------------------===//
1124// spirv.GLFClampOp
1125//===----------------------------------------------------------------------===//
1126
1127ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
1128 OperationState &result) {
1129 return parseOneResultSameOperandTypeOp(parser, result);
1130}
1131void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(op: *this, p); }
1132
1133//===----------------------------------------------------------------------===//
1134// spirv.GLUClampOp
1135//===----------------------------------------------------------------------===//
1136
1137ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
1138 OperationState &result) {
1139 return parseOneResultSameOperandTypeOp(parser, result);
1140}
1141void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(op: *this, p); }
1142
1143//===----------------------------------------------------------------------===//
1144// spirv.GLSClampOp
1145//===----------------------------------------------------------------------===//
1146
1147ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
1148 OperationState &result) {
1149 return parseOneResultSameOperandTypeOp(parser, result);
1150}
1151void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(op: *this, p); }
1152
1153//===----------------------------------------------------------------------===//
1154// spirv.GLFmaOp
1155//===----------------------------------------------------------------------===//
1156
1157ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
1158 return parseOneResultSameOperandTypeOp(parser, result);
1159}
1160void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(op: *this, p); }
1161
1162//===----------------------------------------------------------------------===//
1163// spirv.GlobalVariable
1164//===----------------------------------------------------------------------===//
1165
1166void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1167 Type type, StringRef name,
1168 unsigned descriptorSet, unsigned binding) {
1169 build(odsBuilder&: builder, odsState&: state, type: TypeAttr::get(type), sym_name: builder.getStringAttr(bytes: name));
1170 state.addAttribute(
1171 name: spirv::SPIRVDialect::getAttributeName(decoration: spirv::Decoration::DescriptorSet),
1172 attr: builder.getI32IntegerAttr(value: descriptorSet));
1173 state.addAttribute(
1174 name: spirv::SPIRVDialect::getAttributeName(decoration: spirv::Decoration::Binding),
1175 attr: builder.getI32IntegerAttr(value: binding));
1176}
1177
1178void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1179 Type type, StringRef name,
1180 spirv::BuiltIn builtin) {
1181 build(odsBuilder&: builder, odsState&: state, type: TypeAttr::get(type), sym_name: builder.getStringAttr(bytes: name));
1182 state.addAttribute(
1183 name: spirv::SPIRVDialect::getAttributeName(decoration: spirv::Decoration::BuiltIn),
1184 attr: builder.getStringAttr(bytes: spirv::stringifyBuiltIn(builtin)));
1185}
1186
1187ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
1188 OperationState &result) {
1189 // Parse variable name.
1190 StringAttr nameAttr;
1191 StringRef initializerAttrName =
1192 spirv::GlobalVariableOp::getInitializerAttrName(name: result.name);
1193 if (parser.parseSymbolName(result&: nameAttr, attrName: SymbolTable::getSymbolAttrName(),
1194 attrs&: result.attributes)) {
1195 return failure();
1196 }
1197
1198 // Parse optional initializer
1199 if (succeeded(Result: parser.parseOptionalKeyword(keyword: initializerAttrName))) {
1200 FlatSymbolRefAttr initSymbol;
1201 if (parser.parseLParen() ||
1202 parser.parseAttribute(result&: initSymbol, type: Type(), attrName: initializerAttrName,
1203 attrs&: result.attributes) ||
1204 parser.parseRParen())
1205 return failure();
1206 }
1207
1208 if (parseVariableDecorations(parser, state&: result)) {
1209 return failure();
1210 }
1211
1212 Type type;
1213 StringRef typeAttrName =
1214 spirv::GlobalVariableOp::getTypeAttrName(name: result.name);
1215 auto loc = parser.getCurrentLocation();
1216 if (parser.parseColonType(result&: type)) {
1217 return failure();
1218 }
1219 if (!llvm::isa<spirv::PointerType>(Val: type)) {
1220 return parser.emitError(loc, message: "expected spirv.ptr type");
1221 }
1222 result.addAttribute(name: typeAttrName, attr: TypeAttr::get(type));
1223
1224 return success();
1225}
1226
1227void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
1228 SmallVector<StringRef, 4> elidedAttrs{
1229 spirv::attributeName<spirv::StorageClass>()};
1230
1231 // Print variable name.
1232 printer << ' ';
1233 printer.printSymbolName(symbolRef: getSymName());
1234 elidedAttrs.push_back(Elt: SymbolTable::getSymbolAttrName());
1235
1236 StringRef initializerAttrName = this->getInitializerAttrName();
1237 // Print optional initializer
1238 if (auto initializer = this->getInitializer()) {
1239 printer << " " << initializerAttrName << '(';
1240 printer.printSymbolName(symbolRef: *initializer);
1241 printer << ')';
1242 elidedAttrs.push_back(Elt: initializerAttrName);
1243 }
1244
1245 StringRef typeAttrName = this->getTypeAttrName();
1246 elidedAttrs.push_back(Elt: typeAttrName);
1247 spirv::printVariableDecorations(op: *this, printer, elidedAttrs);
1248 printer << " : " << getType();
1249}
1250
1251LogicalResult spirv::GlobalVariableOp::verify() {
1252 if (!llvm::isa<spirv::PointerType>(Val: getType()))
1253 return emitOpError(message: "result must be of a !spv.ptr type");
1254
1255 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1256 // object. It cannot be Generic. It must be the same as the Storage Class
1257 // operand of the Result Type."
1258 // Also, Function storage class is reserved by spirv.Variable.
1259 auto storageClass = this->storageClass();
1260 if (storageClass == spirv::StorageClass::Generic ||
1261 storageClass == spirv::StorageClass::Function) {
1262 return emitOpError(message: "storage class cannot be '")
1263 << stringifyStorageClass(storageClass) << "'";
1264 }
1265
1266 if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
1267 name: this->getInitializerAttrName())) {
1268 Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
1269 from: (*this)->getParentOp(), symbol: init.getAttr());
1270 // TODO: Currently only variable initialization with specialization
1271 // constants and other variables is supported. They could be normal
1272 // constants in the module scope as well.
1273 if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
1274 spirv::SpecConstantCompositeOp>(Val: initOp)) {
1275 return emitOpError(message: "initializer must be result of a "
1276 "spirv.SpecConstant or spirv.GlobalVariable or "
1277 "spirv.SpecConstantCompositeOp op");
1278 }
1279 }
1280
1281 return success();
1282}
1283
1284//===----------------------------------------------------------------------===//
1285// spirv.INTEL.SubgroupBlockRead
1286//===----------------------------------------------------------------------===//
1287
1288LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
1289 if (failed(Result: verifyBlockReadWritePtrAndValTypes(op: *this, ptr: getPtr(), val: getValue())))
1290 return failure();
1291
1292 return success();
1293}
1294
1295//===----------------------------------------------------------------------===//
1296// spirv.INTEL.SubgroupBlockWrite
1297//===----------------------------------------------------------------------===//
1298
1299ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
1300 OperationState &result) {
1301 // Parse the storage class specification
1302 spirv::StorageClass storageClass;
1303 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
1304 auto loc = parser.getCurrentLocation();
1305 Type elementType;
1306 if (parseEnumStrAttr(value&: storageClass, parser) ||
1307 parser.parseOperandList(result&: operandInfo, requiredOperandCount: 2) || parser.parseColon() ||
1308 parser.parseType(result&: elementType)) {
1309 return failure();
1310 }
1311
1312 auto ptrType = spirv::PointerType::get(pointeeType: elementType, storageClass);
1313 if (auto valVecTy = llvm::dyn_cast<VectorType>(Val&: elementType))
1314 ptrType = spirv::PointerType::get(pointeeType: valVecTy.getElementType(), storageClass);
1315
1316 if (parser.resolveOperands(operands&: operandInfo, types: {ptrType, elementType}, loc,
1317 result&: result.operands)) {
1318 return failure();
1319 }
1320 return success();
1321}
1322
1323void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
1324 printer << " " << getPtr() << ", " << getValue() << " : "
1325 << getValue().getType();
1326}
1327
1328LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
1329 if (failed(Result: verifyBlockReadWritePtrAndValTypes(op: *this, ptr: getPtr(), val: getValue())))
1330 return failure();
1331
1332 return success();
1333}
1334
1335//===----------------------------------------------------------------------===//
1336// spirv.IAddCarryOp
1337//===----------------------------------------------------------------------===//
1338
1339LogicalResult spirv::IAddCarryOp::verify() {
1340 return ::verifyArithmeticExtendedBinaryOp(op: *this);
1341}
1342
1343ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
1344 OperationState &result) {
1345 return ::parseArithmeticExtendedBinaryOp(parser, result);
1346}
1347
1348void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1349 ::printArithmeticExtendedBinaryOp(op: *this, printer);
1350}
1351
1352//===----------------------------------------------------------------------===//
1353// spirv.ISubBorrowOp
1354//===----------------------------------------------------------------------===//
1355
1356LogicalResult spirv::ISubBorrowOp::verify() {
1357 return ::verifyArithmeticExtendedBinaryOp(op: *this);
1358}
1359
1360ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
1361 OperationState &result) {
1362 return ::parseArithmeticExtendedBinaryOp(parser, result);
1363}
1364
1365void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
1366 ::printArithmeticExtendedBinaryOp(op: *this, printer);
1367}
1368
1369//===----------------------------------------------------------------------===//
1370// spirv.SMulExtended
1371//===----------------------------------------------------------------------===//
1372
1373LogicalResult spirv::SMulExtendedOp::verify() {
1374 return ::verifyArithmeticExtendedBinaryOp(op: *this);
1375}
1376
1377ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
1378 OperationState &result) {
1379 return ::parseArithmeticExtendedBinaryOp(parser, result);
1380}
1381
1382void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) {
1383 ::printArithmeticExtendedBinaryOp(op: *this, printer);
1384}
1385
1386//===----------------------------------------------------------------------===//
1387// spirv.UMulExtended
1388//===----------------------------------------------------------------------===//
1389
1390LogicalResult spirv::UMulExtendedOp::verify() {
1391 return ::verifyArithmeticExtendedBinaryOp(op: *this);
1392}
1393
1394ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
1395 OperationState &result) {
1396 return ::parseArithmeticExtendedBinaryOp(parser, result);
1397}
1398
1399void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
1400 ::printArithmeticExtendedBinaryOp(op: *this, printer);
1401}
1402
1403//===----------------------------------------------------------------------===//
1404// spirv.MemoryBarrierOp
1405//===----------------------------------------------------------------------===//
1406
1407LogicalResult spirv::MemoryBarrierOp::verify() {
1408 return verifyMemorySemantics(op: getOperation(), memorySemantics: getMemorySemantics());
1409}
1410
1411//===----------------------------------------------------------------------===//
1412// spirv.module
1413//===----------------------------------------------------------------------===//
1414
1415void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1416 std::optional<StringRef> name) {
1417 OpBuilder::InsertionGuard guard(builder);
1418 builder.createBlock(parent: state.addRegion());
1419 if (name) {
1420 state.attributes.append(name: mlir::SymbolTable::getSymbolAttrName(),
1421 attr: builder.getStringAttr(bytes: *name));
1422 }
1423}
1424
1425void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1426 spirv::AddressingModel addressingModel,
1427 spirv::MemoryModel memoryModel,
1428 std::optional<VerCapExtAttr> vceTriple,
1429 std::optional<StringRef> name) {
1430 state.addAttribute(
1431 name: "addressing_model",
1432 attr: builder.getAttr<spirv::AddressingModelAttr>(args&: addressingModel));
1433 state.addAttribute(name: "memory_model",
1434 attr: builder.getAttr<spirv::MemoryModelAttr>(args&: memoryModel));
1435 OpBuilder::InsertionGuard guard(builder);
1436 builder.createBlock(parent: state.addRegion());
1437 if (vceTriple)
1438 state.addAttribute(name: getVCETripleAttrName(), attr: *vceTriple);
1439 if (name)
1440 state.addAttribute(name: mlir::SymbolTable::getSymbolAttrName(),
1441 attr: builder.getStringAttr(bytes: *name));
1442}
1443
1444ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
1445 OperationState &result) {
1446 Region *body = result.addRegion();
1447
1448 // If the name is present, parse it.
1449 StringAttr nameAttr;
1450 (void)parser.parseOptionalSymbolName(
1451 result&: nameAttr, attrName: mlir::SymbolTable::getSymbolAttrName(), attrs&: result.attributes);
1452
1453 // Parse attributes
1454 spirv::AddressingModel addrModel;
1455 spirv::MemoryModel memoryModel;
1456 if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(value&: addrModel, parser,
1457 state&: result) ||
1458 spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(value&: memoryModel, parser,
1459 state&: result))
1460 return failure();
1461
1462 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "requires"))) {
1463 spirv::VerCapExtAttr vceTriple;
1464 if (parser.parseAttribute(result&: vceTriple,
1465 attrName: spirv::ModuleOp::getVCETripleAttrName(),
1466 attrs&: result.attributes))
1467 return failure();
1468 }
1469
1470 if (parser.parseOptionalAttrDictWithKeyword(result&: result.attributes) ||
1471 parser.parseRegion(region&: *body, /*arguments=*/{}))
1472 return failure();
1473
1474 // Make sure we have at least one block.
1475 if (body->empty())
1476 body->push_back(block: new Block());
1477
1478 return success();
1479}
1480
1481void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1482 if (std::optional<StringRef> name = getName()) {
1483 printer << ' ';
1484 printer.printSymbolName(symbolRef: *name);
1485 }
1486
1487 SmallVector<StringRef, 2> elidedAttrs;
1488
1489 printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1490 << spirv::stringifyMemoryModel(getMemoryModel());
1491 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1492 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1493 elidedAttrs.assign(IL: {addressingModelAttrName, memoryModelAttrName,
1494 mlir::SymbolTable::getSymbolAttrName()});
1495
1496 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1497 printer << " requires " << *triple;
1498 elidedAttrs.push_back(Elt: spirv::ModuleOp::getVCETripleAttrName());
1499 }
1500
1501 printer.printOptionalAttrDictWithKeyword(attrs: (*this)->getAttrs(), elidedAttrs);
1502 printer << ' ';
1503 printer.printRegion(blocks&: getRegion());
1504}
1505
1506LogicalResult spirv::ModuleOp::verifyRegions() {
1507 Dialect *dialect = (*this)->getDialect();
1508 DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
1509 entryPoints;
1510 mlir::SymbolTable table(*this);
1511
1512 for (auto &op : *getBody()) {
1513 if (op.getDialect() != dialect)
1514 return op.emitError(message: "'spirv.module' can only contain spirv.* ops");
1515
1516 // For EntryPoint op, check that the function and execution model is not
1517 // duplicated in EntryPointOps. Also verify that the interface specified
1518 // comes from globalVariables here to make this check cheaper.
1519 if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(Val&: op)) {
1520 auto funcOp = table.lookup<spirv::FuncOp>(name: entryPointOp.getFn());
1521 if (!funcOp) {
1522 return entryPointOp.emitError(message: "function '")
1523 << entryPointOp.getFn() << "' not found in 'spirv.module'";
1524 }
1525 if (auto interface = entryPointOp.getInterface()) {
1526 for (Attribute varRef : interface) {
1527 auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(Val&: varRef);
1528 if (!varSymRef) {
1529 return entryPointOp.emitError(
1530 message: "expected symbol reference for interface "
1531 "specification instead of '")
1532 << varRef;
1533 }
1534 auto variableOp =
1535 table.lookup<spirv::GlobalVariableOp>(name: varSymRef.getValue());
1536 if (!variableOp) {
1537 return entryPointOp.emitError(message: "expected spirv.GlobalVariable "
1538 "symbol reference instead of'")
1539 << varSymRef << "'";
1540 }
1541 }
1542 }
1543
1544 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1545 funcOp, entryPointOp.getExecutionModel());
1546 if (!entryPoints.try_emplace(Key: key, Args&: entryPointOp).second)
1547 return entryPointOp.emitError(message: "duplicate of a previous EntryPointOp");
1548 } else if (auto funcOp = dyn_cast<spirv::FuncOp>(Val&: op)) {
1549 // If the function is external and does not have 'Import'
1550 // linkage_attributes(LinkageAttributes), throw an error. 'Import'
1551 // LinkageAttributes is used to import external functions.
1552 auto linkageAttr = funcOp.getLinkageAttributes();
1553 auto hasImportLinkage =
1554 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1555 spirv::LinkageType::Import);
1556 if (funcOp.isExternal() && !hasImportLinkage)
1557 return op.emitError(
1558 message: "'spirv.module' cannot contain external functions "
1559 "without 'Import' linkage_attributes (LinkageAttributes)");
1560
1561 // TODO: move this check to spirv.func.
1562 for (auto &block : funcOp)
1563 for (auto &op : block) {
1564 if (op.getDialect() != dialect)
1565 return op.emitError(
1566 message: "functions in 'spirv.module' can only contain spirv.* ops");
1567 }
1568 }
1569 }
1570
1571 return success();
1572}
1573
1574//===----------------------------------------------------------------------===//
1575// spirv.mlir.referenceof
1576//===----------------------------------------------------------------------===//
1577
1578LogicalResult spirv::ReferenceOfOp::verify() {
1579 auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1580 from: (*this)->getParentOp(), symbol: getSpecConstAttr());
1581 Type constType;
1582
1583 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(Val: specConstSym);
1584 if (specConstOp)
1585 constType = specConstOp.getDefaultValue().getType();
1586
1587 auto specConstCompositeOp =
1588 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(Val: specConstSym);
1589 if (specConstCompositeOp)
1590 constType = specConstCompositeOp.getType();
1591
1592 if (!specConstOp && !specConstCompositeOp)
1593 return emitOpError(
1594 message: "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1595
1596 if (getReference().getType() != constType)
1597 return emitOpError(message: "result type mismatch with the referenced "
1598 "specialization constant's type");
1599
1600 return success();
1601}
1602
1603//===----------------------------------------------------------------------===//
1604// spirv.SpecConstant
1605//===----------------------------------------------------------------------===//
1606
1607ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
1608 OperationState &result) {
1609 StringAttr nameAttr;
1610 Attribute valueAttr;
1611 StringRef defaultValueAttrName =
1612 spirv::SpecConstantOp::getDefaultValueAttrName(name: result.name);
1613
1614 if (parser.parseSymbolName(result&: nameAttr, attrName: SymbolTable::getSymbolAttrName(),
1615 attrs&: result.attributes))
1616 return failure();
1617
1618 // Parse optional spec_id.
1619 if (succeeded(Result: parser.parseOptionalKeyword(keyword: kSpecIdAttrName))) {
1620 IntegerAttr specIdAttr;
1621 if (parser.parseLParen() ||
1622 parser.parseAttribute(result&: specIdAttr, attrName: kSpecIdAttrName, attrs&: result.attributes) ||
1623 parser.parseRParen())
1624 return failure();
1625 }
1626
1627 if (parser.parseEqual() ||
1628 parser.parseAttribute(result&: valueAttr, attrName: defaultValueAttrName, attrs&: result.attributes))
1629 return failure();
1630
1631 return success();
1632}
1633
1634void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
1635 printer << ' ';
1636 printer.printSymbolName(symbolRef: getSymName());
1637 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(name: kSpecIdAttrName))
1638 printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1639 printer << " = " << getDefaultValue();
1640}
1641
1642LogicalResult spirv::SpecConstantOp::verify() {
1643 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(name: kSpecIdAttrName))
1644 if (specID.getValue().isNegative())
1645 return emitOpError(message: "SpecId cannot be negative");
1646
1647 auto value = getDefaultValue();
1648 if (llvm::isa<IntegerAttr, FloatAttr>(Val: value)) {
1649 // Make sure bitwidth is allowed.
1650 if (!llvm::isa<spirv::SPIRVType>(Val: value.getType()))
1651 return emitOpError(message: "default value bitwidth disallowed");
1652 return success();
1653 }
1654 return emitOpError(
1655 message: "default value can only be a bool, integer, or float scalar");
1656}
1657
1658//===----------------------------------------------------------------------===//
1659// spirv.VectorShuffle
1660//===----------------------------------------------------------------------===//
1661
1662LogicalResult spirv::VectorShuffleOp::verify() {
1663 VectorType resultType = llvm::cast<VectorType>(Val: getType());
1664
1665 size_t numResultElements = resultType.getNumElements();
1666 if (numResultElements != getComponents().size())
1667 return emitOpError(message: "result type element count (")
1668 << numResultElements
1669 << ") mismatch with the number of component selectors ("
1670 << getComponents().size() << ")";
1671
1672 size_t totalSrcElements =
1673 llvm::cast<VectorType>(Val: getVector1().getType()).getNumElements() +
1674 llvm::cast<VectorType>(Val: getVector2().getType()).getNumElements();
1675
1676 for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1677 uint32_t index = selector.getZExtValue();
1678 if (index >= totalSrcElements &&
1679 index != std::numeric_limits<uint32_t>().max())
1680 return emitOpError(message: "component selector ")
1681 << index << " out of range: expected to be in [0, "
1682 << totalSrcElements << ") or 0xffffffff";
1683 }
1684 return success();
1685}
1686
1687//===----------------------------------------------------------------------===//
1688// spirv.MatrixTimesScalar
1689//===----------------------------------------------------------------------===//
1690
1691LogicalResult spirv::MatrixTimesScalarOp::verify() {
1692 Type elementType =
1693 llvm::TypeSwitch<Type, Type>(getMatrix().getType())
1694 .Case<spirv::CooperativeMatrixType, spirv::MatrixType>(
1695 caseFn: [](auto matrixType) { return matrixType.getElementType(); })
1696 .Default(defaultFn: [](Type) { return nullptr; });
1697
1698 assert(elementType && "Unhandled type");
1699
1700 // Check that the scalar type is the same as the matrix element type.
1701 if (getScalar().getType() != elementType)
1702 return emitOpError(message: "input matrix components' type and scaling value must "
1703 "have the same type");
1704
1705 return success();
1706}
1707
1708//===----------------------------------------------------------------------===//
1709// spirv.Transpose
1710//===----------------------------------------------------------------------===//
1711
1712LogicalResult spirv::TransposeOp::verify() {
1713 auto inputMatrix = llvm::cast<spirv::MatrixType>(Val: getMatrix().getType());
1714 auto resultMatrix = llvm::cast<spirv::MatrixType>(Val: getResult().getType());
1715
1716 // Verify that the input and output matrices have correct shapes.
1717 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1718 return emitError(message: "input matrix rows count must be equal to "
1719 "output matrix columns count");
1720
1721 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1722 return emitError(message: "input matrix columns count must be equal to "
1723 "output matrix rows count");
1724
1725 // Verify that the input and output matrices have the same component type
1726 if (inputMatrix.getElementType() != resultMatrix.getElementType())
1727 return emitError(message: "input and output matrices must have the same "
1728 "component type");
1729
1730 return success();
1731}
1732
1733//===----------------------------------------------------------------------===//
1734// spirv.MatrixTimesVector
1735//===----------------------------------------------------------------------===//
1736
1737LogicalResult spirv::MatrixTimesVectorOp::verify() {
1738 auto matrixType = llvm::cast<spirv::MatrixType>(Val: getMatrix().getType());
1739 auto vectorType = llvm::cast<VectorType>(Val: getVector().getType());
1740 auto resultType = llvm::cast<VectorType>(Val: getType());
1741
1742 if (matrixType.getNumColumns() != vectorType.getNumElements())
1743 return emitOpError(message: "matrix columns (")
1744 << matrixType.getNumColumns() << ") must match vector operand size ("
1745 << vectorType.getNumElements() << ")";
1746
1747 if (resultType.getNumElements() != matrixType.getNumRows())
1748 return emitOpError(message: "result size (")
1749 << resultType.getNumElements() << ") must match the matrix rows ("
1750 << matrixType.getNumRows() << ")";
1751
1752 if (matrixType.getElementType() != resultType.getElementType())
1753 return emitOpError(message: "matrix and result element types must match");
1754
1755 return success();
1756}
1757
1758//===----------------------------------------------------------------------===//
1759// spirv.VectorTimesMatrix
1760//===----------------------------------------------------------------------===//
1761
1762LogicalResult spirv::VectorTimesMatrixOp::verify() {
1763 auto vectorType = llvm::cast<VectorType>(Val: getVector().getType());
1764 auto matrixType = llvm::cast<spirv::MatrixType>(Val: getMatrix().getType());
1765 auto resultType = llvm::cast<VectorType>(Val: getType());
1766
1767 if (matrixType.getNumRows() != vectorType.getNumElements())
1768 return emitOpError(message: "number of components in vector must equal the number "
1769 "of components in each column in matrix");
1770
1771 if (resultType.getNumElements() != matrixType.getNumColumns())
1772 return emitOpError(message: "number of columns in matrix must equal the number of "
1773 "components in result");
1774
1775 if (matrixType.getElementType() != resultType.getElementType())
1776 return emitOpError(message: "matrix must be a matrix with the same component type "
1777 "as the component type in result");
1778
1779 return success();
1780}
1781
1782//===----------------------------------------------------------------------===//
1783// spirv.MatrixTimesMatrix
1784//===----------------------------------------------------------------------===//
1785
1786LogicalResult spirv::MatrixTimesMatrixOp::verify() {
1787 auto leftMatrix = llvm::cast<spirv::MatrixType>(Val: getLeftmatrix().getType());
1788 auto rightMatrix = llvm::cast<spirv::MatrixType>(Val: getRightmatrix().getType());
1789 auto resultMatrix = llvm::cast<spirv::MatrixType>(Val: getResult().getType());
1790
1791 // left matrix columns' count and right matrix rows' count must be equal
1792 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1793 return emitError(message: "left matrix columns' count must be equal to "
1794 "the right matrix rows' count");
1795
1796 // right and result matrices columns' count must be the same
1797 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1798 return emitError(
1799 message: "right and result matrices must have equal columns' count");
1800
1801 // right and result matrices component type must be the same
1802 if (rightMatrix.getElementType() != resultMatrix.getElementType())
1803 return emitError(message: "right and result matrices' component type must"
1804 " be the same");
1805
1806 // left and result matrices component type must be the same
1807 if (leftMatrix.getElementType() != resultMatrix.getElementType())
1808 return emitError(message: "left and result matrices' component type"
1809 " must be the same");
1810
1811 // left and result matrices rows count must be the same
1812 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1813 return emitError(message: "left and result matrices must have equal rows' count");
1814
1815 return success();
1816}
1817
1818//===----------------------------------------------------------------------===//
1819// spirv.SpecConstantComposite
1820//===----------------------------------------------------------------------===//
1821
1822ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
1823 OperationState &result) {
1824
1825 StringAttr compositeName;
1826 if (parser.parseSymbolName(result&: compositeName, attrName: SymbolTable::getSymbolAttrName(),
1827 attrs&: result.attributes))
1828 return failure();
1829
1830 if (parser.parseLParen())
1831 return failure();
1832
1833 SmallVector<Attribute, 4> constituents;
1834
1835 do {
1836 // The name of the constituent attribute isn't important
1837 const char *attrName = "spec_const";
1838 FlatSymbolRefAttr specConstRef;
1839 NamedAttrList attrs;
1840
1841 if (parser.parseAttribute(result&: specConstRef, type: Type(), attrName, attrs))
1842 return failure();
1843
1844 constituents.push_back(Elt: specConstRef);
1845 } while (!parser.parseOptionalComma());
1846
1847 if (parser.parseRParen())
1848 return failure();
1849
1850 StringAttr compositeSpecConstituentsName =
1851 spirv::SpecConstantCompositeOp::getConstituentsAttrName(name: result.name);
1852 result.addAttribute(name: compositeSpecConstituentsName,
1853 attr: parser.getBuilder().getArrayAttr(value: constituents));
1854
1855 Type type;
1856 if (parser.parseColonType(result&: type))
1857 return failure();
1858
1859 StringAttr typeAttrName =
1860 spirv::SpecConstantCompositeOp::getTypeAttrName(name: result.name);
1861 result.addAttribute(name: typeAttrName, attr: TypeAttr::get(type));
1862
1863 return success();
1864}
1865
1866void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
1867 printer << " ";
1868 printer.printSymbolName(symbolRef: getSymName());
1869 printer << " (" << llvm::interleaved(R: this->getConstituents().getValue())
1870 << ") : " << getType();
1871}
1872
1873LogicalResult spirv::SpecConstantCompositeOp::verify() {
1874 auto cType = llvm::dyn_cast<spirv::CompositeType>(Val: getType());
1875 auto constituents = this->getConstituents().getValue();
1876
1877 if (!cType)
1878 return emitError(message: "result type must be a composite type, but provided ")
1879 << getType();
1880
1881 if (llvm::isa<spirv::CooperativeMatrixType>(Val: cType))
1882 return emitError(message: "unsupported composite type ") << cType;
1883 if (constituents.size() != cType.getNumElements())
1884 return emitError(message: "has incorrect number of operands: expected ")
1885 << cType.getNumElements() << ", but provided "
1886 << constituents.size();
1887
1888 for (auto index : llvm::seq<uint32_t>(Begin: 0, End: constituents.size())) {
1889 auto constituent = llvm::cast<FlatSymbolRefAttr>(Val: constituents[index]);
1890
1891 auto constituentSpecConstOp =
1892 dyn_cast<spirv::SpecConstantOp>(Val: SymbolTable::lookupNearestSymbolFrom(
1893 from: (*this)->getParentOp(), symbol: constituent.getAttr()));
1894
1895 if (constituentSpecConstOp.getDefaultValue().getType() !=
1896 cType.getElementType(index))
1897 return emitError(message: "has incorrect types of operands: expected ")
1898 << cType.getElementType(index) << ", but provided "
1899 << constituentSpecConstOp.getDefaultValue().getType();
1900 }
1901
1902 return success();
1903}
1904
1905//===----------------------------------------------------------------------===//
1906// spirv.EXTSpecConstantCompositeReplicateOp
1907//===----------------------------------------------------------------------===//
1908
1909ParseResult
1910spirv::EXTSpecConstantCompositeReplicateOp::parse(OpAsmParser &parser,
1911 OperationState &result) {
1912 StringAttr compositeName;
1913 FlatSymbolRefAttr specConstRef;
1914 const char *attrName = "spec_const";
1915 NamedAttrList attrs;
1916 Type type;
1917
1918 if (parser.parseSymbolName(result&: compositeName, attrName: SymbolTable::getSymbolAttrName(),
1919 attrs&: result.attributes) ||
1920 parser.parseLParen() ||
1921 parser.parseAttribute(result&: specConstRef, type: Type(), attrName, attrs) ||
1922 parser.parseRParen() || parser.parseColonType(result&: type))
1923 return failure();
1924
1925 StringAttr compositeSpecConstituentName =
1926 spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
1927 name: result.name);
1928 result.addAttribute(name: compositeSpecConstituentName, attr: specConstRef);
1929
1930 StringAttr typeAttrName =
1931 spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(name: result.name);
1932 result.addAttribute(name: typeAttrName, attr: TypeAttr::get(type));
1933
1934 return success();
1935}
1936
1937void spirv::EXTSpecConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
1938 printer << " ";
1939 printer.printSymbolName(symbolRef: getSymName());
1940 printer << " (" << this->getConstituent() << ") : " << getType();
1941}
1942
1943LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
1944 auto compositeType = dyn_cast<spirv::CompositeType>(Val: getType());
1945 if (!compositeType)
1946 return emitError(message: "result type must be a composite type, but provided ")
1947 << getType();
1948
1949 Operation *constituentOp = SymbolTable::lookupNearestSymbolFrom(
1950 from: (*this)->getParentOp(), symbol: this->getConstituent());
1951 if (!constituentOp)
1952 return emitError(
1953 message: "splat spec constant reference defining constituent not found");
1954
1955 auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(Val: constituentOp);
1956 if (!constituentSpecConstOp)
1957 return emitError(message: "constituent is not a spec constant");
1958
1959 Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
1960 Type compositeElementType = compositeType.getElementType(0);
1961 if (constituentType != compositeElementType)
1962 return emitError(message: "constituent has incorrect type: expected ")
1963 << compositeElementType << ", but provided " << constituentType;
1964
1965 return success();
1966}
1967
1968//===----------------------------------------------------------------------===//
1969// spirv.SpecConstantOperation
1970//===----------------------------------------------------------------------===//
1971
1972ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
1973 OperationState &result) {
1974 Region *body = result.addRegion();
1975
1976 if (parser.parseKeyword(keyword: "wraps"))
1977 return failure();
1978
1979 body->push_back(block: new Block);
1980 Block &block = body->back();
1981 Operation *wrappedOp = parser.parseGenericOperation(insertBlock: &block, insertPt: block.begin());
1982
1983 if (!wrappedOp)
1984 return failure();
1985
1986 OpBuilder builder(parser.getContext());
1987 builder.setInsertionPointToEnd(&block);
1988 builder.create<spirv::YieldOp>(location: wrappedOp->getLoc(), args: wrappedOp->getResult(idx: 0));
1989 result.location = wrappedOp->getLoc();
1990
1991 result.addTypes(newTypes: wrappedOp->getResult(idx: 0).getType());
1992
1993 if (parser.parseOptionalAttrDict(result&: result.attributes))
1994 return failure();
1995
1996 return success();
1997}
1998
1999void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) {
2000 printer << " wraps ";
2001 printer.printGenericOp(op: &getBody().front().front());
2002}
2003
2004LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
2005 Block &block = getRegion().getBlocks().front();
2006
2007 if (block.getOperations().size() != 2)
2008 return emitOpError(message: "expected exactly 2 nested ops");
2009
2010 Operation &enclosedOp = block.getOperations().front();
2011
2012 if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>())
2013 return emitOpError(message: "invalid enclosed op");
2014
2015 for (auto operand : enclosedOp.getOperands())
2016 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
2017 spirv::SpecConstantOperationOp>(Val: operand.getDefiningOp()))
2018 return emitOpError(
2019 message: "invalid operand, must be defined by a constant operation");
2020
2021 return success();
2022}
2023
2024//===----------------------------------------------------------------------===//
2025// spirv.GL.FrexpStruct
2026//===----------------------------------------------------------------------===//
2027
2028LogicalResult spirv::GLFrexpStructOp::verify() {
2029 spirv::StructType structTy =
2030 llvm::dyn_cast<spirv::StructType>(Val: getResult().getType());
2031
2032 if (structTy.getNumElements() != 2)
2033 return emitError(message: "result type must be a struct type with two memebers");
2034
2035 Type significandTy = structTy.getElementType(0);
2036 Type exponentTy = structTy.getElementType(1);
2037 VectorType exponentVecTy = llvm::dyn_cast<VectorType>(Val&: exponentTy);
2038 IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(Val&: exponentTy);
2039
2040 Type operandTy = getOperand().getType();
2041 VectorType operandVecTy = llvm::dyn_cast<VectorType>(Val&: operandTy);
2042 FloatType operandFTy = llvm::dyn_cast<FloatType>(Val&: operandTy);
2043
2044 if (significandTy != operandTy)
2045 return emitError(message: "member zero of the resulting struct type must be the "
2046 "same type as the operand");
2047
2048 if (exponentVecTy) {
2049 IntegerType componentIntTy =
2050 llvm::dyn_cast<IntegerType>(Val: exponentVecTy.getElementType());
2051 if (!componentIntTy || componentIntTy.getWidth() != 32)
2052 return emitError(message: "member one of the resulting struct type must"
2053 "be a scalar or vector of 32 bit integer type");
2054 } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
2055 return emitError(message: "member one of the resulting struct type "
2056 "must be a scalar or vector of 32 bit integer type");
2057 }
2058
2059 // Check that the two member types have the same number of components
2060 if (operandVecTy && exponentVecTy &&
2061 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
2062 return success();
2063
2064 if (operandFTy && exponentIntTy)
2065 return success();
2066
2067 return emitError(message: "member one of the resulting struct type must have the same "
2068 "number of components as the operand type");
2069}
2070
2071//===----------------------------------------------------------------------===//
2072// spirv.GL.Ldexp
2073//===----------------------------------------------------------------------===//
2074
2075LogicalResult spirv::GLLdexpOp::verify() {
2076 Type significandType = getX().getType();
2077 Type exponentType = getExp().getType();
2078
2079 if (llvm::isa<FloatType>(Val: significandType) !=
2080 llvm::isa<IntegerType>(Val: exponentType))
2081 return emitOpError(message: "operands must both be scalars or vectors");
2082
2083 auto getNumElements = [](Type type) -> unsigned {
2084 if (auto vectorType = llvm::dyn_cast<VectorType>(Val&: type))
2085 return vectorType.getNumElements();
2086 return 1;
2087 };
2088
2089 if (getNumElements(significandType) != getNumElements(exponentType))
2090 return emitOpError(message: "operands must have the same number of elements");
2091
2092 return success();
2093}
2094
2095//===----------------------------------------------------------------------===//
2096// spirv.ShiftLeftLogicalOp
2097//===----------------------------------------------------------------------===//
2098
2099LogicalResult spirv::ShiftLeftLogicalOp::verify() {
2100 return verifyShiftOp(op: *this);
2101}
2102
2103//===----------------------------------------------------------------------===//
2104// spirv.ShiftRightArithmeticOp
2105//===----------------------------------------------------------------------===//
2106
2107LogicalResult spirv::ShiftRightArithmeticOp::verify() {
2108 return verifyShiftOp(op: *this);
2109}
2110
2111//===----------------------------------------------------------------------===//
2112// spirv.ShiftRightLogicalOp
2113//===----------------------------------------------------------------------===//
2114
2115LogicalResult spirv::ShiftRightLogicalOp::verify() {
2116 return verifyShiftOp(op: *this);
2117}
2118
2119//===----------------------------------------------------------------------===//
2120// spirv.VectorTimesScalarOp
2121//===----------------------------------------------------------------------===//
2122
2123LogicalResult spirv::VectorTimesScalarOp::verify() {
2124 if (getVector().getType() != getType())
2125 return emitOpError(message: "vector operand and result type mismatch");
2126 auto scalarType = llvm::cast<VectorType>(Val: getType()).getElementType();
2127 if (getScalar().getType() != scalarType)
2128 return emitOpError(message: "scalar operand and result element type match");
2129 return success();
2130}
2131

source code of mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp