//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the SPIR-V dialect in MLIR.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
14
15#include "SPIRVParsingUtils.h"
16
17#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
20#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21#include "mlir/Dialect/UB/IR/UBOps.h"
22#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinTypes.h"
24#include "mlir/IR/DialectImplementation.h"
25#include "mlir/IR/MLIRContext.h"
26#include "mlir/Parser/Parser.h"
27#include "mlir/Transforms/InliningUtils.h"
28#include "llvm/ADT/DenseMap.h"
29#include "llvm/ADT/Sequence.h"
30#include "llvm/ADT/SetVector.h"
31#include "llvm/ADT/StringExtras.h"
32#include "llvm/ADT/StringMap.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/raw_ostream.h"
35
36using namespace mlir;
37using namespace mlir::spirv;
38
39#include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
40
41//===----------------------------------------------------------------------===//
42// InlinerInterface
43//===----------------------------------------------------------------------===//
44
45/// Returns true if the given region contains spirv.Return or spirv.ReturnValue
46/// ops.
47static inline bool containsReturn(Region &region) {
48 return llvm::any_of(region, [](Block &block) {
49 Operation *terminator = block.getTerminator();
50 return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
51 });
52}
53
54namespace {
55/// This class defines the interface for inlining within the SPIR-V dialect.
56struct SPIRVInlinerInterface : public DialectInlinerInterface {
57 using DialectInlinerInterface::DialectInlinerInterface;
58
59 /// All call operations within SPIRV can be inlined.
60 bool isLegalToInline(Operation *call, Operation *callable,
61 bool wouldBeCloned) const final {
62 return true;
63 }
64
65 /// Returns true if the given region 'src' can be inlined into the region
66 /// 'dest' that is attached to an operation registered to the current dialect.
67 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
68 IRMapping &) const final {
69 // Return true here when inlining into spirv.func, spirv.mlir.selection, and
70 // spirv.mlir.loop operations.
71 auto *op = dest->getParentOp();
72 return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
73 }
74
75 /// Returns true if the given operation 'op', that is registered to this
76 /// dialect, can be inlined into the region 'dest' that is attached to an
77 /// operation registered to the current dialect.
78 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
79 IRMapping &) const final {
80 // TODO: Enable inlining structured control flows with return.
81 if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
82 containsReturn(op->getRegion(0)))
83 return false;
84 // TODO: we need to filter OpKill here to avoid inlining it to
85 // a loop continue construct:
86 // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
87 // However OpKill is fragment shader specific and we don't support it yet.
88 return true;
89 }
90
91 /// Handle the given inlined terminator by replacing it with a new operation
92 /// as necessary.
93 void handleTerminator(Operation *op, Block *newDest) const final {
94 if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
95 OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
96 op->erase();
97 } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
98 OpBuilder(op).create<spirv::BranchOp>(retValOp->getLoc(), newDest,
99 retValOp->getOperands());
100 op->erase();
101 }
102 }
103
104 /// Handle the given inlined terminator by replacing it with a new operation
105 /// as necessary.
106 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
107 // Only spirv.ReturnValue needs to be handled here.
108 auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
109 if (!retValOp)
110 return;
111
112 // Replace the values directly with the return operands.
113 assert(valuesToRepl.size() == 1 &&
114 "spirv.ReturnValue expected to only handle one result");
115 valuesToRepl.front().replaceAllUsesWith(newValue: retValOp.getValue());
116 }
117};
118} // namespace
119
120//===----------------------------------------------------------------------===//
121// SPIR-V Dialect
122//===----------------------------------------------------------------------===//
123
124void SPIRVDialect::initialize() {
125 registerAttributes();
126 registerTypes();
127
128 // Add SPIR-V ops.
129 addOperations<
130#define GET_OP_LIST
131#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
132 >();
133
134 addInterfaces<SPIRVInlinerInterface>();
135
136 // Allow unknown operations because SPIR-V is extensible.
137 allowUnknownOperations();
138 declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>();
139}
140
141std::string SPIRVDialect::getAttributeName(Decoration decoration) {
142 return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
143}
144
145//===----------------------------------------------------------------------===//
146// Type Parsing
147//===----------------------------------------------------------------------===//
148
149// Forward declarations.
150template <typename ValTy>
151static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
152 DialectAsmParser &parser);
153template <>
154std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
155 DialectAsmParser &parser);
156
157template <>
158std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
159 DialectAsmParser &parser);
160
161static Type parseAndVerifyType(SPIRVDialect const &dialect,
162 DialectAsmParser &parser) {
163 Type type;
164 SMLoc typeLoc = parser.getCurrentLocation();
165 if (parser.parseType(result&: type))
166 return Type();
167
168 // Allow SPIR-V dialect types
169 if (&type.getDialect() == &dialect)
170 return type;
171
172 // Check other allowed types
173 if (auto t = llvm::dyn_cast<FloatType>(Val&: type)) {
174 if (type.isBF16()) {
175 parser.emitError(loc: typeLoc, message: "cannot use 'bf16' to compose SPIR-V types");
176 return Type();
177 }
178 } else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
179 if (!ScalarType::isValid(t)) {
180 parser.emitError(loc: typeLoc,
181 message: "only 1/8/16/32/64-bit integer type allowed but found ")
182 << type;
183 return Type();
184 }
185 } else if (auto t = llvm::dyn_cast<VectorType>(type)) {
186 if (t.getRank() != 1) {
187 parser.emitError(loc: typeLoc, message: "only 1-D vector allowed but found ") << t;
188 return Type();
189 }
190 if (t.getNumElements() > 4) {
191 parser.emitError(
192 loc: typeLoc, message: "vector length has to be less than or equal to 4 but found ")
193 << t.getNumElements();
194 return Type();
195 }
196 } else {
197 parser.emitError(loc: typeLoc, message: "cannot use ")
198 << type << " to compose SPIR-V types";
199 return Type();
200 }
201
202 return type;
203}
204
205static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
206 DialectAsmParser &parser) {
207 Type type;
208 SMLoc typeLoc = parser.getCurrentLocation();
209 if (parser.parseType(result&: type))
210 return Type();
211
212 if (auto t = llvm::dyn_cast<VectorType>(type)) {
213 if (t.getRank() != 1) {
214 parser.emitError(loc: typeLoc, message: "only 1-D vector allowed but found ") << t;
215 return Type();
216 }
217 if (t.getNumElements() > 4 || t.getNumElements() < 2) {
218 parser.emitError(loc: typeLoc,
219 message: "matrix columns size has to be less than or equal "
220 "to 4 and greater than or equal 2, but found ")
221 << t.getNumElements();
222 return Type();
223 }
224
225 if (!llvm::isa<FloatType>(t.getElementType())) {
226 parser.emitError(loc: typeLoc, message: "matrix columns' elements must be of "
227 "Float type, got ")
228 << t.getElementType();
229 return Type();
230 }
231 } else {
232 parser.emitError(loc: typeLoc, message: "matrix must be composed using vector "
233 "type, got ")
234 << type;
235 return Type();
236 }
237
238 return type;
239}
240
241static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
242 DialectAsmParser &parser) {
243 Type type;
244 SMLoc typeLoc = parser.getCurrentLocation();
245 if (parser.parseType(result&: type))
246 return Type();
247
248 if (!llvm::isa<ImageType>(Val: type)) {
249 parser.emitError(loc: typeLoc,
250 message: "sampled image must be composed using image type, got ")
251 << type;
252 return Type();
253 }
254
255 return type;
256}
257
258/// Parses an optional `, stride = N` assembly segment. If no parsing failure
259/// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
260/// missing.
261static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
262 DialectAsmParser &parser,
263 unsigned &stride) {
264 if (failed(result: parser.parseOptionalComma())) {
265 stride = 0;
266 return success();
267 }
268
269 if (parser.parseKeyword(keyword: "stride") || parser.parseEqual())
270 return failure();
271
272 SMLoc strideLoc = parser.getCurrentLocation();
273 std::optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
274 if (!optStride)
275 return failure();
276
277 if (!(stride = *optStride)) {
278 parser.emitError(loc: strideLoc, message: "ArrayStride must be greater than zero");
279 return failure();
280 }
281 return success();
282}
283
284// element-type ::= integer-type
285// | floating-point-type
286// | vector-type
287// | spirv-type
288//
289// array-type ::= `!spirv.array` `<` integer-literal `x` element-type
290// (`,` `stride` `=` integer-literal)? `>`
291static Type parseArrayType(SPIRVDialect const &dialect,
292 DialectAsmParser &parser) {
293 if (parser.parseLess())
294 return Type();
295
296 SmallVector<int64_t, 1> countDims;
297 SMLoc countLoc = parser.getCurrentLocation();
298 if (parser.parseDimensionList(dimensions&: countDims, /*allowDynamic=*/false))
299 return Type();
300 if (countDims.size() != 1) {
301 parser.emitError(loc: countLoc,
302 message: "expected single integer for array element count");
303 return Type();
304 }
305
306 // According to the SPIR-V spec:
307 // "Length is the number of elements in the array. It must be at least 1."
308 int64_t count = countDims[0];
309 if (count == 0) {
310 parser.emitError(loc: countLoc, message: "expected array length greater than 0");
311 return Type();
312 }
313
314 Type elementType = parseAndVerifyType(dialect, parser);
315 if (!elementType)
316 return Type();
317
318 unsigned stride = 0;
319 if (failed(parseOptionalArrayStride(dialect, parser, stride)))
320 return Type();
321
322 if (parser.parseGreater())
323 return Type();
324 return ArrayType::get(elementType, elementCount: count, stride);
325}
326
327// cooperative-matrix-type ::=
328// `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,`
329// scope `,` use `>`
330static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
331 DialectAsmParser &parser) {
332 if (parser.parseLess())
333 return {};
334
335 SmallVector<int64_t, 2> dims;
336 SMLoc countLoc = parser.getCurrentLocation();
337 if (parser.parseDimensionList(dimensions&: dims, /*allowDynamic=*/false))
338 return {};
339
340 if (dims.size() != 2) {
341 parser.emitError(loc: countLoc, message: "expected row and column count");
342 return {};
343 }
344
345 auto elementTy = parseAndVerifyType(dialect, parser);
346 if (!elementTy)
347 return {};
348
349 Scope scope;
350 if (parser.parseComma() ||
351 spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
352 return {};
353
354 CooperativeMatrixUseKHR use;
355 if (parser.parseComma() ||
356 spirv::parseEnumKeywordAttr(use, parser, "use <id>"))
357 return {};
358
359 if (parser.parseGreater())
360 return {};
361
362 return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
363}
364
// joint-matrix-type ::= `!spirv.jointmatrix` `<` rows `x` columns `x`
//                       element-type `,` layout `,` scope `>`
368static Type parseJointMatrixType(SPIRVDialect const &dialect,
369 DialectAsmParser &parser) {
370 if (parser.parseLess())
371 return Type();
372
373 SmallVector<int64_t, 2> dims;
374 SMLoc countLoc = parser.getCurrentLocation();
375 if (parser.parseDimensionList(dimensions&: dims, /*allowDynamic=*/false))
376 return Type();
377
378 if (dims.size() != 2) {
379 parser.emitError(loc: countLoc, message: "expected rows and columns size");
380 return Type();
381 }
382
383 auto elementTy = parseAndVerifyType(dialect, parser);
384 if (!elementTy)
385 return Type();
386 MatrixLayout matrixLayout;
387 if (parser.parseComma() ||
388 spirv::parseEnumKeywordAttr(matrixLayout, parser, "matrixLayout <id>"))
389 return Type();
390 Scope scope;
391 if (parser.parseComma() ||
392 spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
393 return Type();
394 if (parser.parseGreater())
395 return Type();
396 return JointMatrixINTELType::get(elementTy, scope, dims[0], dims[1],
397 matrixLayout);
398}
399
400// TODO: Reorder methods to be utilities first and parse*Type
401// methods in alphabetical order
402//
403// storage-class ::= `UniformConstant`
404// | `Uniform`
405// | `Workgroup`
406// | <and other storage classes...>
407//
408// pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>`
409static Type parsePointerType(SPIRVDialect const &dialect,
410 DialectAsmParser &parser) {
411 if (parser.parseLess())
412 return Type();
413
414 auto pointeeType = parseAndVerifyType(dialect, parser);
415 if (!pointeeType)
416 return Type();
417
418 StringRef storageClassSpec;
419 SMLoc storageClassLoc = parser.getCurrentLocation();
420 if (parser.parseComma() || parser.parseKeyword(keyword: &storageClassSpec))
421 return Type();
422
423 auto storageClass = symbolizeStorageClass(storageClassSpec);
424 if (!storageClass) {
425 parser.emitError(loc: storageClassLoc, message: "unknown storage class: ")
426 << storageClassSpec;
427 return Type();
428 }
429 if (parser.parseGreater())
430 return Type();
431 return PointerType::get(pointeeType, *storageClass);
432}
433
434// runtime-array-type ::= `!spirv.rtarray` `<` element-type
435// (`,` `stride` `=` integer-literal)? `>`
436static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
437 DialectAsmParser &parser) {
438 if (parser.parseLess())
439 return Type();
440
441 Type elementType = parseAndVerifyType(dialect, parser);
442 if (!elementType)
443 return Type();
444
445 unsigned stride = 0;
446 if (failed(parseOptionalArrayStride(dialect, parser, stride)))
447 return Type();
448
449 if (parser.parseGreater())
450 return Type();
451 return RuntimeArrayType::get(elementType, stride);
452}
453
454// matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>`
455static Type parseMatrixType(SPIRVDialect const &dialect,
456 DialectAsmParser &parser) {
457 if (parser.parseLess())
458 return Type();
459
460 SmallVector<int64_t, 1> countDims;
461 SMLoc countLoc = parser.getCurrentLocation();
462 if (parser.parseDimensionList(dimensions&: countDims, /*allowDynamic=*/false))
463 return Type();
464 if (countDims.size() != 1) {
465 parser.emitError(loc: countLoc, message: "expected single unsigned "
466 "integer for number of columns");
467 return Type();
468 }
469
470 int64_t columnCount = countDims[0];
471 // According to the specification, Matrices can have 2, 3, or 4 columns
472 if (columnCount < 2 || columnCount > 4) {
473 parser.emitError(loc: countLoc, message: "matrix is expected to have 2, 3, or 4 "
474 "columns");
475 return Type();
476 }
477
478 Type columnType = parseAndVerifyMatrixType(dialect, parser);
479 if (!columnType)
480 return Type();
481
482 if (parser.parseGreater())
483 return Type();
484
485 return MatrixType::get(columnType, columnCount);
486}
487
488// Specialize this function to parse each of the parameters that define an
489// ImageType. By default it assumes this is an enum type.
490template <typename ValTy>
491static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
492 DialectAsmParser &parser) {
493 StringRef enumSpec;
494 SMLoc enumLoc = parser.getCurrentLocation();
495 if (parser.parseKeyword(keyword: &enumSpec)) {
496 return std::nullopt;
497 }
498
499 auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
500 if (!val)
501 parser.emitError(loc: enumLoc, message: "unknown attribute: '") << enumSpec << "'";
502 return val;
503}
504
505template <>
506std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
507 DialectAsmParser &parser) {
508 // TODO: Further verify that the element type can be sampled
509 auto ty = parseAndVerifyType(dialect, parser);
510 if (!ty)
511 return std::nullopt;
512 return ty;
513}
514
515template <typename IntTy>
516static std::optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
517 DialectAsmParser &parser) {
518 IntTy offsetVal = std::numeric_limits<IntTy>::max();
519 if (parser.parseInteger(offsetVal))
520 return std::nullopt;
521 return offsetVal;
522}
523
524template <>
525std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
526 DialectAsmParser &parser) {
527 return parseAndVerifyInteger<unsigned>(dialect, parser);
528}
529
530namespace {
531// Functor object to parse a comma separated list of specs. The function
532// parseAndVerify does the actual parsing and verification of individual
533// elements. This is a functor since parsing the last element of the list
534// (termination condition) needs partial specialization.
535template <typename ParseType, typename... Args>
536struct ParseCommaSeparatedList {
537 std::optional<std::tuple<ParseType, Args...>>
538 operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
539 auto parseVal = parseAndVerify<ParseType>(dialect, parser);
540 if (!parseVal)
541 return std::nullopt;
542
543 auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
544 if (numArgs != 0 && failed(result: parser.parseComma()))
545 return std::nullopt;
546 auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
547 if (!remainingValues)
548 return std::nullopt;
549 return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
550 remainingValues.value());
551 }
552};
553
554// Partial specialization of the function to parse a comma separated list of
555// specs to parse the last element of the list.
556template <typename ParseType>
557struct ParseCommaSeparatedList<ParseType> {
558 std::optional<std::tuple<ParseType>>
559 operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
560 if (auto value = parseAndVerify<ParseType>(dialect, parser))
561 return std::tuple<ParseType>(*value);
562 return std::nullopt;
563 }
564};
565} // namespace
566
567// dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
568//
569// depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
570//
571// arrayed-info ::= `NonArrayed` | `Arrayed`
572//
573// sampling-info ::= `SingleSampled` | `MultiSampled`
574//
575// sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
576//
577// format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
578//
579// image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,`
580// arrayed-info `,` sampling-info `,`
581// sampler-use-info `,` format `>`
582static Type parseImageType(SPIRVDialect const &dialect,
583 DialectAsmParser &parser) {
584 if (parser.parseLess())
585 return Type();
586
587 auto value =
588 ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
589 ImageSamplingInfo, ImageSamplerUseInfo,
590 ImageFormat>{}(dialect, parser);
591 if (!value)
592 return Type();
593
594 if (parser.parseGreater())
595 return Type();
596 return ImageType::get(*value);
597}
598
// sampled-image-type ::= `!spirv.sampled_image<` image-type `>`
600static Type parseSampledImageType(SPIRVDialect const &dialect,
601 DialectAsmParser &parser) {
602 if (parser.parseLess())
603 return Type();
604
605 Type parsedType = parseAndVerifySampledImageType(dialect, parser);
606 if (!parsedType)
607 return Type();
608
609 if (parser.parseGreater())
610 return Type();
611 return SampledImageType::get(imageType: parsedType);
612}
613
614// Parse decorations associated with a member.
615static ParseResult parseStructMemberDecorations(
616 SPIRVDialect const &dialect, DialectAsmParser &parser,
617 ArrayRef<Type> memberTypes,
618 SmallVectorImpl<StructType::OffsetInfo> &offsetInfo,
619 SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) {
620
621 // Check if the first element is offset.
622 SMLoc offsetLoc = parser.getCurrentLocation();
623 StructType::OffsetInfo offset = 0;
624 OptionalParseResult offsetParseResult = parser.parseOptionalInteger(result&: offset);
625 if (offsetParseResult.has_value()) {
626 if (failed(result: *offsetParseResult))
627 return failure();
628
629 if (offsetInfo.size() != memberTypes.size() - 1) {
630 return parser.emitError(loc: offsetLoc,
631 message: "offset specification must be given for "
632 "all members");
633 }
634 offsetInfo.push_back(Elt: offset);
635 }
636
637 // Check for no spirv::Decorations.
638 if (succeeded(result: parser.parseOptionalRSquare()))
639 return success();
640
641 // If there was an offset, make sure to parse the comma.
642 if (offsetParseResult.has_value() && parser.parseComma())
643 return failure();
644
645 // Check for spirv::Decorations.
646 auto parseDecorations = [&]() {
647 auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
648 if (!memberDecoration)
649 return failure();
650
651 // Parse member decoration value if it exists.
652 if (succeeded(result: parser.parseOptionalEqual())) {
653 auto memberDecorationValue =
654 parseAndVerifyInteger<uint32_t>(dialect, parser);
655
656 if (!memberDecorationValue)
657 return failure();
658
659 memberDecorationInfo.emplace_back(
660 static_cast<uint32_t>(memberTypes.size() - 1), 1,
661 memberDecoration.value(), memberDecorationValue.value());
662 } else {
663 memberDecorationInfo.emplace_back(
664 static_cast<uint32_t>(memberTypes.size() - 1), 0,
665 memberDecoration.value(), 0);
666 }
667 return success();
668 };
669 if (failed(result: parser.parseCommaSeparatedList(parseElementFn: parseDecorations)) ||
670 failed(result: parser.parseRSquare()))
671 return failure();
672
673 return success();
674}
675
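// decoration ::= spirv-decoration (`=` integer-literal)?
// struct-member-decoration ::= integer-literal (`,` decoration)*
//                            | decoration (`,` decoration)*
// struct-member ::= spirv-type (`[` struct-member-decoration? `]`)?
// struct-type ::= `!spirv.struct<` id `>`
//               | `!spirv.struct<` (id `,`)?
//                   `(` (struct-member (`,` struct-member)*)? `)`
//                 `>`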
682static Type parseStructType(SPIRVDialect const &dialect,
683 DialectAsmParser &parser) {
684 // TODO: This function is quite lengthy. Break it down into smaller chunks.
685
686 if (parser.parseLess())
687 return Type();
688
689 StringRef identifier;
690 FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse;
691
692 // Check if this is an identified struct type.
693 if (succeeded(result: parser.parseOptionalKeyword(keyword: &identifier))) {
694 // Check if this is a possible recursive reference.
695 auto structType =
696 StructType::getIdentified(context: dialect.getContext(), identifier);
697 cyclicParse = parser.tryStartCyclicParse(structType);
698 if (succeeded(result: parser.parseOptionalGreater())) {
699 if (succeeded(result: cyclicParse)) {
700 parser.emitError(
701 loc: parser.getNameLoc(),
702 message: "recursive struct reference not nested in struct definition");
703
704 return Type();
705 }
706
707 return structType;
708 }
709
710 if (failed(result: parser.parseComma()))
711 return Type();
712
713 if (failed(result: cyclicParse)) {
714 parser.emitError(loc: parser.getNameLoc(),
715 message: "identifier already used for an enclosing struct");
716 return Type();
717 }
718 }
719
720 if (failed(result: parser.parseLParen()))
721 return Type();
722
723 if (succeeded(result: parser.parseOptionalRParen()) &&
724 succeeded(result: parser.parseOptionalGreater())) {
725 return StructType::getEmpty(context: dialect.getContext(), identifier);
726 }
727
728 StructType idStructTy;
729
730 if (!identifier.empty())
731 idStructTy = StructType::getIdentified(context: dialect.getContext(), identifier);
732
733 SmallVector<Type, 4> memberTypes;
734 SmallVector<StructType::OffsetInfo, 4> offsetInfo;
735 SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo;
736
737 do {
738 Type memberType;
739 if (parser.parseType(result&: memberType))
740 return Type();
741 memberTypes.push_back(Elt: memberType);
742
743 if (succeeded(result: parser.parseOptionalLSquare()))
744 if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
745 memberDecorationInfo))
746 return Type();
747 } while (succeeded(result: parser.parseOptionalComma()));
748
749 if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
750 parser.emitError(loc: parser.getNameLoc(),
751 message: "offset specification must be given for all members");
752 return Type();
753 }
754
755 if (failed(result: parser.parseRParen()) || failed(result: parser.parseGreater()))
756 return Type();
757
758 if (!identifier.empty()) {
759 if (failed(result: idStructTy.trySetBody(memberTypes, offsetInfo,
760 memberDecorations: memberDecorationInfo)))
761 return Type();
762 return idStructTy;
763 }
764
765 return StructType::get(memberTypes, offsetInfo, memberDecorations: memberDecorationInfo);
766}
767
// spirv-type ::= array-type
//              | cooperative-matrix-type
//              | element-type
//              | image-type
//              | joint-matrix-type
//              | matrix-type
//              | pointer-type
//              | runtime-array-type
//              | sampled-image-type
//              | struct-type
775Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
776 StringRef keyword;
777 if (parser.parseKeyword(&keyword))
778 return Type();
779
780 if (keyword == "array")
781 return parseArrayType(*this, parser);
782 if (keyword == "coopmatrix")
783 return parseCooperativeMatrixType(*this, parser);
784 if (keyword == "jointmatrix")
785 return parseJointMatrixType(*this, parser);
786 if (keyword == "image")
787 return parseImageType(*this, parser);
788 if (keyword == "ptr")
789 return parsePointerType(*this, parser);
790 if (keyword == "rtarray")
791 return parseRuntimeArrayType(*this, parser);
792 if (keyword == "sampled_image")
793 return parseSampledImageType(*this, parser);
794 if (keyword == "struct")
795 return parseStructType(*this, parser);
796 if (keyword == "matrix")
797 return parseMatrixType(*this, parser);
798 parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
799 return Type();
800}
801
802//===----------------------------------------------------------------------===//
803// Type Printing
804//===----------------------------------------------------------------------===//
805
806static void print(ArrayType type, DialectAsmPrinter &os) {
807 os << "array<" << type.getNumElements() << " x " << type.getElementType();
808 if (unsigned stride = type.getArrayStride())
809 os << ", stride=" << stride;
810 os << ">";
811}
812
813static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
814 os << "rtarray<" << type.getElementType();
815 if (unsigned stride = type.getArrayStride())
816 os << ", stride=" << stride;
817 os << ">";
818}
819
820static void print(PointerType type, DialectAsmPrinter &os) {
821 os << "ptr<" << type.getPointeeType() << ", "
822 << stringifyStorageClass(type.getStorageClass()) << ">";
823}
824
825static void print(ImageType type, DialectAsmPrinter &os) {
826 os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
827 << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
828 << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
829 << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
830 << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
831 << stringifyImageFormat(type.getImageFormat()) << ">";
832}
833
834static void print(SampledImageType type, DialectAsmPrinter &os) {
835 os << "sampled_image<" << type.getImageType() << ">";
836}
837
838static void print(StructType type, DialectAsmPrinter &os) {
839 FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
840
841 os << "struct<";
842
843 if (type.isIdentified()) {
844 os << type.getIdentifier();
845
846 cyclicPrint = os.tryStartCyclicPrint(attrOrType: type);
847 if (failed(result: cyclicPrint)) {
848 os << ">";
849 return;
850 }
851
852 os << ", ";
853 }
854
855 os << "(";
856
857 auto printMember = [&](unsigned i) {
858 os << type.getElementType(i);
859 SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations;
860 type.getMemberDecorations(i, decorationsInfo&: decorations);
861 if (type.hasOffset() || !decorations.empty()) {
862 os << " [";
863 if (type.hasOffset()) {
864 os << type.getMemberOffset(i);
865 if (!decorations.empty())
866 os << ", ";
867 }
868 auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
869 os << stringifyDecoration(decoration.decoration);
870 if (decoration.hasValue) {
871 os << "=" << decoration.decorationValue;
872 }
873 };
874 llvm::interleaveComma(c: decorations, os, each_fn: eachFn);
875 os << "]";
876 }
877 };
878 llvm::interleaveComma(c: llvm::seq<unsigned>(Begin: 0, End: type.getNumElements()), os,
879 each_fn: printMember);
880 os << ")>";
881}
882
883static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
884 os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"
885 << type.getElementType() << ", " << type.getScope() << ", "
886 << type.getUse() << ">";
887}
888
889static void print(JointMatrixINTELType type, DialectAsmPrinter &os) {
890 os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
891 os << type.getElementType() << ", "
892 << stringifyMatrixLayout(type.getMatrixLayout());
893 os << ", " << stringifyScope(type.getScope()) << ">";
894}
895
896static void print(MatrixType type, DialectAsmPrinter &os) {
897 os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
898 os << ">";
899}
900
901void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
902 TypeSwitch<Type>(type)
903 .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, PointerType,
904 RuntimeArrayType, ImageType, SampledImageType, StructType,
905 MatrixType>([&](auto type) { print(type, os); })
906 .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
907}
908
909//===----------------------------------------------------------------------===//
910// Constant
911//===----------------------------------------------------------------------===//
912
913Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
914 Attribute value, Type type,
915 Location loc) {
916 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
917 return builder.create<ub::PoisonOp>(loc, type, poison);
918
919 if (!spirv::ConstantOp::isBuildableWith(type))
920 return nullptr;
921
922 return builder.create<spirv::ConstantOp>(loc, type, value);
923}
924
925//===----------------------------------------------------------------------===//
926// Shader Interface ABI
927//===----------------------------------------------------------------------===//
928
929LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
930 NamedAttribute attribute) {
931 StringRef symbol = attribute.getName().strref();
932 Attribute attr = attribute.getValue();
933
934 if (symbol == spirv::getEntryPointABIAttrName()) {
935 if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) {
936 return op->emitError("'")
937 << symbol << "' attribute must be an entry point ABI attribute";
938 }
939 } else if (symbol == spirv::getTargetEnvAttrName()) {
940 if (!llvm::isa<spirv::TargetEnvAttr>(attr))
941 return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
942 } else {
943 return op->emitError("found unsupported '")
944 << symbol << "' attribute on operation";
945 }
946
947 return success();
948}
949
950/// Verifies the given SPIR-V `attribute` attached to a value of the given
951/// `valueType` is valid.
952static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
953 NamedAttribute attribute) {
954 StringRef symbol = attribute.getName().strref();
955 Attribute attr = attribute.getValue();
956
957 if (symbol == spirv::getInterfaceVarABIAttrName()) {
958 auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(Val&: attr);
959 if (!varABIAttr)
960 return emitError(loc, message: "'")
961 << symbol << "' must be a spirv::InterfaceVarABIAttr";
962
963 if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
964 return emitError(loc, message: "'") << symbol
965 << "' attribute cannot specify storage class "
966 "when attaching to a non-scalar value";
967 return success();
968 }
969 if (symbol == spirv::DecorationAttr::name) {
970 if (!isa<spirv::DecorationAttr>(attr))
971 return emitError(loc, message: "'")
972 << symbol << "' must be a spirv::DecorationAttr";
973 return success();
974 }
975
976 return emitError(loc, message: "found unsupported '")
977 << symbol << "' attribute on region argument";
978}
979
980LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
981 unsigned regionIndex,
982 unsigned argIndex,
983 NamedAttribute attribute) {
984 auto funcOp = dyn_cast<FunctionOpInterface>(op);
985 if (!funcOp)
986 return success();
987 Type argType = funcOp.getArgumentTypes()[argIndex];
988
989 return verifyRegionAttribute(op->getLoc(), argType, attribute);
990}
991
992LogicalResult SPIRVDialect::verifyRegionResultAttribute(
993 Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
994 NamedAttribute attribute) {
995 return op->emitError("cannot attach SPIR-V attributes to region result");
996}
997

source code of mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp