//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the SPIR-V dialect in MLIR. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
14 | |
15 | #include "SPIRVParsingUtils.h" |
16 | |
17 | #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h" |
18 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
19 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
20 | #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" |
21 | #include "mlir/Dialect/UB/IR/UBOps.h" |
22 | #include "mlir/IR/Builders.h" |
23 | #include "mlir/IR/BuiltinTypes.h" |
24 | #include "mlir/IR/DialectImplementation.h" |
25 | #include "mlir/IR/MLIRContext.h" |
26 | #include "mlir/Parser/Parser.h" |
27 | #include "mlir/Transforms/InliningUtils.h" |
28 | #include "llvm/ADT/DenseMap.h" |
29 | #include "llvm/ADT/Sequence.h" |
30 | #include "llvm/ADT/SetVector.h" |
31 | #include "llvm/ADT/StringExtras.h" |
32 | #include "llvm/ADT/StringMap.h" |
33 | #include "llvm/ADT/TypeSwitch.h" |
34 | #include "llvm/Support/raw_ostream.h" |
35 | |
36 | using namespace mlir; |
37 | using namespace mlir::spirv; |
38 | |
39 | #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc" |
40 | |
41 | //===----------------------------------------------------------------------===// |
42 | // InlinerInterface |
43 | //===----------------------------------------------------------------------===// |
44 | |
45 | /// Returns true if the given region contains spirv.Return or spirv.ReturnValue |
46 | /// ops. |
47 | static inline bool containsReturn(Region ®ion) { |
48 | return llvm::any_of(region, [](Block &block) { |
49 | Operation *terminator = block.getTerminator(); |
50 | return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator); |
51 | }); |
52 | } |
53 | |
54 | namespace { |
55 | /// This class defines the interface for inlining within the SPIR-V dialect. |
56 | struct SPIRVInlinerInterface : public DialectInlinerInterface { |
57 | using DialectInlinerInterface::DialectInlinerInterface; |
58 | |
59 | /// All call operations within SPIRV can be inlined. |
60 | bool isLegalToInline(Operation *call, Operation *callable, |
61 | bool wouldBeCloned) const final { |
62 | return true; |
63 | } |
64 | |
65 | /// Returns true if the given region 'src' can be inlined into the region |
66 | /// 'dest' that is attached to an operation registered to the current dialect. |
67 | bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, |
68 | IRMapping &) const final { |
69 | // Return true here when inlining into spirv.func, spirv.mlir.selection, and |
70 | // spirv.mlir.loop operations. |
71 | auto *op = dest->getParentOp(); |
72 | return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op); |
73 | } |
74 | |
75 | /// Returns true if the given operation 'op', that is registered to this |
76 | /// dialect, can be inlined into the region 'dest' that is attached to an |
77 | /// operation registered to the current dialect. |
78 | bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, |
79 | IRMapping &) const final { |
80 | // TODO: Enable inlining structured control flows with return. |
81 | if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) && |
82 | containsReturn(op->getRegion(0))) |
83 | return false; |
84 | // TODO: we need to filter OpKill here to avoid inlining it to |
85 | // a loop continue construct: |
86 | // https://github.com/KhronosGroup/SPIRV-Headers/issues/86 |
87 | // However OpKill is fragment shader specific and we don't support it yet. |
88 | return true; |
89 | } |
90 | |
91 | /// Handle the given inlined terminator by replacing it with a new operation |
92 | /// as necessary. |
93 | void handleTerminator(Operation *op, Block *newDest) const final { |
94 | if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) { |
95 | OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest); |
96 | op->erase(); |
97 | } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) { |
98 | OpBuilder(op).create<spirv::BranchOp>(retValOp->getLoc(), newDest, |
99 | retValOp->getOperands()); |
100 | op->erase(); |
101 | } |
102 | } |
103 | |
104 | /// Handle the given inlined terminator by replacing it with a new operation |
105 | /// as necessary. |
106 | void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { |
107 | // Only spirv.ReturnValue needs to be handled here. |
108 | auto retValOp = dyn_cast<spirv::ReturnValueOp>(op); |
109 | if (!retValOp) |
110 | return; |
111 | |
112 | // Replace the values directly with the return operands. |
113 | assert(valuesToRepl.size() == 1 && |
114 | "spirv.ReturnValue expected to only handle one result" ); |
115 | valuesToRepl.front().replaceAllUsesWith(newValue: retValOp.getValue()); |
116 | } |
117 | }; |
118 | } // namespace |
119 | |
120 | //===----------------------------------------------------------------------===// |
121 | // SPIR-V Dialect |
122 | //===----------------------------------------------------------------------===// |
123 | |
/// Dialect set-up hook: registers the SPIR-V attributes and types, the
/// ODS-generated operation list, and the SPIR-V inliner interface. It also
/// lets the dialect accept unregistered operations and declares a promised
/// gpu::TargetAttrInterface implementation for TargetEnvAttr.
void SPIRVDialect::initialize() {
  registerAttributes();
  registerTypes();

  // Add SPIR-V ops.
  // The op list is produced by TableGen: SPIRVOps.cpp.inc expands to a
  // comma-separated list of op classes when GET_OP_LIST is defined.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
      >();

  addInterfaces<SPIRVInlinerInterface>();

  // Allow unknown operations because SPIR-V is extensible.
  allowUnknownOperations();
  // The implementation of gpu::TargetAttrInterface for TargetEnvAttr is not
  // attached in this function; this only declares that it is promised
  // (presumably attached by a separately registered extension — confirm).
  declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>();
}
140 | |
141 | std::string SPIRVDialect::getAttributeName(Decoration decoration) { |
142 | return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration)); |
143 | } |
144 | |
145 | //===----------------------------------------------------------------------===// |
146 | // Type Parsing |
147 | //===----------------------------------------------------------------------===// |
148 | |
// Forward declarations.
//
// parseAndVerify<ValTy> parses one value of type ValTy from the dialect
// assembly. The primary template (defined further below) treats ValTy as a
// SPIR-V enum keyword. The explicit specializations handle types (Type) and
// integer literals (unsigned). They are declared here because
// parseOptionalArrayStride uses parseAndVerify<unsigned> before the
// definitions appear, and an explicit specialization must be declared before
// its first use.
template <typename ValTy>
static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
                                           DialectAsmParser &parser);
