//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the SPIR-V dialect in MLIR. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
14 | |
15 | #include "SPIRVParsingUtils.h" |
16 | |
17 | #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h" |
18 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
19 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
20 | #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" |
21 | #include "mlir/Dialect/UB/IR/UBOps.h" |
22 | #include "mlir/IR/Builders.h" |
23 | #include "mlir/IR/BuiltinTypes.h" |
24 | #include "mlir/IR/DialectImplementation.h" |
25 | #include "mlir/IR/MLIRContext.h" |
26 | #include "mlir/Parser/Parser.h" |
27 | #include "mlir/Transforms/InliningUtils.h" |
28 | #include "llvm/ADT/DenseMap.h" |
29 | #include "llvm/ADT/Sequence.h" |
30 | #include "llvm/ADT/SetVector.h" |
31 | #include "llvm/ADT/StringExtras.h" |
32 | #include "llvm/ADT/StringMap.h" |
33 | #include "llvm/ADT/TypeSwitch.h" |
34 | #include "llvm/Support/raw_ostream.h" |
35 | |
36 | using namespace mlir; |
37 | using namespace mlir::spirv; |
38 | |
39 | #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc" |
40 | |
41 | //===----------------------------------------------------------------------===// |
42 | // InlinerInterface |
43 | //===----------------------------------------------------------------------===// |
44 | |
45 | /// Returns true if the given region contains spirv.Return or spirv.ReturnValue |
46 | /// ops. |
47 | static inline bool containsReturn(Region ®ion) { |
48 | return llvm::any_of(region, [](Block &block) { |
49 | Operation *terminator = block.getTerminator(); |
50 | return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator); |
51 | }); |
52 | } |
53 | |
54 | namespace { |
55 | /// This class defines the interface for inlining within the SPIR-V dialect. |
56 | struct SPIRVInlinerInterface : public DialectInlinerInterface { |
57 | using DialectInlinerInterface::DialectInlinerInterface; |
58 | |
59 | /// All call operations within SPIRV can be inlined. |
60 | bool isLegalToInline(Operation *call, Operation *callable, |
61 | bool wouldBeCloned) const final { |
62 | return true; |
63 | } |
64 | |
65 | /// Returns true if the given region 'src' can be inlined into the region |
66 | /// 'dest' that is attached to an operation registered to the current dialect. |
67 | bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, |
68 | IRMapping &) const final { |
69 | // Return true here when inlining into spirv.func, spirv.mlir.selection, and |
70 | // spirv.mlir.loop operations. |
71 | auto *op = dest->getParentOp(); |
72 | return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op); |
73 | } |
74 | |
75 | /// Returns true if the given operation 'op', that is registered to this |
76 | /// dialect, can be inlined into the region 'dest' that is attached to an |
77 | /// operation registered to the current dialect. |
78 | bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, |
79 | IRMapping &) const final { |
80 | // TODO: Enable inlining structured control flows with return. |
81 | if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) && |
82 | containsReturn(op->getRegion(0))) |
83 | return false; |
84 | // TODO: we need to filter OpKill here to avoid inlining it to |
85 | // a loop continue construct: |
86 | // https://github.com/KhronosGroup/SPIRV-Headers/issues/86 |
87 | // For now, we just disallow inlining OpKill anywhere in the code, |
88 | // but this restriction should be relaxed, as pointed above. |
89 | if (isa<spirv::KillOp>(op)) |
90 | return false; |
91 | |
92 | return true; |
93 | } |
94 | |
95 | /// Handle the given inlined terminator by replacing it with a new operation |
96 | /// as necessary. |
97 | void handleTerminator(Operation *op, Block *newDest) const final { |
98 | if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) { |
99 | OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest); |
100 | op->erase(); |
101 | } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) { |
102 | OpBuilder(op).create<spirv::BranchOp>(retValOp->getLoc(), newDest, |
103 | retValOp->getOperands()); |
104 | op->erase(); |
105 | } |
106 | } |
107 | |
108 | /// Handle the given inlined terminator by replacing it with a new operation |
109 | /// as necessary. |
110 | void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { |
111 | // Only spirv.ReturnValue needs to be handled here. |
112 | auto retValOp = dyn_cast<spirv::ReturnValueOp>(op); |
113 | if (!retValOp) |
114 | return; |
115 | |
116 | // Replace the values directly with the return operands. |
117 | assert(valuesToRepl.size() == 1 && |
118 | "spirv.ReturnValue expected to only handle one result"); |
119 | valuesToRepl.front().replaceAllUsesWith(newValue: retValOp.getValue()); |
120 | } |
121 | }; |
122 | } // namespace |
123 | |
124 | //===----------------------------------------------------------------------===// |
125 | // SPIR-V Dialect |
126 | //===----------------------------------------------------------------------===// |
127 | |
/// Sets up the SPIR-V dialect: registers its attributes and types, adds the
/// generated operation list and the inliner interface, permits unknown
/// operations, and declares the promised gpu::TargetAttrInterface for
/// TargetEnvAttr.
void SPIRVDialect::initialize() {
  registerAttributes();
  registerTypes();

  // Add SPIR-V ops.
  // GET_OP_LIST selects the op-list section of the generated include, which
  // is spliced in as the template argument list of addOperations.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
      >();

  addInterfaces<SPIRVInlinerInterface>();

  // Allow unknown operations because SPIR-V is extensible.
  allowUnknownOperations();
  declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>();
}
144 | |
145 | std::string SPIRVDialect::getAttributeName(Decoration decoration) { |
146 | return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration)); |
147 | } |
148 | |
149 | //===----------------------------------------------------------------------===// |
150 | // Type Parsing |
151 | //===----------------------------------------------------------------------===// |
152 | |
// Forward declarations.
//
// `parseAndVerify` parses a single value of type `ValTy` and yields
// std::nullopt on failure. The primary template and the explicit
// specializations for `Type` and `unsigned` are defined further down in this
// file; they are declared here because earlier helpers (e.g. the array stride
// parser) call them before their definitions.
template <typename ValTy>
static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
                                           DialectAsmParser &parser);
template <>
std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
                                         DialectAsmParser &parser);

template <>
std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
                                                 DialectAsmParser &parser);
164 | |
165 | static Type parseAndVerifyType(SPIRVDialect const &dialect, |
166 | DialectAsmParser &parser) { |
167 | Type type; |
168 | SMLoc typeLoc = parser.getCurrentLocation(); |
169 | if (parser.parseType(result&: type)) |
170 | return Type(); |
171 | |
172 | // Allow SPIR-V dialect types |
173 | if (&type.getDialect() == &dialect) |
174 | return type; |
175 | |
176 | // Check other allowed types |
177 | if (auto t = llvm::dyn_cast<FloatType>(type)) { |
178 | if (type.isBF16()) { |
179 | parser.emitError(loc: typeLoc, message: "cannot use 'bf16' to compose SPIR-V types"); |
180 | return Type(); |
181 | } |
182 | } else if (auto t = llvm::dyn_cast<IntegerType>(type)) { |
183 | if (!ScalarType::isValid(t)) { |
184 | parser.emitError(loc: typeLoc, |
185 | message: "only 1/8/16/32/64-bit integer type allowed but found ") |
186 | << type; |
187 | return Type(); |
188 | } |
189 | } else if (auto t = llvm::dyn_cast<VectorType>(type)) { |
190 | if (t.getRank() != 1) { |
191 | parser.emitError(loc: typeLoc, message: "only 1-D vector allowed but found ") << t; |
192 | return Type(); |
193 | } |
194 | if (t.getNumElements() > 4) { |
195 | parser.emitError( |
196 | loc: typeLoc, message: "vector length has to be less than or equal to 4 but found ") |
197 | << t.getNumElements(); |
198 | return Type(); |
199 | } |
200 | } else { |
201 | parser.emitError(loc: typeLoc, message: "cannot use ") |
202 | << type << " to compose SPIR-V types"; |
203 | return Type(); |
204 | } |
205 | |
206 | return type; |
207 | } |
208 | |
209 | static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, |
210 | DialectAsmParser &parser) { |
211 | Type type; |
212 | SMLoc typeLoc = parser.getCurrentLocation(); |
213 | if (parser.parseType(result&: type)) |
214 | return Type(); |
215 | |
216 | if (auto t = llvm::dyn_cast<VectorType>(type)) { |
217 | if (t.getRank() != 1) { |
218 | parser.emitError(loc: typeLoc, message: "only 1-D vector allowed but found ") << t; |
219 | return Type(); |
220 | } |
221 | if (t.getNumElements() > 4 || t.getNumElements() < 2) { |
222 | parser.emitError(loc: typeLoc, |
223 | message: "matrix columns size has to be less than or equal " |
224 | "to 4 and greater than or equal 2, but found ") |
225 | << t.getNumElements(); |
226 | return Type(); |
227 | } |
228 | |
229 | if (!llvm::isa<FloatType>(t.getElementType())) { |
230 | parser.emitError(loc: typeLoc, message: "matrix columns' elements must be of " |
231 | "Float type, got ") |
232 | << t.getElementType(); |
233 | return Type(); |
234 | } |
235 | } else { |
236 | parser.emitError(loc: typeLoc, message: "matrix must be composed using vector " |
237 | "type, got ") |
238 | << type; |
239 | return Type(); |
240 | } |
241 | |
242 | return type; |
243 | } |
244 | |
245 | static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, |
246 | DialectAsmParser &parser) { |
247 | Type type; |
248 | SMLoc typeLoc = parser.getCurrentLocation(); |
249 | if (parser.parseType(result&: type)) |
250 | return Type(); |
251 | |
252 | if (!llvm::isa<ImageType>(Val: type)) { |
253 | parser.emitError(loc: typeLoc, |
254 | message: "sampled image must be composed using image type, got ") |
255 | << type; |
256 | return Type(); |
257 | } |
258 | |
259 | return type; |
260 | } |
261 | |
262 | /// Parses an optional `, stride = N` assembly segment. If no parsing failure |
263 | /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if |
264 | /// missing. |
265 | static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, |
266 | DialectAsmParser &parser, |
267 | unsigned &stride) { |
268 | if (failed(Result: parser.parseOptionalComma())) { |
269 | stride = 0; |
270 | return success(); |
271 | } |
272 | |
273 | if (parser.parseKeyword(keyword: "stride") || parser.parseEqual()) |
274 | return failure(); |
275 | |
276 | SMLoc strideLoc = parser.getCurrentLocation(); |
277 | std::optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser); |
278 | if (!optStride) |
279 | return failure(); |
280 | |
281 | if (!(stride = *optStride)) { |
282 | parser.emitError(loc: strideLoc, message: "ArrayStride must be greater than zero"); |
283 | return failure(); |
284 | } |
285 | return success(); |
286 | } |
287 | |
288 | // element-type ::= integer-type |
289 | // | floating-point-type |
290 | // | vector-type |
291 | // | spirv-type |
292 | // |
293 | // array-type ::= `!spirv.array` `<` integer-literal `x` element-type |
294 | // (`,` `stride` `=` integer-literal)? `>` |
295 | static Type parseArrayType(SPIRVDialect const &dialect, |
296 | DialectAsmParser &parser) { |
297 | if (parser.parseLess()) |
298 | return Type(); |
299 | |
300 | SmallVector<int64_t, 1> countDims; |
301 | SMLoc countLoc = parser.getCurrentLocation(); |
302 | if (parser.parseDimensionList(dimensions&: countDims, /*allowDynamic=*/false)) |
303 | return Type(); |
304 | if (countDims.size() != 1) { |
305 | parser.emitError(loc: countLoc, |
306 | message: "expected single integer for array element count"); |
307 | return Type(); |
308 | } |
309 | |
310 | // According to the SPIR-V spec: |
311 | // "Length is the number of elements in the array. It must be at least 1." |
312 | int64_t count = countDims[0]; |
313 | if (count == 0) { |
314 | parser.emitError(loc: countLoc, message: "expected array length greater than 0"); |
315 | return Type(); |
316 | } |
317 | |
318 | Type elementType = parseAndVerifyType(dialect, parser); |
319 | if (!elementType) |
320 | return Type(); |
321 | |
322 | unsigned stride = 0; |
323 | if (failed(parseOptionalArrayStride(dialect, parser, stride))) |
324 | return Type(); |
325 | |
326 | if (parser.parseGreater()) |
327 | return Type(); |
328 | return ArrayType::get(elementType, elementCount: count, stride); |
329 | } |
330 | |
331 | // cooperative-matrix-type ::= |
332 | // `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,` |
333 | // scope `,` use `>` |
334 | static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, |
335 | DialectAsmParser &parser) { |
336 | if (parser.parseLess()) |
337 | return {}; |
338 | |
339 | SmallVector<int64_t, 2> dims; |
340 | SMLoc countLoc = parser.getCurrentLocation(); |
341 | if (parser.parseDimensionList(dimensions&: dims, /*allowDynamic=*/false)) |
342 | return {}; |
343 | |
344 | if (dims.size() != 2) { |
345 | parser.emitError(loc: countLoc, message: "expected row and column count"); |
346 | return {}; |
347 | } |
348 | |
349 | auto elementTy = parseAndVerifyType(dialect, parser); |
350 | if (!elementTy) |
351 | return {}; |
352 | |
353 | Scope scope; |
354 | if (parser.parseComma() || |
355 | spirv::parseEnumKeywordAttr(scope, parser, "scope <id>")) |
356 | return {}; |
357 | |
358 | CooperativeMatrixUseKHR use; |
359 | if (parser.parseComma() || |
360 | spirv::parseEnumKeywordAttr(use, parser, "use <id>")) |
361 | return {}; |
362 | |
363 | if (parser.parseGreater()) |
364 | return {}; |
365 | |
366 | return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use); |
367 | } |
368 | |
369 | // TODO: Reorder methods to be utilities first and parse*Type |
370 | // methods in alphabetical order |
371 | // |
372 | // storage-class ::= `UniformConstant` |
373 | // | `Uniform` |
374 | // | `Workgroup` |
375 | // | <and other storage classes...> |
376 | // |
377 | // pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>` |
378 | static Type parsePointerType(SPIRVDialect const &dialect, |
379 | DialectAsmParser &parser) { |
380 | if (parser.parseLess()) |
381 | return Type(); |
382 | |
383 | auto pointeeType = parseAndVerifyType(dialect, parser); |
384 | if (!pointeeType) |
385 | return Type(); |
386 | |
387 | StringRef storageClassSpec; |
388 | SMLoc storageClassLoc = parser.getCurrentLocation(); |
389 | if (parser.parseComma() || parser.parseKeyword(keyword: &storageClassSpec)) |
390 | return Type(); |
391 | |
392 | auto storageClass = symbolizeStorageClass(storageClassSpec); |
393 | if (!storageClass) { |
394 | parser.emitError(loc: storageClassLoc, message: "unknown storage class: ") |
395 | << storageClassSpec; |
396 | return Type(); |
397 | } |
398 | if (parser.parseGreater()) |
399 | return Type(); |
400 | return PointerType::get(pointeeType, *storageClass); |
401 | } |
402 | |
403 | // runtime-array-type ::= `!spirv.rtarray` `<` element-type |
404 | // (`,` `stride` `=` integer-literal)? `>` |
405 | static Type parseRuntimeArrayType(SPIRVDialect const &dialect, |
406 | DialectAsmParser &parser) { |
407 | if (parser.parseLess()) |
408 | return Type(); |
409 | |
410 | Type elementType = parseAndVerifyType(dialect, parser); |
411 | if (!elementType) |
412 | return Type(); |
413 | |
414 | unsigned stride = 0; |
415 | if (failed(parseOptionalArrayStride(dialect, parser, stride))) |
416 | return Type(); |
417 | |
418 | if (parser.parseGreater()) |
419 | return Type(); |
420 | return RuntimeArrayType::get(elementType, stride); |
421 | } |
422 | |
423 | // matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>` |
424 | static Type parseMatrixType(SPIRVDialect const &dialect, |
425 | DialectAsmParser &parser) { |
426 | if (parser.parseLess()) |
427 | return Type(); |
428 | |
429 | SmallVector<int64_t, 1> countDims; |
430 | SMLoc countLoc = parser.getCurrentLocation(); |
431 | if (parser.parseDimensionList(dimensions&: countDims, /*allowDynamic=*/false)) |
432 | return Type(); |
433 | if (countDims.size() != 1) { |
434 | parser.emitError(loc: countLoc, message: "expected single unsigned " |
435 | "integer for number of columns"); |
436 | return Type(); |
437 | } |
438 | |
439 | int64_t columnCount = countDims[0]; |
440 | // According to the specification, Matrices can have 2, 3, or 4 columns |
441 | if (columnCount < 2 || columnCount > 4) { |
442 | parser.emitError(loc: countLoc, message: "matrix is expected to have 2, 3, or 4 " |
443 | "columns"); |
444 | return Type(); |
445 | } |
446 | |
447 | Type columnType = parseAndVerifyMatrixType(dialect, parser); |
448 | if (!columnType) |
449 | return Type(); |
450 | |
451 | if (parser.parseGreater()) |
452 | return Type(); |
453 | |
454 | return MatrixType::get(columnType, columnCount); |
455 | } |
456 | |
457 | // Specialize this function to parse each of the parameters that define an |
458 | // ImageType. By default it assumes this is an enum type. |
459 | template <typename ValTy> |
460 | static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, |
461 | DialectAsmParser &parser) { |
462 | StringRef enumSpec; |
463 | SMLoc enumLoc = parser.getCurrentLocation(); |
464 | if (parser.parseKeyword(keyword: &enumSpec)) { |
465 | return std::nullopt; |
466 | } |
467 | |
468 | auto val = spirv::symbolizeEnum<ValTy>(enumSpec); |
469 | if (!val) |
470 | parser.emitError(loc: enumLoc, message: "unknown attribute: '") << enumSpec << "'"; |
471 | return val; |
472 | } |
473 | |
474 | template <> |
475 | std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, |
476 | DialectAsmParser &parser) { |
477 | // TODO: Further verify that the element type can be sampled |
478 | auto ty = parseAndVerifyType(dialect, parser); |
479 | if (!ty) |
480 | return std::nullopt; |
481 | return ty; |
482 | } |
483 | |
484 | template <typename IntTy> |
485 | static std::optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect, |
486 | DialectAsmParser &parser) { |
487 | IntTy offsetVal = std::numeric_limits<IntTy>::max(); |
488 | if (parser.parseInteger(offsetVal)) |
489 | return std::nullopt; |
490 | return offsetVal; |
491 | } |
492 | |
493 | template <> |
494 | std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect, |
495 | DialectAsmParser &parser) { |
496 | return parseAndVerifyInteger<unsigned>(dialect, parser); |
497 | } |
498 | |
499 | namespace { |
500 | // Functor object to parse a comma separated list of specs. The function |
501 | // parseAndVerify does the actual parsing and verification of individual |
502 | // elements. This is a functor since parsing the last element of the list |
503 | // (termination condition) needs partial specialization. |
504 | template <typename ParseType, typename... Args> |
505 | struct ParseCommaSeparatedList { |
506 | std::optional<std::tuple<ParseType, Args...>> |
507 | operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const { |
508 | auto parseVal = parseAndVerify<ParseType>(dialect, parser); |
509 | if (!parseVal) |
510 | return std::nullopt; |
511 | |
512 | auto numArgs = std::tuple_size<std::tuple<Args...>>::value; |
513 | if (numArgs != 0 && failed(Result: parser.parseComma())) |
514 | return std::nullopt; |
515 | auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser); |
516 | if (!remainingValues) |
517 | return std::nullopt; |
518 | return std::tuple_cat(std::tuple<ParseType>(parseVal.value()), |
519 | remainingValues.value()); |
520 | } |
521 | }; |
522 | |
523 | // Partial specialization of the function to parse a comma separated list of |
524 | // specs to parse the last element of the list. |
525 | template <typename ParseType> |
526 | struct ParseCommaSeparatedList<ParseType> { |
527 | std::optional<std::tuple<ParseType>> |
528 | operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const { |
529 | if (auto value = parseAndVerify<ParseType>(dialect, parser)) |
530 | return std::tuple<ParseType>(*value); |
531 | return std::nullopt; |
532 | } |
533 | }; |
534 | } // namespace |
535 | |
536 | // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...> |
537 | // |
538 | // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown` |
539 | // |
540 | // arrayed-info ::= `NonArrayed` | `Arrayed` |
541 | // |
542 | // sampling-info ::= `SingleSampled` | `MultiSampled` |
543 | // |
544 | // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler` |
545 | // |
546 | // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...> |
547 | // |
548 | // image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,` |
549 | // arrayed-info `,` sampling-info `,` |
550 | // sampler-use-info `,` format `>` |
551 | static Type parseImageType(SPIRVDialect const &dialect, |
552 | DialectAsmParser &parser) { |
553 | if (parser.parseLess()) |
554 | return Type(); |
555 | |
556 | auto value = |
557 | ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo, |
558 | ImageSamplingInfo, ImageSamplerUseInfo, |
559 | ImageFormat>{}(dialect, parser); |
560 | if (!value) |
561 | return Type(); |
562 | |
563 | if (parser.parseGreater()) |
564 | return Type(); |
565 | return ImageType::get(*value); |
566 | } |
567 | |
568 | // sampledImage-type :: = `!spirv.sampledImage<` image-type `>` |
569 | static Type parseSampledImageType(SPIRVDialect const &dialect, |
570 | DialectAsmParser &parser) { |
571 | if (parser.parseLess()) |
572 | return Type(); |
573 | |
574 | Type parsedType = parseAndVerifySampledImageType(dialect, parser); |
575 | if (!parsedType) |
576 | return Type(); |
577 | |
578 | if (parser.parseGreater()) |
579 | return Type(); |
580 | return SampledImageType::get(imageType: parsedType); |
581 | } |
582 | |
583 | // Parse decorations associated with a member. |
584 | static ParseResult parseStructMemberDecorations( |
585 | SPIRVDialect const &dialect, DialectAsmParser &parser, |
586 | ArrayRef<Type> memberTypes, |
587 | SmallVectorImpl<StructType::OffsetInfo> &offsetInfo, |
588 | SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) { |
589 | |
590 | // Check if the first element is offset. |
591 | SMLoc offsetLoc = parser.getCurrentLocation(); |
592 | StructType::OffsetInfo offset = 0; |
593 | OptionalParseResult offsetParseResult = parser.parseOptionalInteger(result&: offset); |
594 | if (offsetParseResult.has_value()) { |
595 | if (failed(Result: *offsetParseResult)) |
596 | return failure(); |
597 | |
598 | if (offsetInfo.size() != memberTypes.size() - 1) { |
599 | return parser.emitError(loc: offsetLoc, |
600 | message: "offset specification must be given for " |
601 | "all members"); |
602 | } |
603 | offsetInfo.push_back(Elt: offset); |
604 | } |
605 | |
606 | // Check for no spirv::Decorations. |
607 | if (succeeded(Result: parser.parseOptionalRSquare())) |
608 | return success(); |
609 | |
610 | // If there was an offset, make sure to parse the comma. |
611 | if (offsetParseResult.has_value() && parser.parseComma()) |
612 | return failure(); |
613 | |
614 | // Check for spirv::Decorations. |
615 | auto parseDecorations = [&]() { |
616 | auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser); |
617 | if (!memberDecoration) |
618 | return failure(); |
619 | |
620 | // Parse member decoration value if it exists. |
621 | if (succeeded(Result: parser.parseOptionalEqual())) { |
622 | auto memberDecorationValue = |
623 | parseAndVerifyInteger<uint32_t>(dialect, parser); |
624 | |
625 | if (!memberDecorationValue) |
626 | return failure(); |
627 | |
628 | memberDecorationInfo.emplace_back( |
629 | static_cast<uint32_t>(memberTypes.size() - 1), 1, |
630 | memberDecoration.value(), memberDecorationValue.value()); |
631 | } else { |
632 | memberDecorationInfo.emplace_back( |
633 | static_cast<uint32_t>(memberTypes.size() - 1), 0, |
634 | memberDecoration.value(), 0); |
635 | } |
636 | return success(); |
637 | }; |
638 | if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: parseDecorations)) || |
639 | failed(Result: parser.parseRSquare())) |
640 | return failure(); |
641 | |
642 | return success(); |
643 | } |
644 | |
645 | // struct-member-decoration ::= integer-literal? spirv-decoration* |
646 | // struct-type ::= |
647 | // `!spirv.struct<` (id `,`)? |
648 | // `(` |
649 | // (spirv-type (`[` struct-member-decoration `]`)?)* |
650 | // `)>` |
651 | static Type parseStructType(SPIRVDialect const &dialect, |
652 | DialectAsmParser &parser) { |
653 | // TODO: This function is quite lengthy. Break it down into smaller chunks. |
654 | |
655 | if (parser.parseLess()) |
656 | return Type(); |
657 | |
658 | StringRef identifier; |
659 | FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse; |
660 | |
661 | // Check if this is an identified struct type. |
662 | if (succeeded(Result: parser.parseOptionalKeyword(keyword: &identifier))) { |
663 | // Check if this is a possible recursive reference. |
664 | auto structType = |
665 | StructType::getIdentified(context: dialect.getContext(), identifier); |
666 | cyclicParse = parser.tryStartCyclicParse(structType); |
667 | if (succeeded(Result: parser.parseOptionalGreater())) { |
668 | if (succeeded(Result: cyclicParse)) { |
669 | parser.emitError( |
670 | loc: parser.getNameLoc(), |
671 | message: "recursive struct reference not nested in struct definition"); |
672 | |
673 | return Type(); |
674 | } |
675 | |
676 | return structType; |
677 | } |
678 | |
679 | if (failed(Result: parser.parseComma())) |
680 | return Type(); |
681 | |
682 | if (failed(Result: cyclicParse)) { |
683 | parser.emitError(loc: parser.getNameLoc(), |
684 | message: "identifier already used for an enclosing struct"); |
685 | return Type(); |
686 | } |
687 | } |
688 | |
689 | if (failed(Result: parser.parseLParen())) |
690 | return Type(); |
691 | |
692 | if (succeeded(Result: parser.parseOptionalRParen()) && |
693 | succeeded(Result: parser.parseOptionalGreater())) { |
694 | return StructType::getEmpty(context: dialect.getContext(), identifier); |
695 | } |
696 | |
697 | StructType idStructTy; |
698 | |
699 | if (!identifier.empty()) |
700 | idStructTy = StructType::getIdentified(context: dialect.getContext(), identifier); |
701 | |
702 | SmallVector<Type, 4> memberTypes; |
703 | SmallVector<StructType::OffsetInfo, 4> offsetInfo; |
704 | SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo; |
705 | |
706 | do { |
707 | Type memberType; |
708 | if (parser.parseType(result&: memberType)) |
709 | return Type(); |
710 | memberTypes.push_back(Elt: memberType); |
711 | |
712 | if (succeeded(Result: parser.parseOptionalLSquare())) |
713 | if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo, |
714 | memberDecorationInfo)) |
715 | return Type(); |
716 | } while (succeeded(Result: parser.parseOptionalComma())); |
717 | |
718 | if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) { |
719 | parser.emitError(loc: parser.getNameLoc(), |
720 | message: "offset specification must be given for all members"); |
721 | return Type(); |
722 | } |
723 | |
724 | if (failed(Result: parser.parseRParen()) || failed(Result: parser.parseGreater())) |
725 | return Type(); |
726 | |
727 | if (!identifier.empty()) { |
728 | if (failed(Result: idStructTy.trySetBody(memberTypes, offsetInfo, |
729 | memberDecorations: memberDecorationInfo))) |
730 | return Type(); |
731 | return idStructTy; |
732 | } |
733 | |
734 | return StructType::get(memberTypes, offsetInfo, memberDecorations: memberDecorationInfo); |
735 | } |
736 | |
737 | // spirv-type ::= array-type |
738 | // | element-type |
739 | // | image-type |
740 | // | pointer-type |
741 | // | runtime-array-type |
742 | // | sampled-image-type |
743 | // | struct-type |
744 | Type SPIRVDialect::parseType(DialectAsmParser &parser) const { |
745 | StringRef keyword; |
746 | if (parser.parseKeyword(&keyword)) |
747 | return Type(); |
748 | |
749 | if (keyword == "array") |
750 | return parseArrayType(*this, parser); |
751 | if (keyword == "coopmatrix") |
752 | return parseCooperativeMatrixType(*this, parser); |
753 | if (keyword == "image") |
754 | return parseImageType(*this, parser); |
755 | if (keyword == "ptr") |
756 | return parsePointerType(*this, parser); |
757 | if (keyword == "rtarray") |
758 | return parseRuntimeArrayType(*this, parser); |
759 | if (keyword == "sampled_image") |
760 | return parseSampledImageType(*this, parser); |
761 | if (keyword == "struct") |
762 | return parseStructType(*this, parser); |
763 | if (keyword == "matrix") |
764 | return parseMatrixType(*this, parser); |
765 | parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword; |
766 | return Type(); |
767 | } |
768 | |
769 | //===----------------------------------------------------------------------===// |
770 | // Type Printing |
771 | //===----------------------------------------------------------------------===// |
772 | |
773 | static void print(ArrayType type, DialectAsmPrinter &os) { |
774 | os << "array<"<< type.getNumElements() << " x "<< type.getElementType(); |
775 | if (unsigned stride = type.getArrayStride()) |
776 | os << ", stride="<< stride; |
777 | os << ">"; |
778 | } |
779 | |
780 | static void print(RuntimeArrayType type, DialectAsmPrinter &os) { |
781 | os << "rtarray<"<< type.getElementType(); |
782 | if (unsigned stride = type.getArrayStride()) |
783 | os << ", stride="<< stride; |
784 | os << ">"; |
785 | } |
786 | |
787 | static void print(PointerType type, DialectAsmPrinter &os) { |
788 | os << "ptr<"<< type.getPointeeType() << ", " |
789 | << stringifyStorageClass(type.getStorageClass()) << ">"; |
790 | } |
791 | |
792 | static void print(ImageType type, DialectAsmPrinter &os) { |
793 | os << "image<"<< type.getElementType() << ", "<< stringifyDim(type.getDim()) |
794 | << ", "<< stringifyImageDepthInfo(type.getDepthInfo()) << ", " |
795 | << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", " |
796 | << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", " |
797 | << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", " |
798 | << stringifyImageFormat(type.getImageFormat()) << ">"; |
799 | } |
800 | |
801 | static void print(SampledImageType type, DialectAsmPrinter &os) { |
802 | os << "sampled_image<"<< type.getImageType() << ">"; |
803 | } |
804 | |
805 | static void print(StructType type, DialectAsmPrinter &os) { |
806 | FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint; |
807 | |
808 | os << "struct<"; |
809 | |
810 | if (type.isIdentified()) { |
811 | os << type.getIdentifier(); |
812 | |
813 | cyclicPrint = os.tryStartCyclicPrint(attrOrType: type); |
814 | if (failed(Result: cyclicPrint)) { |
815 | os << ">"; |
816 | return; |
817 | } |
818 | |
819 | os << ", "; |
820 | } |
821 | |
822 | os << "("; |
823 | |
824 | auto printMember = [&](unsigned i) { |
825 | os << type.getElementType(i); |
826 | SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations; |
827 | type.getMemberDecorations(i, decorationsInfo&: decorations); |
828 | if (type.hasOffset() || !decorations.empty()) { |
829 | os << " ["; |
830 | if (type.hasOffset()) { |
831 | os << type.getMemberOffset(i); |
832 | if (!decorations.empty()) |
833 | os << ", "; |
834 | } |
835 | auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) { |
836 | os << stringifyDecoration(decoration.decoration); |
837 | if (decoration.hasValue) { |
838 | os << "="<< decoration.decorationValue; |
839 | } |
840 | }; |
841 | llvm::interleaveComma(c: decorations, os, each_fn: eachFn); |
842 | os << "]"; |
843 | } |
844 | }; |
845 | llvm::interleaveComma(c: llvm::seq<unsigned>(Begin: 0, End: type.getNumElements()), os, |
846 | each_fn: printMember); |
847 | os << ")>"; |
848 | } |
849 | |
850 | static void print(CooperativeMatrixType type, DialectAsmPrinter &os) { |
851 | os << "coopmatrix<"<< type.getRows() << "x"<< type.getColumns() << "x" |
852 | << type.getElementType() << ", "<< type.getScope() << ", " |
853 | << type.getUse() << ">"; |
854 | } |
855 | |
856 | static void print(MatrixType type, DialectAsmPrinter &os) { |
857 | os << "matrix<"<< type.getNumColumns() << " x "<< type.getColumnType(); |
858 | os << ">"; |
859 | } |
860 | |
861 | void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const { |
862 | TypeSwitch<Type>(type) |
863 | .Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType, |
864 | ImageType, SampledImageType, StructType, MatrixType>( |
865 | [&](auto type) { print(type, os); }) |
866 | .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); }); |
867 | } |
868 | |
869 | //===----------------------------------------------------------------------===// |
870 | // Constant |
871 | //===----------------------------------------------------------------------===// |
872 | |
873 | Operation *SPIRVDialect::materializeConstant(OpBuilder &builder, |
874 | Attribute value, Type type, |
875 | Location loc) { |
876 | if (auto poison = dyn_cast<ub::PoisonAttr>(value)) |
877 | return builder.create<ub::PoisonOp>(loc, type, poison); |
878 | |
879 | if (!spirv::ConstantOp::isBuildableWith(type)) |
880 | return nullptr; |
881 | |
882 | return builder.create<spirv::ConstantOp>(loc, type, value); |
883 | } |
884 | |
885 | //===----------------------------------------------------------------------===// |
886 | // Shader Interface ABI |
887 | //===----------------------------------------------------------------------===// |
888 | |
889 | LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op, |
890 | NamedAttribute attribute) { |
891 | StringRef symbol = attribute.getName().strref(); |
892 | Attribute attr = attribute.getValue(); |
893 | |
894 | if (symbol == spirv::getEntryPointABIAttrName()) { |
895 | if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) { |
896 | return op->emitError("'") |
897 | << symbol << "' attribute must be an entry point ABI attribute"; |
898 | } |
899 | } else if (symbol == spirv::getTargetEnvAttrName()) { |
900 | if (!llvm::isa<spirv::TargetEnvAttr>(attr)) |
901 | return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr"; |
902 | } else { |
903 | return op->emitError("found unsupported '") |
904 | << symbol << "' attribute on operation"; |
905 | } |
906 | |
907 | return success(); |
908 | } |
909 | |
910 | /// Verifies the given SPIR-V `attribute` attached to a value of the given |
911 | /// `valueType` is valid. |
912 | static LogicalResult verifyRegionAttribute(Location loc, Type valueType, |
913 | NamedAttribute attribute) { |
914 | StringRef symbol = attribute.getName().strref(); |
915 | Attribute attr = attribute.getValue(); |
916 | |
917 | if (symbol == spirv::getInterfaceVarABIAttrName()) { |
918 | auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(Val&: attr); |
919 | if (!varABIAttr) |
920 | return emitError(loc, message: "'") |
921 | << symbol << "' must be a spirv::InterfaceVarABIAttr"; |
922 | |
923 | if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat()) |
924 | return emitError(loc, message: "'") << symbol |
925 | << "' attribute cannot specify storage class " |
926 | "when attaching to a non-scalar value"; |
927 | return success(); |
928 | } |
929 | if (symbol == spirv::DecorationAttr::name) { |
930 | if (!isa<spirv::DecorationAttr>(attr)) |
931 | return emitError(loc, message: "'") |
932 | << symbol << "' must be a spirv::DecorationAttr"; |
933 | return success(); |
934 | } |
935 | |
936 | return emitError(loc, message: "found unsupported '") |
937 | << symbol << "' attribute on region argument"; |
938 | } |
939 | |
940 | LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op, |
941 | unsigned regionIndex, |
942 | unsigned argIndex, |
943 | NamedAttribute attribute) { |
944 | auto funcOp = dyn_cast<FunctionOpInterface>(op); |
945 | if (!funcOp) |
946 | return success(); |
947 | Type argType = funcOp.getArgumentTypes()[argIndex]; |
948 | |
949 | return verifyRegionAttribute(op->getLoc(), argType, attribute); |
950 | } |
951 | |
952 | LogicalResult SPIRVDialect::verifyRegionResultAttribute( |
953 | Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/, |
954 | NamedAttribute attribute) { |
955 | return op->emitError("cannot attach SPIR-V attributes to region result"); |
956 | } |
957 |
Definitions
- containsReturn
- SPIRVInlinerInterface
- isLegalToInline
- isLegalToInline
- isLegalToInline
- handleTerminator
- handleTerminator
- parseAndVerifyType
- parseAndVerifyMatrixType
- parseAndVerifySampledImageType
- parseOptionalArrayStride
- parseArrayType
- parseCooperativeMatrixType
- parsePointerType
- parseRuntimeArrayType
- parseMatrixType
- parseAndVerify
- parseAndVerify
- parseAndVerifyInteger
- parseAndVerify
- ParseCommaSeparatedList
- operator()
- ParseCommaSeparatedList
- operator()
- parseImageType
- parseSampledImageType
- parseStructMemberDecorations
- parseStructType
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more