1 | //===- SPIRVAttributes.cpp - SPIR-V attribute definitions -----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
10 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
11 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
12 | #include "mlir/IR/Builders.h" |
13 | #include "mlir/IR/DialectImplementation.h" |
14 | #include "llvm/ADT/TypeSwitch.h" |
15 | |
16 | using namespace mlir; |
17 | using namespace mlir::spirv; |
18 | |
19 | //===----------------------------------------------------------------------===// |
20 | // TableGen'erated attribute utility functions |
21 | //===----------------------------------------------------------------------===// |
22 | |
23 | namespace mlir { |
24 | namespace spirv { |
25 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttrUtils.inc" |
26 | } // namespace spirv |
27 | |
28 | //===----------------------------------------------------------------------===// |
29 | // Attribute storage classes |
30 | //===----------------------------------------------------------------------===// |
31 | |
32 | namespace spirv { |
33 | namespace detail { |
34 | |
35 | struct InterfaceVarABIAttributeStorage : public AttributeStorage { |
36 | using KeyTy = std::tuple<Attribute, Attribute, Attribute>; |
37 | |
38 | InterfaceVarABIAttributeStorage(Attribute descriptorSet, Attribute binding, |
39 | Attribute storageClass) |
40 | : descriptorSet(descriptorSet), binding(binding), |
41 | storageClass(storageClass) {} |
42 | |
43 | bool operator==(const KeyTy &key) const { |
44 | return std::get<0>(key) == descriptorSet && std::get<1>(key) == binding && |
45 | std::get<2>(key) == storageClass; |
46 | } |
47 | |
48 | static InterfaceVarABIAttributeStorage * |
49 | construct(AttributeStorageAllocator &allocator, const KeyTy &key) { |
50 | return new (allocator.allocate<InterfaceVarABIAttributeStorage>()) |
51 | InterfaceVarABIAttributeStorage(std::get<0>(key), std::get<1>(key), |
52 | std::get<2>(key)); |
53 | } |
54 | |
55 | Attribute descriptorSet; |
56 | Attribute binding; |
57 | Attribute storageClass; |
58 | }; |
59 | |
60 | struct VerCapExtAttributeStorage : public AttributeStorage { |
61 | using KeyTy = std::tuple<Attribute, Attribute, Attribute>; |
62 | |
63 | VerCapExtAttributeStorage(Attribute version, Attribute capabilities, |
64 | Attribute extensions) |
65 | : version(version), capabilities(capabilities), extensions(extensions) {} |
66 | |
67 | bool operator==(const KeyTy &key) const { |
68 | return std::get<0>(key) == version && std::get<1>(key) == capabilities && |
69 | std::get<2>(key) == extensions; |
70 | } |
71 | |
72 | static VerCapExtAttributeStorage * |
73 | construct(AttributeStorageAllocator &allocator, const KeyTy &key) { |
74 | return new (allocator.allocate<VerCapExtAttributeStorage>()) |
75 | VerCapExtAttributeStorage(std::get<0>(key), std::get<1>(key), |
76 | std::get<2>(key)); |
77 | } |
78 | |
79 | Attribute version; |
80 | Attribute capabilities; |
81 | Attribute extensions; |
82 | }; |
83 | |
84 | struct TargetEnvAttributeStorage : public AttributeStorage { |
85 | using KeyTy = |
86 | std::tuple<Attribute, ClientAPI, Vendor, DeviceType, uint32_t, Attribute>; |
87 | |
88 | TargetEnvAttributeStorage(Attribute triple, ClientAPI clientAPI, |
89 | Vendor vendorID, DeviceType deviceType, |
90 | uint32_t deviceID, Attribute limits) |
91 | : triple(triple), limits(limits), clientAPI(clientAPI), |
92 | vendorID(vendorID), deviceType(deviceType), deviceID(deviceID) {} |
93 | |
94 | bool operator==(const KeyTy &key) const { |
95 | return key == std::make_tuple(triple, clientAPI, vendorID, deviceType, |
96 | deviceID, limits); |
97 | } |
98 | |
99 | static TargetEnvAttributeStorage * |
100 | construct(AttributeStorageAllocator &allocator, const KeyTy &key) { |
101 | return new (allocator.allocate<TargetEnvAttributeStorage>()) |
102 | TargetEnvAttributeStorage(std::get<0>(key), std::get<1>(key), |
103 | std::get<2>(key), std::get<3>(key), |
104 | std::get<4>(key), std::get<5>(key)); |
105 | } |
106 | |
107 | Attribute triple; |
108 | Attribute limits; |
109 | ClientAPI clientAPI; |
110 | Vendor vendorID; |
111 | DeviceType deviceType; |
112 | uint32_t deviceID; |
113 | }; |
114 | } // namespace detail |
115 | } // namespace spirv |
116 | } // namespace mlir |
117 | |
118 | //===----------------------------------------------------------------------===// |
119 | // InterfaceVarABIAttr |
120 | //===----------------------------------------------------------------------===// |
121 | |
122 | spirv::InterfaceVarABIAttr |
123 | spirv::InterfaceVarABIAttr::get(uint32_t descriptorSet, uint32_t binding, |
124 | std::optional<spirv::StorageClass> storageClass, |
125 | MLIRContext *context) { |
126 | Builder b(context); |
127 | auto descriptorSetAttr = b.getI32IntegerAttr(descriptorSet); |
128 | auto bindingAttr = b.getI32IntegerAttr(binding); |
129 | auto storageClassAttr = |
130 | storageClass ? b.getI32IntegerAttr(static_cast<uint32_t>(*storageClass)) |
131 | : IntegerAttr(); |
132 | return get(descriptorSetAttr, bindingAttr, storageClassAttr); |
133 | } |
134 | |
135 | spirv::InterfaceVarABIAttr |
136 | spirv::InterfaceVarABIAttr::get(IntegerAttr descriptorSet, IntegerAttr binding, |
137 | IntegerAttr storageClass) { |
138 | assert(descriptorSet && binding); |
139 | MLIRContext *context = descriptorSet.getContext(); |
140 | return Base::get(context, descriptorSet, binding, storageClass); |
141 | } |
142 | |
143 | StringRef spirv::InterfaceVarABIAttr::getKindName() { |
144 | return "interface_var_abi" ; |
145 | } |
146 | |
147 | uint32_t spirv::InterfaceVarABIAttr::getBinding() { |
148 | return llvm::cast<IntegerAttr>(getImpl()->binding).getInt(); |
149 | } |
150 | |
151 | uint32_t spirv::InterfaceVarABIAttr::getDescriptorSet() { |
152 | return llvm::cast<IntegerAttr>(getImpl()->descriptorSet).getInt(); |
153 | } |
154 | |
155 | std::optional<spirv::StorageClass> |
156 | spirv::InterfaceVarABIAttr::getStorageClass() { |
157 | if (getImpl()->storageClass) |
158 | return static_cast<spirv::StorageClass>( |
159 | llvm::cast<IntegerAttr>(getImpl()->storageClass) |
160 | .getValue() |
161 | .getZExtValue()); |
162 | return std::nullopt; |
163 | } |
164 | |
165 | LogicalResult spirv::InterfaceVarABIAttr::verify( |
166 | function_ref<InFlightDiagnostic()> emitError, IntegerAttr descriptorSet, |
167 | IntegerAttr binding, IntegerAttr storageClass) { |
168 | if (!descriptorSet.getType().isSignlessInteger(32)) |
169 | return emitError() << "expected 32-bit integer for descriptor set" ; |
170 | |
171 | if (!binding.getType().isSignlessInteger(32)) |
172 | return emitError() << "expected 32-bit integer for binding" ; |
173 | |
174 | if (storageClass) { |
175 | if (auto storageClassAttr = llvm::cast<IntegerAttr>(storageClass)) { |
176 | auto storageClassValue = |
177 | spirv::symbolizeStorageClass(storageClassAttr.getInt()); |
178 | if (!storageClassValue) |
179 | return emitError() << "unknown storage class" ; |
180 | } else { |
181 | return emitError() << "expected valid storage class" ; |
182 | } |
183 | } |
184 | |
185 | return success(); |
186 | } |
187 | |
188 | //===----------------------------------------------------------------------===// |
189 | // VerCapExtAttr |
190 | //===----------------------------------------------------------------------===// |
191 | |
192 | spirv::VerCapExtAttr spirv::VerCapExtAttr::get( |
193 | spirv::Version version, ArrayRef<spirv::Capability> capabilities, |
194 | ArrayRef<spirv::Extension> extensions, MLIRContext *context) { |
195 | Builder b(context); |
196 | |
197 | auto versionAttr = b.getI32IntegerAttr(static_cast<uint32_t>(version)); |
198 | |
199 | SmallVector<Attribute, 4> capAttrs; |
200 | capAttrs.reserve(N: capabilities.size()); |
201 | for (spirv::Capability cap : capabilities) |
202 | capAttrs.push_back(b.getI32IntegerAttr(static_cast<uint32_t>(cap))); |
203 | |
204 | SmallVector<Attribute, 4> extAttrs; |
205 | extAttrs.reserve(N: extensions.size()); |
206 | for (spirv::Extension ext : extensions) |
207 | extAttrs.push_back(b.getStringAttr(spirv::stringifyExtension(ext))); |
208 | |
209 | return get(version: versionAttr, capabilities: b.getArrayAttr(capAttrs), extensions: b.getArrayAttr(extAttrs)); |
210 | } |
211 | |
212 | spirv::VerCapExtAttr spirv::VerCapExtAttr::get(IntegerAttr version, |
213 | ArrayAttr capabilities, |
214 | ArrayAttr extensions) { |
215 | assert(version && capabilities && extensions); |
216 | MLIRContext *context = version.getContext(); |
217 | return Base::get(context, version, capabilities, extensions); |
218 | } |
219 | |
220 | StringRef spirv::VerCapExtAttr::getKindName() { return "vce" ; } |
221 | |
222 | spirv::Version spirv::VerCapExtAttr::getVersion() { |
223 | return static_cast<spirv::Version>( |
224 | llvm::cast<IntegerAttr>(getImpl()->version).getValue().getZExtValue()); |
225 | } |
226 | |
227 | spirv::VerCapExtAttr::ext_iterator::ext_iterator(ArrayAttr::iterator it) |
228 | : llvm::mapped_iterator<ArrayAttr::iterator, |
229 | spirv::Extension (*)(Attribute)>( |
230 | it, [](Attribute attr) { |
231 | return *symbolizeExtension(llvm::cast<StringAttr>(attr).getValue()); |
232 | }) {} |
233 | |
234 | spirv::VerCapExtAttr::ext_range spirv::VerCapExtAttr::getExtensions() { |
235 | auto range = getExtensionsAttr().getValue(); |
236 | return {ext_iterator(range.begin()), ext_iterator(range.end())}; |
237 | } |
238 | |
239 | ArrayAttr spirv::VerCapExtAttr::getExtensionsAttr() { |
240 | return llvm::cast<ArrayAttr>(getImpl()->extensions); |
241 | } |
242 | |
243 | spirv::VerCapExtAttr::cap_iterator::cap_iterator(ArrayAttr::iterator it) |
244 | : llvm::mapped_iterator<ArrayAttr::iterator, |
245 | spirv::Capability (*)(Attribute)>( |
246 | it, [](Attribute attr) { |
247 | return *symbolizeCapability( |
248 | llvm::cast<IntegerAttr>(attr).getValue().getZExtValue()); |
249 | }) {} |
250 | |
251 | spirv::VerCapExtAttr::cap_range spirv::VerCapExtAttr::getCapabilities() { |
252 | auto range = getCapabilitiesAttr().getValue(); |
253 | return {cap_iterator(range.begin()), cap_iterator(range.end())}; |
254 | } |
255 | |
256 | ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() { |
257 | return llvm::cast<ArrayAttr>(getImpl()->capabilities); |
258 | } |
259 | |
260 | LogicalResult |
261 | spirv::VerCapExtAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
262 | IntegerAttr version, ArrayAttr capabilities, |
263 | ArrayAttr extensions) { |
264 | if (!version.getType().isSignlessInteger(32)) |
265 | return emitError() << "expected 32-bit integer for version" ; |
266 | |
267 | if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) { |
268 | if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) |
269 | if (spirv::symbolizeCapability(intAttr.getValue().getZExtValue())) |
270 | return true; |
271 | return false; |
272 | })) |
273 | return emitError() << "unknown capability in capability list" ; |
274 | |
275 | if (!llvm::all_of(extensions.getValue(), [](Attribute attr) { |
276 | if (auto strAttr = llvm::dyn_cast<StringAttr>(attr)) |
277 | if (spirv::symbolizeExtension(strAttr.getValue())) |
278 | return true; |
279 | return false; |
280 | })) |
281 | return emitError() << "unknown extension in extension list" ; |
282 | |
283 | return success(); |
284 | } |
285 | |
286 | //===----------------------------------------------------------------------===// |
287 | // TargetEnvAttr |
288 | //===----------------------------------------------------------------------===// |
289 | |
290 | spirv::TargetEnvAttr spirv::TargetEnvAttr::get( |
291 | spirv::VerCapExtAttr triple, ResourceLimitsAttr limits, ClientAPI clientAPI, |
292 | Vendor vendorID, DeviceType deviceType, uint32_t deviceID) { |
293 | assert(triple && limits && "expected valid triple and limits" ); |
294 | MLIRContext *context = triple.getContext(); |
295 | return Base::get(context, triple, clientAPI, vendorID, deviceType, deviceID, |
296 | limits); |
297 | } |
298 | |
299 | StringRef spirv::TargetEnvAttr::getKindName() { return "target_env" ; } |
300 | |
301 | spirv::VerCapExtAttr spirv::TargetEnvAttr::getTripleAttr() const { |
302 | return llvm::cast<spirv::VerCapExtAttr>(getImpl()->triple); |
303 | } |
304 | |
305 | spirv::Version spirv::TargetEnvAttr::getVersion() const { |
306 | return getTripleAttr().getVersion(); |
307 | } |
308 | |
309 | spirv::VerCapExtAttr::ext_range spirv::TargetEnvAttr::getExtensions() { |
310 | return getTripleAttr().getExtensions(); |
311 | } |
312 | |
313 | ArrayAttr spirv::TargetEnvAttr::getExtensionsAttr() { |
314 | return getTripleAttr().getExtensionsAttr(); |
315 | } |
316 | |
317 | spirv::VerCapExtAttr::cap_range spirv::TargetEnvAttr::getCapabilities() { |
318 | return getTripleAttr().getCapabilities(); |
319 | } |
320 | |
321 | ArrayAttr spirv::TargetEnvAttr::getCapabilitiesAttr() { |
322 | return getTripleAttr().getCapabilitiesAttr(); |
323 | } |
324 | |
325 | spirv::ClientAPI spirv::TargetEnvAttr::getClientAPI() const { |
326 | return getImpl()->clientAPI; |
327 | } |
328 | |
329 | spirv::Vendor spirv::TargetEnvAttr::getVendorID() const { |
330 | return getImpl()->vendorID; |
331 | } |
332 | |
333 | spirv::DeviceType spirv::TargetEnvAttr::getDeviceType() const { |
334 | return getImpl()->deviceType; |
335 | } |
336 | |
337 | uint32_t spirv::TargetEnvAttr::getDeviceID() const { |
338 | return getImpl()->deviceID; |
339 | } |
340 | |
341 | spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() const { |
342 | return llvm::cast<spirv::ResourceLimitsAttr>(getImpl()->limits); |
343 | } |
344 | |
345 | //===----------------------------------------------------------------------===// |
346 | // ODS Generated Attributes |
347 | //===----------------------------------------------------------------------===// |
348 | |
349 | #define GET_ATTRDEF_CLASSES |
350 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.cpp.inc" |
351 | |
352 | //===----------------------------------------------------------------------===// |
353 | // Attribute Parsing |
354 | //===----------------------------------------------------------------------===// |
355 | |
356 | /// Parses a comma-separated list of keywords, invokes `processKeyword` on each |
357 | /// of the parsed keyword, and returns failure if any error occurs. |
358 | static ParseResult |
359 | parseKeywordList(DialectAsmParser &parser, |
360 | function_ref<LogicalResult(SMLoc, StringRef)> processKeyword) { |
361 | if (parser.parseLSquare()) |
362 | return failure(); |
363 | |
364 | // Special case for empty list. |
365 | if (succeeded(result: parser.parseOptionalRSquare())) |
366 | return success(); |
367 | |
368 | // Keep parsing the keyword and an optional comma following it. If the comma |
369 | // is successfully parsed, then we have more keywords to parse. |
370 | if (failed(result: parser.parseCommaSeparatedList([&]() { |
371 | auto loc = parser.getCurrentLocation(); |
372 | StringRef keyword; |
373 | if (parser.parseKeyword(keyword: &keyword) || |
374 | failed(processKeyword(loc, keyword))) |
375 | return failure(); |
376 | return success(); |
377 | }))) |
378 | return failure(); |
379 | return parser.parseRSquare(); |
380 | } |
381 | |
382 | /// Parses a spirv::InterfaceVarABIAttr. |
383 | static Attribute parseInterfaceVarABIAttr(DialectAsmParser &parser) { |
384 | if (parser.parseLess()) |
385 | return {}; |
386 | |
387 | Builder &builder = parser.getBuilder(); |
388 | |
389 | if (parser.parseLParen()) |
390 | return {}; |
391 | |
392 | IntegerAttr descriptorSetAttr; |
393 | { |
394 | auto loc = parser.getCurrentLocation(); |
395 | uint32_t descriptorSet = 0; |
396 | auto descriptorSetParseResult = parser.parseOptionalInteger(result&: descriptorSet); |
397 | |
398 | if (!descriptorSetParseResult.has_value() || |
399 | failed(result: *descriptorSetParseResult)) { |
400 | parser.emitError(loc, message: "missing descriptor set" ); |
401 | return {}; |
402 | } |
403 | descriptorSetAttr = builder.getI32IntegerAttr(descriptorSet); |
404 | } |
405 | |
406 | if (parser.parseComma()) |
407 | return {}; |
408 | |
409 | IntegerAttr bindingAttr; |
410 | { |
411 | auto loc = parser.getCurrentLocation(); |
412 | uint32_t binding = 0; |
413 | auto bindingParseResult = parser.parseOptionalInteger(result&: binding); |
414 | |
415 | if (!bindingParseResult.has_value() || failed(result: *bindingParseResult)) { |
416 | parser.emitError(loc, message: "missing binding" ); |
417 | return {}; |
418 | } |
419 | bindingAttr = builder.getI32IntegerAttr(binding); |
420 | } |
421 | |
422 | if (parser.parseRParen()) |
423 | return {}; |
424 | |
425 | IntegerAttr storageClassAttr; |
426 | { |
427 | if (succeeded(result: parser.parseOptionalComma())) { |
428 | auto loc = parser.getCurrentLocation(); |
429 | StringRef storageClass; |
430 | if (parser.parseKeyword(keyword: &storageClass)) |
431 | return {}; |
432 | |
433 | if (auto storageClassSymbol = |
434 | spirv::symbolizeStorageClass(storageClass)) { |
435 | storageClassAttr = builder.getI32IntegerAttr( |
436 | static_cast<uint32_t>(*storageClassSymbol)); |
437 | } else { |
438 | parser.emitError(loc, message: "unknown storage class: " ) << storageClass; |
439 | return {}; |
440 | } |
441 | } |
442 | } |
443 | |
444 | if (parser.parseGreater()) |
445 | return {}; |
446 | |
447 | return spirv::InterfaceVarABIAttr::get(descriptorSetAttr, bindingAttr, |
448 | storageClassAttr); |
449 | } |
450 | |
451 | static Attribute parseVerCapExtAttr(DialectAsmParser &parser) { |
452 | if (parser.parseLess()) |
453 | return {}; |
454 | |
455 | Builder &builder = parser.getBuilder(); |
456 | |
457 | IntegerAttr versionAttr; |
458 | { |
459 | auto loc = parser.getCurrentLocation(); |
460 | StringRef version; |
461 | if (parser.parseKeyword(keyword: &version) || parser.parseComma()) |
462 | return {}; |
463 | |
464 | if (auto versionSymbol = spirv::symbolizeVersion(version)) { |
465 | versionAttr = |
466 | builder.getI32IntegerAttr(static_cast<uint32_t>(*versionSymbol)); |
467 | } else { |
468 | parser.emitError(loc, message: "unknown version: " ) << version; |
469 | return {}; |
470 | } |
471 | } |
472 | |
473 | ArrayAttr capabilitiesAttr; |
474 | { |
475 | SmallVector<Attribute, 4> capabilities; |
476 | SMLoc errorloc; |
477 | StringRef errorKeyword; |
478 | |
479 | auto processCapability = [&](SMLoc loc, StringRef capability) { |
480 | if (auto capSymbol = spirv::symbolizeCapability(capability)) { |
481 | capabilities.push_back( |
482 | builder.getI32IntegerAttr(static_cast<uint32_t>(*capSymbol))); |
483 | return success(); |
484 | } |
485 | return errorloc = loc, errorKeyword = capability, failure(); |
486 | }; |
487 | if (parseKeywordList(parser, processCapability) || parser.parseComma()) { |
488 | if (!errorKeyword.empty()) |
489 | parser.emitError(loc: errorloc, message: "unknown capability: " ) << errorKeyword; |
490 | return {}; |
491 | } |
492 | |
493 | capabilitiesAttr = builder.getArrayAttr(capabilities); |
494 | } |
495 | |
496 | ArrayAttr extensionsAttr; |
497 | { |
498 | SmallVector<Attribute, 1> extensions; |
499 | SMLoc errorloc; |
500 | StringRef errorKeyword; |
501 | |
502 | auto processExtension = [&](SMLoc loc, StringRef extension) { |
503 | if (spirv::symbolizeExtension(extension)) { |
504 | extensions.push_back(builder.getStringAttr(extension)); |
505 | return success(); |
506 | } |
507 | return errorloc = loc, errorKeyword = extension, failure(); |
508 | }; |
509 | if (parseKeywordList(parser, processExtension)) { |
510 | if (!errorKeyword.empty()) |
511 | parser.emitError(loc: errorloc, message: "unknown extension: " ) << errorKeyword; |
512 | return {}; |
513 | } |
514 | |
515 | extensionsAttr = builder.getArrayAttr(extensions); |
516 | } |
517 | |
518 | if (parser.parseGreater()) |
519 | return {}; |
520 | |
521 | return spirv::VerCapExtAttr::get(version: versionAttr, capabilities: capabilitiesAttr, |
522 | extensions: extensionsAttr); |
523 | } |
524 | |
525 | /// Parses a spirv::TargetEnvAttr. |
526 | static Attribute parseTargetEnvAttr(DialectAsmParser &parser) { |
527 | if (parser.parseLess()) |
528 | return {}; |
529 | |
530 | spirv::VerCapExtAttr tripleAttr; |
531 | if (parser.parseAttribute(result&: tripleAttr) || parser.parseComma()) |
532 | return {}; |
533 | |
534 | auto clientAPI = spirv::ClientAPI::Unknown; |
535 | if (succeeded(result: parser.parseOptionalKeyword(keyword: "api" ))) { |
536 | if (parser.parseEqual()) |
537 | return {}; |
538 | auto loc = parser.getCurrentLocation(); |
539 | StringRef apiStr; |
540 | if (parser.parseKeyword(keyword: &apiStr)) |
541 | return {}; |
542 | if (auto apiSymbol = spirv::symbolizeClientAPI(apiStr)) |
543 | clientAPI = *apiSymbol; |
544 | else |
545 | parser.emitError(loc, message: "unknown client API: " ) << apiStr; |
546 | if (parser.parseComma()) |
547 | return {}; |
548 | } |
549 | |
550 | // Parse [vendor[:device-type[:device-id]]] |
551 | Vendor vendorID = Vendor::Unknown; |
552 | DeviceType deviceType = DeviceType::Unknown; |
553 | uint32_t deviceID = spirv::TargetEnvAttr::kUnknownDeviceID; |
554 | { |
555 | auto loc = parser.getCurrentLocation(); |
556 | StringRef vendorStr; |
557 | if (succeeded(result: parser.parseOptionalKeyword(keyword: &vendorStr))) { |
558 | if (auto vendorSymbol = spirv::symbolizeVendor(vendorStr)) |
559 | vendorID = *vendorSymbol; |
560 | else |
561 | parser.emitError(loc, message: "unknown vendor: " ) << vendorStr; |
562 | |
563 | if (succeeded(result: parser.parseOptionalColon())) { |
564 | loc = parser.getCurrentLocation(); |
565 | StringRef deviceTypeStr; |
566 | if (parser.parseKeyword(keyword: &deviceTypeStr)) |
567 | return {}; |
568 | if (auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr)) |
569 | deviceType = *deviceTypeSymbol; |
570 | else |
571 | parser.emitError(loc, message: "unknown device type: " ) << deviceTypeStr; |
572 | |
573 | if (succeeded(result: parser.parseOptionalColon())) { |
574 | loc = parser.getCurrentLocation(); |
575 | if (parser.parseInteger(result&: deviceID)) |
576 | return {}; |
577 | } |
578 | } |
579 | if (parser.parseComma()) |
580 | return {}; |
581 | } |
582 | } |
583 | |
584 | ResourceLimitsAttr limitsAttr; |
585 | if (parser.parseAttribute(limitsAttr) || parser.parseGreater()) |
586 | return {}; |
587 | |
588 | return spirv::TargetEnvAttr::get(tripleAttr, limitsAttr, clientAPI, vendorID, |
589 | deviceType, deviceID); |
590 | } |
591 | |
592 | Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser, |
593 | Type type) const { |
594 | // SPIR-V attributes are dictionaries so they do not have type. |
595 | if (type) { |
596 | parser.emitError(parser.getNameLoc(), "unexpected type" ); |
597 | return {}; |
598 | } |
599 | |
600 | // Parse the kind keyword first. |
601 | StringRef attrKind; |
602 | Attribute attr; |
603 | OptionalParseResult result = |
604 | generatedAttributeParser(parser, &attrKind, type, attr); |
605 | if (result.has_value()) |
606 | return attr; |
607 | |
608 | if (attrKind == spirv::TargetEnvAttr::getKindName()) |
609 | return parseTargetEnvAttr(parser); |
610 | if (attrKind == spirv::VerCapExtAttr::getKindName()) |
611 | return parseVerCapExtAttr(parser); |
612 | if (attrKind == spirv::InterfaceVarABIAttr::getKindName()) |
613 | return parseInterfaceVarABIAttr(parser); |
614 | |
615 | parser.emitError(parser.getNameLoc(), "unknown SPIR-V attribute kind: " ) |
616 | << attrKind; |
617 | return {}; |
618 | } |
619 | |
620 | //===----------------------------------------------------------------------===// |
621 | // Attribute Printing |
622 | //===----------------------------------------------------------------------===// |
623 | |
624 | static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer) { |
625 | auto &os = printer.getStream(); |
626 | printer << spirv::VerCapExtAttr::getKindName() << "<" |
627 | << spirv::stringifyVersion(triple.getVersion()) << ", [" ; |
628 | llvm::interleaveComma( |
629 | triple.getCapabilities(), os, |
630 | [&](spirv::Capability cap) { os << spirv::stringifyCapability(cap); }); |
631 | printer << "], [" ; |
632 | llvm::interleaveComma(triple.getExtensionsAttr(), os, [&](Attribute attr) { |
633 | os << llvm::cast<StringAttr>(attr).getValue(); |
634 | }); |
635 | printer << "]>" ; |
636 | } |
637 | |
638 | static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) { |
639 | printer << spirv::TargetEnvAttr::getKindName() << "<#spirv." ; |
640 | print(triple: targetEnv.getTripleAttr(), printer); |
641 | auto clientAPI = targetEnv.getClientAPI(); |
642 | if (clientAPI != spirv::ClientAPI::Unknown) |
643 | printer << ", api=" << clientAPI; |
644 | spirv::Vendor vendorID = targetEnv.getVendorID(); |
645 | spirv::DeviceType deviceType = targetEnv.getDeviceType(); |
646 | uint32_t deviceID = targetEnv.getDeviceID(); |
647 | if (vendorID != spirv::Vendor::Unknown) { |
648 | printer << ", " << spirv::stringifyVendor(vendorID); |
649 | if (deviceType != spirv::DeviceType::Unknown) { |
650 | printer << ":" << spirv::stringifyDeviceType(deviceType); |
651 | if (deviceID != spirv::TargetEnvAttr::kUnknownDeviceID) |
652 | printer << ":" << deviceID; |
653 | } |
654 | } |
655 | printer << ", " << targetEnv.getResourceLimits() << ">" ; |
656 | } |
657 | |
658 | static void print(spirv::InterfaceVarABIAttr interfaceVarABIAttr, |
659 | DialectAsmPrinter &printer) { |
660 | printer << spirv::InterfaceVarABIAttr::getKindName() << "<(" |
661 | << interfaceVarABIAttr.getDescriptorSet() << ", " |
662 | << interfaceVarABIAttr.getBinding() << ")" ; |
663 | auto storageClass = interfaceVarABIAttr.getStorageClass(); |
664 | if (storageClass) |
665 | printer << ", " << spirv::stringifyStorageClass(*storageClass); |
666 | printer << ">" ; |
667 | } |
668 | |
669 | void SPIRVDialect::printAttribute(Attribute attr, |
670 | DialectAsmPrinter &printer) const { |
671 | if (succeeded(generatedAttributePrinter(attr, printer))) |
672 | return; |
673 | |
674 | if (auto targetEnv = llvm::dyn_cast<TargetEnvAttr>(attr)) |
675 | print(targetEnv, printer); |
676 | else if (auto vceAttr = llvm::dyn_cast<VerCapExtAttr>(attr)) |
677 | print(vceAttr, printer); |
678 | else if (auto interfaceVarABIAttr = llvm::dyn_cast<InterfaceVarABIAttr>(attr)) |
679 | print(interfaceVarABIAttr, printer); |
680 | else |
681 | llvm_unreachable("unhandled SPIR-V attribute kind" ); |
682 | } |
683 | |
684 | //===----------------------------------------------------------------------===// |
685 | // SPIR-V Dialect |
686 | //===----------------------------------------------------------------------===// |
687 | |
/// Registers the SPIR-V dialect attributes: first the hand-written C++
/// attribute classes defined in this file, then the ODS-generated ones.
void spirv::SPIRVDialect::registerAttributes() {
  addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
  // With GET_ATTRDEF_LIST defined, the generated .inc file expands to the
  // comma-separated list of ODS attribute classes, which becomes the template
  // argument list of this call.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.cpp.inc"
      >();
}
695 | |