template <>
std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
                                         DialectAsmParser &parser);

template <>
std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
                                                 DialectAsmParser &parser);
160 | |
161 | static Type parseAndVerifyType(SPIRVDialect const &dialect, |
162 | DialectAsmParser &parser) { |
163 | Type type; |
164 | SMLoc typeLoc = parser.getCurrentLocation(); |
165 | if (parser.parseType(result&: type)) |
166 | return Type(); |
167 | |
168 | // Allow SPIR-V dialect types |
169 | if (&type.getDialect() == &dialect) |
170 | return type; |
171 | |
172 | // Check other allowed types |
173 | if (auto t = llvm::dyn_cast<FloatType>(Val&: type)) { |
174 | if (type.isBF16()) { |
175 | parser.emitError(loc: typeLoc, message: "cannot use 'bf16' to compose SPIR-V types" ); |
176 | return Type(); |
177 | } |
178 | } else if (auto t = llvm::dyn_cast<IntegerType>(type)) { |
179 | if (!ScalarType::isValid(t)) { |
180 | parser.emitError(loc: typeLoc, |
181 | message: "only 1/8/16/32/64-bit integer type allowed but found " ) |
182 | << type; |
183 | return Type(); |
184 | } |
185 | } else if (auto t = llvm::dyn_cast<VectorType>(type)) { |
186 | if (t.getRank() != 1) { |
187 | parser.emitError(loc: typeLoc, message: "only 1-D vector allowed but found " ) << t; |
188 | return Type(); |
189 | } |
190 | if (t.getNumElements() > 4) { |
191 | parser.emitError( |
192 | loc: typeLoc, message: "vector length has to be less than or equal to 4 but found " ) |
193 | << t.getNumElements(); |
194 | return Type(); |
195 | } |
196 | } else { |
197 | parser.emitError(loc: typeLoc, message: "cannot use " ) |
198 | << type << " to compose SPIR-V types" ; |
199 | return Type(); |
200 | } |
201 | |
202 | return type; |
203 | } |
204 | |
205 | static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, |
206 | DialectAsmParser &parser) { |
207 | Type type; |
208 | SMLoc typeLoc = parser.getCurrentLocation(); |
209 | if (parser.parseType(result&: type)) |
210 | return Type(); |
211 | |
212 | if (auto t = llvm::dyn_cast<VectorType>(type)) { |
213 | if (t.getRank() != 1) { |
214 | parser.emitError(loc: typeLoc, message: "only 1-D vector allowed but found " ) << t; |
215 | return Type(); |
216 | } |
217 | if (t.getNumElements() > 4 || t.getNumElements() < 2) { |
218 | parser.emitError(loc: typeLoc, |
219 | message: "matrix columns size has to be less than or equal " |
220 | "to 4 and greater than or equal 2, but found " ) |
221 | << t.getNumElements(); |
222 | return Type(); |
223 | } |
224 | |
225 | if (!llvm::isa<FloatType>(t.getElementType())) { |
226 | parser.emitError(loc: typeLoc, message: "matrix columns' elements must be of " |
227 | "Float type, got " ) |
228 | << t.getElementType(); |
229 | return Type(); |
230 | } |
231 | } else { |
232 | parser.emitError(loc: typeLoc, message: "matrix must be composed using vector " |
233 | "type, got " ) |
234 | << type; |
235 | return Type(); |
236 | } |
237 | |
238 | return type; |
239 | } |
240 | |
241 | static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, |
242 | DialectAsmParser &parser) { |
243 | Type type; |
244 | SMLoc typeLoc = parser.getCurrentLocation(); |
245 | if (parser.parseType(result&: type)) |
246 | return Type(); |
247 | |
248 | if (!llvm::isa<ImageType>(Val: type)) { |
249 | parser.emitError(loc: typeLoc, |
250 | message: "sampled image must be composed using image type, got " ) |
251 | << type; |
252 | return Type(); |
253 | } |
254 | |
255 | return type; |
256 | } |
257 | |
258 | /// Parses an optional `, stride = N` assembly segment. If no parsing failure |
259 | /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if |
260 | /// missing. |
261 | static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, |
262 | DialectAsmParser &parser, |
263 | unsigned &stride) { |
264 | if (failed(result: parser.parseOptionalComma())) { |
265 | stride = 0; |
266 | return success(); |
267 | } |
268 | |
269 | if (parser.parseKeyword(keyword: "stride" ) || parser.parseEqual()) |
270 | return failure(); |
271 | |
272 | SMLoc strideLoc = parser.getCurrentLocation(); |
273 | std::optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser); |
274 | if (!optStride) |
275 | return failure(); |
276 | |
277 | if (!(stride = *optStride)) { |
278 | parser.emitError(loc: strideLoc, message: "ArrayStride must be greater than zero" ); |
279 | return failure(); |
280 | } |
281 | return success(); |
282 | } |
283 | |
284 | // element-type ::= integer-type |
285 | // | floating-point-type |
286 | // | vector-type |
287 | // | spirv-type |
288 | // |
289 | // array-type ::= `!spirv.array` `<` integer-literal `x` element-type |
290 | // (`,` `stride` `=` integer-literal)? `>` |
291 | static Type parseArrayType(SPIRVDialect const &dialect, |
292 | DialectAsmParser &parser) { |
293 | if (parser.parseLess()) |
294 | return Type(); |
295 | |
296 | SmallVector<int64_t, 1> countDims; |
297 | SMLoc countLoc = parser.getCurrentLocation(); |
298 | if (parser.parseDimensionList(dimensions&: countDims, /*allowDynamic=*/false)) |
299 | return Type(); |
300 | if (countDims.size() != 1) { |
301 | parser.emitError(loc: countLoc, |
302 | message: "expected single integer for array element count" ); |
303 | return Type(); |
304 | } |
305 | |
306 | // According to the SPIR-V spec: |
307 | // "Length is the number of elements in the array. It must be at least 1." |
308 | int64_t count = countDims[0]; |
309 | if (count == 0) { |
310 | parser.emitError(loc: countLoc, message: "expected array length greater than 0" ); |
311 | return Type(); |
312 | } |
313 | |
314 | Type elementType = parseAndVerifyType(dialect, parser); |
315 | if (!elementType) |
316 | return Type(); |
317 | |
318 | unsigned stride = 0; |
319 | if (failed(parseOptionalArrayStride(dialect, parser, stride))) |
320 | return Type(); |
321 | |
322 | if (parser.parseGreater()) |
323 | return Type(); |
324 | return ArrayType::get(elementType, elementCount: count, stride); |
325 | } |
326 | |
327 | // cooperative-matrix-type ::= |
328 | // `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,` |
329 | // scope `,` use `>` |
330 | static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, |
331 | DialectAsmParser &parser) { |
332 | if (parser.parseLess()) |
333 | return {}; |
334 | |
335 | SmallVector<int64_t, 2> dims; |
336 | SMLoc countLoc = parser.getCurrentLocation(); |
337 | if (parser.parseDimensionList(dimensions&: dims, /*allowDynamic=*/false)) |
338 | return {}; |
339 | |
340 | if (dims.size() != 2) { |
341 | parser.emitError(loc: countLoc, message: "expected row and column count" ); |
342 | return {}; |
343 | } |
344 | |
345 | auto elementTy = parseAndVerifyType(dialect, parser); |
346 | if (!elementTy) |
347 | return {}; |
348 | |
349 | Scope scope; |
350 | if (parser.parseComma() || |
351 | spirv::parseEnumKeywordAttr(scope, parser, "scope <id>" )) |
352 | return {}; |
353 | |
354 | CooperativeMatrixUseKHR use; |
355 | if (parser.parseComma() || |
356 | spirv::parseEnumKeywordAttr(use, parser, "use <id>" )) |
357 | return {}; |
358 | |
359 | if (parser.parseGreater()) |
360 | return {}; |
361 | |
362 | return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use); |
363 | } |
364 | |
365 | // joint-matrix-type ::= `!spirv.jointmatrix` `<`rows `x` columns `x` |
366 | // element-type |
367 | // `,` layout `,` scope`>` |
368 | static Type parseJointMatrixType(SPIRVDialect const &dialect, |
369 | DialectAsmParser &parser) { |
370 | if (parser.parseLess()) |
371 | return Type(); |
372 | |
373 | SmallVector<int64_t, 2> dims; |
374 | SMLoc countLoc = parser.getCurrentLocation(); |
375 | if (parser.parseDimensionList(dimensions&: dims, /*allowDynamic=*/false)) |
376 | return Type(); |
377 | |
378 | if (dims.size() != 2) { |
379 | parser.emitError(loc: countLoc, message: "expected rows and columns size" ); |
380 | return Type(); |
381 | } |
382 | |
383 | auto elementTy = parseAndVerifyType(dialect, parser); |
384 | if (!elementTy) |
385 | return Type(); |
386 | MatrixLayout matrixLayout; |
387 | if (parser.parseComma() || |
388 | spirv::parseEnumKeywordAttr(matrixLayout, parser, "matrixLayout <id>" )) |
389 | return Type(); |
390 | Scope scope; |
391 | if (parser.parseComma() || |
392 | spirv::parseEnumKeywordAttr(scope, parser, "scope <id>" )) |
393 | return Type(); |
394 | if (parser.parseGreater()) |
395 | return Type(); |
396 | return JointMatrixINTELType::get(elementTy, scope, dims[0], dims[1], |
397 | matrixLayout); |
398 | } |
399 | |
400 | // TODO: Reorder methods to be utilities first and parse*Type |
401 | // methods in alphabetical order |
402 | // |
403 | // storage-class ::= `UniformConstant` |
404 | // | `Uniform` |
405 | // | `Workgroup` |
406 | // | <and other storage classes...> |
407 | // |
408 | // pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>` |
409 | static Type parsePointerType(SPIRVDialect const &dialect, |
410 | DialectAsmParser &parser) { |
411 | if (parser.parseLess()) |
412 | return Type(); |
413 | |
414 | auto pointeeType = parseAndVerifyType(dialect, parser); |
415 | if (!pointeeType) |
416 | return Type(); |
417 | |
418 | StringRef storageClassSpec; |
419 | SMLoc storageClassLoc = parser.getCurrentLocation(); |
420 | if (parser.parseComma() || parser.parseKeyword(keyword: &storageClassSpec)) |
421 | return Type(); |
422 | |
423 | auto storageClass = symbolizeStorageClass(storageClassSpec); |
424 | if (!storageClass) { |
425 | parser.emitError(loc: storageClassLoc, message: "unknown storage class: " ) |
426 | << storageClassSpec; |
427 | return Type(); |
428 | } |
429 | if (parser.parseGreater()) |
430 | return Type(); |
431 | return PointerType::get(pointeeType, *storageClass); |
432 | } |
433 | |
434 | // runtime-array-type ::= `!spirv.rtarray` `<` element-type |
435 | // (`,` `stride` `=` integer-literal)? `>` |
436 | static Type parseRuntimeArrayType(SPIRVDialect const &dialect, |
437 | DialectAsmParser &parser) { |
438 | if (parser.parseLess()) |
439 | return Type(); |
440 | |
441 | Type elementType = parseAndVerifyType(dialect, parser); |
442 | if (!elementType) |
443 | return Type(); |
444 | |
445 | unsigned stride = 0; |
446 | if (failed(parseOptionalArrayStride(dialect, parser, stride))) |
447 | return Type(); |
448 | |
449 | if (parser.parseGreater()) |
450 | return Type(); |
451 | return RuntimeArrayType::get(elementType, stride); |
452 | } |
453 | |
454 | // matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>` |
455 | static Type parseMatrixType(SPIRVDialect const &dialect, |
456 | DialectAsmParser &parser) { |
457 | if (parser.parseLess()) |
458 | return Type(); |
459 | |
460 | SmallVector<int64_t, 1> countDims; |
461 | SMLoc countLoc = parser.getCurrentLocation(); |
462 | if (parser.parseDimensionList(dimensions&: countDims, /*allowDynamic=*/false)) |
463 | return Type(); |
464 | if (countDims.size() != 1) { |
465 | parser.emitError(loc: countLoc, message: "expected single unsigned " |
466 | "integer for number of columns" ); |
467 | return Type(); |
468 | } |
469 | |
470 | int64_t columnCount = countDims[0]; |
471 | // According to the specification, Matrices can have 2, 3, or 4 columns |
472 | if (columnCount < 2 || columnCount > 4) { |
473 | parser.emitError(loc: countLoc, message: "matrix is expected to have 2, 3, or 4 " |
474 | "columns" ); |
475 | return Type(); |
476 | } |
477 | |
478 | Type columnType = parseAndVerifyMatrixType(dialect, parser); |
479 | if (!columnType) |
480 | return Type(); |
481 | |
482 | if (parser.parseGreater()) |
483 | return Type(); |
484 | |
485 | return MatrixType::get(columnType, columnCount); |
486 | } |
487 | |
488 | // Specialize this function to parse each of the parameters that define an |
489 | // ImageType. By default it assumes this is an enum type. |
490 | template <typename ValTy> |
491 | static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, |
492 | DialectAsmParser &parser) { |
493 | StringRef enumSpec; |
494 | SMLoc enumLoc = parser.getCurrentLocation(); |
495 | if (parser.parseKeyword(keyword: &enumSpec)) { |
496 | return std::nullopt; |
497 | } |
498 | |
499 | auto val = spirv::symbolizeEnum<ValTy>(enumSpec); |
500 | if (!val) |
501 | parser.emitError(loc: enumLoc, message: "unknown attribute: '" ) << enumSpec << "'" ; |
502 | return val; |
503 | } |
504 | |
505 | template <> |
506 | std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, |
507 | DialectAsmParser &parser) { |
508 | // TODO: Further verify that the element type can be sampled |
509 | auto ty = parseAndVerifyType(dialect, parser); |
510 | if (!ty) |
511 | return std::nullopt; |
512 | return ty; |
513 | } |
514 | |
515 | template <typename IntTy> |
516 | static std::optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect, |
517 | DialectAsmParser &parser) { |
518 | IntTy offsetVal = std::numeric_limits<IntTy>::max(); |
519 | if (parser.parseInteger(offsetVal)) |
520 | return std::nullopt; |
521 | return offsetVal; |
522 | } |
523 | |
524 | template <> |
525 | std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect, |
526 | DialectAsmParser &parser) { |
527 | return parseAndVerifyInteger<unsigned>(dialect, parser); |
528 | } |
529 | |
530 | namespace { |
531 | // Functor object to parse a comma separated list of specs. The function |
532 | // parseAndVerify does the actual parsing and verification of individual |
533 | // elements. This is a functor since parsing the last element of the list |
534 | // (termination condition) needs partial specialization. |
535 | template <typename ParseType, typename... Args> |
536 | struct ParseCommaSeparatedList { |
537 | std::optional<std::tuple<ParseType, Args...>> |
538 | operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const { |
539 | auto parseVal = parseAndVerify<ParseType>(dialect, parser); |
540 | if (!parseVal) |
541 | return std::nullopt; |
542 | |
543 | auto numArgs = std::tuple_size<std::tuple<Args...>>::value; |
544 | if (numArgs != 0 && failed(result: parser.parseComma())) |
545 | return std::nullopt; |
546 | auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser); |
547 | if (!remainingValues) |
548 | return std::nullopt; |
549 | return std::tuple_cat(std::tuple<ParseType>(parseVal.value()), |
550 | remainingValues.value()); |
551 | } |
552 | }; |
553 | |
554 | // Partial specialization of the function to parse a comma separated list of |
555 | // specs to parse the last element of the list. |
556 | template <typename ParseType> |
557 | struct ParseCommaSeparatedList<ParseType> { |
558 | std::optional<std::tuple<ParseType>> |
559 | operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const { |
560 | if (auto value = parseAndVerify<ParseType>(dialect, parser)) |
561 | return std::tuple<ParseType>(*value); |
562 | return std::nullopt; |
563 | } |
564 | }; |
565 | } // namespace |
566 | |
567 | // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...> |
568 | // |
569 | // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown` |
570 | // |
571 | // arrayed-info ::= `NonArrayed` | `Arrayed` |
572 | // |
573 | // sampling-info ::= `SingleSampled` | `MultiSampled` |
574 | // |
575 | // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler` |
576 | // |
577 | // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...> |
578 | // |
579 | // image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,` |
580 | // arrayed-info `,` sampling-info `,` |
581 | // sampler-use-info `,` format `>` |
582 | static Type parseImageType(SPIRVDialect const &dialect, |
583 | DialectAsmParser &parser) { |
584 | if (parser.parseLess()) |
585 | return Type(); |
586 | |
587 | auto value = |
588 | ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo, |
589 | ImageSamplingInfo, ImageSamplerUseInfo, |
590 | ImageFormat>{}(dialect, parser); |
591 | if (!value) |
592 | return Type(); |
593 | |
594 | if (parser.parseGreater()) |
595 | return Type(); |
596 | return ImageType::get(*value); |
597 | } |
598 | |
599 | // sampledImage-type :: = `!spirv.sampledImage<` image-type `>` |
600 | static Type parseSampledImageType(SPIRVDialect const &dialect, |
601 | DialectAsmParser &parser) { |
602 | if (parser.parseLess()) |
603 | return Type(); |
604 | |
605 | Type parsedType = parseAndVerifySampledImageType(dialect, parser); |
606 | if (!parsedType) |
607 | return Type(); |
608 | |
609 | if (parser.parseGreater()) |
610 | return Type(); |
611 | return SampledImageType::get(imageType: parsedType); |
612 | } |
613 | |
614 | // Parse decorations associated with a member. |
615 | static ParseResult parseStructMemberDecorations( |
616 | SPIRVDialect const &dialect, DialectAsmParser &parser, |
617 | ArrayRef<Type> memberTypes, |
618 | SmallVectorImpl<StructType::OffsetInfo> &offsetInfo, |
619 | SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) { |
620 | |
621 | // Check if the first element is offset. |
622 | SMLoc offsetLoc = parser.getCurrentLocation(); |
623 | StructType::OffsetInfo offset = 0; |
624 | OptionalParseResult offsetParseResult = parser.parseOptionalInteger(result&: offset); |
625 | if (offsetParseResult.has_value()) { |
626 | if (failed(result: *offsetParseResult)) |
627 | return failure(); |
628 | |
629 | if (offsetInfo.size() != memberTypes.size() - 1) { |
630 | return parser.emitError(loc: offsetLoc, |
631 | message: "offset specification must be given for " |
632 | "all members" ); |
633 | } |
634 | offsetInfo.push_back(Elt: offset); |
635 | } |
636 | |
637 | // Check for no spirv::Decorations. |
638 | if (succeeded(result: parser.parseOptionalRSquare())) |
639 | return success(); |
640 | |
641 | // If there was an offset, make sure to parse the comma. |
642 | if (offsetParseResult.has_value() && parser.parseComma()) |
643 | return failure(); |
644 | |
645 | // Check for spirv::Decorations. |
646 | auto parseDecorations = [&]() { |
647 | auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser); |
648 | if (!memberDecoration) |
649 | return failure(); |
650 | |
651 | // Parse member decoration value if it exists. |
652 | if (succeeded(result: parser.parseOptionalEqual())) { |
653 | auto memberDecorationValue = |
654 | parseAndVerifyInteger<uint32_t>(dialect, parser); |
655 | |
656 | if (!memberDecorationValue) |
657 | return failure(); |
658 | |
659 | memberDecorationInfo.emplace_back( |
660 | static_cast<uint32_t>(memberTypes.size() - 1), 1, |
661 | memberDecoration.value(), memberDecorationValue.value()); |
662 | } else { |
663 | memberDecorationInfo.emplace_back( |
664 | static_cast<uint32_t>(memberTypes.size() - 1), 0, |
665 | memberDecoration.value(), 0); |
666 | } |
667 | return success(); |
668 | }; |
669 | if (failed(result: parser.parseCommaSeparatedList(parseElementFn: parseDecorations)) || |
670 | failed(result: parser.parseRSquare())) |
671 | return failure(); |
672 | |
673 | return success(); |
674 | } |
675 | |
676 | // struct-member-decoration ::= integer-literal? spirv-decoration* |
677 | // struct-type ::= |
678 | // `!spirv.struct<` (id `,`)? |
679 | // `(` |
680 | // (spirv-type (`[` struct-member-decoration `]`)?)* |
681 | // `)>` |
682 | static Type parseStructType(SPIRVDialect const &dialect, |
683 | DialectAsmParser &parser) { |
684 | // TODO: This function is quite lengthy. Break it down into smaller chunks. |
685 | |
686 | if (parser.parseLess()) |
687 | return Type(); |
688 | |
689 | StringRef identifier; |
690 | FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse; |
691 | |
692 | // Check if this is an identified struct type. |
693 | if (succeeded(result: parser.parseOptionalKeyword(keyword: &identifier))) { |
694 | // Check if this is a possible recursive reference. |
695 | auto structType = |
696 | StructType::getIdentified(context: dialect.getContext(), identifier); |
697 | cyclicParse = parser.tryStartCyclicParse(structType); |
698 | if (succeeded(result: parser.parseOptionalGreater())) { |
699 | if (succeeded(result: cyclicParse)) { |
700 | parser.emitError( |
701 | loc: parser.getNameLoc(), |
702 | message: "recursive struct reference not nested in struct definition" ); |
703 | |
704 | return Type(); |
705 | } |
706 | |
707 | return structType; |
708 | } |
709 | |
710 | if (failed(result: parser.parseComma())) |
711 | return Type(); |
712 | |
713 | if (failed(result: cyclicParse)) { |
714 | parser.emitError(loc: parser.getNameLoc(), |
715 | message: "identifier already used for an enclosing struct" ); |
716 | return Type(); |
717 | } |
718 | } |
719 | |
720 | if (failed(result: parser.parseLParen())) |
721 | return Type(); |
722 | |
723 | if (succeeded(result: parser.parseOptionalRParen()) && |
724 | succeeded(result: parser.parseOptionalGreater())) { |
725 | return StructType::getEmpty(context: dialect.getContext(), identifier); |
726 | } |
727 | |
728 | StructType idStructTy; |
729 | |
730 | if (!identifier.empty()) |
731 | idStructTy = StructType::getIdentified(context: dialect.getContext(), identifier); |
732 | |
733 | SmallVector<Type, 4> memberTypes; |
734 | SmallVector<StructType::OffsetInfo, 4> offsetInfo; |
735 | SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo; |
736 | |
737 | do { |
738 | Type memberType; |
739 | if (parser.parseType(result&: memberType)) |
740 | return Type(); |
741 | memberTypes.push_back(Elt: memberType); |
742 | |
743 | if (succeeded(result: parser.parseOptionalLSquare())) |
744 | if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo, |
745 | memberDecorationInfo)) |
746 | return Type(); |
747 | } while (succeeded(result: parser.parseOptionalComma())); |
748 | |
749 | if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) { |
750 | parser.emitError(loc: parser.getNameLoc(), |
751 | message: "offset specification must be given for all members" ); |
752 | return Type(); |
753 | } |
754 | |
755 | if (failed(result: parser.parseRParen()) || failed(result: parser.parseGreater())) |
756 | return Type(); |
757 | |
758 | if (!identifier.empty()) { |
759 | if (failed(result: idStructTy.trySetBody(memberTypes, offsetInfo, |
760 | memberDecorations: memberDecorationInfo))) |
761 | return Type(); |
762 | return idStructTy; |
763 | } |
764 | |
765 | return StructType::get(memberTypes, offsetInfo, memberDecorations: memberDecorationInfo); |
766 | } |
767 | |
768 | // spirv-type ::= array-type |
769 | // | element-type |
770 | // | image-type |
771 | // | pointer-type |
772 | // | runtime-array-type |
773 | // | sampled-image-type |
774 | // | struct-type |
775 | Type SPIRVDialect::parseType(DialectAsmParser &parser) const { |
776 | StringRef keyword; |
777 | if (parser.parseKeyword(&keyword)) |
778 | return Type(); |
779 | |
780 | if (keyword == "array" ) |
781 | return parseArrayType(*this, parser); |
782 | if (keyword == "coopmatrix" ) |
783 | return parseCooperativeMatrixType(*this, parser); |
784 | if (keyword == "jointmatrix" ) |
785 | return parseJointMatrixType(*this, parser); |
786 | if (keyword == "image" ) |
787 | return parseImageType(*this, parser); |
788 | if (keyword == "ptr" ) |
789 | return parsePointerType(*this, parser); |
790 | if (keyword == "rtarray" ) |
791 | return parseRuntimeArrayType(*this, parser); |
792 | if (keyword == "sampled_image" ) |
793 | return parseSampledImageType(*this, parser); |
794 | if (keyword == "struct" ) |
795 | return parseStructType(*this, parser); |
796 | if (keyword == "matrix" ) |
797 | return parseMatrixType(*this, parser); |
798 | parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: " ) << keyword; |
799 | return Type(); |
800 | } |
801 | |
802 | //===----------------------------------------------------------------------===// |
803 | // Type Printing |
804 | //===----------------------------------------------------------------------===// |
805 | |
806 | static void print(ArrayType type, DialectAsmPrinter &os) { |
807 | os << "array<" << type.getNumElements() << " x " << type.getElementType(); |
808 | if (unsigned stride = type.getArrayStride()) |
809 | os << ", stride=" << stride; |
810 | os << ">" ; |
811 | } |
812 | |
813 | static void print(RuntimeArrayType type, DialectAsmPrinter &os) { |
814 | os << "rtarray<" << type.getElementType(); |
815 | if (unsigned stride = type.getArrayStride()) |
816 | os << ", stride=" << stride; |
817 | os << ">" ; |
818 | } |
819 | |
820 | static void print(PointerType type, DialectAsmPrinter &os) { |
821 | os << "ptr<" << type.getPointeeType() << ", " |
822 | << stringifyStorageClass(type.getStorageClass()) << ">" ; |
823 | } |
824 | |
825 | static void print(ImageType type, DialectAsmPrinter &os) { |
826 | os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim()) |
827 | << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", " |
828 | << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", " |
829 | << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", " |
830 | << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", " |
831 | << stringifyImageFormat(type.getImageFormat()) << ">" ; |
832 | } |
833 | |
834 | static void print(SampledImageType type, DialectAsmPrinter &os) { |
835 | os << "sampled_image<" << type.getImageType() << ">" ; |
836 | } |
837 | |
838 | static void print(StructType type, DialectAsmPrinter &os) { |
839 | FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint; |
840 | |
841 | os << "struct<" ; |
842 | |
843 | if (type.isIdentified()) { |
844 | os << type.getIdentifier(); |
845 | |
846 | cyclicPrint = os.tryStartCyclicPrint(attrOrType: type); |
847 | if (failed(result: cyclicPrint)) { |
848 | os << ">" ; |
849 | return; |
850 | } |
851 | |
852 | os << ", " ; |
853 | } |
854 | |
855 | os << "(" ; |
856 | |
857 | auto printMember = [&](unsigned i) { |
858 | os << type.getElementType(i); |
859 | SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations; |
860 | type.getMemberDecorations(i, decorationsInfo&: decorations); |
861 | if (type.hasOffset() || !decorations.empty()) { |
862 | os << " [" ; |
863 | if (type.hasOffset()) { |
864 | os << type.getMemberOffset(i); |
865 | if (!decorations.empty()) |
866 | os << ", " ; |
867 | } |
868 | auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) { |
869 | os << stringifyDecoration(decoration.decoration); |
870 | if (decoration.hasValue) { |
871 | os << "=" << decoration.decorationValue; |
872 | } |
873 | }; |
874 | llvm::interleaveComma(c: decorations, os, each_fn: eachFn); |
875 | os << "]" ; |
876 | } |
877 | }; |
878 | llvm::interleaveComma(c: llvm::seq<unsigned>(Begin: 0, End: type.getNumElements()), os, |
879 | each_fn: printMember); |
880 | os << ")>" ; |
881 | } |
882 | |
883 | static void print(CooperativeMatrixType type, DialectAsmPrinter &os) { |
884 | os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x" |
885 | << type.getElementType() << ", " << type.getScope() << ", " |
886 | << type.getUse() << ">" ; |
887 | } |
888 | |
889 | static void print(JointMatrixINTELType type, DialectAsmPrinter &os) { |
890 | os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x" ; |
891 | os << type.getElementType() << ", " |
892 | << stringifyMatrixLayout(type.getMatrixLayout()); |
893 | os << ", " << stringifyScope(type.getScope()) << ">" ; |
894 | } |
895 | |
896 | static void print(MatrixType type, DialectAsmPrinter &os) { |
897 | os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType(); |
898 | os << ">" ; |
899 | } |
900 | |
901 | void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const { |
902 | TypeSwitch<Type>(type) |
903 | .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, PointerType, |
904 | RuntimeArrayType, ImageType, SampledImageType, StructType, |
905 | MatrixType>([&](auto type) { print(type, os); }) |
906 | .Default([](Type) { llvm_unreachable("unhandled SPIR-V type" ); }); |
907 | } |
908 | |
909 | //===----------------------------------------------------------------------===// |
910 | // Constant |
911 | //===----------------------------------------------------------------------===// |
912 | |
913 | Operation *SPIRVDialect::materializeConstant(OpBuilder &builder, |
914 | Attribute value, Type type, |
915 | Location loc) { |
916 | if (auto poison = dyn_cast<ub::PoisonAttr>(value)) |
917 | return builder.create<ub::PoisonOp>(loc, type, poison); |
918 | |
919 | if (!spirv::ConstantOp::isBuildableWith(type)) |
920 | return nullptr; |
921 | |
922 | return builder.create<spirv::ConstantOp>(loc, type, value); |
923 | } |
924 | |
925 | //===----------------------------------------------------------------------===// |
926 | // Shader Interface ABI |
927 | //===----------------------------------------------------------------------===// |
928 | |
929 | LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op, |
930 | NamedAttribute attribute) { |
931 | StringRef symbol = attribute.getName().strref(); |
932 | Attribute attr = attribute.getValue(); |
933 | |
934 | if (symbol == spirv::getEntryPointABIAttrName()) { |
935 | if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) { |
936 | return op->emitError("'" ) |
937 | << symbol << "' attribute must be an entry point ABI attribute" ; |
938 | } |
939 | } else if (symbol == spirv::getTargetEnvAttrName()) { |
940 | if (!llvm::isa<spirv::TargetEnvAttr>(attr)) |
941 | return op->emitError("'" ) << symbol << "' must be a spirv::TargetEnvAttr" ; |
942 | } else { |
943 | return op->emitError("found unsupported '" ) |
944 | << symbol << "' attribute on operation" ; |
945 | } |
946 | |
947 | return success(); |
948 | } |
949 | |
950 | /// Verifies the given SPIR-V `attribute` attached to a value of the given |
951 | /// `valueType` is valid. |
952 | static LogicalResult verifyRegionAttribute(Location loc, Type valueType, |
953 | NamedAttribute attribute) { |
954 | StringRef symbol = attribute.getName().strref(); |
955 | Attribute attr = attribute.getValue(); |
956 | |
957 | if (symbol == spirv::getInterfaceVarABIAttrName()) { |
958 | auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(Val&: attr); |
959 | if (!varABIAttr) |
960 | return emitError(loc, message: "'" ) |
961 | << symbol << "' must be a spirv::InterfaceVarABIAttr" ; |
962 | |
963 | if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat()) |
964 | return emitError(loc, message: "'" ) << symbol |
965 | << "' attribute cannot specify storage class " |
966 | "when attaching to a non-scalar value" ; |
967 | return success(); |
968 | } |
969 | if (symbol == spirv::DecorationAttr::name) { |
970 | if (!isa<spirv::DecorationAttr>(attr)) |
971 | return emitError(loc, message: "'" ) |
972 | << symbol << "' must be a spirv::DecorationAttr" ; |
973 | return success(); |
974 | } |
975 | |
976 | return emitError(loc, message: "found unsupported '" ) |
977 | << symbol << "' attribute on region argument" ; |
978 | } |
979 | |
980 | LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op, |
981 | unsigned regionIndex, |
982 | unsigned argIndex, |
983 | NamedAttribute attribute) { |
984 | auto funcOp = dyn_cast<FunctionOpInterface>(op); |
985 | if (!funcOp) |
986 | return success(); |
987 | Type argType = funcOp.getArgumentTypes()[argIndex]; |
988 | |
989 | return verifyRegionAttribute(op->getLoc(), argType, attribute); |
990 | } |
991 | |
992 | LogicalResult SPIRVDialect::verifyRegionResultAttribute( |
993 | Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/, |
994 | NamedAttribute attribute) { |
995 | return op->emitError("cannot attach SPIR-V attributes to region result" ); |
996 | } |
997 | |