//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the SPIR-V dialect in MLIR.
//
//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
14
15#include "SPIRVParsingUtils.h"
16
17#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
20#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21#include "mlir/Dialect/UB/IR/UBOps.h"
22#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinTypes.h"
24#include "mlir/IR/DialectImplementation.h"
25#include "mlir/IR/MLIRContext.h"
26#include "mlir/Parser/Parser.h"
27#include "mlir/Transforms/InliningUtils.h"
28#include "llvm/ADT/DenseMap.h"
29#include "llvm/ADT/Sequence.h"
30#include "llvm/ADT/SetVector.h"
31#include "llvm/ADT/StringExtras.h"
32#include "llvm/ADT/StringMap.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/raw_ostream.h"
35
36using namespace mlir;
37using namespace mlir::spirv;
38
39#include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
40
41//===----------------------------------------------------------------------===//
42// InlinerInterface
43//===----------------------------------------------------------------------===//
44
45/// Returns true if the given region contains spirv.Return or spirv.ReturnValue
46/// ops.
47static inline bool containsReturn(Region &region) {
48 return llvm::any_of(region, [](Block &block) {
49 Operation *terminator = block.getTerminator();
50 return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
51 });
52}
53
54namespace {
55/// This class defines the interface for inlining within the SPIR-V dialect.
56struct SPIRVInlinerInterface : public DialectInlinerInterface {
57 using DialectInlinerInterface::DialectInlinerInterface;
58
59 /// All call operations within SPIRV can be inlined.
60 bool isLegalToInline(Operation *call, Operation *callable,
61 bool wouldBeCloned) const final {
62 return true;
63 }
64
65 /// Returns true if the given region 'src' can be inlined into the region
66 /// 'dest' that is attached to an operation registered to the current dialect.
67 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
68 IRMapping &) const final {
69 // Return true here when inlining into spirv.func, spirv.mlir.selection, and
70 // spirv.mlir.loop operations.
71 auto *op = dest->getParentOp();
72 return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
73 }
74
75 /// Returns true if the given operation 'op', that is registered to this
76 /// dialect, can be inlined into the region 'dest' that is attached to an
77 /// operation registered to the current dialect.
78 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
79 IRMapping &) const final {
80 // TODO: Enable inlining structured control flows with return.
81 if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
82 containsReturn(op->getRegion(0)))
83 return false;
84 // TODO: we need to filter OpKill here to avoid inlining it to
85 // a loop continue construct:
86 // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
87 // For now, we just disallow inlining OpKill anywhere in the code,
88 // but this restriction should be relaxed, as pointed above.
89 if (isa<spirv::KillOp>(op))
90 return false;
91
92 return true;
93 }
94
95 /// Handle the given inlined terminator by replacing it with a new operation
96 /// as necessary.
97 void handleTerminator(Operation *op, Block *newDest) const final {
98 if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
99 OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
100 op->erase();
101 } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
102 OpBuilder(op).create<spirv::BranchOp>(retValOp->getLoc(), newDest,
103 retValOp->getOperands());
104 op->erase();
105 }
106 }
107
108 /// Handle the given inlined terminator by replacing it with a new operation
109 /// as necessary.
110 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
111 // Only spirv.ReturnValue needs to be handled here.
112 auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
113 if (!retValOp)
114 return;
115
116 // Replace the values directly with the return operands.
117 assert(valuesToRepl.size() == 1 &&
118 "spirv.ReturnValue expected to only handle one result");
119 valuesToRepl.front().replaceAllUsesWith(newValue: retValOp.getValue());
120 }
121};
122} // namespace
123
124//===----------------------------------------------------------------------===//
125// SPIR-V Dialect
126//===----------------------------------------------------------------------===//
127
128void SPIRVDialect::initialize() {
129 registerAttributes();
130 registerTypes();
131
132 // Add SPIR-V ops.
133 addOperations<
134#define GET_OP_LIST
135#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
136 >();
137
138 addInterfaces<SPIRVInlinerInterface>();
139
140 // Allow unknown operations because SPIR-V is extensible.
141 allowUnknownOperations();
142 declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>();
143}
144
145std::string SPIRVDialect::getAttributeName(Decoration decoration) {
146 return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
147}
148
149//===----------------------------------------------------------------------===//
150// Type Parsing
151//===----------------------------------------------------------------------===//
152
153// Forward declarations.
154template <typename ValTy>
155static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
156 DialectAsmParser &parser);
157template <>
158std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
159 DialectAsmParser &parser);
160
161template <>
162std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
163 DialectAsmParser &parser);
164
165static Type parseAndVerifyType(SPIRVDialect const &dialect,
166 DialectAsmParser &parser) {
167 Type type;
168 SMLoc typeLoc = parser.getCurrentLocation();
169 if (parser.parseType(result&: type))
170 return Type();
171
172 // Allow SPIR-V dialect types
173 if (&type.getDialect() == &dialect)
174 return type;
175
176 // Check other allowed types
177 if (auto t = llvm::dyn_cast<FloatType>(type)) {
178 if (type.isBF16()) {
179 parser.emitError(loc: typeLoc, message: "cannot use 'bf16' to compose SPIR-V types");
180 return Type();
181 }
182 } else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
183 if (!ScalarType::isValid(t)) {
184 parser.emitError(loc: typeLoc,
185 message: "only 1/8/16/32/64-bit integer type allowed but found ")
186 << type;
187 return Type();
188 }
189 } else if (auto t = llvm::dyn_cast<VectorType>(type)) {
190 if (t.getRank() != 1) {
191 parser.emitError(loc: typeLoc, message: "only 1-D vector allowed but found ") << t;
192 return Type();
193 }
194 if (t.getNumElements() > 4) {
195 parser.emitError(
196 loc: typeLoc, message: "vector length has to be less than or equal to 4 but found ")
197 << t.getNumElements();
198 return Type();
199 }
200 } else {
201 parser.emitError(loc: typeLoc, message: "cannot use ")
202 << type << " to compose SPIR-V types";
203 return Type();
204 }
205
206 return type;
207}
208
209static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
210 DialectAsmParser &parser) {
211 Type type;
212 SMLoc typeLoc = parser.getCurrentLocation();
213 if (parser.parseType(result&: type))
214 return Type();
215
216 if (auto t = llvm::dyn_cast<VectorType>(type)) {
217 if (t.getRank() != 1) {
218 parser.emitError(loc: typeLoc, message: "only 1-D vector allowed but found ") << t;
219 return Type();
220 }
221 if (t.getNumElements() > 4 || t.getNumElements() < 2) {
222 parser.emitError(loc: typeLoc,
223 message: "matrix columns size has to be less than or equal "
224 "to 4 and greater than or equal 2, but found ")
225 << t.getNumElements();
226 return Type();
227 }
228
229 if (!llvm::isa<FloatType>(t.getElementType())) {
230 parser.emitError(loc: typeLoc, message: "matrix columns' elements must be of "
231 "Float type, got ")
232 << t.getElementType();
233 return Type();
234 }
235 } else {
236 parser.emitError(loc: typeLoc, message: "matrix must be composed using vector "
237 "type, got ")
238 << type;
239 return Type();
240 }
241
242 return type;
243}
244
245static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
246 DialectAsmParser &parser) {
247 Type type;
248 SMLoc typeLoc = parser.getCurrentLocation();
249 if (parser.parseType(result&: type))
250 return Type();
251
252 if (!llvm::isa<ImageType>(Val: type)) {
253 parser.emitError(loc: typeLoc,
254 message: "sampled image must be composed using image type, got ")
255 << type;
256 return Type();
257 }
258
259 return type;
260}
261
262/// Parses an optional `, stride = N` assembly segment. If no parsing failure
263/// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
264/// missing.
265static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
266 DialectAsmParser &parser,
267 unsigned &stride) {
268 if (failed(Result: parser.parseOptionalComma())) {
269 stride = 0;
270 return success();
271 }
272
273 if (parser.parseKeyword(keyword: "stride") || parser.parseEqual())
274 return failure();
275
276 SMLoc strideLoc = parser.getCurrentLocation();
277 std::optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
278 if (!optStride)
279 return failure();
280
281 if (!(stride = *optStride)) {
282 parser.emitError(loc: strideLoc, message: "ArrayStride must be greater than zero");
283 return failure();
284 }
285 return success();
286}
287
288// element-type ::= integer-type
289// | floating-point-type
290// | vector-type
291// | spirv-type
292//
293// array-type ::= `!spirv.array` `<` integer-literal `x` element-type
294// (`,` `stride` `=` integer-literal)? `>`
295static Type parseArrayType(SPIRVDialect const &dialect,
296 DialectAsmParser &parser) {
297 if (parser.parseLess())
298 return Type();
299
300 SmallVector<int64_t, 1> countDims;
301 SMLoc countLoc = parser.getCurrentLocation();
302 if (parser.parseDimensionList(dimensions&: countDims, /*allowDynamic=*/false))
303 return Type();
304 if (countDims.size() != 1) {
305 parser.emitError(loc: countLoc,
306 message: "expected single integer for array element count");
307 return Type();
308 }
309
310 // According to the SPIR-V spec:
311 // "Length is the number of elements in the array. It must be at least 1."
312 int64_t count = countDims[0];
313 if (count == 0) {
314 parser.emitError(loc: countLoc, message: "expected array length greater than 0");
315 return Type();
316 }
317
318 Type elementType = parseAndVerifyType(dialect, parser);
319 if (!elementType)
320 return Type();
321
322 unsigned stride = 0;
323 if (failed(parseOptionalArrayStride(dialect, parser, stride)))
324 return Type();
325
326 if (parser.parseGreater())
327 return Type();
328 return ArrayType::get(elementType, elementCount: count, stride);
329}
330
331// cooperative-matrix-type ::=
332// `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,`
333// scope `,` use `>`
334static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
335 DialectAsmParser &parser) {
336 if (parser.parseLess())
337 return {};
338
339 SmallVector<int64_t, 2> dims;
340 SMLoc countLoc = parser.getCurrentLocation();
341 if (parser.parseDimensionList(dimensions&: dims, /*allowDynamic=*/false))
342 return {};
343
344 if (dims.size() != 2) {
345 parser.emitError(loc: countLoc, message: "expected row and column count");
346 return {};
347 }
348
349 auto elementTy = parseAndVerifyType(dialect, parser);
350 if (!elementTy)
351 return {};
352
353 Scope scope;
354 if (parser.parseComma() ||
355 spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
356 return {};
357
358 CooperativeMatrixUseKHR use;
359 if (parser.parseComma() ||
360 spirv::parseEnumKeywordAttr(use, parser, "use <id>"))
361 return {};
362
363 if (parser.parseGreater())
364 return {};
365
366 return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
367}
368
369// TODO: Reorder methods to be utilities first and parse*Type
370// methods in alphabetical order
371//
372// storage-class ::= `UniformConstant`
373// | `Uniform`
374// | `Workgroup`
375// | <and other storage classes...>
376//
377// pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>`
378static Type parsePointerType(SPIRVDialect const &dialect,
379 DialectAsmParser &parser) {
380 if (parser.parseLess())
381 return Type();
382
383 auto pointeeType = parseAndVerifyType(dialect, parser);
384 if (!pointeeType)
385 return Type();
386
387 StringRef storageClassSpec;
388 SMLoc storageClassLoc = parser.getCurrentLocation();
389 if (parser.parseComma() || parser.parseKeyword(keyword: &storageClassSpec))
390 return Type();
391
392 auto storageClass = symbolizeStorageClass(storageClassSpec);
393 if (!storageClass) {
394 parser.emitError(loc: storageClassLoc, message: "unknown storage class: ")
395 << storageClassSpec;
396 return Type();
397 }
398 if (parser.parseGreater())
399 return Type();
400 return PointerType::get(pointeeType, *storageClass);
401}
402
403// runtime-array-type ::= `!spirv.rtarray` `<` element-type
404// (`,` `stride` `=` integer-literal)? `>`
405static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
406 DialectAsmParser &parser) {
407 if (parser.parseLess())
408 return Type();
409
410 Type elementType = parseAndVerifyType(dialect, parser);
411 if (!elementType)
412 return Type();
413
414 unsigned stride = 0;
415 if (failed(parseOptionalArrayStride(dialect, parser, stride)))
416 return Type();
417
418 if (parser.parseGreater())
419 return Type();
420 return RuntimeArrayType::get(elementType, stride);
421}
422
423// matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>`
424static Type parseMatrixType(SPIRVDialect const &dialect,
425 DialectAsmParser &parser) {
426 if (parser.parseLess())
427 return Type();
428
429 SmallVector<int64_t, 1> countDims;
430 SMLoc countLoc = parser.getCurrentLocation();
431 if (parser.parseDimensionList(dimensions&: countDims, /*allowDynamic=*/false))
432 return Type();
433 if (countDims.size() != 1) {
434 parser.emitError(loc: countLoc, message: "expected single unsigned "
435 "integer for number of columns");
436 return Type();
437 }
438
439 int64_t columnCount = countDims[0];
440 // According to the specification, Matrices can have 2, 3, or 4 columns
441 if (columnCount < 2 || columnCount > 4) {
442 parser.emitError(loc: countLoc, message: "matrix is expected to have 2, 3, or 4 "
443 "columns");
444 return Type();
445 }
446
447 Type columnType = parseAndVerifyMatrixType(dialect, parser);
448 if (!columnType)
449 return Type();
450
451 if (parser.parseGreater())
452 return Type();
453
454 return MatrixType::get(columnType, columnCount);
455}
456
457// Specialize this function to parse each of the parameters that define an
458// ImageType. By default it assumes this is an enum type.
459template <typename ValTy>
460static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
461 DialectAsmParser &parser) {
462 StringRef enumSpec;
463 SMLoc enumLoc = parser.getCurrentLocation();
464 if (parser.parseKeyword(keyword: &enumSpec)) {
465 return std::nullopt;
466 }
467
468 auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
469 if (!val)
470 parser.emitError(loc: enumLoc, message: "unknown attribute: '") << enumSpec << "'";
471 return val;
472}
473
474template <>
475std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
476 DialectAsmParser &parser) {
477 // TODO: Further verify that the element type can be sampled
478 auto ty = parseAndVerifyType(dialect, parser);
479 if (!ty)
480 return std::nullopt;
481 return ty;
482}
483
484template <typename IntTy>
485static std::optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
486 DialectAsmParser &parser) {
487 IntTy offsetVal = std::numeric_limits<IntTy>::max();
488 if (parser.parseInteger(offsetVal))
489 return std::nullopt;
490 return offsetVal;
491}
492
493template <>
494std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
495 DialectAsmParser &parser) {
496 return parseAndVerifyInteger<unsigned>(dialect, parser);
497}
498
499namespace {
500// Functor object to parse a comma separated list of specs. The function
501// parseAndVerify does the actual parsing and verification of individual
502// elements. This is a functor since parsing the last element of the list
503// (termination condition) needs partial specialization.
504template <typename ParseType, typename... Args>
505struct ParseCommaSeparatedList {
506 std::optional<std::tuple<ParseType, Args...>>
507 operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
508 auto parseVal = parseAndVerify<ParseType>(dialect, parser);
509 if (!parseVal)
510 return std::nullopt;
511
512 auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
513 if (numArgs != 0 && failed(Result: parser.parseComma()))
514 return std::nullopt;
515 auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
516 if (!remainingValues)
517 return std::nullopt;
518 return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
519 remainingValues.value());
520 }
521};
522
523// Partial specialization of the function to parse a comma separated list of
524// specs to parse the last element of the list.
525template <typename ParseType>
526struct ParseCommaSeparatedList<ParseType> {
527 std::optional<std::tuple<ParseType>>
528 operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
529 if (auto value = parseAndVerify<ParseType>(dialect, parser))
530 return std::tuple<ParseType>(*value);
531 return std::nullopt;
532 }
533};
534} // namespace
535
536// dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
537//
538// depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
539//
540// arrayed-info ::= `NonArrayed` | `Arrayed`
541//
542// sampling-info ::= `SingleSampled` | `MultiSampled`
543//
544// sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
545//
546// format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
547//
548// image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,`
549// arrayed-info `,` sampling-info `,`
550// sampler-use-info `,` format `>`
551static Type parseImageType(SPIRVDialect const &dialect,
552 DialectAsmParser &parser) {
553 if (parser.parseLess())
554 return Type();
555
556 auto value =
557 ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
558 ImageSamplingInfo, ImageSamplerUseInfo,
559 ImageFormat>{}(dialect, parser);
560 if (!value)
561 return Type();
562
563 if (parser.parseGreater())
564 return Type();
565 return ImageType::get(*value);
566}
567
568// sampledImage-type :: = `!spirv.sampledImage<` image-type `>`
569static Type parseSampledImageType(SPIRVDialect const &dialect,
570 DialectAsmParser &parser) {
571 if (parser.parseLess())
572 return Type();
573
574 Type parsedType = parseAndVerifySampledImageType(dialect, parser);
575 if (!parsedType)
576 return Type();
577
578 if (parser.parseGreater())
579 return Type();
580 return SampledImageType::get(imageType: parsedType);
581}
582
583// Parse decorations associated with a member.
584static ParseResult parseStructMemberDecorations(
585 SPIRVDialect const &dialect, DialectAsmParser &parser,
586 ArrayRef<Type> memberTypes,
587 SmallVectorImpl<StructType::OffsetInfo> &offsetInfo,
588 SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) {
589
590 // Check if the first element is offset.
591 SMLoc offsetLoc = parser.getCurrentLocation();
592 StructType::OffsetInfo offset = 0;
593 OptionalParseResult offsetParseResult = parser.parseOptionalInteger(result&: offset);
594 if (offsetParseResult.has_value()) {
595 if (failed(Result: *offsetParseResult))
596 return failure();
597
598 if (offsetInfo.size() != memberTypes.size() - 1) {
599 return parser.emitError(loc: offsetLoc,
600 message: "offset specification must be given for "
601 "all members");
602 }
603 offsetInfo.push_back(Elt: offset);
604 }
605
606 // Check for no spirv::Decorations.
607 if (succeeded(Result: parser.parseOptionalRSquare()))
608 return success();
609
610 // If there was an offset, make sure to parse the comma.
611 if (offsetParseResult.has_value() && parser.parseComma())
612 return failure();
613
614 // Check for spirv::Decorations.
615 auto parseDecorations = [&]() {
616 auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
617 if (!memberDecoration)
618 return failure();
619
620 // Parse member decoration value if it exists.
621 if (succeeded(Result: parser.parseOptionalEqual())) {
622 auto memberDecorationValue =
623 parseAndVerifyInteger<uint32_t>(dialect, parser);
624
625 if (!memberDecorationValue)
626 return failure();
627
628 memberDecorationInfo.emplace_back(
629 static_cast<uint32_t>(memberTypes.size() - 1), 1,
630 memberDecoration.value(), memberDecorationValue.value());
631 } else {
632 memberDecorationInfo.emplace_back(
633 static_cast<uint32_t>(memberTypes.size() - 1), 0,
634 memberDecoration.value(), 0);
635 }
636 return success();
637 };
638 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: parseDecorations)) ||
639 failed(Result: parser.parseRSquare()))
640 return failure();
641
642 return success();
643}
644
645// struct-member-decoration ::= integer-literal? spirv-decoration*
646// struct-type ::=
647// `!spirv.struct<` (id `,`)?
648// `(`
649// (spirv-type (`[` struct-member-decoration `]`)?)*
650// `)>`
651static Type parseStructType(SPIRVDialect const &dialect,
652 DialectAsmParser &parser) {
653 // TODO: This function is quite lengthy. Break it down into smaller chunks.
654
655 if (parser.parseLess())
656 return Type();
657
658 StringRef identifier;
659 FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse;
660
661 // Check if this is an identified struct type.
662 if (succeeded(Result: parser.parseOptionalKeyword(keyword: &identifier))) {
663 // Check if this is a possible recursive reference.
664 auto structType =
665 StructType::getIdentified(context: dialect.getContext(), identifier);
666 cyclicParse = parser.tryStartCyclicParse(structType);
667 if (succeeded(Result: parser.parseOptionalGreater())) {
668 if (succeeded(Result: cyclicParse)) {
669 parser.emitError(
670 loc: parser.getNameLoc(),
671 message: "recursive struct reference not nested in struct definition");
672
673 return Type();
674 }
675
676 return structType;
677 }
678
679 if (failed(Result: parser.parseComma()))
680 return Type();
681
682 if (failed(Result: cyclicParse)) {
683 parser.emitError(loc: parser.getNameLoc(),
684 message: "identifier already used for an enclosing struct");
685 return Type();
686 }
687 }
688
689 if (failed(Result: parser.parseLParen()))
690 return Type();
691
692 if (succeeded(Result: parser.parseOptionalRParen()) &&
693 succeeded(Result: parser.parseOptionalGreater())) {
694 return StructType::getEmpty(context: dialect.getContext(), identifier);
695 }
696
697 StructType idStructTy;
698
699 if (!identifier.empty())
700 idStructTy = StructType::getIdentified(context: dialect.getContext(), identifier);
701
702 SmallVector<Type, 4> memberTypes;
703 SmallVector<StructType::OffsetInfo, 4> offsetInfo;
704 SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo;
705
706 do {
707 Type memberType;
708 if (parser.parseType(result&: memberType))
709 return Type();
710 memberTypes.push_back(Elt: memberType);
711
712 if (succeeded(Result: parser.parseOptionalLSquare()))
713 if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
714 memberDecorationInfo))
715 return Type();
716 } while (succeeded(Result: parser.parseOptionalComma()));
717
718 if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
719 parser.emitError(loc: parser.getNameLoc(),
720 message: "offset specification must be given for all members");
721 return Type();
722 }
723
724 if (failed(Result: parser.parseRParen()) || failed(Result: parser.parseGreater()))
725 return Type();
726
727 if (!identifier.empty()) {
728 if (failed(Result: idStructTy.trySetBody(memberTypes, offsetInfo,
729 memberDecorations: memberDecorationInfo)))
730 return Type();
731 return idStructTy;
732 }
733
734 return StructType::get(memberTypes, offsetInfo, memberDecorations: memberDecorationInfo);
735}
736
737// spirv-type ::= array-type
738// | element-type
739// | image-type
740// | pointer-type
741// | runtime-array-type
742// | sampled-image-type
743// | struct-type
744Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
745 StringRef keyword;
746 if (parser.parseKeyword(&keyword))
747 return Type();
748
749 if (keyword == "array")
750 return parseArrayType(*this, parser);
751 if (keyword == "coopmatrix")
752 return parseCooperativeMatrixType(*this, parser);
753 if (keyword == "image")
754 return parseImageType(*this, parser);
755 if (keyword == "ptr")
756 return parsePointerType(*this, parser);
757 if (keyword == "rtarray")
758 return parseRuntimeArrayType(*this, parser);
759 if (keyword == "sampled_image")
760 return parseSampledImageType(*this, parser);
761 if (keyword == "struct")
762 return parseStructType(*this, parser);
763 if (keyword == "matrix")
764 return parseMatrixType(*this, parser);
765 parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
766 return Type();
767}
768
769//===----------------------------------------------------------------------===//
770// Type Printing
771//===----------------------------------------------------------------------===//
772
773static void print(ArrayType type, DialectAsmPrinter &os) {
774 os << "array<" << type.getNumElements() << " x " << type.getElementType();
775 if (unsigned stride = type.getArrayStride())
776 os << ", stride=" << stride;
777 os << ">";
778}
779
780static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
781 os << "rtarray<" << type.getElementType();
782 if (unsigned stride = type.getArrayStride())
783 os << ", stride=" << stride;
784 os << ">";
785}
786
787static void print(PointerType type, DialectAsmPrinter &os) {
788 os << "ptr<" << type.getPointeeType() << ", "
789 << stringifyStorageClass(type.getStorageClass()) << ">";
790}
791
792static void print(ImageType type, DialectAsmPrinter &os) {
793 os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
794 << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
795 << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
796 << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
797 << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
798 << stringifyImageFormat(type.getImageFormat()) << ">";
799}
800
801static void print(SampledImageType type, DialectAsmPrinter &os) {
802 os << "sampled_image<" << type.getImageType() << ">";
803}
804
805static void print(StructType type, DialectAsmPrinter &os) {
806 FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
807
808 os << "struct<";
809
810 if (type.isIdentified()) {
811 os << type.getIdentifier();
812
813 cyclicPrint = os.tryStartCyclicPrint(attrOrType: type);
814 if (failed(Result: cyclicPrint)) {
815 os << ">";
816 return;
817 }
818
819 os << ", ";
820 }
821
822 os << "(";
823
824 auto printMember = [&](unsigned i) {
825 os << type.getElementType(i);
826 SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations;
827 type.getMemberDecorations(i, decorationsInfo&: decorations);
828 if (type.hasOffset() || !decorations.empty()) {
829 os << " [";
830 if (type.hasOffset()) {
831 os << type.getMemberOffset(i);
832 if (!decorations.empty())
833 os << ", ";
834 }
835 auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
836 os << stringifyDecoration(decoration.decoration);
837 if (decoration.hasValue) {
838 os << "=" << decoration.decorationValue;
839 }
840 };
841 llvm::interleaveComma(c: decorations, os, each_fn: eachFn);
842 os << "]";
843 }
844 };
845 llvm::interleaveComma(c: llvm::seq<unsigned>(Begin: 0, End: type.getNumElements()), os,
846 each_fn: printMember);
847 os << ")>";
848}
849
850static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
851 os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"
852 << type.getElementType() << ", " << type.getScope() << ", "
853 << type.getUse() << ">";
854}
855
856static void print(MatrixType type, DialectAsmPrinter &os) {
857 os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
858 os << ">";
859}
860
861void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
862 TypeSwitch<Type>(type)
863 .Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
864 ImageType, SampledImageType, StructType, MatrixType>(
865 [&](auto type) { print(type, os); })
866 .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
867}
868
869//===----------------------------------------------------------------------===//
870// Constant
871//===----------------------------------------------------------------------===//
872
873Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
874 Attribute value, Type type,
875 Location loc) {
876 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
877 return builder.create<ub::PoisonOp>(loc, type, poison);
878
879 if (!spirv::ConstantOp::isBuildableWith(type))
880 return nullptr;
881
882 return builder.create<spirv::ConstantOp>(loc, type, value);
883}
884
885//===----------------------------------------------------------------------===//
886// Shader Interface ABI
887//===----------------------------------------------------------------------===//
888
889LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
890 NamedAttribute attribute) {
891 StringRef symbol = attribute.getName().strref();
892 Attribute attr = attribute.getValue();
893
894 if (symbol == spirv::getEntryPointABIAttrName()) {
895 if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) {
896 return op->emitError("'")
897 << symbol << "' attribute must be an entry point ABI attribute";
898 }
899 } else if (symbol == spirv::getTargetEnvAttrName()) {
900 if (!llvm::isa<spirv::TargetEnvAttr>(attr))
901 return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
902 } else {
903 return op->emitError("found unsupported '")
904 << symbol << "' attribute on operation";
905 }
906
907 return success();
908}
909
910/// Verifies the given SPIR-V `attribute` attached to a value of the given
911/// `valueType` is valid.
912static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
913 NamedAttribute attribute) {
914 StringRef symbol = attribute.getName().strref();
915 Attribute attr = attribute.getValue();
916
917 if (symbol == spirv::getInterfaceVarABIAttrName()) {
918 auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(Val&: attr);
919 if (!varABIAttr)
920 return emitError(loc, message: "'")
921 << symbol << "' must be a spirv::InterfaceVarABIAttr";
922
923 if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
924 return emitError(loc, message: "'") << symbol
925 << "' attribute cannot specify storage class "
926 "when attaching to a non-scalar value";
927 return success();
928 }
929 if (symbol == spirv::DecorationAttr::name) {
930 if (!isa<spirv::DecorationAttr>(attr))
931 return emitError(loc, message: "'")
932 << symbol << "' must be a spirv::DecorationAttr";
933 return success();
934 }
935
936 return emitError(loc, message: "found unsupported '")
937 << symbol << "' attribute on region argument";
938}
939
940LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
941 unsigned regionIndex,
942 unsigned argIndex,
943 NamedAttribute attribute) {
944 auto funcOp = dyn_cast<FunctionOpInterface>(op);
945 if (!funcOp)
946 return success();
947 Type argType = funcOp.getArgumentTypes()[argIndex];
948
949 return verifyRegionAttribute(op->getLoc(), argType, attribute);
950}
951
952LogicalResult SPIRVDialect::verifyRegionResultAttribute(
953 Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
954 NamedAttribute attribute) {
955 return op->emitError("cannot attach SPIR-V attributes to region result");
956}
957

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp