//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the SPIR-V dialect in MLIR.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
14
15#include "SPIRVParsingUtils.h"
16
17#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
20#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21#include "mlir/Dialect/UB/IR/UBOps.h"
22#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinTypes.h"
24#include "mlir/IR/DialectImplementation.h"
25#include "mlir/IR/MLIRContext.h"
26#include "mlir/Parser/Parser.h"
27#include "mlir/Transforms/InliningUtils.h"
28#include "llvm/ADT/Sequence.h"
29#include "llvm/ADT/StringExtras.h"
30#include "llvm/ADT/TypeSwitch.h"
31
32using namespace mlir;
33using namespace mlir::spirv;
34
35#include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
36
37//===----------------------------------------------------------------------===//
38// InlinerInterface
39//===----------------------------------------------------------------------===//
40
41/// Returns true if the given region contains spirv.Return or spirv.ReturnValue
42/// ops.
43static inline bool containsReturn(Region &region) {
44 return llvm::any_of(Range&: region, P: [](Block &block) {
45 Operation *terminator = block.getTerminator();
46 return isa<spirv::ReturnOp, spirv::ReturnValueOp>(Val: terminator);
47 });
48}
49
50namespace {
51/// This class defines the interface for inlining within the SPIR-V dialect.
52struct SPIRVInlinerInterface : public DialectInlinerInterface {
53 using DialectInlinerInterface::DialectInlinerInterface;
54
55 /// All call operations within SPIRV can be inlined.
56 bool isLegalToInline(Operation *call, Operation *callable,
57 bool wouldBeCloned) const final {
58 return true;
59 }
60
61 /// Returns true if the given region 'src' can be inlined into the region
62 /// 'dest' that is attached to an operation registered to the current dialect.
63 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
64 IRMapping &) const final {
65 // Return true here when inlining into spirv.func, spirv.mlir.selection, and
66 // spirv.mlir.loop operations.
67 auto *op = dest->getParentOp();
68 return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(Val: op);
69 }
70
71 /// Returns true if the given operation 'op', that is registered to this
72 /// dialect, can be inlined into the region 'dest' that is attached to an
73 /// operation registered to the current dialect.
74 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
75 IRMapping &) const final {
76 // TODO: Enable inlining structured control flows with return.
77 if ((isa<spirv::SelectionOp, spirv::LoopOp>(Val: op)) &&
78 containsReturn(region&: op->getRegion(index: 0)))
79 return false;
80 // TODO: we need to filter OpKill here to avoid inlining it to
81 // a loop continue construct:
82 // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
83 // For now, we just disallow inlining OpKill anywhere in the code,
84 // but this restriction should be relaxed, as pointed above.
85 if (isa<spirv::KillOp>(Val: op))
86 return false;
87
88 return true;
89 }
90
91 /// Handle the given inlined terminator by replacing it with a new operation
92 /// as necessary.
93 void handleTerminator(Operation *op, Block *newDest) const final {
94 if (auto returnOp = dyn_cast<spirv::ReturnOp>(Val: op)) {
95 OpBuilder(op).create<spirv::BranchOp>(location: op->getLoc(), args&: newDest);
96 op->erase();
97 } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(Val: op)) {
98 OpBuilder(op).create<spirv::BranchOp>(location: retValOp->getLoc(), args&: newDest,
99 args: retValOp->getOperands());
100 op->erase();
101 }
102 }
103
104 /// Handle the given inlined terminator by replacing it with a new operation
105 /// as necessary.
106 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
107 // Only spirv.ReturnValue needs to be handled here.
108 auto retValOp = dyn_cast<spirv::ReturnValueOp>(Val: op);
109 if (!retValOp)
110 return;
111
112 // Replace the values directly with the return operands.
113 assert(valuesToRepl.size() == 1 &&
114 "spirv.ReturnValue expected to only handle one result");
115 valuesToRepl.front().replaceAllUsesWith(newValue: retValOp.getValue());
116 }
117};
118} // namespace
119
120//===----------------------------------------------------------------------===//
121// SPIR-V Dialect
122//===----------------------------------------------------------------------===//
123
/// Dialect initialization: calls the attribute and type registration hooks,
/// adds every op from the TableGen-generated op list, attaches the inliner
/// interface, and permits unregistered ops.
void SPIRVDialect::initialize() {
  registerAttributes();
  registerTypes();

  // Add SPIR-V ops. The comma-separated op class list is expanded from the
  // GET_OP_LIST section of the generated SPIRVOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
      >();

  addInterfaces<SPIRVInlinerInterface>();

  // Allow unknown operations because SPIR-V is extensible.
  allowUnknownOperations();
  // Declares that TargetEnvAttr will implement gpu::TargetAttrInterface; the
  // implementation itself is not in this file (NOTE(review): presumably
  // registered by a GPU/SPIR-V target extension — confirm).
  declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>();
}
140
141std::string SPIRVDialect::getAttributeName(Decoration decoration) {
142 return llvm::convertToSnakeFromCamelCase(input: stringifyDecoration(decoration));
143}
144
145//===----------------------------------------------------------------------===//
146// Type Parsing
147//===----------------------------------------------------------------------===//
148
// Forward declarations. The primary template parses an enum keyword; the
// explicit specializations for `Type` and `unsigned` must be declared before
// any use that would otherwise implicitly instantiate the primary template
// (e.g. parseOptionalArrayStride calls parseAndVerify<unsigned> before its
// definition further below).
template <typename ValTy>
static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
                                           DialectAsmParser &parser);
