1 | //===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the parser for the MLIR Attributes. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "Parser.h" |
14 | |
15 | #include "AsmParserImpl.h" |
16 | #include "mlir/AsmParser/AsmParserState.h" |
17 | #include "mlir/IR/AffineMap.h" |
18 | #include "mlir/IR/BuiltinAttributes.h" |
19 | #include "mlir/IR/BuiltinDialect.h" |
20 | #include "mlir/IR/BuiltinTypes.h" |
21 | #include "mlir/IR/DialectImplementation.h" |
22 | #include "mlir/IR/DialectResourceBlobManager.h" |
23 | #include "mlir/IR/IntegerSet.h" |
24 | #include "llvm/ADT/StringExtras.h" |
25 | #include "llvm/Support/Endian.h" |
26 | #include <optional> |
27 | |
28 | using namespace mlir; |
29 | using namespace mlir::detail; |
30 | |
31 | /// Parse an arbitrary attribute. |
32 | /// |
33 | /// attribute-value ::= `unit` |
34 | /// | bool-literal |
35 | /// | integer-literal (`:` (index-type | integer-type))? |
36 | /// | float-literal (`:` float-type)? |
37 | /// | string-literal (`:` type)? |
38 | /// | type |
39 | /// | `[` `:` (integer-type | float-type) tensor-literal `]` |
40 | /// | `[` (attribute-value (`,` attribute-value)*)? `]` |
41 | /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` |
42 | /// | symbol-ref-id (`::` symbol-ref-id)* |
43 | /// | `dense` `<` tensor-literal `>` `:` |
44 | /// (tensor-type | vector-type) |
45 | /// | `sparse` `<` attribute-value `,` attribute-value `>` |
46 | /// `:` (tensor-type | vector-type) |
47 | /// | `strided` `<` `[` comma-separated-int-or-question `]` |
48 | /// (`,` `offset` `:` integer-literal)? `>` |
49 | /// | distinct-attribute |
50 | /// | extended-attribute |
51 | /// |
52 | Attribute Parser::parseAttribute(Type type) { |
53 | switch (getToken().getKind()) { |
54 | // Parse an AffineMap or IntegerSet attribute. |
55 | case Token::kw_affine_map: { |
56 | consumeToken(kind: Token::kw_affine_map); |
57 | |
58 | AffineMap map; |
59 | if (parseToken(expectedToken: Token::less, message: "expected '<' in affine map" ) || |
60 | parseAffineMapReference(map) || |
61 | parseToken(expectedToken: Token::greater, message: "expected '>' in affine map" )) |
62 | return Attribute(); |
63 | return AffineMapAttr::get(map); |
64 | } |
65 | case Token::kw_affine_set: { |
66 | consumeToken(kind: Token::kw_affine_set); |
67 | |
68 | IntegerSet set; |
69 | if (parseToken(expectedToken: Token::less, message: "expected '<' in integer set" ) || |
70 | parseIntegerSetReference(set) || |
71 | parseToken(expectedToken: Token::greater, message: "expected '>' in integer set" )) |
72 | return Attribute(); |
73 | return IntegerSetAttr::get(set); |
74 | } |
75 | |
76 | // Parse an array attribute. |
77 | case Token::l_square: { |
78 | consumeToken(kind: Token::l_square); |
79 | SmallVector<Attribute, 4> elements; |
80 | auto parseElt = [&]() -> ParseResult { |
81 | elements.push_back(Elt: parseAttribute()); |
82 | return elements.back() ? success() : failure(); |
83 | }; |
84 | |
85 | if (parseCommaSeparatedListUntil(rightToken: Token::r_square, parseElement: parseElt)) |
86 | return nullptr; |
87 | return builder.getArrayAttr(elements); |
88 | } |
89 | |
90 | // Parse a boolean attribute. |
91 | case Token::kw_false: |
92 | consumeToken(kind: Token::kw_false); |
93 | return builder.getBoolAttr(value: false); |
94 | case Token::kw_true: |
95 | consumeToken(kind: Token::kw_true); |
96 | return builder.getBoolAttr(value: true); |
97 | |
98 | // Parse a dense elements attribute. |
99 | case Token::kw_dense: |
100 | return parseDenseElementsAttr(attrType: type); |
101 | |
102 | // Parse a dense resource elements attribute. |
103 | case Token::kw_dense_resource: |
104 | return parseDenseResourceElementsAttr(attrType: type); |
105 | |
106 | // Parse a dense array attribute. |
107 | case Token::kw_array: |
108 | return parseDenseArrayAttr(type); |
109 | |
110 | // Parse a dictionary attribute. |
111 | case Token::l_brace: { |
112 | NamedAttrList elements; |
113 | if (parseAttributeDict(attributes&: elements)) |
114 | return nullptr; |
115 | return elements.getDictionary(getContext()); |
116 | } |
117 | |
118 | // Parse an extended attribute, i.e. alias or dialect attribute. |
119 | case Token::hash_identifier: |
120 | return parseExtendedAttr(type); |
121 | |
122 | // Parse floating point and integer attributes. |
123 | case Token::floatliteral: |
124 | return parseFloatAttr(type, /*isNegative=*/false); |
125 | case Token::integer: |
126 | return parseDecOrHexAttr(type, /*isNegative=*/false); |
127 | case Token::minus: { |
128 | consumeToken(kind: Token::minus); |
129 | if (getToken().is(k: Token::integer)) |
130 | return parseDecOrHexAttr(type, /*isNegative=*/true); |
131 | if (getToken().is(k: Token::floatliteral)) |
132 | return parseFloatAttr(type, /*isNegative=*/true); |
133 | |
134 | return (emitWrongTokenError( |
135 | message: "expected constant integer or floating point value" ), |
136 | nullptr); |
137 | } |
138 | |
139 | // Parse a location attribute. |
140 | case Token::kw_loc: { |
141 | consumeToken(kind: Token::kw_loc); |
142 | |
143 | LocationAttr locAttr; |
144 | if (parseToken(expectedToken: Token::l_paren, message: "expected '(' in inline location" ) || |
145 | parseLocationInstance(loc&: locAttr) || |
146 | parseToken(expectedToken: Token::r_paren, message: "expected ')' in inline location" )) |
147 | return Attribute(); |
148 | return locAttr; |
149 | } |
150 | |
151 | // Parse a sparse elements attribute. |
152 | case Token::kw_sparse: |
153 | return parseSparseElementsAttr(attrType: type); |
154 | |
155 | // Parse a strided layout attribute. |
156 | case Token::kw_strided: |
157 | return parseStridedLayoutAttr(); |
158 | |
159 | // Parse a distinct attribute. |
160 | case Token::kw_distinct: |
161 | return parseDistinctAttr(type); |
162 | |
163 | // Parse a string attribute. |
164 | case Token::string: { |
165 | auto val = getToken().getStringValue(); |
166 | consumeToken(kind: Token::string); |
167 | // Parse the optional trailing colon type if one wasn't explicitly provided. |
168 | if (!type && consumeIf(kind: Token::colon) && !(type = parseType())) |
169 | return Attribute(); |
170 | |
171 | return type ? StringAttr::get(val, type) |
172 | : StringAttr::get(getContext(), val); |
173 | } |
174 | |
175 | // Parse a symbol reference attribute. |
176 | case Token::at_identifier: { |
177 | // When populating the parser state, this is a list of locations for all of |
178 | // the nested references. |
179 | SmallVector<SMRange> referenceLocations; |
180 | if (state.asmState) |
181 | referenceLocations.push_back(Elt: getToken().getLocRange()); |
182 | |
183 | // Parse the top-level reference. |
184 | std::string nameStr = getToken().getSymbolReference(); |
185 | consumeToken(kind: Token::at_identifier); |
186 | |
187 | // Parse any nested references. |
188 | std::vector<FlatSymbolRefAttr> nestedRefs; |
189 | while (getToken().is(k: Token::colon)) { |
190 | // Check for the '::' prefix. |
191 | const char *curPointer = getToken().getLoc().getPointer(); |
192 | consumeToken(kind: Token::colon); |
193 | if (!consumeIf(kind: Token::colon)) { |
194 | if (getToken().isNot(k1: Token::eof, k2: Token::error)) { |
195 | state.lex.resetPointer(newPointer: curPointer); |
196 | consumeToken(); |
197 | } |
198 | break; |
199 | } |
200 | // Parse the reference itself. |
201 | auto curLoc = getToken().getLoc(); |
202 | if (getToken().isNot(k: Token::at_identifier)) { |
203 | emitError(loc: curLoc, message: "expected nested symbol reference identifier" ); |
204 | return Attribute(); |
205 | } |
206 | |
207 | // If we are populating the assembly state, add the location for this |
208 | // reference. |
209 | if (state.asmState) |
210 | referenceLocations.push_back(Elt: getToken().getLocRange()); |
211 | |
212 | std::string nameStr = getToken().getSymbolReference(); |
213 | consumeToken(kind: Token::at_identifier); |
214 | nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr)); |
215 | } |
216 | SymbolRefAttr symbolRefAttr = |
217 | SymbolRefAttr::get(getContext(), nameStr, nestedRefs); |
218 | |
219 | // If we are populating the assembly state, record this symbol reference. |
220 | if (state.asmState) |
221 | state.asmState->addUses(symbolRefAttr, referenceLocations); |
222 | return symbolRefAttr; |
223 | } |
224 | |
225 | // Parse a 'unit' attribute. |
226 | case Token::kw_unit: |
227 | consumeToken(kind: Token::kw_unit); |
228 | return builder.getUnitAttr(); |
229 | |
230 | // Handle completion of an attribute. |
231 | case Token::code_complete: |
232 | if (getToken().isCodeCompletionFor(kind: Token::hash_identifier)) |
233 | return parseExtendedAttr(type); |
234 | return codeCompleteAttribute(); |
235 | |
236 | default: |
237 | // Parse a type attribute. We parse `Optional` here to allow for providing a |
238 | // better error message. |
239 | Type type; |
240 | OptionalParseResult result = parseOptionalType(type); |
241 | if (!result.has_value()) |
242 | return emitWrongTokenError(message: "expected attribute value" ), Attribute(); |
243 | return failed(*result) ? Attribute() : TypeAttr::get(type); |
244 | } |
245 | } |
246 | |
247 | /// Parse an optional attribute with the provided type. |
248 | OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute, |
249 | Type type) { |
250 | switch (getToken().getKind()) { |
251 | case Token::at_identifier: |
252 | case Token::floatliteral: |
253 | case Token::integer: |
254 | case Token::hash_identifier: |
255 | case Token::kw_affine_map: |
256 | case Token::kw_affine_set: |
257 | case Token::kw_dense: |
258 | case Token::kw_dense_resource: |
259 | case Token::kw_false: |
260 | case Token::kw_loc: |
261 | case Token::kw_sparse: |
262 | case Token::kw_true: |
263 | case Token::kw_unit: |
264 | case Token::l_brace: |
265 | case Token::l_square: |
266 | case Token::minus: |
267 | case Token::string: |
268 | attribute = parseAttribute(type); |
269 | return success(isSuccess: attribute != nullptr); |
270 | |
271 | default: |
272 | // Parse an optional type attribute. |
273 | Type type; |
274 | OptionalParseResult result = parseOptionalType(type); |
275 | if (result.has_value() && succeeded(*result)) |
276 | attribute = TypeAttr::get(type); |
277 | return result; |
278 | } |
279 | } |
280 | OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute, |
281 | Type type) { |
282 | return parseOptionalAttributeWithToken(kind: Token::l_square, attr&: attribute, type); |
283 | } |
284 | OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute, |
285 | Type type) { |
286 | return parseOptionalAttributeWithToken(kind: Token::string, attr&: attribute, type); |
287 | } |
288 | OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result, |
289 | Type type) { |
290 | return parseOptionalAttributeWithToken(kind: Token::at_identifier, attr&: result, type); |
291 | } |
292 | |
293 | /// Attribute dictionary. |
294 | /// |
295 | /// attribute-dict ::= `{` `}` |
296 | /// | `{` attribute-entry (`,` attribute-entry)* `}` |
297 | /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value |
298 | /// |
299 | ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) { |
300 | llvm::SmallDenseSet<StringAttr> seenKeys; |
301 | auto parseElt = [&]() -> ParseResult { |
302 | // The name of an attribute can either be a bare identifier, or a string. |
303 | std::optional<StringAttr> nameId; |
304 | if (getToken().is(k: Token::string)) |
305 | nameId = builder.getStringAttr(getToken().getStringValue()); |
306 | else if (getToken().isAny(k1: Token::bare_identifier, k2: Token::inttype) || |
307 | getToken().isKeyword()) |
308 | nameId = builder.getStringAttr(getTokenSpelling()); |
309 | else |
310 | return emitWrongTokenError(message: "expected attribute name" ); |
311 | |
312 | if (nameId->empty()) |
313 | return emitError(message: "expected valid attribute name" ); |
314 | |
315 | if (!seenKeys.insert(*nameId).second) |
316 | return emitError(message: "duplicate key '" ) |
317 | << nameId->getValue() << "' in dictionary attribute" ; |
318 | consumeToken(); |
319 | |
320 | // Lazy load a dialect in the context if there is a possible namespace. |
321 | auto splitName = nameId->strref().split('.'); |
322 | if (!splitName.second.empty()) |
323 | getContext()->getOrLoadDialect(splitName.first); |
324 | |
325 | // Try to parse the '=' for the attribute value. |
326 | if (!consumeIf(kind: Token::equal)) { |
327 | // If there is no '=', we treat this as a unit attribute. |
328 | attributes.push_back(newAttribute: {*nameId, builder.getUnitAttr()}); |
329 | return success(); |
330 | } |
331 | |
332 | auto attr = parseAttribute(); |
333 | if (!attr) |
334 | return failure(); |
335 | attributes.push_back(newAttribute: {*nameId, attr}); |
336 | return success(); |
337 | }; |
338 | |
339 | return parseCommaSeparatedList(delimiter: Delimiter::Braces, parseElementFn: parseElt, |
340 | contextMessage: " in attribute dictionary" ); |
341 | } |
342 | |
343 | /// Parse a float attribute. |
344 | Attribute Parser::parseFloatAttr(Type type, bool isNegative) { |
345 | auto val = getToken().getFloatingPointValue(); |
346 | if (!val) |
347 | return (emitError(message: "floating point value too large for attribute" ), nullptr); |
348 | consumeToken(kind: Token::floatliteral); |
349 | if (!type) { |
350 | // Default to F64 when no type is specified. |
351 | if (!consumeIf(kind: Token::colon)) |
352 | type = builder.getF64Type(); |
353 | else if (!(type = parseType())) |
354 | return nullptr; |
355 | } |
356 | if (!isa<FloatType>(Val: type)) |
357 | return (emitError(message: "floating point value not valid for specified type" ), |
358 | nullptr); |
359 | return FloatAttr::get(type, isNegative ? -*val : *val); |
360 | } |
361 | |
362 | /// Construct an APint from a parsed value, a known attribute type and |
363 | /// sign. |
364 | static std::optional<APInt> buildAttributeAPInt(Type type, bool isNegative, |
365 | StringRef spelling) { |
366 | // Parse the integer value into an APInt that is big enough to hold the value. |
367 | APInt result; |
368 | bool isHex = spelling.size() > 1 && spelling[1] == 'x'; |
369 | if (spelling.getAsInteger(Radix: isHex ? 0 : 10, Result&: result)) |
370 | return std::nullopt; |
371 | |
372 | // Extend or truncate the bitwidth to the right size. |
373 | unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth |
374 | : type.getIntOrFloatBitWidth(); |
375 | |
376 | if (width > result.getBitWidth()) { |
377 | result = result.zext(width); |
378 | } else if (width < result.getBitWidth()) { |
379 | // The parser can return an unnecessarily wide result with leading zeros. |
380 | // This isn't a problem, but truncating off bits is bad. |
381 | if (result.countl_zero() < result.getBitWidth() - width) |
382 | return std::nullopt; |
383 | |
384 | result = result.trunc(width); |
385 | } |
386 | |
387 | if (width == 0) { |
388 | // 0 bit integers cannot be negative and manipulation of their sign bit will |
389 | // assert, so short-cut validation here. |
390 | if (isNegative) |
391 | return std::nullopt; |
392 | } else if (isNegative) { |
393 | // The value is negative, we have an overflow if the sign bit is not set |
394 | // in the negated apInt. |
395 | result.negate(); |
396 | if (!result.isSignBitSet()) |
397 | return std::nullopt; |
398 | } else if ((type.isSignedInteger() || type.isIndex()) && |
399 | result.isSignBitSet()) { |
400 | // The value is a positive signed integer or index, |
401 | // we have an overflow if the sign bit is set. |
402 | return std::nullopt; |
403 | } |
404 | |
405 | return result; |
406 | } |
407 | |
408 | /// Parse a decimal or a hexadecimal literal, which can be either an integer |
409 | /// or a float attribute. |
410 | Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { |
411 | Token tok = getToken(); |
412 | StringRef spelling = tok.getSpelling(); |
413 | SMLoc loc = tok.getLoc(); |
414 | |
415 | consumeToken(kind: Token::integer); |
416 | if (!type) { |
417 | // Default to i64 if not type is specified. |
418 | if (!consumeIf(kind: Token::colon)) |
419 | type = builder.getIntegerType(64); |
420 | else if (!(type = parseType())) |
421 | return nullptr; |
422 | } |
423 | |
424 | if (auto floatType = dyn_cast<FloatType>(Val&: type)) { |
425 | std::optional<APFloat> result; |
426 | if (failed(result: parseFloatFromIntegerLiteral(result, tok, isNegative, |
427 | semantics: floatType.getFloatSemantics(), |
428 | typeSizeInBits: floatType.getWidth()))) |
429 | return Attribute(); |
430 | return FloatAttr::get(floatType, *result); |
431 | } |
432 | |
433 | if (!isa<IntegerType, IndexType>(Val: type)) |
434 | return emitError(loc, message: "integer literal not valid for specified type" ), |
435 | nullptr; |
436 | |
437 | if (isNegative && type.isUnsignedInteger()) { |
438 | emitError(loc, |
439 | message: "negative integer literal not valid for unsigned integer type" ); |
440 | return nullptr; |
441 | } |
442 | |
443 | std::optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling); |
444 | if (!apInt) |
445 | return emitError(loc, message: "integer constant out of range for attribute" ), |
446 | nullptr; |
447 | return builder.getIntegerAttr(type, *apInt); |
448 | } |
449 | |
450 | //===----------------------------------------------------------------------===// |
451 | // TensorLiteralParser |
452 | //===----------------------------------------------------------------------===// |
453 | |
454 | /// Parse elements values stored within a hex string. On success, the values are |
455 | /// stored into 'result'. |
456 | static ParseResult parseElementAttrHexValues(Parser &parser, Token tok, |
457 | std::string &result) { |
458 | if (std::optional<std::string> value = tok.getHexStringValue()) { |
459 | result = std::move(*value); |
460 | return success(); |
461 | } |
462 | return parser.emitError( |
463 | loc: tok.getLoc(), message: "expected string containing hex digits starting with `0x`" ); |
464 | } |
465 | |
466 | namespace { |
467 | /// This class implements a parser for TensorLiterals. A tensor literal is |
468 | /// either a single element (e.g, 5) or a multi-dimensional list of elements |
469 | /// (e.g., [[5, 5]]). |
470 | class TensorLiteralParser { |
471 | public: |
472 | TensorLiteralParser(Parser &p) : p(p) {} |
473 | |
474 | /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser |
475 | /// may also parse a tensor literal that is store as a hex string. |
476 | ParseResult parse(bool allowHex); |
477 | |
478 | /// Build a dense attribute instance with the parsed elements and the given |
479 | /// shaped type. |
480 | DenseElementsAttr getAttr(SMLoc loc, ShapedType type); |
481 | |
482 | ArrayRef<int64_t> getShape() const { return shape; } |
483 | |
484 | private: |
485 | /// Get the parsed elements for an integer attribute. |
486 | ParseResult getIntAttrElements(SMLoc loc, Type eltTy, |
487 | std::vector<APInt> &intValues); |
488 | |
489 | /// Get the parsed elements for a float attribute. |
490 | ParseResult getFloatAttrElements(SMLoc loc, FloatType eltTy, |
491 | std::vector<APFloat> &floatValues); |
492 | |
493 | /// Build a Dense String attribute for the given type. |
494 | DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy); |
495 | |
496 | /// Build a Dense attribute with hex data for the given type. |
497 | DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type); |
498 | |
499 | /// Parse a single element, returning failure if it isn't a valid element |
500 | /// literal. For example: |
501 | /// parseElement(1) -> Success, 1 |
502 | /// parseElement([1]) -> Failure |
503 | ParseResult parseElement(); |
504 | |
505 | /// Parse a list of either lists or elements, returning the dimensions of the |
506 | /// parsed sub-tensors in dims. For example: |
507 | /// parseList([1, 2, 3]) -> Success, [3] |
508 | /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] |
509 | /// parseList([[1, 2], 3]) -> Failure |
510 | /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure |
511 | ParseResult parseList(SmallVectorImpl<int64_t> &dims); |
512 | |
513 | /// Parse a literal that was printed as a hex string. |
514 | ParseResult parseHexElements(); |
515 | |
516 | Parser &p; |
517 | |
518 | /// The shape inferred from the parsed elements. |
519 | SmallVector<int64_t, 4> shape; |
520 | |
521 | /// Storage used when parsing elements, this is a pair of <is_negated, token>. |
522 | std::vector<std::pair<bool, Token>> storage; |
523 | |
524 | /// Storage used when parsing elements that were stored as hex values. |
525 | std::optional<Token> hexStorage; |
526 | }; |
527 | } // namespace |
528 | |
529 | /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser |
530 | /// may also parse a tensor literal that is store as a hex string. |
531 | ParseResult TensorLiteralParser::parse(bool allowHex) { |
532 | // If hex is allowed, check for a string literal. |
533 | if (allowHex && p.getToken().is(k: Token::string)) { |
534 | hexStorage = p.getToken(); |
535 | p.consumeToken(kind: Token::string); |
536 | return success(); |
537 | } |
538 | // Otherwise, parse a list or an individual element. |
539 | if (p.getToken().is(k: Token::l_square)) |
540 | return parseList(dims&: shape); |
541 | return parseElement(); |
542 | } |
543 | |
544 | /// Build a dense attribute instance with the parsed elements and the given |
545 | /// shaped type. |
546 | DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) { |
547 | Type eltType = type.getElementType(); |
548 | |
549 | // Check to see if we parse the literal from a hex string. |
550 | if (hexStorage && |
551 | (eltType.isIntOrIndexOrFloat() || isa<ComplexType>(eltType))) |
552 | return getHexAttr(loc, type); |
553 | |
554 | // Check that the parsed storage size has the same number of elements to the |
555 | // type, or is a known splat. |
556 | if (!shape.empty() && getShape() != type.getShape()) { |
557 | p.emitError(loc) << "inferred shape of elements literal ([" << getShape() |
558 | << "]) does not match type ([" << type.getShape() << "])" ; |
559 | return nullptr; |
560 | } |
561 | |
562 | // Handle the case where no elements were parsed. |
563 | if (!hexStorage && storage.empty() && type.getNumElements()) { |
564 | p.emitError(loc) << "parsed zero elements, but type (" << type |
565 | << ") expected at least 1" ; |
566 | return nullptr; |
567 | } |
568 | |
569 | // Handle complex types in the specific element type cases below. |
570 | bool isComplex = false; |
571 | if (ComplexType complexTy = dyn_cast<ComplexType>(eltType)) { |
572 | eltType = complexTy.getElementType(); |
573 | isComplex = true; |
574 | } |
575 | |
576 | // Handle integer and index types. |
577 | if (eltType.isIntOrIndex()) { |
578 | std::vector<APInt> intValues; |
579 | if (failed(result: getIntAttrElements(loc, eltTy: eltType, intValues))) |
580 | return nullptr; |
581 | if (isComplex) { |
582 | // If this is a complex, treat the parsed values as complex values. |
583 | auto complexData = llvm::ArrayRef( |
584 | reinterpret_cast<std::complex<APInt> *>(intValues.data()), |
585 | intValues.size() / 2); |
586 | return DenseElementsAttr::get(type, complexData); |
587 | } |
588 | return DenseElementsAttr::get(type, intValues); |
589 | } |
590 | // Handle floating point types. |
591 | if (FloatType floatTy = dyn_cast<FloatType>(Val&: eltType)) { |
592 | std::vector<APFloat> floatValues; |
593 | if (failed(result: getFloatAttrElements(loc, eltTy: floatTy, floatValues))) |
594 | return nullptr; |
595 | if (isComplex) { |
596 | // If this is a complex, treat the parsed values as complex values. |
597 | auto complexData = llvm::ArrayRef( |
598 | reinterpret_cast<std::complex<APFloat> *>(floatValues.data()), |
599 | floatValues.size() / 2); |
600 | return DenseElementsAttr::get(type, complexData); |
601 | } |
602 | return DenseElementsAttr::get(type, floatValues); |
603 | } |
604 | |
605 | // Other types are assumed to be string representations. |
606 | return getStringAttr(loc, type, type.getElementType()); |
607 | } |
608 | |
609 | /// Build a Dense Integer attribute for the given type. |
610 | ParseResult |
611 | TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy, |
612 | std::vector<APInt> &intValues) { |
613 | intValues.reserve(n: storage.size()); |
614 | bool isUintType = eltTy.isUnsignedInteger(); |
615 | for (const auto &signAndToken : storage) { |
616 | bool isNegative = signAndToken.first; |
617 | const Token &token = signAndToken.second; |
618 | auto tokenLoc = token.getLoc(); |
619 | |
620 | if (isNegative && isUintType) { |
621 | return p.emitError(loc: tokenLoc) |
622 | << "expected unsigned integer elements, but parsed negative value" ; |
623 | } |
624 | |
625 | // Check to see if floating point values were parsed. |
626 | if (token.is(k: Token::floatliteral)) { |
627 | return p.emitError(loc: tokenLoc) |
628 | << "expected integer elements, but parsed floating-point" ; |
629 | } |
630 | |
631 | assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && |
632 | "unexpected token type" ); |
633 | if (token.isAny(k1: Token::kw_true, k2: Token::kw_false)) { |
634 | if (!eltTy.isInteger(width: 1)) { |
635 | return p.emitError(loc: tokenLoc) |
636 | << "expected i1 type for 'true' or 'false' values" ; |
637 | } |
638 | APInt apInt(1, token.is(k: Token::kw_true), /*isSigned=*/false); |
639 | intValues.push_back(x: apInt); |
640 | continue; |
641 | } |
642 | |
643 | // Create APInt values for each element with the correct bitwidth. |
644 | std::optional<APInt> apInt = |
645 | buildAttributeAPInt(type: eltTy, isNegative, spelling: token.getSpelling()); |
646 | if (!apInt) |
647 | return p.emitError(loc: tokenLoc, message: "integer constant out of range for type" ); |
648 | intValues.push_back(x: *apInt); |
649 | } |
650 | return success(); |
651 | } |
652 | |
653 | /// Build a Dense Float attribute for the given type. |
654 | ParseResult |
655 | TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy, |
656 | std::vector<APFloat> &floatValues) { |
657 | floatValues.reserve(n: storage.size()); |
658 | for (const auto &signAndToken : storage) { |
659 | bool isNegative = signAndToken.first; |
660 | const Token &token = signAndToken.second; |
661 | |
662 | // Handle hexadecimal float literals. |
663 | if (token.is(k: Token::integer) && token.getSpelling().starts_with(Prefix: "0x" )) { |
664 | std::optional<APFloat> result; |
665 | if (failed(result: p.parseFloatFromIntegerLiteral(result, tok: token, isNegative, |
666 | semantics: eltTy.getFloatSemantics(), |
667 | typeSizeInBits: eltTy.getWidth()))) |
668 | return failure(); |
669 | |
670 | floatValues.push_back(x: *result); |
671 | continue; |
672 | } |
673 | |
674 | // Check to see if any decimal integers or booleans were parsed. |
675 | if (!token.is(k: Token::floatliteral)) |
676 | return p.emitError() |
677 | << "expected floating-point elements, but parsed integer" ; |
678 | |
679 | // Build the float values from tokens. |
680 | auto val = token.getFloatingPointValue(); |
681 | if (!val) |
682 | return p.emitError(message: "floating point value too large for attribute" ); |
683 | |
684 | APFloat apVal(isNegative ? -*val : *val); |
685 | if (!eltTy.isF64()) { |
686 | bool unused; |
687 | apVal.convert(ToSemantics: eltTy.getFloatSemantics(), RM: APFloat::rmNearestTiesToEven, |
688 | losesInfo: &unused); |
689 | } |
690 | floatValues.push_back(x: apVal); |
691 | } |
692 | return success(); |
693 | } |
694 | |
695 | /// Build a Dense String attribute for the given type. |
696 | DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type, |
697 | Type eltTy) { |
698 | if (hexStorage.has_value()) { |
699 | auto stringValue = hexStorage->getStringValue(); |
700 | return DenseStringElementsAttr::get(type, {stringValue}); |
701 | } |
702 | |
703 | std::vector<std::string> stringValues; |
704 | std::vector<StringRef> stringRefValues; |
705 | stringValues.reserve(n: storage.size()); |
706 | stringRefValues.reserve(n: storage.size()); |
707 | |
708 | for (auto val : storage) { |
709 | stringValues.push_back(x: val.second.getStringValue()); |
710 | stringRefValues.emplace_back(args&: stringValues.back()); |
711 | } |
712 | |
713 | return DenseStringElementsAttr::get(type, stringRefValues); |
714 | } |
715 | |
716 | /// Build a Dense attribute with hex data for the given type. |
717 | DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) { |
718 | Type elementType = type.getElementType(); |
719 | if (!elementType.isIntOrIndexOrFloat() && !isa<ComplexType>(elementType)) { |
720 | p.emitError(loc) |
721 | << "expected floating-point, integer, or complex element type, got " |
722 | << elementType; |
723 | return nullptr; |
724 | } |
725 | |
726 | std::string data; |
727 | if (parseElementAttrHexValues(parser&: p, tok: *hexStorage, result&: data)) |
728 | return nullptr; |
729 | |
730 | ArrayRef<char> rawData(data.data(), data.size()); |
731 | bool detectedSplat = false; |
732 | if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) { |
733 | p.emitError(loc) << "elements hex data size is invalid for provided type: " |
734 | << type; |
735 | return nullptr; |
736 | } |
737 | |
738 | if (llvm::endianness::native == llvm::endianness::big) { |
739 | // Convert endianess in big-endian(BE) machines. `rawData` is |
740 | // little-endian(LE) because HEX in raw data of dense element attribute |
741 | // is always LE format. It is converted into BE here to be used in BE |
742 | // machines. |
743 | SmallVector<char, 64> outDataVec(rawData.size()); |
744 | MutableArrayRef<char> convRawData(outDataVec); |
745 | DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( |
746 | rawData, convRawData, type); |
747 | return DenseElementsAttr::getFromRawBuffer(type, convRawData); |
748 | } |
749 | |
750 | return DenseElementsAttr::getFromRawBuffer(type, rawData); |
751 | } |
752 | |
753 | ParseResult TensorLiteralParser::parseElement() { |
754 | switch (p.getToken().getKind()) { |
755 | // Parse a boolean element. |
756 | case Token::kw_true: |
757 | case Token::kw_false: |
758 | case Token::floatliteral: |
759 | case Token::integer: |
760 | storage.emplace_back(/*isNegative=*/args: false, args: p.getToken()); |
761 | p.consumeToken(); |
762 | break; |
763 | |
764 | // Parse a signed integer or a negative floating-point element. |
765 | case Token::minus: |
766 | p.consumeToken(kind: Token::minus); |
767 | if (!p.getToken().isAny(k1: Token::floatliteral, k2: Token::integer)) |
768 | return p.emitError(message: "expected integer or floating point literal" ); |
769 | storage.emplace_back(/*isNegative=*/args: true, args: p.getToken()); |
770 | p.consumeToken(); |
771 | break; |
772 | |
773 | case Token::string: |
774 | storage.emplace_back(/*isNegative=*/args: false, args: p.getToken()); |
775 | p.consumeToken(); |
776 | break; |
777 | |
778 | // Parse a complex element of the form '(' element ',' element ')'. |
779 | case Token::l_paren: |
780 | p.consumeToken(kind: Token::l_paren); |
781 | if (parseElement() || |
782 | p.parseToken(expectedToken: Token::comma, message: "expected ',' between complex elements" ) || |
783 | parseElement() || |
784 | p.parseToken(expectedToken: Token::r_paren, message: "expected ')' after complex elements" )) |
785 | return failure(); |
786 | break; |
787 | |
788 | default: |
789 | return p.emitError(message: "expected element literal of primitive type" ); |
790 | } |
791 | |
792 | return success(); |
793 | } |
794 | |
795 | /// Parse a list of either lists or elements, returning the dimensions of the |
796 | /// parsed sub-tensors in dims. For example: |
797 | /// parseList([1, 2, 3]) -> Success, [3] |
798 | /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] |
799 | /// parseList([[1, 2], 3]) -> Failure |
800 | /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure |
801 | ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) { |
802 | auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims, |
803 | const SmallVectorImpl<int64_t> &newDims) -> ParseResult { |
804 | if (prevDims == newDims) |
805 | return success(); |
806 | return p.emitError(message: "tensor literal is invalid; ranks are not consistent " |
807 | "between elements" ); |
808 | }; |
809 | |
810 | bool first = true; |
811 | SmallVector<int64_t, 4> newDims; |
812 | unsigned size = 0; |
813 | auto parseOneElement = [&]() -> ParseResult { |
814 | SmallVector<int64_t, 4> thisDims; |
815 | if (p.getToken().getKind() == Token::l_square) { |
816 | if (parseList(dims&: thisDims)) |
817 | return failure(); |
818 | } else if (parseElement()) { |
819 | return failure(); |
820 | } |
821 | ++size; |
822 | if (!first) |
823 | return checkDims(newDims, thisDims); |
824 | newDims = thisDims; |
825 | first = false; |
826 | return success(); |
827 | }; |
828 | if (p.parseCommaSeparatedList(delimiter: Parser::Delimiter::Square, parseElementFn: parseOneElement)) |
829 | return failure(); |
830 | |
831 | // Return the sublists' dimensions with 'size' prepended. |
832 | dims.clear(); |
833 | dims.push_back(Elt: size); |
834 | dims.append(in_start: newDims.begin(), in_end: newDims.end()); |
835 | return success(); |
836 | } |
837 | |
838 | //===----------------------------------------------------------------------===// |
839 | // DenseArrayAttr Parser |
840 | //===----------------------------------------------------------------------===// |
841 | |
842 | namespace { |
843 | /// A generic dense array element parser. It parsers integer and floating point |
844 | /// elements. |
845 | class DenseArrayElementParser { |
846 | public: |
847 | explicit DenseArrayElementParser(Type type) : type(type) {} |
848 | |
849 | /// Parse an integer element. |
850 | ParseResult parseIntegerElement(Parser &p); |
851 | |
852 | /// Parse a floating point element. |
853 | ParseResult parseFloatElement(Parser &p); |
854 | |
855 | /// Convert the current contents to a dense array. |
856 | DenseArrayAttr getAttr() { return DenseArrayAttr::get(type, size, rawData); } |
857 | |
858 | private: |
859 | /// Append the raw data of an APInt to the result. |
860 | void append(const APInt &data); |
861 | |
862 | /// The array element type. |
863 | Type type; |
864 | /// The resultant byte array representing the contents of the array. |
865 | std::vector<char> rawData; |
866 | /// The number of elements in the array. |
867 | int64_t size = 0; |
868 | }; |
869 | } // namespace |
870 | |
871 | void DenseArrayElementParser::append(const APInt &data) { |
872 | if (data.getBitWidth()) { |
873 | assert(data.getBitWidth() % 8 == 0); |
874 | unsigned byteSize = data.getBitWidth() / 8; |
875 | size_t offset = rawData.size(); |
876 | rawData.insert(position: rawData.end(), n: byteSize, x: 0); |
877 | llvm::StoreIntToMemory( |
878 | IntVal: data, Dst: reinterpret_cast<uint8_t *>(rawData.data() + offset), StoreBytes: byteSize); |
879 | } |
880 | ++size; |
881 | } |
882 | |
883 | ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) { |
884 | bool isNegative = p.consumeIf(kind: Token::minus); |
885 | |
886 | // Parse an integer literal as an APInt. |
887 | std::optional<APInt> value; |
888 | StringRef spelling = p.getToken().getSpelling(); |
889 | if (p.getToken().isAny(k1: Token::kw_true, k2: Token::kw_false)) { |
890 | if (!type.isInteger(width: 1)) |
891 | return p.emitError(message: "expected i1 type for 'true' or 'false' values" ); |
892 | value = APInt(/*numBits=*/8, p.getToken().is(k: Token::kw_true), |
893 | !type.isUnsignedInteger()); |
894 | p.consumeToken(); |
895 | } else if (p.consumeIf(kind: Token::integer)) { |
896 | value = buildAttributeAPInt(type, isNegative, spelling); |
897 | if (!value) |
898 | return p.emitError(message: "integer constant out of range" ); |
899 | } else { |
900 | return p.emitError(message: "expected integer literal" ); |
901 | } |
902 | append(data: *value); |
903 | return success(); |
904 | } |
905 | |
906 | ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) { |
907 | bool isNegative = p.consumeIf(kind: Token::minus); |
908 | |
909 | Token token = p.getToken(); |
910 | std::optional<APFloat> result; |
911 | auto floatType = cast<FloatType>(Val&: type); |
912 | if (p.consumeIf(kind: Token::integer)) { |
913 | // Parse an integer literal as a float. |
914 | if (p.parseFloatFromIntegerLiteral(result, tok: token, isNegative, |
915 | semantics: floatType.getFloatSemantics(), |
916 | typeSizeInBits: floatType.getWidth())) |
917 | return failure(); |
918 | } else if (p.consumeIf(kind: Token::floatliteral)) { |
919 | // Parse a floating point literal. |
920 | std::optional<double> val = token.getFloatingPointValue(); |
921 | if (!val) |
922 | return failure(); |
923 | result = APFloat(isNegative ? -*val : *val); |
924 | if (!type.isF64()) { |
925 | bool unused; |
926 | result->convert(ToSemantics: floatType.getFloatSemantics(), |
927 | RM: APFloat::rmNearestTiesToEven, losesInfo: &unused); |
928 | } |
929 | } else { |
930 | return p.emitError(message: "expected integer or floating point literal" ); |
931 | } |
932 | |
933 | append(data: result->bitcastToAPInt()); |
934 | return success(); |
935 | } |
936 | |
937 | /// Parse a dense array attribute. |
938 | Attribute Parser::parseDenseArrayAttr(Type attrType) { |
939 | consumeToken(kind: Token::kw_array); |
940 | if (parseToken(expectedToken: Token::less, message: "expected '<' after 'array'" )) |
941 | return {}; |
942 | |
943 | SMLoc typeLoc = getToken().getLoc(); |
944 | Type eltType = parseType(); |
945 | if (!eltType) { |
946 | emitError(loc: typeLoc, message: "expected an integer or floating point type" ); |
947 | return {}; |
948 | } |
949 | |
950 | // Only bool or integer and floating point elements divisible by bytes are |
951 | // supported. |
952 | if (!eltType.isIntOrIndexOrFloat()) { |
953 | emitError(loc: typeLoc, message: "expected integer or float type, got: " ) << eltType; |
954 | return {}; |
955 | } |
956 | if (!eltType.isInteger(width: 1) && eltType.getIntOrFloatBitWidth() % 8 != 0) { |
957 | emitError(loc: typeLoc, message: "element type bitwidth must be a multiple of 8" ); |
958 | return {}; |
959 | } |
960 | |
961 | // Check for empty list. |
962 | if (consumeIf(Token::greater)) |
963 | return DenseArrayAttr::get(eltType, 0, {}); |
964 | |
965 | if (parseToken(expectedToken: Token::colon, message: "expected ':' after dense array type" )) |
966 | return {}; |
967 | |
968 | DenseArrayElementParser eltParser(eltType); |
969 | if (eltType.isIntOrIndex()) { |
970 | if (parseCommaSeparatedList( |
971 | parseElementFn: [&] { return eltParser.parseIntegerElement(p&: *this); })) |
972 | return {}; |
973 | } else { |
974 | if (parseCommaSeparatedList( |
975 | parseElementFn: [&] { return eltParser.parseFloatElement(p&: *this); })) |
976 | return {}; |
977 | } |
978 | if (parseToken(expectedToken: Token::greater, message: "expected '>' to close an array attribute" )) |
979 | return {}; |
980 | return eltParser.getAttr(); |
981 | } |
982 | |
983 | /// Parse a dense elements attribute. |
984 | Attribute Parser::parseDenseElementsAttr(Type attrType) { |
985 | auto attribLoc = getToken().getLoc(); |
986 | consumeToken(kind: Token::kw_dense); |
987 | if (parseToken(expectedToken: Token::less, message: "expected '<' after 'dense'" )) |
988 | return nullptr; |
989 | |
990 | // Parse the literal data if necessary. |
991 | TensorLiteralParser literalParser(*this); |
992 | if (!consumeIf(kind: Token::greater)) { |
993 | if (literalParser.parse(/*allowHex=*/true) || |
994 | parseToken(expectedToken: Token::greater, message: "expected '>'" )) |
995 | return nullptr; |
996 | } |
997 | |
998 | // If the type is specified `parseElementsLiteralType` will not parse a type. |
999 | // Use the attribute location as the location for error reporting in that |
1000 | // case. |
1001 | auto loc = attrType ? attribLoc : getToken().getLoc(); |
1002 | auto type = parseElementsLiteralType(attrType); |
1003 | if (!type) |
1004 | return nullptr; |
1005 | return literalParser.getAttr(loc, type); |
1006 | } |
1007 | |
1008 | Attribute Parser::parseDenseResourceElementsAttr(Type attrType) { |
1009 | auto loc = getToken().getLoc(); |
1010 | consumeToken(kind: Token::kw_dense_resource); |
1011 | if (parseToken(expectedToken: Token::less, message: "expected '<' after 'dense_resource'" )) |
1012 | return nullptr; |
1013 | |
1014 | // Parse the resource handle. |
1015 | FailureOr<AsmDialectResourceHandle> rawHandle = |
1016 | parseResourceHandle(getContext()->getLoadedDialect<BuiltinDialect>()); |
1017 | if (failed(result: rawHandle) || parseToken(expectedToken: Token::greater, message: "expected '>'" )) |
1018 | return nullptr; |
1019 | |
1020 | auto *handle = dyn_cast<DenseResourceElementsHandle>(Val: &*rawHandle); |
1021 | if (!handle) |
1022 | return emitError(loc, message: "invalid `dense_resource` handle type" ), nullptr; |
1023 | |
1024 | // Parse the type of the attribute if the user didn't provide one. |
1025 | SMLoc typeLoc = loc; |
1026 | if (!attrType) { |
1027 | typeLoc = getToken().getLoc(); |
1028 | if (parseToken(expectedToken: Token::colon, message: "expected ':'" ) || !(attrType = parseType())) |
1029 | return nullptr; |
1030 | } |
1031 | |
1032 | ShapedType shapedType = dyn_cast<ShapedType>(attrType); |
1033 | if (!shapedType) { |
1034 | emitError(loc: typeLoc, message: "`dense_resource` expected a shaped type" ); |
1035 | return nullptr; |
1036 | } |
1037 | |
1038 | return DenseResourceElementsAttr::get(shapedType, *handle); |
1039 | } |
1040 | |
1041 | /// Shaped type for elements attribute. |
1042 | /// |
1043 | /// elements-literal-type ::= vector-type | ranked-tensor-type |
1044 | /// |
1045 | /// This method also checks the type has static shape. |
1046 | ShapedType Parser::parseElementsLiteralType(Type type) { |
1047 | // If the user didn't provide a type, parse the colon type for the literal. |
1048 | if (!type) { |
1049 | if (parseToken(expectedToken: Token::colon, message: "expected ':'" )) |
1050 | return nullptr; |
1051 | if (!(type = parseType())) |
1052 | return nullptr; |
1053 | } |
1054 | |
1055 | auto sType = dyn_cast<ShapedType>(type); |
1056 | if (!sType) { |
1057 | emitError(message: "elements literal must be a shaped type" ); |
1058 | return nullptr; |
1059 | } |
1060 | |
1061 | if (!sType.hasStaticShape()) |
1062 | return (emitError(message: "elements literal type must have static shape" ), nullptr); |
1063 | |
1064 | return sType; |
1065 | } |
1066 | |
1067 | /// Parse a sparse elements attribute. |
1068 | Attribute Parser::parseSparseElementsAttr(Type attrType) { |
1069 | SMLoc loc = getToken().getLoc(); |
1070 | consumeToken(kind: Token::kw_sparse); |
1071 | if (parseToken(expectedToken: Token::less, message: "Expected '<' after 'sparse'" )) |
1072 | return nullptr; |
1073 | |
1074 | // Check for the case where all elements are sparse. The indices are |
1075 | // represented by a 2-dimensional shape where the second dimension is the rank |
1076 | // of the type. |
1077 | Type indiceEltType = builder.getIntegerType(64); |
1078 | if (consumeIf(kind: Token::greater)) { |
1079 | ShapedType type = parseElementsLiteralType(attrType); |
1080 | if (!type) |
1081 | return nullptr; |
1082 | |
1083 | // Construct the sparse elements attr using zero element indice/value |
1084 | // attributes. |
1085 | ShapedType indicesType = |
1086 | RankedTensorType::get({0, type.getRank()}, indiceEltType); |
1087 | ShapedType valuesType = RankedTensorType::get({0}, type.getElementType()); |
1088 | return getChecked<SparseElementsAttr>( |
1089 | loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()), |
1090 | DenseElementsAttr::get(valuesType, ArrayRef<Attribute>())); |
1091 | } |
1092 | |
1093 | /// Parse the indices. We don't allow hex values here as we may need to use |
1094 | /// the inferred shape. |
1095 | auto indicesLoc = getToken().getLoc(); |
1096 | TensorLiteralParser indiceParser(*this); |
1097 | if (indiceParser.parse(/*allowHex=*/false)) |
1098 | return nullptr; |
1099 | |
1100 | if (parseToken(expectedToken: Token::comma, message: "expected ','" )) |
1101 | return nullptr; |
1102 | |
1103 | /// Parse the values. |
1104 | auto valuesLoc = getToken().getLoc(); |
1105 | TensorLiteralParser valuesParser(*this); |
1106 | if (valuesParser.parse(/*allowHex=*/true)) |
1107 | return nullptr; |
1108 | |
1109 | if (parseToken(expectedToken: Token::greater, message: "expected '>'" )) |
1110 | return nullptr; |
1111 | |
1112 | auto type = parseElementsLiteralType(attrType); |
1113 | if (!type) |
1114 | return nullptr; |
1115 | |
1116 | // If the indices are a splat, i.e. the literal parser parsed an element and |
1117 | // not a list, we set the shape explicitly. The indices are represented by a |
1118 | // 2-dimensional shape where the second dimension is the rank of the type. |
1119 | // Given that the parsed indices is a splat, we know that we only have one |
1120 | // indice and thus one for the first dimension. |
1121 | ShapedType indicesType; |
1122 | if (indiceParser.getShape().empty()) { |
1123 | indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); |
1124 | } else { |
1125 | // Otherwise, set the shape to the one parsed by the literal parser. |
1126 | indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); |
1127 | } |
1128 | auto indices = indiceParser.getAttr(indicesLoc, indicesType); |
1129 | |
1130 | // If the values are a splat, set the shape explicitly based on the number of |
1131 | // indices. The number of indices is encoded in the first dimension of the |
1132 | // indice shape type. |
1133 | auto valuesEltType = type.getElementType(); |
1134 | ShapedType valuesType = |
1135 | valuesParser.getShape().empty() |
1136 | ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) |
1137 | : RankedTensorType::get(valuesParser.getShape(), valuesEltType); |
1138 | auto values = valuesParser.getAttr(valuesLoc, valuesType); |
1139 | |
1140 | // Build the sparse elements attribute by the indices and values. |
1141 | return getChecked<SparseElementsAttr>(loc, type, indices, values); |
1142 | } |
1143 | |
1144 | Attribute Parser::parseStridedLayoutAttr() { |
1145 | // Callback for error emissing at the keyword token location. |
1146 | llvm::SMLoc loc = getToken().getLoc(); |
1147 | auto errorEmitter = [&] { return emitError(loc); }; |
1148 | |
1149 | consumeToken(kind: Token::kw_strided); |
1150 | if (failed(result: parseToken(expectedToken: Token::less, message: "expected '<' after 'strided'" )) || |
1151 | failed(result: parseToken(expectedToken: Token::l_square, message: "expected '['" ))) |
1152 | return nullptr; |
1153 | |
1154 | // Parses either an integer token or a question mark token. Reports an error |
1155 | // and returns std::nullopt if the current token is neither. The integer token |
1156 | // must fit into int64_t limits. |
1157 | auto parseStrideOrOffset = [&]() -> std::optional<int64_t> { |
1158 | if (consumeIf(Token::question)) |
1159 | return ShapedType::kDynamic; |
1160 | |
1161 | SMLoc loc = getToken().getLoc(); |
1162 | auto emitWrongTokenError = [&] { |
1163 | emitError(loc, message: "expected a 64-bit signed integer or '?'" ); |
1164 | return std::nullopt; |
1165 | }; |
1166 | |
1167 | bool negative = consumeIf(kind: Token::minus); |
1168 | |
1169 | if (getToken().is(k: Token::integer)) { |
1170 | std::optional<uint64_t> value = getToken().getUInt64IntegerValue(); |
1171 | if (!value || |
1172 | *value > static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) |
1173 | return emitWrongTokenError(); |
1174 | consumeToken(); |
1175 | auto result = static_cast<int64_t>(*value); |
1176 | if (negative) |
1177 | result = -result; |
1178 | |
1179 | return result; |
1180 | } |
1181 | |
1182 | return emitWrongTokenError(); |
1183 | }; |
1184 | |
1185 | // Parse strides. |
1186 | SmallVector<int64_t> strides; |
1187 | if (!getToken().is(k: Token::r_square)) { |
1188 | do { |
1189 | std::optional<int64_t> stride = parseStrideOrOffset(); |
1190 | if (!stride) |
1191 | return nullptr; |
1192 | strides.push_back(Elt: *stride); |
1193 | } while (consumeIf(kind: Token::comma)); |
1194 | } |
1195 | |
1196 | if (failed(result: parseToken(expectedToken: Token::r_square, message: "expected ']'" ))) |
1197 | return nullptr; |
1198 | |
1199 | // Fast path in absence of offset. |
1200 | if (consumeIf(kind: Token::greater)) { |
1201 | if (failed(StridedLayoutAttr::verify(errorEmitter, |
1202 | /*offset=*/0, strides))) |
1203 | return nullptr; |
1204 | return StridedLayoutAttr::get(getContext(), /*offset=*/0, strides); |
1205 | } |
1206 | |
1207 | if (failed(result: parseToken(expectedToken: Token::comma, message: "expected ','" )) || |
1208 | failed(result: parseToken(expectedToken: Token::kw_offset, message: "expected 'offset' after comma" )) || |
1209 | failed(result: parseToken(expectedToken: Token::colon, message: "expected ':' after 'offset'" ))) |
1210 | return nullptr; |
1211 | |
1212 | std::optional<int64_t> offset = parseStrideOrOffset(); |
1213 | if (!offset || failed(result: parseToken(expectedToken: Token::greater, message: "expected '>'" ))) |
1214 | return nullptr; |
1215 | |
1216 | if (failed(StridedLayoutAttr::verify(errorEmitter, *offset, strides))) |
1217 | return nullptr; |
1218 | return StridedLayoutAttr::get(getContext(), *offset, strides); |
1219 | // return getChecked<StridedLayoutAttr>(loc,getContext(), *offset, strides); |
1220 | } |
1221 | |
1222 | /// Parse a distinct attribute. |
1223 | /// |
1224 | /// distinct-attribute ::= `distinct` |
1225 | /// `[` integer-literal `]<` attribute-value `>` |
1226 | /// |
1227 | Attribute Parser::parseDistinctAttr(Type type) { |
1228 | SMLoc loc = getToken().getLoc(); |
1229 | consumeToken(kind: Token::kw_distinct); |
1230 | if (parseToken(expectedToken: Token::l_square, message: "expected '[' after 'distinct'" )) |
1231 | return {}; |
1232 | |
1233 | // Parse the distinct integer identifier. |
1234 | Token token = getToken(); |
1235 | if (parseToken(expectedToken: Token::integer, message: "expected distinct ID" )) |
1236 | return {}; |
1237 | std::optional<uint64_t> value = token.getUInt64IntegerValue(); |
1238 | if (!value) { |
1239 | emitError(message: "expected an unsigned 64-bit integer" ); |
1240 | return {}; |
1241 | } |
1242 | |
1243 | // Parse the referenced attribute. |
1244 | if (parseToken(expectedToken: Token::r_square, message: "expected ']' to close distinct ID" ) || |
1245 | parseToken(expectedToken: Token::less, message: "expected '<' after distinct ID" )) |
1246 | return {}; |
1247 | |
1248 | Attribute referencedAttr; |
1249 | if (getToken().is(k: Token::greater)) { |
1250 | consumeToken(); |
1251 | referencedAttr = builder.getUnitAttr(); |
1252 | } else { |
1253 | referencedAttr = parseAttribute(type); |
1254 | if (!referencedAttr) { |
1255 | emitError(message: "expected attribute" ); |
1256 | return {}; |
1257 | } |
1258 | |
1259 | if (parseToken(expectedToken: Token::greater, message: "expected '>' to close distinct attribute" )) |
1260 | return {}; |
1261 | } |
1262 | |
1263 | // Add the distinct attribute to the parser state, if it has not been parsed |
1264 | // before. Otherwise, check if the parsed reference attribute matches the one |
1265 | // found in the parser state. |
1266 | DenseMap<uint64_t, DistinctAttr> &distinctAttrs = |
1267 | state.symbols.distinctAttributes; |
1268 | auto it = distinctAttrs.find(Val: *value); |
1269 | if (it == distinctAttrs.end()) { |
1270 | DistinctAttr distinctAttr = DistinctAttr::create(referencedAttr); |
1271 | it = distinctAttrs.try_emplace(Key: *value, Args&: distinctAttr).first; |
1272 | } else if (it->getSecond().getReferencedAttr() != referencedAttr) { |
1273 | emitError(loc, message: "referenced attribute does not match previous definition: " ) |
1274 | << it->getSecond().getReferencedAttr(); |
1275 | return {}; |
1276 | } |
1277 | |
1278 | return it->getSecond(); |
1279 | } |
1280 | |