1 | //===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements the parser for the MLIR Attributes.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "Parser.h" |
14 | |
15 | #include "mlir/AsmParser/AsmParserState.h" |
16 | #include "mlir/IR/AffineMap.h" |
17 | #include "mlir/IR/BuiltinAttributes.h" |
18 | #include "mlir/IR/BuiltinDialect.h" |
19 | #include "mlir/IR/BuiltinTypes.h" |
20 | #include "mlir/IR/DialectResourceBlobManager.h" |
21 | #include "mlir/IR/IntegerSet.h" |
22 | #include <optional> |
23 | |
24 | using namespace mlir; |
25 | using namespace mlir::detail; |
26 | |
27 | /// Parse an arbitrary attribute. |
28 | /// |
29 | /// attribute-value ::= `unit` |
30 | /// | bool-literal |
31 | /// | integer-literal (`:` (index-type | integer-type))? |
32 | /// | float-literal (`:` float-type)? |
33 | /// | string-literal (`:` type)? |
34 | /// | type |
35 | /// | `[` `:` (integer-type | float-type) tensor-literal `]` |
36 | /// | `[` (attribute-value (`,` attribute-value)*)? `]` |
37 | /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` |
38 | /// | symbol-ref-id (`::` symbol-ref-id)* |
39 | /// | `dense` `<` tensor-literal `>` `:` |
40 | /// (tensor-type | vector-type) |
41 | /// | `sparse` `<` attribute-value `,` attribute-value `>` |
42 | /// `:` (tensor-type | vector-type) |
43 | /// | `strided` `<` `[` comma-separated-int-or-question `]` |
44 | /// (`,` `offset` `:` integer-literal)? `>` |
45 | /// | distinct-attribute |
46 | /// | extended-attribute |
47 | /// |
48 | Attribute Parser::parseAttribute(Type type) { |
49 | switch (getToken().getKind()) { |
50 | // Parse an AffineMap or IntegerSet attribute. |
51 | case Token::kw_affine_map: { |
52 | consumeToken(kind: Token::kw_affine_map); |
53 | |
54 | AffineMap map; |
55 | if (parseToken(expectedToken: Token::less, message: "expected '<' in affine map") || |
56 | parseAffineMapReference(map) || |
57 | parseToken(expectedToken: Token::greater, message: "expected '>' in affine map")) |
58 | return Attribute(); |
59 | return AffineMapAttr::get(map); |
60 | } |
61 | case Token::kw_affine_set: { |
62 | consumeToken(kind: Token::kw_affine_set); |
63 | |
64 | IntegerSet set; |
65 | if (parseToken(expectedToken: Token::less, message: "expected '<' in integer set") || |
66 | parseIntegerSetReference(set) || |
67 | parseToken(expectedToken: Token::greater, message: "expected '>' in integer set")) |
68 | return Attribute(); |
69 | return IntegerSetAttr::get(set); |
70 | } |
71 | |
72 | // Parse an array attribute. |
73 | case Token::l_square: { |
74 | consumeToken(kind: Token::l_square); |
75 | SmallVector<Attribute, 4> elements; |
76 | auto parseElt = [&]() -> ParseResult { |
77 | elements.push_back(Elt: parseAttribute()); |
78 | return elements.back() ? success() : failure(); |
79 | }; |
80 | |
81 | if (parseCommaSeparatedListUntil(rightToken: Token::r_square, parseElement: parseElt)) |
82 | return nullptr; |
83 | return builder.getArrayAttr(elements); |
84 | } |
85 | |
86 | // Parse a boolean attribute. |
87 | case Token::kw_false: |
88 | consumeToken(kind: Token::kw_false); |
89 | return builder.getBoolAttr(value: false); |
90 | case Token::kw_true: |
91 | consumeToken(kind: Token::kw_true); |
92 | return builder.getBoolAttr(value: true); |
93 | |
94 | // Parse a dense elements attribute. |
95 | case Token::kw_dense: |
96 | return parseDenseElementsAttr(attrType: type); |
97 | |
98 | // Parse a dense resource elements attribute. |
99 | case Token::kw_dense_resource: |
100 | return parseDenseResourceElementsAttr(attrType: type); |
101 | |
102 | // Parse a dense array attribute. |
103 | case Token::kw_array: |
104 | return parseDenseArrayAttr(type); |
105 | |
106 | // Parse a dictionary attribute. |
107 | case Token::l_brace: { |
108 | NamedAttrList elements; |
109 | if (parseAttributeDict(attributes&: elements)) |
110 | return nullptr; |
111 | return elements.getDictionary(getContext()); |
112 | } |
113 | |
114 | // Parse an extended attribute, i.e. alias or dialect attribute. |
115 | case Token::hash_identifier: |
116 | return parseExtendedAttr(type); |
117 | |
118 | // Parse floating point and integer attributes. |
119 | case Token::floatliteral: |
120 | return parseFloatAttr(type, /*isNegative=*/false); |
121 | case Token::integer: |
122 | return parseDecOrHexAttr(type, /*isNegative=*/false); |
123 | case Token::minus: { |
124 | consumeToken(kind: Token::minus); |
125 | if (getToken().is(k: Token::integer)) |
126 | return parseDecOrHexAttr(type, /*isNegative=*/true); |
127 | if (getToken().is(k: Token::floatliteral)) |
128 | return parseFloatAttr(type, /*isNegative=*/true); |
129 | |
130 | return (emitWrongTokenError( |
131 | message: "expected constant integer or floating point value"), |
132 | nullptr); |
133 | } |
134 | |
135 | // Parse a location attribute. |
136 | case Token::kw_loc: { |
137 | consumeToken(kind: Token::kw_loc); |
138 | |
139 | LocationAttr locAttr; |
140 | if (parseToken(expectedToken: Token::l_paren, message: "expected '(' in inline location") || |
141 | parseLocationInstance(loc&: locAttr) || |
142 | parseToken(expectedToken: Token::r_paren, message: "expected ')' in inline location")) |
143 | return Attribute(); |
144 | return locAttr; |
145 | } |
146 | |
147 | // Parse a sparse elements attribute. |
148 | case Token::kw_sparse: |
149 | return parseSparseElementsAttr(attrType: type); |
150 | |
151 | // Parse a strided layout attribute. |
152 | case Token::kw_strided: |
153 | return parseStridedLayoutAttr(); |
154 | |
155 | // Parse a distinct attribute. |
156 | case Token::kw_distinct: |
157 | return parseDistinctAttr(type); |
158 | |
159 | // Parse a string attribute. |
160 | case Token::string: { |
161 | auto val = getToken().getStringValue(); |
162 | consumeToken(kind: Token::string); |
163 | // Parse the optional trailing colon type if one wasn't explicitly provided. |
164 | if (!type && consumeIf(kind: Token::colon) && !(type = parseType())) |
165 | return Attribute(); |
166 | |
167 | return type ? StringAttr::get(val, type) |
168 | : StringAttr::get(getContext(), val); |
169 | } |
170 | |
171 | // Parse a symbol reference attribute. |
172 | case Token::at_identifier: { |
173 | // When populating the parser state, this is a list of locations for all of |
174 | // the nested references. |
175 | SmallVector<SMRange> referenceLocations; |
176 | if (state.asmState) |
177 | referenceLocations.push_back(Elt: getToken().getLocRange()); |
178 | |
179 | // Parse the top-level reference. |
180 | std::string nameStr = getToken().getSymbolReference(); |
181 | consumeToken(kind: Token::at_identifier); |
182 | |
183 | // Parse any nested references. |
184 | std::vector<FlatSymbolRefAttr> nestedRefs; |
185 | while (getToken().is(k: Token::colon)) { |
186 | // Check for the '::' prefix. |
187 | const char *curPointer = getToken().getLoc().getPointer(); |
188 | consumeToken(kind: Token::colon); |
189 | if (!consumeIf(kind: Token::colon)) { |
190 | if (getToken().isNot(k1: Token::eof, k2: Token::error)) { |
191 | state.lex.resetPointer(newPointer: curPointer); |
192 | consumeToken(); |
193 | } |
194 | break; |
195 | } |
196 | // Parse the reference itself. |
197 | auto curLoc = getToken().getLoc(); |
198 | if (getToken().isNot(k: Token::at_identifier)) { |
199 | emitError(loc: curLoc, message: "expected nested symbol reference identifier"); |
200 | return Attribute(); |
201 | } |
202 | |
203 | // If we are populating the assembly state, add the location for this |
204 | // reference. |
205 | if (state.asmState) |
206 | referenceLocations.push_back(Elt: getToken().getLocRange()); |
207 | |
208 | std::string nameStr = getToken().getSymbolReference(); |
209 | consumeToken(kind: Token::at_identifier); |
210 | nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr)); |
211 | } |
212 | SymbolRefAttr symbolRefAttr = |
213 | SymbolRefAttr::get(getContext(), nameStr, nestedRefs); |
214 | |
215 | // If we are populating the assembly state, record this symbol reference. |
216 | if (state.asmState) |
217 | state.asmState->addUses(symbolRefAttr, referenceLocations); |
218 | return symbolRefAttr; |
219 | } |
220 | |
221 | // Parse a 'unit' attribute. |
222 | case Token::kw_unit: |
223 | consumeToken(kind: Token::kw_unit); |
224 | return builder.getUnitAttr(); |
225 | |
226 | // Handle completion of an attribute. |
227 | case Token::code_complete: |
228 | if (getToken().isCodeCompletionFor(kind: Token::hash_identifier)) |
229 | return parseExtendedAttr(type); |
230 | return codeCompleteAttribute(); |
231 | |
232 | default: |
233 | // Parse a type attribute. We parse `Optional` here to allow for providing a |
234 | // better error message. |
235 | Type type; |
236 | OptionalParseResult result = parseOptionalType(type); |
237 | if (!result.has_value()) |
238 | return emitWrongTokenError(message: "expected attribute value"), Attribute(); |
239 | return failed(*result) ? Attribute() : TypeAttr::get(type); |
240 | } |
241 | } |
242 | |
243 | /// Parse an optional attribute with the provided type. |
244 | OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute, |
245 | Type type) { |
246 | switch (getToken().getKind()) { |
247 | case Token::at_identifier: |
248 | case Token::floatliteral: |
249 | case Token::integer: |
250 | case Token::hash_identifier: |
251 | case Token::kw_affine_map: |
252 | case Token::kw_affine_set: |
253 | case Token::kw_dense: |
254 | case Token::kw_dense_resource: |
255 | case Token::kw_false: |
256 | case Token::kw_loc: |
257 | case Token::kw_sparse: |
258 | case Token::kw_true: |
259 | case Token::kw_unit: |
260 | case Token::l_brace: |
261 | case Token::l_square: |
262 | case Token::minus: |
263 | case Token::string: |
264 | attribute = parseAttribute(type); |
265 | return success(IsSuccess: attribute != nullptr); |
266 | |
267 | default: |
268 | // Parse an optional type attribute. |
269 | Type type; |
270 | OptionalParseResult result = parseOptionalType(type); |
271 | if (result.has_value() && succeeded(*result)) |
272 | attribute = TypeAttr::get(type); |
273 | return result; |
274 | } |
275 | } |
276 | OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute, |
277 | Type type) { |
278 | return parseOptionalAttributeWithToken(kind: Token::l_square, attr&: attribute, type); |
279 | } |
280 | OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute, |
281 | Type type) { |
282 | return parseOptionalAttributeWithToken(kind: Token::string, attr&: attribute, type); |
283 | } |
284 | OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result, |
285 | Type type) { |
286 | return parseOptionalAttributeWithToken(kind: Token::at_identifier, attr&: result, type); |
287 | } |
288 | |
289 | /// Attribute dictionary. |
290 | /// |
291 | /// attribute-dict ::= `{` `}` |
292 | /// | `{` attribute-entry (`,` attribute-entry)* `}` |
293 | /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value |
294 | /// |
295 | ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) { |
296 | llvm::SmallDenseSet<StringAttr> seenKeys; |
297 | auto parseElt = [&]() -> ParseResult { |
298 | // The name of an attribute can either be a bare identifier, or a string. |
299 | std::optional<StringAttr> nameId; |
300 | if (getToken().is(k: Token::string)) |
301 | nameId = builder.getStringAttr(getToken().getStringValue()); |
302 | else if (getToken().isAny(k1: Token::bare_identifier, k2: Token::inttype) || |
303 | getToken().isKeyword()) |
304 | nameId = builder.getStringAttr(getTokenSpelling()); |
305 | else |
306 | return emitWrongTokenError(message: "expected attribute name"); |
307 | |
308 | if (nameId->empty()) |
309 | return emitError(message: "expected valid attribute name"); |
310 | |
311 | if (!seenKeys.insert(*nameId).second) |
312 | return emitError(message: "duplicate key '") |
313 | << nameId->getValue() << "' in dictionary attribute"; |
314 | consumeToken(); |
315 | |
316 | // Lazy load a dialect in the context if there is a possible namespace. |
317 | auto splitName = nameId->strref().split('.'); |
318 | if (!splitName.second.empty()) |
319 | getContext()->getOrLoadDialect(splitName.first); |
320 | |
321 | // Try to parse the '=' for the attribute value. |
322 | if (!consumeIf(kind: Token::equal)) { |
323 | // If there is no '=', we treat this as a unit attribute. |
324 | attributes.push_back(newAttribute: {*nameId, builder.getUnitAttr()}); |
325 | return success(); |
326 | } |
327 | |
328 | auto attr = parseAttribute(); |
329 | if (!attr) |
330 | return failure(); |
331 | attributes.push_back(newAttribute: {*nameId, attr}); |
332 | return success(); |
333 | }; |
334 | |
335 | return parseCommaSeparatedList(delimiter: Delimiter::Braces, parseElementFn: parseElt, |
336 | contextMessage: " in attribute dictionary"); |
337 | } |
338 | |
339 | /// Parse a float attribute. |
340 | Attribute Parser::parseFloatAttr(Type type, bool isNegative) { |
341 | auto val = getToken().getFloatingPointValue(); |
342 | if (!val) |
343 | return (emitError(message: "floating point value too large for attribute"), nullptr); |
344 | consumeToken(kind: Token::floatliteral); |
345 | if (!type) { |
346 | // Default to F64 when no type is specified. |
347 | if (!consumeIf(kind: Token::colon)) |
348 | type = builder.getF64Type(); |
349 | else if (!(type = parseType())) |
350 | return nullptr; |
351 | } |
352 | if (!isa<FloatType>(Val: type)) |
353 | return (emitError(message: "floating point value not valid for specified type"), |
354 | nullptr); |
355 | return FloatAttr::get(type, isNegative ? -*val : *val); |
356 | } |
357 | |
358 | /// Construct an APint from a parsed value, a known attribute type and |
359 | /// sign. |
360 | static std::optional<APInt> buildAttributeAPInt(Type type, bool isNegative, |
361 | StringRef spelling) { |
362 | // Parse the integer value into an APInt that is big enough to hold the value. |
363 | APInt result; |
364 | bool isHex = spelling.size() > 1 && spelling[1] == 'x'; |
365 | if (spelling.getAsInteger(Radix: isHex ? 0 : 10, Result&: result)) |
366 | return std::nullopt; |
367 | |
368 | // Extend or truncate the bitwidth to the right size. |
369 | unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth |
370 | : type.getIntOrFloatBitWidth(); |
371 | |
372 | if (width > result.getBitWidth()) { |
373 | result = result.zext(width); |
374 | } else if (width < result.getBitWidth()) { |
375 | // The parser can return an unnecessarily wide result with leading zeros. |
376 | // This isn't a problem, but truncating off bits is bad. |
377 | if (result.countl_zero() < result.getBitWidth() - width) |
378 | return std::nullopt; |
379 | |
380 | result = result.trunc(width); |
381 | } |
382 | |
383 | if (width == 0) { |
384 | // 0 bit integers cannot be negative and manipulation of their sign bit will |
385 | // assert, so short-cut validation here. |
386 | if (isNegative) |
387 | return std::nullopt; |
388 | } else if (isNegative) { |
389 | // The value is negative, we have an overflow if the sign bit is not set |
390 | // in the negated apInt. |
391 | result.negate(); |
392 | if (!result.isSignBitSet()) |
393 | return std::nullopt; |
394 | } else if ((type.isSignedInteger() || type.isIndex()) && |
395 | result.isSignBitSet()) { |
396 | // The value is a positive signed integer or index, |
397 | // we have an overflow if the sign bit is set. |
398 | return std::nullopt; |
399 | } |
400 | |
401 | return result; |
402 | } |
403 | |
404 | /// Parse a decimal or a hexadecimal literal, which can be either an integer |
405 | /// or a float attribute. |
406 | Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { |
407 | Token tok = getToken(); |
408 | StringRef spelling = tok.getSpelling(); |
409 | SMLoc loc = tok.getLoc(); |
410 | |
411 | consumeToken(kind: Token::integer); |
412 | if (!type) { |
413 | // Default to i64 if not type is specified. |
414 | if (!consumeIf(kind: Token::colon)) |
415 | type = builder.getIntegerType(64); |
416 | else if (!(type = parseType())) |
417 | return nullptr; |
418 | } |
419 | |
420 | if (auto floatType = dyn_cast<FloatType>(type)) { |
421 | std::optional<APFloat> result; |
422 | if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative, |
423 | semantics: floatType.getFloatSemantics()))) |
424 | return Attribute(); |
425 | return FloatAttr::get(floatType, *result); |
426 | } |
427 | |
428 | if (!isa<IntegerType, IndexType>(Val: type)) |
429 | return emitError(loc, message: "integer literal not valid for specified type"), |
430 | nullptr; |
431 | |
432 | if (isNegative && type.isUnsignedInteger()) { |
433 | emitError(loc, |
434 | message: "negative integer literal not valid for unsigned integer type"); |
435 | return nullptr; |
436 | } |
437 | |
438 | std::optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling); |
439 | if (!apInt) |
440 | return emitError(loc, message: "integer constant out of range for attribute"), |
441 | nullptr; |
442 | return builder.getIntegerAttr(type, *apInt); |
443 | } |
444 | |
445 | //===----------------------------------------------------------------------===// |
446 | // TensorLiteralParser |
447 | //===----------------------------------------------------------------------===// |
448 | |
449 | /// Parse elements values stored within a hex string. On success, the values are |
450 | /// stored into 'result'. |
451 | static ParseResult parseElementAttrHexValues(Parser &parser, Token tok, |
452 | std::string &result) { |
453 | if (std::optional<std::string> value = tok.getHexStringValue()) { |
454 | result = std::move(*value); |
455 | return success(); |
456 | } |
457 | return parser.emitError( |
458 | loc: tok.getLoc(), message: "expected string containing hex digits starting with `0x`"); |
459 | } |
460 | |
461 | namespace { |
462 | /// This class implements a parser for TensorLiterals. A tensor literal is |
463 | /// either a single element (e.g, 5) or a multi-dimensional list of elements |
464 | /// (e.g., [[5, 5]]). |
465 | class TensorLiteralParser { |
466 | public: |
467 | TensorLiteralParser(Parser &p) : p(p) {} |
468 | |
469 | /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser |
470 | /// may also parse a tensor literal that is store as a hex string. |
471 | ParseResult parse(bool allowHex); |
472 | |
473 | /// Build a dense attribute instance with the parsed elements and the given |
474 | /// shaped type. |
475 | DenseElementsAttr getAttr(SMLoc loc, ShapedType type); |
476 | |
477 | ArrayRef<int64_t> getShape() const { return shape; } |
478 | |
479 | private: |
480 | /// Get the parsed elements for an integer attribute. |
481 | ParseResult getIntAttrElements(SMLoc loc, Type eltTy, |
482 | std::vector<APInt> &intValues); |
483 | |
484 | /// Get the parsed elements for a float attribute. |
485 | ParseResult getFloatAttrElements(SMLoc loc, FloatType eltTy, |
486 | std::vector<APFloat> &floatValues); |
487 | |
488 | /// Build a Dense String attribute for the given type. |
489 | DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy); |
490 | |
491 | /// Build a Dense attribute with hex data for the given type. |
492 | DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type); |
493 | |
494 | /// Parse a single element, returning failure if it isn't a valid element |
495 | /// literal. For example: |
496 | /// parseElement(1) -> Success, 1 |
497 | /// parseElement([1]) -> Failure |
498 | ParseResult parseElement(); |
499 | |
500 | /// Parse a list of either lists or elements, returning the dimensions of the |
501 | /// parsed sub-tensors in dims. For example: |
502 | /// parseList([1, 2, 3]) -> Success, [3] |
503 | /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] |
504 | /// parseList([[1, 2], 3]) -> Failure |
505 | /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure |
506 | ParseResult parseList(SmallVectorImpl<int64_t> &dims); |
507 | |
508 | /// Parse a literal that was printed as a hex string. |
509 | ParseResult parseHexElements(); |
510 | |
511 | Parser &p; |
512 | |
513 | /// The shape inferred from the parsed elements. |
514 | SmallVector<int64_t, 4> shape; |
515 | |
516 | /// Storage used when parsing elements, this is a pair of <is_negated, token>. |
517 | std::vector<std::pair<bool, Token>> storage; |
518 | |
519 | /// Storage used when parsing elements that were stored as hex values. |
520 | std::optional<Token> hexStorage; |
521 | }; |
522 | } // namespace |
523 | |
524 | /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser |
525 | /// may also parse a tensor literal that is store as a hex string. |
526 | ParseResult TensorLiteralParser::parse(bool allowHex) { |
527 | // If hex is allowed, check for a string literal. |
528 | if (allowHex && p.getToken().is(k: Token::string)) { |
529 | hexStorage = p.getToken(); |
530 | p.consumeToken(kind: Token::string); |
531 | return success(); |
532 | } |
533 | // Otherwise, parse a list or an individual element. |
534 | if (p.getToken().is(k: Token::l_square)) |
535 | return parseList(dims&: shape); |
536 | return parseElement(); |
537 | } |
538 | |
539 | /// Build a dense attribute instance with the parsed elements and the given |
540 | /// shaped type. |
541 | DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) { |
542 | Type eltType = type.getElementType(); |
543 | |
544 | // Check to see if we parse the literal from a hex string. |
545 | if (hexStorage && |
546 | (eltType.isIntOrIndexOrFloat() || isa<ComplexType>(eltType))) |
547 | return getHexAttr(loc, type); |
548 | |
549 | // Check that the parsed storage size has the same number of elements to the |
550 | // type, or is a known splat. |
551 | if (!shape.empty() && getShape() != type.getShape()) { |
552 | p.emitError(loc) << "inferred shape of elements literal (["<< getShape() |
553 | << "]) does not match type (["<< type.getShape() << "])"; |
554 | return nullptr; |
555 | } |
556 | |
557 | // Handle the case where no elements were parsed. |
558 | if (!hexStorage && storage.empty() && type.getNumElements()) { |
559 | p.emitError(loc) << "parsed zero elements, but type ("<< type |
560 | << ") expected at least 1"; |
561 | return nullptr; |
562 | } |
563 | |
564 | // Handle complex types in the specific element type cases below. |
565 | bool isComplex = false; |
566 | if (ComplexType complexTy = dyn_cast<ComplexType>(eltType)) { |
567 | eltType = complexTy.getElementType(); |
568 | isComplex = true; |
569 | // Complex types have N*2 elements or complex splat. |
570 | // Empty shape may mean a splat or empty literal, only validate splats. |
571 | bool isSplat = shape.empty() && type.getNumElements() != 0; |
572 | if (isSplat && storage.size() != 2) { |
573 | p.emitError(loc) << "parsed "<< storage.size() << " elements, but type (" |
574 | << complexTy << ") expected 2 elements"; |
575 | return nullptr; |
576 | } |
577 | if (!shape.empty() && |
578 | storage.size() != static_cast<size_t>(type.getNumElements()) * 2) { |
579 | p.emitError(loc) << "parsed "<< storage.size() << " elements, but type (" |
580 | << type << ") expected "<< type.getNumElements() * 2 |
581 | << " elements"; |
582 | return nullptr; |
583 | } |
584 | } |
585 | |
586 | // Handle integer and index types. |
587 | if (eltType.isIntOrIndex()) { |
588 | std::vector<APInt> intValues; |
589 | if (failed(Result: getIntAttrElements(loc, eltTy: eltType, intValues))) |
590 | return nullptr; |
591 | if (isComplex) { |
592 | // If this is a complex, treat the parsed values as complex values. |
593 | auto complexData = llvm::ArrayRef( |
594 | reinterpret_cast<std::complex<APInt> *>(intValues.data()), |
595 | intValues.size() / 2); |
596 | return DenseElementsAttr::get(type, complexData); |
597 | } |
598 | return DenseElementsAttr::get(type, intValues); |
599 | } |
600 | // Handle floating point types. |
601 | if (FloatType floatTy = dyn_cast<FloatType>(eltType)) { |
602 | std::vector<APFloat> floatValues; |
603 | if (failed(getFloatAttrElements(loc, eltTy: floatTy, floatValues))) |
604 | return nullptr; |
605 | if (isComplex) { |
606 | // If this is a complex, treat the parsed values as complex values. |
607 | auto complexData = llvm::ArrayRef( |
608 | reinterpret_cast<std::complex<APFloat> *>(floatValues.data()), |
609 | floatValues.size() / 2); |
610 | return DenseElementsAttr::get(type, complexData); |
611 | } |
612 | return DenseElementsAttr::get(type, floatValues); |
613 | } |
614 | |
615 | // Other types are assumed to be string representations. |
616 | return getStringAttr(loc, type, type.getElementType()); |
617 | } |
618 | |
619 | /// Build a Dense Integer attribute for the given type. |
620 | ParseResult |
621 | TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy, |
622 | std::vector<APInt> &intValues) { |
623 | intValues.reserve(n: storage.size()); |
624 | bool isUintType = eltTy.isUnsignedInteger(); |
625 | for (const auto &signAndToken : storage) { |
626 | bool isNegative = signAndToken.first; |
627 | const Token &token = signAndToken.second; |
628 | auto tokenLoc = token.getLoc(); |
629 | |
630 | if (isNegative && isUintType) { |
631 | return p.emitError(loc: tokenLoc) |
632 | << "expected unsigned integer elements, but parsed negative value"; |
633 | } |
634 | |
635 | // Check to see if floating point values were parsed. |
636 | if (token.is(k: Token::floatliteral)) { |
637 | return p.emitError(loc: tokenLoc) |
638 | << "expected integer elements, but parsed floating-point"; |
639 | } |
640 | |
641 | assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && |
642 | "unexpected token type"); |
643 | if (token.isAny(k1: Token::kw_true, k2: Token::kw_false)) { |
644 | if (!eltTy.isInteger(width: 1)) { |
645 | return p.emitError(loc: tokenLoc) |
646 | << "expected i1 type for 'true' or 'false' values"; |
647 | } |
648 | APInt apInt(1, token.is(k: Token::kw_true), /*isSigned=*/false); |
649 | intValues.push_back(x: apInt); |
650 | continue; |
651 | } |
652 | |
653 | // Create APInt values for each element with the correct bitwidth. |
654 | std::optional<APInt> apInt = |
655 | buildAttributeAPInt(type: eltTy, isNegative, spelling: token.getSpelling()); |
656 | if (!apInt) |
657 | return p.emitError(loc: tokenLoc, message: "integer constant out of range for type"); |
658 | intValues.push_back(x: *apInt); |
659 | } |
660 | return success(); |
661 | } |
662 | |
663 | /// Build a Dense Float attribute for the given type. |
664 | ParseResult |
665 | TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy, |
666 | std::vector<APFloat> &floatValues) { |
667 | floatValues.reserve(n: storage.size()); |
668 | for (const auto &signAndToken : storage) { |
669 | bool isNegative = signAndToken.first; |
670 | const Token &token = signAndToken.second; |
671 | std::optional<APFloat> result; |
672 | if (failed(p.parseFloatFromLiteral(result, tok: token, isNegative, |
673 | semantics: eltTy.getFloatSemantics()))) |
674 | return failure(); |
675 | floatValues.push_back(x: *result); |
676 | } |
677 | return success(); |
678 | } |
679 | |
680 | /// Build a Dense String attribute for the given type. |
681 | DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type, |
682 | Type eltTy) { |
683 | if (hexStorage.has_value()) { |
684 | auto stringValue = hexStorage->getStringValue(); |
685 | return DenseStringElementsAttr::get(type, {stringValue}); |
686 | } |
687 | |
688 | std::vector<std::string> stringValues; |
689 | std::vector<StringRef> stringRefValues; |
690 | stringValues.reserve(n: storage.size()); |
691 | stringRefValues.reserve(n: storage.size()); |
692 | |
693 | for (auto val : storage) { |
694 | if (!val.second.is(k: Token::string)) { |
695 | p.emitError(loc) << "expected string token, got " |
696 | << val.second.getSpelling(); |
697 | return nullptr; |
698 | } |
699 | stringValues.push_back(x: val.second.getStringValue()); |
700 | stringRefValues.emplace_back(args&: stringValues.back()); |
701 | } |
702 | |
703 | return DenseStringElementsAttr::get(type, stringRefValues); |
704 | } |
705 | |
706 | /// Build a Dense attribute with hex data for the given type. |
707 | DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) { |
708 | Type elementType = type.getElementType(); |
709 | if (!elementType.isIntOrIndexOrFloat() && !isa<ComplexType>(elementType)) { |
710 | p.emitError(loc) |
711 | << "expected floating-point, integer, or complex element type, got " |
712 | << elementType; |
713 | return nullptr; |
714 | } |
715 | |
716 | std::string data; |
717 | if (parseElementAttrHexValues(parser&: p, tok: *hexStorage, result&: data)) |
718 | return nullptr; |
719 | |
720 | ArrayRef<char> rawData(data.data(), data.size()); |
721 | bool detectedSplat = false; |
722 | if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) { |
723 | p.emitError(loc) << "elements hex data size is invalid for provided type: " |
724 | << type; |
725 | return nullptr; |
726 | } |
727 | |
728 | if (llvm::endianness::native == llvm::endianness::big) { |
729 | // Convert endianess in big-endian(BE) machines. `rawData` is |
730 | // little-endian(LE) because HEX in raw data of dense element attribute |
731 | // is always LE format. It is converted into BE here to be used in BE |
732 | // machines. |
733 | SmallVector<char, 64> outDataVec(rawData.size()); |
734 | MutableArrayRef<char> convRawData(outDataVec); |
735 | DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( |
736 | rawData, convRawData, type); |
737 | return DenseElementsAttr::getFromRawBuffer(type, convRawData); |
738 | } |
739 | |
740 | return DenseElementsAttr::getFromRawBuffer(type, rawData); |
741 | } |
742 | |
743 | ParseResult TensorLiteralParser::parseElement() { |
744 | switch (p.getToken().getKind()) { |
745 | // Parse a boolean element. |
746 | case Token::kw_true: |
747 | case Token::kw_false: |
748 | case Token::floatliteral: |
749 | case Token::integer: |
750 | storage.emplace_back(/*isNegative=*/args: false, args: p.getToken()); |
751 | p.consumeToken(); |
752 | break; |
753 | |
754 | // Parse a signed integer or a negative floating-point element. |
755 | case Token::minus: |
756 | p.consumeToken(kind: Token::minus); |
757 | if (!p.getToken().isAny(k1: Token::floatliteral, k2: Token::integer)) |
758 | return p.emitError(message: "expected integer or floating point literal"); |
759 | storage.emplace_back(/*isNegative=*/args: true, args: p.getToken()); |
760 | p.consumeToken(); |
761 | break; |
762 | |
763 | case Token::string: |
764 | storage.emplace_back(/*isNegative=*/args: false, args: p.getToken()); |
765 | p.consumeToken(); |
766 | break; |
767 | |
768 | // Parse a complex element of the form '(' element ',' element ')'. |
769 | case Token::l_paren: |
770 | p.consumeToken(kind: Token::l_paren); |
771 | if (parseElement() || |
772 | p.parseToken(expectedToken: Token::comma, message: "expected ',' between complex elements") || |
773 | parseElement() || |
774 | p.parseToken(expectedToken: Token::r_paren, message: "expected ')' after complex elements")) |
775 | return failure(); |
776 | break; |
777 | |
778 | default: |
779 | return p.emitError(message: "expected element literal of primitive type"); |
780 | } |
781 | |
782 | return success(); |
783 | } |
784 | |
785 | /// Parse a list of either lists or elements, returning the dimensions of the |
786 | /// parsed sub-tensors in dims. For example: |
787 | /// parseList([1, 2, 3]) -> Success, [3] |
788 | /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] |
789 | /// parseList([[1, 2], 3]) -> Failure |
790 | /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure |
791 | ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) { |
792 | auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims, |
793 | const SmallVectorImpl<int64_t> &newDims) -> ParseResult { |
794 | if (prevDims == newDims) |
795 | return success(); |
796 | return p.emitError(message: "tensor literal is invalid; ranks are not consistent " |
797 | "between elements"); |
798 | }; |
799 | |
800 | bool first = true; |
801 | SmallVector<int64_t, 4> newDims; |
802 | unsigned size = 0; |
803 | auto parseOneElement = [&]() -> ParseResult { |
804 | SmallVector<int64_t, 4> thisDims; |
805 | if (p.getToken().getKind() == Token::l_square) { |
806 | if (parseList(dims&: thisDims)) |
807 | return failure(); |
808 | } else if (parseElement()) { |
809 | return failure(); |
810 | } |
811 | ++size; |
812 | if (!first) |
813 | return checkDims(newDims, thisDims); |
814 | newDims = thisDims; |
815 | first = false; |
816 | return success(); |
817 | }; |
818 | if (p.parseCommaSeparatedList(delimiter: Parser::Delimiter::Square, parseElementFn: parseOneElement)) |
819 | return failure(); |
820 | |
821 | // Return the sublists' dimensions with 'size' prepended. |
822 | dims.clear(); |
823 | dims.push_back(Elt: size); |
824 | dims.append(in_start: newDims.begin(), in_end: newDims.end()); |
825 | return success(); |
826 | } |
827 | |
828 | //===----------------------------------------------------------------------===// |
829 | // DenseArrayAttr Parser |
830 | //===----------------------------------------------------------------------===// |
831 | |
832 | namespace { |
833 | /// A generic dense array element parser. It parsers integer and floating point |
834 | /// elements. |
835 | class DenseArrayElementParser { |
836 | public: |
837 | explicit DenseArrayElementParser(Type type) : type(type) {} |
838 | |
839 | /// Parse an integer element. |
840 | ParseResult parseIntegerElement(Parser &p); |
841 | |
842 | /// Parse a floating point element. |
843 | ParseResult parseFloatElement(Parser &p); |
844 | |
845 | /// Convert the current contents to a dense array. |
846 | DenseArrayAttr getAttr() { return DenseArrayAttr::get(type, size, rawData); } |
847 | |
848 | private: |
849 | /// Append the raw data of an APInt to the result. |
850 | void append(const APInt &data); |
851 | |
852 | /// The array element type. |
853 | Type type; |
854 | /// The resultant byte array representing the contents of the array. |
855 | std::vector<char> rawData; |
856 | /// The number of elements in the array. |
857 | int64_t size = 0; |
858 | }; |
859 | } // namespace |
860 | |
861 | void DenseArrayElementParser::append(const APInt &data) { |
862 | if (data.getBitWidth()) { |
863 | assert(data.getBitWidth() % 8 == 0); |
864 | unsigned byteSize = data.getBitWidth() / 8; |
865 | size_t offset = rawData.size(); |
866 | rawData.insert(position: rawData.end(), n: byteSize, x: 0); |
867 | llvm::StoreIntToMemory( |
868 | IntVal: data, Dst: reinterpret_cast<uint8_t *>(rawData.data() + offset), StoreBytes: byteSize); |
869 | } |
870 | ++size; |
871 | } |
872 | |
873 | ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) { |
874 | bool isNegative = p.consumeIf(kind: Token::minus); |
875 | |
876 | // Parse an integer literal as an APInt. |
877 | std::optional<APInt> value; |
878 | StringRef spelling = p.getToken().getSpelling(); |
879 | if (p.getToken().isAny(k1: Token::kw_true, k2: Token::kw_false)) { |
880 | if (!type.isInteger(width: 1)) |
881 | return p.emitError(message: "expected i1 type for 'true' or 'false' values"); |
882 | value = APInt(/*numBits=*/8, p.getToken().is(k: Token::kw_true), |
883 | !type.isUnsignedInteger()); |
884 | p.consumeToken(); |
885 | } else if (p.consumeIf(kind: Token::integer)) { |
886 | value = buildAttributeAPInt(type, isNegative, spelling); |
887 | if (!value) |
888 | return p.emitError(message: "integer constant out of range"); |
889 | } else { |
890 | return p.emitError(message: "expected integer literal"); |
891 | } |
892 | append(data: *value); |
893 | return success(); |
894 | } |
895 | |
896 | ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) { |
897 | bool isNegative = p.consumeIf(kind: Token::minus); |
898 | Token token = p.getToken(); |
899 | std::optional<APFloat> fromIntLit; |
900 | if (failed( |
901 | p.parseFloatFromLiteral(result&: fromIntLit, tok: token, isNegative, |
902 | semantics: cast<FloatType>(type).getFloatSemantics()))) |
903 | return failure(); |
904 | p.consumeToken(); |
905 | append(data: fromIntLit->bitcastToAPInt()); |
906 | return success(); |
907 | } |
908 | |
909 | /// Parse a dense array attribute. |
910 | Attribute Parser::parseDenseArrayAttr(Type attrType) { |
911 | consumeToken(kind: Token::kw_array); |
912 | if (parseToken(expectedToken: Token::less, message: "expected '<' after 'array'")) |
913 | return {}; |
914 | |
915 | SMLoc typeLoc = getToken().getLoc(); |
916 | Type eltType = parseType(); |
917 | if (!eltType) { |
918 | emitError(loc: typeLoc, message: "expected an integer or floating point type"); |
919 | return {}; |
920 | } |
921 | |
922 | // Only bool or integer and floating point elements divisible by bytes are |
923 | // supported. |
924 | if (!eltType.isIntOrIndexOrFloat()) { |
925 | emitError(loc: typeLoc, message: "expected integer or float type, got: ") << eltType; |
926 | return {}; |
927 | } |
928 | if (!eltType.isInteger(width: 1) && eltType.getIntOrFloatBitWidth() % 8 != 0) { |
929 | emitError(loc: typeLoc, message: "element type bitwidth must be a multiple of 8"); |
930 | return {}; |
931 | } |
932 | |
933 | // Check for empty list. |
934 | if (consumeIf(Token::greater)) |
935 | return DenseArrayAttr::get(eltType, 0, {}); |
936 | |
937 | if (parseToken(expectedToken: Token::colon, message: "expected ':' after dense array type")) |
938 | return {}; |
939 | |
940 | DenseArrayElementParser eltParser(eltType); |
941 | if (eltType.isIntOrIndex()) { |
942 | if (parseCommaSeparatedList( |
943 | parseElementFn: [&] { return eltParser.parseIntegerElement(p&: *this); })) |
944 | return {}; |
945 | } else { |
946 | if (parseCommaSeparatedList( |
947 | parseElementFn: [&] { return eltParser.parseFloatElement(p&: *this); })) |
948 | return {}; |
949 | } |
950 | if (parseToken(expectedToken: Token::greater, message: "expected '>' to close an array attribute")) |
951 | return {}; |
952 | return eltParser.getAttr(); |
953 | } |
954 | |
955 | /// Parse a dense elements attribute. |
956 | Attribute Parser::parseDenseElementsAttr(Type attrType) { |
957 | auto attribLoc = getToken().getLoc(); |
958 | consumeToken(kind: Token::kw_dense); |
959 | if (parseToken(expectedToken: Token::less, message: "expected '<' after 'dense'")) |
960 | return nullptr; |
961 | |
962 | // Parse the literal data if necessary. |
963 | TensorLiteralParser literalParser(*this); |
964 | if (!consumeIf(kind: Token::greater)) { |
965 | if (literalParser.parse(/*allowHex=*/true) || |
966 | parseToken(expectedToken: Token::greater, message: "expected '>'")) |
967 | return nullptr; |
968 | } |
969 | |
970 | auto type = parseElementsLiteralType(attribLoc, attrType); |
971 | if (!type) |
972 | return nullptr; |
973 | return literalParser.getAttr(attribLoc, type); |
974 | } |
975 | |
976 | Attribute Parser::parseDenseResourceElementsAttr(Type attrType) { |
977 | auto loc = getToken().getLoc(); |
978 | consumeToken(kind: Token::kw_dense_resource); |
979 | if (parseToken(expectedToken: Token::less, message: "expected '<' after 'dense_resource'")) |
980 | return nullptr; |
981 | |
982 | // Parse the resource handle. |
983 | FailureOr<AsmDialectResourceHandle> rawHandle = |
984 | parseResourceHandle(getContext()->getLoadedDialect<BuiltinDialect>()); |
985 | if (failed(Result: rawHandle) || parseToken(expectedToken: Token::greater, message: "expected '>'")) |
986 | return nullptr; |
987 | |
988 | auto *handle = dyn_cast<DenseResourceElementsHandle>(Val: &*rawHandle); |
989 | if (!handle) |
990 | return emitError(loc, message: "invalid `dense_resource` handle type"), nullptr; |
991 | |
992 | // Parse the type of the attribute if the user didn't provide one. |
993 | SMLoc typeLoc = loc; |
994 | if (!attrType) { |
995 | typeLoc = getToken().getLoc(); |
996 | if (parseToken(expectedToken: Token::colon, message: "expected ':'") || !(attrType = parseType())) |
997 | return nullptr; |
998 | } |
999 | |
1000 | ShapedType shapedType = dyn_cast<ShapedType>(attrType); |
1001 | if (!shapedType) { |
1002 | emitError(loc: typeLoc, message: "`dense_resource` expected a shaped type"); |
1003 | return nullptr; |
1004 | } |
1005 | |
1006 | return DenseResourceElementsAttr::get(shapedType, *handle); |
1007 | } |
1008 | |
1009 | /// Shaped type for elements attribute. |
1010 | /// |
1011 | /// elements-literal-type ::= vector-type | ranked-tensor-type |
1012 | /// |
1013 | /// This method also checks the type has static shape. |
1014 | ShapedType Parser::parseElementsLiteralType(SMLoc loc, Type type) { |
1015 | // If the user didn't provide a type, parse the colon type for the literal. |
1016 | if (!type) { |
1017 | if (parseToken(expectedToken: Token::colon, message: "expected ':'")) |
1018 | return nullptr; |
1019 | if (!(type = parseType())) |
1020 | return nullptr; |
1021 | } |
1022 | |
1023 | auto sType = dyn_cast<ShapedType>(type); |
1024 | if (!sType) { |
1025 | emitError(loc, message: "elements literal must be a shaped type"); |
1026 | return nullptr; |
1027 | } |
1028 | |
1029 | if (!sType.hasStaticShape()) { |
1030 | emitError(loc, message: "elements literal type must have static shape"); |
1031 | return nullptr; |
1032 | } |
1033 | |
1034 | return sType; |
1035 | } |
1036 | |
1037 | /// Parse a sparse elements attribute. |
1038 | Attribute Parser::parseSparseElementsAttr(Type attrType) { |
1039 | SMLoc loc = getToken().getLoc(); |
1040 | consumeToken(kind: Token::kw_sparse); |
1041 | if (parseToken(expectedToken: Token::less, message: "Expected '<' after 'sparse'")) |
1042 | return nullptr; |
1043 | |
1044 | // Check for the case where all elements are sparse. The indices are |
1045 | // represented by a 2-dimensional shape where the second dimension is the rank |
1046 | // of the type. |
1047 | Type indiceEltType = builder.getIntegerType(64); |
1048 | if (consumeIf(kind: Token::greater)) { |
1049 | ShapedType type = parseElementsLiteralType(loc, attrType); |
1050 | if (!type) |
1051 | return nullptr; |
1052 | |
1053 | // Construct the sparse elements attr using zero element indice/value |
1054 | // attributes. |
1055 | ShapedType indicesType = |
1056 | RankedTensorType::get({0, type.getRank()}, indiceEltType); |
1057 | ShapedType valuesType = RankedTensorType::get({0}, type.getElementType()); |
1058 | return getChecked<SparseElementsAttr>( |
1059 | loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()), |
1060 | DenseElementsAttr::get(valuesType, ArrayRef<Attribute>())); |
1061 | } |
1062 | |
1063 | /// Parse the indices. We don't allow hex values here as we may need to use |
1064 | /// the inferred shape. |
1065 | auto indicesLoc = getToken().getLoc(); |
1066 | TensorLiteralParser indiceParser(*this); |
1067 | if (indiceParser.parse(/*allowHex=*/false)) |
1068 | return nullptr; |
1069 | |
1070 | if (parseToken(expectedToken: Token::comma, message: "expected ','")) |
1071 | return nullptr; |
1072 | |
1073 | /// Parse the values. |
1074 | auto valuesLoc = getToken().getLoc(); |
1075 | TensorLiteralParser valuesParser(*this); |
1076 | if (valuesParser.parse(/*allowHex=*/true)) |
1077 | return nullptr; |
1078 | |
1079 | if (parseToken(expectedToken: Token::greater, message: "expected '>'")) |
1080 | return nullptr; |
1081 | |
1082 | auto type = parseElementsLiteralType(loc, attrType); |
1083 | if (!type) |
1084 | return nullptr; |
1085 | |
1086 | // If the indices are a splat, i.e. the literal parser parsed an element and |
1087 | // not a list, we set the shape explicitly. The indices are represented by a |
1088 | // 2-dimensional shape where the second dimension is the rank of the type. |
1089 | // Given that the parsed indices is a splat, we know that we only have one |
1090 | // indice and thus one for the first dimension. |
1091 | ShapedType indicesType; |
1092 | if (indiceParser.getShape().empty()) { |
1093 | indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); |
1094 | } else { |
1095 | // Otherwise, set the shape to the one parsed by the literal parser. |
1096 | indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); |
1097 | } |
1098 | auto indices = indiceParser.getAttr(indicesLoc, indicesType); |
1099 | if (!indices) |
1100 | return nullptr; |
1101 | |
1102 | // If the values are a splat, set the shape explicitly based on the number of |
1103 | // indices. The number of indices is encoded in the first dimension of the |
1104 | // indice shape type. |
1105 | auto valuesEltType = type.getElementType(); |
1106 | ShapedType valuesType = |
1107 | valuesParser.getShape().empty() |
1108 | ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) |
1109 | : RankedTensorType::get(valuesParser.getShape(), valuesEltType); |
1110 | auto values = valuesParser.getAttr(valuesLoc, valuesType); |
1111 | if (!values) |
1112 | return nullptr; |
1113 | |
1114 | // Build the sparse elements attribute by the indices and values. |
1115 | return getChecked<SparseElementsAttr>(loc, type, indices, values); |
1116 | } |
1117 | |
1118 | Attribute Parser::parseStridedLayoutAttr() { |
1119 | // Callback for error emissing at the keyword token location. |
1120 | llvm::SMLoc loc = getToken().getLoc(); |
1121 | auto errorEmitter = [&] { return emitError(loc); }; |
1122 | |
1123 | consumeToken(kind: Token::kw_strided); |
1124 | if (failed(Result: parseToken(expectedToken: Token::less, message: "expected '<' after 'strided'")) || |
1125 | failed(Result: parseToken(expectedToken: Token::l_square, message: "expected '['"))) |
1126 | return nullptr; |
1127 | |
1128 | // Parses either an integer token or a question mark token. Reports an error |
1129 | // and returns std::nullopt if the current token is neither. The integer token |
1130 | // must fit into int64_t limits. |
1131 | auto parseStrideOrOffset = [&]() -> std::optional<int64_t> { |
1132 | if (consumeIf(Token::question)) |
1133 | return ShapedType::kDynamic; |
1134 | |
1135 | SMLoc loc = getToken().getLoc(); |
1136 | auto emitWrongTokenError = [&] { |
1137 | emitError(loc, message: "expected a 64-bit signed integer or '?'"); |
1138 | return std::nullopt; |
1139 | }; |
1140 | |
1141 | bool negative = consumeIf(kind: Token::minus); |
1142 | |
1143 | if (getToken().is(k: Token::integer)) { |
1144 | std::optional<uint64_t> value = getToken().getUInt64IntegerValue(); |
1145 | if (!value || |
1146 | *value > static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) |
1147 | return emitWrongTokenError(); |
1148 | consumeToken(); |
1149 | auto result = static_cast<int64_t>(*value); |
1150 | if (negative) |
1151 | result = -result; |
1152 | |
1153 | return result; |
1154 | } |
1155 | |
1156 | return emitWrongTokenError(); |
1157 | }; |
1158 | |
1159 | // Parse strides. |
1160 | SmallVector<int64_t> strides; |
1161 | if (!getToken().is(k: Token::r_square)) { |
1162 | do { |
1163 | std::optional<int64_t> stride = parseStrideOrOffset(); |
1164 | if (!stride) |
1165 | return nullptr; |
1166 | strides.push_back(Elt: *stride); |
1167 | } while (consumeIf(kind: Token::comma)); |
1168 | } |
1169 | |
1170 | if (failed(Result: parseToken(expectedToken: Token::r_square, message: "expected ']'"))) |
1171 | return nullptr; |
1172 | |
1173 | // Fast path in absence of offset. |
1174 | if (consumeIf(kind: Token::greater)) { |
1175 | if (failed(StridedLayoutAttr::verify(errorEmitter, |
1176 | /*offset=*/0, strides))) |
1177 | return nullptr; |
1178 | return StridedLayoutAttr::get(getContext(), /*offset=*/0, strides); |
1179 | } |
1180 | |
1181 | if (failed(Result: parseToken(expectedToken: Token::comma, message: "expected ','")) || |
1182 | failed(Result: parseToken(expectedToken: Token::kw_offset, message: "expected 'offset' after comma")) || |
1183 | failed(Result: parseToken(expectedToken: Token::colon, message: "expected ':' after 'offset'"))) |
1184 | return nullptr; |
1185 | |
1186 | std::optional<int64_t> offset = parseStrideOrOffset(); |
1187 | if (!offset || failed(Result: parseToken(expectedToken: Token::greater, message: "expected '>'"))) |
1188 | return nullptr; |
1189 | |
1190 | if (failed(StridedLayoutAttr::verify(errorEmitter, *offset, strides))) |
1191 | return nullptr; |
1192 | return StridedLayoutAttr::get(getContext(), *offset, strides); |
1193 | // return getChecked<StridedLayoutAttr>(loc,getContext(), *offset, strides); |
1194 | } |
1195 | |
1196 | /// Parse a distinct attribute. |
1197 | /// |
1198 | /// distinct-attribute ::= `distinct` |
1199 | /// `[` integer-literal `]<` attribute-value `>` |
1200 | /// |
1201 | Attribute Parser::parseDistinctAttr(Type type) { |
1202 | SMLoc loc = getToken().getLoc(); |
1203 | consumeToken(kind: Token::kw_distinct); |
1204 | if (parseToken(expectedToken: Token::l_square, message: "expected '[' after 'distinct'")) |
1205 | return {}; |
1206 | |
1207 | // Parse the distinct integer identifier. |
1208 | Token token = getToken(); |
1209 | if (parseToken(expectedToken: Token::integer, message: "expected distinct ID")) |
1210 | return {}; |
1211 | std::optional<uint64_t> value = token.getUInt64IntegerValue(); |
1212 | if (!value) { |
1213 | emitError(message: "expected an unsigned 64-bit integer"); |
1214 | return {}; |
1215 | } |
1216 | |
1217 | // Parse the referenced attribute. |
1218 | if (parseToken(expectedToken: Token::r_square, message: "expected ']' to close distinct ID") || |
1219 | parseToken(expectedToken: Token::less, message: "expected '<' after distinct ID")) |
1220 | return {}; |
1221 | |
1222 | Attribute referencedAttr; |
1223 | if (getToken().is(k: Token::greater)) { |
1224 | consumeToken(); |
1225 | referencedAttr = builder.getUnitAttr(); |
1226 | } else { |
1227 | referencedAttr = parseAttribute(type); |
1228 | if (!referencedAttr) { |
1229 | emitError(message: "expected attribute"); |
1230 | return {}; |
1231 | } |
1232 | |
1233 | if (parseToken(expectedToken: Token::greater, message: "expected '>' to close distinct attribute")) |
1234 | return {}; |
1235 | } |
1236 | |
1237 | // Add the distinct attribute to the parser state, if it has not been parsed |
1238 | // before. Otherwise, check if the parsed reference attribute matches the one |
1239 | // found in the parser state. |
1240 | DenseMap<uint64_t, DistinctAttr> &distinctAttrs = |
1241 | state.symbols.distinctAttributes; |
1242 | auto it = distinctAttrs.find(Val: *value); |
1243 | if (it == distinctAttrs.end()) { |
1244 | DistinctAttr distinctAttr = DistinctAttr::create(referencedAttr); |
1245 | it = distinctAttrs.try_emplace(Key: *value, Args&: distinctAttr).first; |
1246 | } else if (it->getSecond().getReferencedAttr() != referencedAttr) { |
1247 | emitError(loc, message: "referenced attribute does not match previous definition: ") |
1248 | << it->getSecond().getReferencedAttr(); |
1249 | return {}; |
1250 | } |
1251 | |
1252 | return it->getSecond(); |
1253 | } |
1254 |
Definitions
- parseAttribute
- parseOptionalAttribute
- parseOptionalAttribute
- parseOptionalAttribute
- parseOptionalAttribute
- parseAttributeDict
- parseFloatAttr
- buildAttributeAPInt
- parseDecOrHexAttr
- parseElementAttrHexValues
- TensorLiteralParser
- TensorLiteralParser
- getShape
- parse
- getAttr
- getIntAttrElements
- getFloatAttrElements
- getStringAttr
- getHexAttr
- parseElement
- parseList
- DenseArrayElementParser
- DenseArrayElementParser
- getAttr
- append
- parseIntegerElement
- parseFloatElement
- parseDenseArrayAttr
- parseDenseElementsAttr
- parseDenseResourceElementsAttr
- parseElementsLiteralType
- parseSparseElementsAttr
- parseStridedLayoutAttr
Learn to use CMake with our Intro Training
Find out more