template <>
std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
                                         DialectAsmParser &parser);

template <>
std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
                                                 DialectAsmParser &parser);
160
161static Type parseAndVerifyType(SPIRVDialect const &dialect,
162 DialectAsmParser &parser) {
163 Type type;
164 SMLoc typeLoc = parser.getCurrentLocation();
165 if (parser.parseType(result&: type))
166 return Type();
167
168 // Allow SPIR-V dialect types
169 if (&type.getDialect() == &dialect)
170 return type;
171
172 // Check other allowed types
173 if (auto t = llvm::dyn_cast<FloatType>(Val&: type)) {
174 // TODO: All float types are allowed for now, but this should be fixed.
175 } else if (auto t = llvm::dyn_cast<IntegerType>(Val&: type)) {
176 if (!ScalarType::isValid(t)) {
177 parser.emitError(loc: typeLoc,
178 message: "only 1/8/16/32/64-bit integer type allowed but found ")
179 << type;
180 return Type();
181 }
182 } else if (auto t = llvm::dyn_cast<VectorType>(Val&: type)) {
183 if (t.getRank() != 1) {
184 parser.emitError(loc: typeLoc, message: "only 1-D vector allowed but found ") << t;
185 return Type();
186 }
187 if (t.getNumElements() > 4) {
188 parser.emitError(
189 loc: typeLoc, message: "vector length has to be less than or equal to 4 but found ")
190 << t.getNumElements();
191 return Type();
192 }
193 } else if (auto t = dyn_cast<TensorArmType>(Val&: type)) {
194 if (!isa<ScalarType>(Val: t.getElementType())) {
195 parser.emitError(
196 loc: typeLoc, message: "only scalar element type allowed in tensor type but found ")
197 << t.getElementType();
198 return Type();
199 }
200 } else {
201 parser.emitError(loc: typeLoc, message: "cannot use ")
202 << type << " to compose SPIR-V types";
203 return Type();
204 }
205
206 return type;
207}
208
209static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
210 DialectAsmParser &parser) {
211 Type type;
212 SMLoc typeLoc = parser.getCurrentLocation();
213 if (parser.parseType(result&: type))
214 return Type();
215
216 if (auto t = llvm::dyn_cast<VectorType>(Val&: type)) {
217 if (t.getRank() != 1) {
218 parser.emitError(loc: typeLoc, message: "only 1-D vector allowed but found ") << t;
219 return Type();
220 }
221 if (t.getNumElements() > 4 || t.getNumElements() < 2) {
222 parser.emitError(loc: typeLoc,
223 message: "matrix columns size has to be less than or equal "
224 "to 4 and greater than or equal 2, but found ")
225 << t.getNumElements();
226 return Type();
227 }
228
229 if (!llvm::isa<FloatType>(Val: t.getElementType())) {
230 parser.emitError(loc: typeLoc, message: "matrix columns' elements must be of "
231 "Float type, got ")
232 << t.getElementType();
233 return Type();
234 }
235 } else {
236 parser.emitError(loc: typeLoc, message: "matrix must be composed using vector "
237 "type, got ")
238 << type;
239 return Type();
240 }
241
242 return type;
243}
244
245static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
246 DialectAsmParser &parser) {
247 Type type;
248 SMLoc typeLoc = parser.getCurrentLocation();
249 if (parser.parseType(result&: type))
250 return Type();
251
252 if (!llvm::isa<ImageType>(Val: type)) {
253 parser.emitError(loc: typeLoc,
254 message: "sampled image must be composed using image type, got ")
255 << type;
256 return Type();
257 }
258
259 return type;
260}
261
262/// Parses an optional `, stride = N` assembly segment. If no parsing failure
263/// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
264/// missing.
265static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
266 DialectAsmParser &parser,
267 unsigned &stride) {
268 if (failed(Result: parser.parseOptionalComma())) {
269 stride = 0;
270 return success();
271 }
272
273 if (parser.parseKeyword(keyword: "stride") || parser.parseEqual())
274 return failure();
275
276 SMLoc strideLoc = parser.getCurrentLocation();
277 std::optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
278 if (!optStride)
279 return failure();
280
281 if (!(stride = *optStride)) {
282 parser.emitError(loc: strideLoc, message: "ArrayStride must be greater than zero");
283 return failure();
284 }
285 return success();
286}
287
288// element-type ::= integer-type
289// | floating-point-type
290// | vector-type
291// | spirv-type
292//
293// array-type ::= `!spirv.array` `<` integer-literal `x` element-type
294// (`,` `stride` `=` integer-literal)? `>`
295static Type parseArrayType(SPIRVDialect const &dialect,
296 DialectAsmParser &parser) {
297 if (parser.parseLess())
298 return Type();
299
300 SmallVector<int64_t, 1> countDims;
301 SMLoc countLoc = parser.getCurrentLocation();
302 if (parser.parseDimensionList(dimensions&: countDims, /*allowDynamic=*/false))
303 return Type();
304 if (countDims.size() != 1) {
305 parser.emitError(loc: countLoc,
306 message: "expected single integer for array element count");
307 return Type();
308 }
309
310 // According to the SPIR-V spec:
311 // "Length is the number of elements in the array. It must be at least 1."
312 int64_t count = countDims[0];
313 if (count == 0) {
314 parser.emitError(loc: countLoc, message: "expected array length greater than 0");
315 return Type();
316 }
317
318 Type elementType = parseAndVerifyType(dialect, parser);
319 if (!elementType)
320 return Type();
321
322 unsigned stride = 0;
323 if (failed(Result: parseOptionalArrayStride(dialect, parser, stride)))
324 return Type();
325
326 if (parser.parseGreater())
327 return Type();
328 return ArrayType::get(elementType, elementCount: count, stride);
329}
330
331// cooperative-matrix-type ::=
332// `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,`
333// scope `,` use `>`
334static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
335 DialectAsmParser &parser) {
336 if (parser.parseLess())
337 return {};
338
339 SmallVector<int64_t, 2> dims;
340 SMLoc countLoc = parser.getCurrentLocation();
341 if (parser.parseDimensionList(dimensions&: dims, /*allowDynamic=*/false))
342 return {};
343
344 if (dims.size() != 2) {
345 parser.emitError(loc: countLoc, message: "expected row and column count");
346 return {};
347 }
348
349 auto elementTy = parseAndVerifyType(dialect, parser);
350 if (!elementTy)
351 return {};
352
353 Scope scope;
354 if (parser.parseComma() ||
355 spirv::parseEnumKeywordAttr(value&: scope, parser, attrName: "scope <id>"))
356 return {};
357
358 CooperativeMatrixUseKHR use;
359 if (parser.parseComma() ||
360 spirv::parseEnumKeywordAttr(value&: use, parser, attrName: "use <id>"))
361 return {};
362
363 if (parser.parseGreater())
364 return {};
365
366 return CooperativeMatrixType::get(elementType: elementTy, rows: dims[0], columns: dims[1], scope, use);
367}
368
369// tensor-arm-type ::=
370// `!spirv.arm.tensor` `<` dim0 `x` dim1 `x` ... `x` dimN `x` element-type`>`
371static Type parseTensorArmType(SPIRVDialect const &dialect,
372 DialectAsmParser &parser) {
373 if (parser.parseLess())
374 return {};
375
376 bool unranked = false;
377 SmallVector<int64_t, 4> dims;
378 SMLoc countLoc = parser.getCurrentLocation();
379
380 if (parser.parseOptionalStar().succeeded()) {
381 unranked = true;
382 if (parser.parseXInDimensionList())
383 return {};
384 } else if (parser.parseDimensionList(dimensions&: dims, /*allowDynamic=*/true)) {
385 return {};
386 }
387
388 if (!unranked && dims.empty()) {
389 parser.emitError(loc: countLoc, message: "arm.tensors do not support rank zero");
390 return {};
391 }
392
393 if (llvm::is_contained(Range&: dims, Element: 0)) {
394 parser.emitError(loc: countLoc, message: "arm.tensors do not support zero dimensions");
395 return {};
396 }
397
398 if (llvm::any_of(Range&: dims, P: [](int64_t dim) { return dim < 0; }) &&
399 llvm::any_of(Range&: dims, P: [](int64_t dim) { return dim > 0; })) {
400 parser.emitError(loc: countLoc, message: "arm.tensor shape dimensions must be either "
401 "fully dynamic or completed shaped");
402 return {};
403 }
404
405 auto elementTy = parseAndVerifyType(dialect, parser);
406 if (!elementTy)
407 return {};
408
409 if (parser.parseGreater())
410 return {};
411
412 return TensorArmType::get(shape: dims, elementType: elementTy);
413}
414
415// TODO: Reorder methods to be utilities first and parse*Type
416// methods in alphabetical order
417//
418// storage-class ::= `UniformConstant`
419// | `Uniform`
420// | `Workgroup`
421// | <and other storage classes...>
422//
423// pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>`
424static Type parsePointerType(SPIRVDialect const &dialect,
425 DialectAsmParser &parser) {
426 if (parser.parseLess())
427 return Type();
428
429 auto pointeeType = parseAndVerifyType(dialect, parser);
430 if (!pointeeType)
431 return Type();
432
433 StringRef storageClassSpec;
434 SMLoc storageClassLoc = parser.getCurrentLocation();
435 if (parser.parseComma() || parser.parseKeyword(keyword: &storageClassSpec))
436 return Type();
437
438 auto storageClass = symbolizeStorageClass(storageClassSpec);
439 if (!storageClass) {
440 parser.emitError(loc: storageClassLoc, message: "unknown storage class: ")
441 << storageClassSpec;
442 return Type();
443 }
444 if (parser.parseGreater())
445 return Type();
446 return PointerType::get(pointeeType, storageClass: *storageClass);
447}
448
449// runtime-array-type ::= `!spirv.rtarray` `<` element-type
450// (`,` `stride` `=` integer-literal)? `>`
451static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
452 DialectAsmParser &parser) {
453 if (parser.parseLess())
454 return Type();
455
456 Type elementType = parseAndVerifyType(dialect, parser);
457 if (!elementType)
458 return Type();
459
460 unsigned stride = 0;
461 if (failed(Result: parseOptionalArrayStride(dialect, parser, stride)))
462 return Type();
463
464 if (parser.parseGreater())
465 return Type();
466 return RuntimeArrayType::get(elementType, stride);
467}
468
469// matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>`
470static Type parseMatrixType(SPIRVDialect const &dialect,
471 DialectAsmParser &parser) {
472 if (parser.parseLess())
473 return Type();
474
475 SmallVector<int64_t, 1> countDims;
476 SMLoc countLoc = parser.getCurrentLocation();
477 if (parser.parseDimensionList(dimensions&: countDims, /*allowDynamic=*/false))
478 return Type();
479 if (countDims.size() != 1) {
480 parser.emitError(loc: countLoc, message: "expected single unsigned "
481 "integer for number of columns");
482 return Type();
483 }
484
485 int64_t columnCount = countDims[0];
486 // According to the specification, Matrices can have 2, 3, or 4 columns
487 if (columnCount < 2 || columnCount > 4) {
488 parser.emitError(loc: countLoc, message: "matrix is expected to have 2, 3, or 4 "
489 "columns");
490 return Type();
491 }
492
493 Type columnType = parseAndVerifyMatrixType(dialect, parser);
494 if (!columnType)
495 return Type();
496
497 if (parser.parseGreater())
498 return Type();
499
500 return MatrixType::get(columnType, columnCount);
501}
502
503// Specialize this function to parse each of the parameters that define an
504// ImageType. By default it assumes this is an enum type.
505template <typename ValTy>
506static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
507 DialectAsmParser &parser) {
508 StringRef enumSpec;
509 SMLoc enumLoc = parser.getCurrentLocation();
510 if (parser.parseKeyword(keyword: &enumSpec)) {
511 return std::nullopt;
512 }
513
514 auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
515 if (!val)
516 parser.emitError(loc: enumLoc, message: "unknown attribute: '") << enumSpec << "'";
517 return val;
518}
519
520template <>
521std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
522 DialectAsmParser &parser) {
523 // TODO: Further verify that the element type can be sampled
524 auto ty = parseAndVerifyType(dialect, parser);
525 if (!ty)
526 return std::nullopt;
527 return ty;
528}
529
530template <typename IntTy>
531static std::optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
532 DialectAsmParser &parser) {
533 IntTy offsetVal = std::numeric_limits<IntTy>::max();
534 if (parser.parseInteger(offsetVal))
535 return std::nullopt;
536 return offsetVal;
537}
538
539template <>
540std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
541 DialectAsmParser &parser) {
542 return parseAndVerifyInteger<unsigned>(dialect, parser);
543}
544
545namespace {
546// Functor object to parse a comma separated list of specs. The function
547// parseAndVerify does the actual parsing and verification of individual
548// elements. This is a functor since parsing the last element of the list
549// (termination condition) needs partial specialization.
550template <typename ParseType, typename... Args>
551struct ParseCommaSeparatedList {
552 std::optional<std::tuple<ParseType, Args...>>
553 operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
554 auto parseVal = parseAndVerify<ParseType>(dialect, parser);
555 if (!parseVal)
556 return std::nullopt;
557
558 auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
559 if (numArgs != 0 && failed(Result: parser.parseComma()))
560 return std::nullopt;
561 auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
562 if (!remainingValues)
563 return std::nullopt;
564 return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
565 remainingValues.value());
566 }
567};
568
569// Partial specialization of the function to parse a comma separated list of
570// specs to parse the last element of the list.
571template <typename ParseType>
572struct ParseCommaSeparatedList<ParseType> {
573 std::optional<std::tuple<ParseType>>
574 operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
575 if (auto value = parseAndVerify<ParseType>(dialect, parser))
576 return std::tuple<ParseType>(*value);
577 return std::nullopt;
578 }
579};
580} // namespace
581
582// dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
583//
584// depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
585//
586// arrayed-info ::= `NonArrayed` | `Arrayed`
587//
588// sampling-info ::= `SingleSampled` | `MultiSampled`
589//
590// sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
591//
592// format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
593//
594// image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,`
595// arrayed-info `,` sampling-info `,`
596// sampler-use-info `,` format `>`
597static Type parseImageType(SPIRVDialect const &dialect,
598 DialectAsmParser &parser) {
599 if (parser.parseLess())
600 return Type();
601
602 auto value =
603 ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
604 ImageSamplingInfo, ImageSamplerUseInfo,
605 ImageFormat>{}(dialect, parser);
606 if (!value)
607 return Type();
608
609 if (parser.parseGreater())
610 return Type();
611 return ImageType::get(*value);
612}
613
614// sampledImage-type :: = `!spirv.sampledImage<` image-type `>`
615static Type parseSampledImageType(SPIRVDialect const &dialect,
616 DialectAsmParser &parser) {
617 if (parser.parseLess())
618 return Type();
619
620 Type parsedType = parseAndVerifySampledImageType(dialect, parser);
621 if (!parsedType)
622 return Type();
623
624 if (parser.parseGreater())
625 return Type();
626 return SampledImageType::get(imageType: parsedType);
627}
628
629// Parse decorations associated with a member.
630static ParseResult parseStructMemberDecorations(
631 SPIRVDialect const &dialect, DialectAsmParser &parser,
632 ArrayRef<Type> memberTypes,
633 SmallVectorImpl<StructType::OffsetInfo> &offsetInfo,
634 SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) {
635
636 // Check if the first element is offset.
637 SMLoc offsetLoc = parser.getCurrentLocation();
638 StructType::OffsetInfo offset = 0;
639 OptionalParseResult offsetParseResult = parser.parseOptionalInteger(result&: offset);
640 if (offsetParseResult.has_value()) {
641 if (failed(Result: *offsetParseResult))
642 return failure();
643
644 if (offsetInfo.size() != memberTypes.size() - 1) {
645 return parser.emitError(loc: offsetLoc,
646 message: "offset specification must be given for "
647 "all members");
648 }
649 offsetInfo.push_back(Elt: offset);
650 }
651
652 // Check for no spirv::Decorations.
653 if (succeeded(Result: parser.parseOptionalRSquare()))
654 return success();
655
656 // If there was an offset, make sure to parse the comma.
657 if (offsetParseResult.has_value() && parser.parseComma())
658 return failure();
659
660 // Check for spirv::Decorations.
661 auto parseDecorations = [&]() {
662 auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
663 if (!memberDecoration)
664 return failure();
665
666 // Parse member decoration value if it exists.
667 if (succeeded(Result: parser.parseOptionalEqual())) {
668 auto memberDecorationValue =
669 parseAndVerifyInteger<uint32_t>(dialect, parser);
670
671 if (!memberDecorationValue)
672 return failure();
673
674 memberDecorationInfo.emplace_back(
675 Args: static_cast<uint32_t>(memberTypes.size() - 1), Args: 1,
676 Args&: memberDecoration.value(), Args&: memberDecorationValue.value());
677 } else {
678 memberDecorationInfo.emplace_back(
679 Args: static_cast<uint32_t>(memberTypes.size() - 1), Args: 0,
680 Args&: memberDecoration.value(), Args: 0);
681 }
682 return success();
683 };
684 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: parseDecorations)) ||
685 failed(Result: parser.parseRSquare()))
686 return failure();
687
688 return success();
689}
690
691// struct-member-decoration ::= integer-literal? spirv-decoration*
692// struct-type ::=
693// `!spirv.struct<` (id `,`)?
694// `(`
695// (spirv-type (`[` struct-member-decoration `]`)?)*
696// `)>`
697static Type parseStructType(SPIRVDialect const &dialect,
698 DialectAsmParser &parser) {
699 // TODO: This function is quite lengthy. Break it down into smaller chunks.
700
701 if (parser.parseLess())
702 return Type();
703
704 StringRef identifier;
705 FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse;
706
707 // Check if this is an identified struct type.
708 if (succeeded(Result: parser.parseOptionalKeyword(keyword: &identifier))) {
709 // Check if this is a possible recursive reference.
710 auto structType =
711 StructType::getIdentified(context: dialect.getContext(), identifier);
712 cyclicParse = parser.tryStartCyclicParse(attrOrType: structType);
713 if (succeeded(Result: parser.parseOptionalGreater())) {
714 if (succeeded(Result: cyclicParse)) {
715 parser.emitError(
716 loc: parser.getNameLoc(),
717 message: "recursive struct reference not nested in struct definition");
718
719 return Type();
720 }
721
722 return structType;
723 }
724
725 if (failed(Result: parser.parseComma()))
726 return Type();
727
728 if (failed(Result: cyclicParse)) {
729 parser.emitError(loc: parser.getNameLoc(),
730 message: "identifier already used for an enclosing struct");
731 return Type();
732 }
733 }
734
735 if (failed(Result: parser.parseLParen()))
736 return Type();
737
738 if (succeeded(Result: parser.parseOptionalRParen()) &&
739 succeeded(Result: parser.parseOptionalGreater())) {
740 return StructType::getEmpty(context: dialect.getContext(), identifier);
741 }
742
743 StructType idStructTy;
744
745 if (!identifier.empty())
746 idStructTy = StructType::getIdentified(context: dialect.getContext(), identifier);
747
748 SmallVector<Type, 4> memberTypes;
749 SmallVector<StructType::OffsetInfo, 4> offsetInfo;
750 SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo;
751
752 do {
753 Type memberType;
754 if (parser.parseType(result&: memberType))
755 return Type();
756 memberTypes.push_back(Elt: memberType);
757
758 if (succeeded(Result: parser.parseOptionalLSquare()))
759 if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
760 memberDecorationInfo))
761 return Type();
762 } while (succeeded(Result: parser.parseOptionalComma()));
763
764 if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
765 parser.emitError(loc: parser.getNameLoc(),
766 message: "offset specification must be given for all members");
767 return Type();
768 }
769
770 if (failed(Result: parser.parseRParen()) || failed(Result: parser.parseGreater()))
771 return Type();
772
773 if (!identifier.empty()) {
774 if (failed(Result: idStructTy.trySetBody(memberTypes, offsetInfo,
775 memberDecorations: memberDecorationInfo)))
776 return Type();
777 return idStructTy;
778 }
779
780 return StructType::get(memberTypes, offsetInfo, memberDecorations: memberDecorationInfo);
781}
782
783// spirv-type ::= array-type
784// | element-type
785// | image-type
786// | pointer-type
787// | runtime-array-type
788// | sampled-image-type
789// | struct-type
790Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
791 StringRef keyword;
792 if (parser.parseKeyword(keyword: &keyword))
793 return Type();
794
795 if (keyword == "array")
796 return parseArrayType(dialect: *this, parser);
797 if (keyword == "coopmatrix")
798 return parseCooperativeMatrixType(dialect: *this, parser);
799 if (keyword == "image")
800 return parseImageType(dialect: *this, parser);
801 if (keyword == "ptr")
802 return parsePointerType(dialect: *this, parser);
803 if (keyword == "rtarray")
804 return parseRuntimeArrayType(dialect: *this, parser);
805 if (keyword == "sampled_image")
806 return parseSampledImageType(dialect: *this, parser);
807 if (keyword == "struct")
808 return parseStructType(dialect: *this, parser);
809 if (keyword == "matrix")
810 return parseMatrixType(dialect: *this, parser);
811 if (keyword == "arm.tensor")
812 return parseTensorArmType(dialect: *this, parser);
813 parser.emitError(loc: parser.getNameLoc(), message: "unknown SPIR-V type: ") << keyword;
814 return Type();
815}
816
817//===----------------------------------------------------------------------===//
818// Type Printing
819//===----------------------------------------------------------------------===//
820
821static void print(ArrayType type, DialectAsmPrinter &os) {
822 os << "array<" << type.getNumElements() << " x " << type.getElementType();
823 if (unsigned stride = type.getArrayStride())
824 os << ", stride=" << stride;
825 os << ">";
826}
827
828static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
829 os << "rtarray<" << type.getElementType();
830 if (unsigned stride = type.getArrayStride())
831 os << ", stride=" << stride;
832 os << ">";
833}
834
835static void print(PointerType type, DialectAsmPrinter &os) {
836 os << "ptr<" << type.getPointeeType() << ", "
837 << stringifyStorageClass(type.getStorageClass()) << ">";
838}
839
840static void print(ImageType type, DialectAsmPrinter &os) {
841 os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
842 << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
843 << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
844 << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
845 << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
846 << stringifyImageFormat(type.getImageFormat()) << ">";
847}
848
849static void print(SampledImageType type, DialectAsmPrinter &os) {
850 os << "sampled_image<" << type.getImageType() << ">";
851}
852
853static void print(StructType type, DialectAsmPrinter &os) {
854 FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
855
856 os << "struct<";
857
858 if (type.isIdentified()) {
859 os << type.getIdentifier();
860
861 cyclicPrint = os.tryStartCyclicPrint(attrOrType: type);
862 if (failed(Result: cyclicPrint)) {
863 os << ">";
864 return;
865 }
866
867 os << ", ";
868 }
869
870 os << "(";
871
872 auto printMember = [&](unsigned i) {
873 os << type.getElementType(i);
874 SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations;
875 type.getMemberDecorations(i, decorationsInfo&: decorations);
876 if (type.hasOffset() || !decorations.empty()) {
877 os << " [";
878 if (type.hasOffset()) {
879 os << type.getMemberOffset(i);
880 if (!decorations.empty())
881 os << ", ";
882 }
883 auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
884 os << stringifyDecoration(decoration.decoration);
885 if (decoration.hasValue) {
886 os << "=" << decoration.decorationValue;
887 }
888 };
889 llvm::interleaveComma(c: decorations, os, each_fn: eachFn);
890 os << "]";
891 }
892 };
893 llvm::interleaveComma(c: llvm::seq<unsigned>(Begin: 0, End: type.getNumElements()), os,
894 each_fn: printMember);
895 os << ")>";
896}
897
898static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
899 os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"
900 << type.getElementType() << ", " << type.getScope() << ", "
901 << type.getUse() << ">";
902}
903
904static void print(MatrixType type, DialectAsmPrinter &os) {
905 os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
906 os << ">";
907}
908
909static void print(TensorArmType type, DialectAsmPrinter &os) {
910 os << "arm.tensor<";
911
912 llvm::interleave(
913 c: type.getShape(), os,
914 each_fn: [&](int64_t dim) {
915 if (ShapedType::isDynamic(dValue: dim))
916 os << '?';
917 else
918 os << dim;
919 },
920 separator: "x");
921 if (!type.hasRank()) {
922 os << "*";
923 }
924 os << "x" << type.getElementType() << ">";
925}
926
927void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
928 TypeSwitch<Type>(type)
929 .Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
930 ImageType, SampledImageType, StructType, MatrixType, TensorArmType>(
931 caseFn: [&](auto type) { print(type, os); })
932 .Default(defaultFn: [](Type) { llvm_unreachable("unhandled SPIR-V type"); });
933}
934
935//===----------------------------------------------------------------------===//
936// Constant
937//===----------------------------------------------------------------------===//
938
939Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
940 Attribute value, Type type,
941 Location loc) {
942 if (auto poison = dyn_cast<ub::PoisonAttr>(Val&: value))
943 return builder.create<ub::PoisonOp>(location: loc, args&: type, args&: poison);
944
945 if (!spirv::ConstantOp::isBuildableWith(type))
946 return nullptr;
947
948 return builder.create<spirv::ConstantOp>(location: loc, args&: type, args&: value);
949}
950
951//===----------------------------------------------------------------------===//
952// Shader Interface ABI
953//===----------------------------------------------------------------------===//
954
955LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
956 NamedAttribute attribute) {
957 StringRef symbol = attribute.getName().strref();
958 Attribute attr = attribute.getValue();
959
960 if (symbol == spirv::getEntryPointABIAttrName()) {
961 if (!llvm::isa<spirv::EntryPointABIAttr>(Val: attr)) {
962 return op->emitError(message: "'")
963 << symbol << "' attribute must be an entry point ABI attribute";
964 }
965 } else if (symbol == spirv::getTargetEnvAttrName()) {
966 if (!llvm::isa<spirv::TargetEnvAttr>(Val: attr))
967 return op->emitError(message: "'") << symbol << "' must be a spirv::TargetEnvAttr";
968 } else {
969 return op->emitError(message: "found unsupported '")
970 << symbol << "' attribute on operation";
971 }
972
973 return success();
974}
975
976/// Verifies the given SPIR-V `attribute` attached to a value of the given
977/// `valueType` is valid.
978static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
979 NamedAttribute attribute) {
980 StringRef symbol = attribute.getName().strref();
981 Attribute attr = attribute.getValue();
982
983 if (symbol == spirv::getInterfaceVarABIAttrName()) {
984 auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(Val&: attr);
985 if (!varABIAttr)
986 return emitError(loc, message: "'")
987 << symbol << "' must be a spirv::InterfaceVarABIAttr";
988
989 if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
990 return emitError(loc, message: "'") << symbol
991 << "' attribute cannot specify storage class "
992 "when attaching to a non-scalar value";
993 return success();
994 }
995 if (symbol == spirv::DecorationAttr::name) {
996 if (!isa<spirv::DecorationAttr>(Val: attr))
997 return emitError(loc, message: "'")
998 << symbol << "' must be a spirv::DecorationAttr";
999 return success();
1000 }
1001
1002 return emitError(loc, message: "found unsupported '")
1003 << symbol << "' attribute on region argument";
1004}
1005
1006LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
1007 unsigned regionIndex,
1008 unsigned argIndex,
1009 NamedAttribute attribute) {
1010 auto funcOp = dyn_cast<FunctionOpInterface>(Val: op);
1011 if (!funcOp)
1012 return success();
1013 Type argType = funcOp.getArgumentTypes()[argIndex];
1014
1015 return verifyRegionAttribute(loc: op->getLoc(), valueType: argType, attribute);
1016}
1017
1018LogicalResult SPIRVDialect::verifyRegionResultAttribute(
1019 Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
1020 NamedAttribute attribute) {
1021 return op->emitError(message: "cannot attach SPIR-V attributes to region result");
1022}
1023

source code of mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp