1//===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements the parser for MLIR attributes.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Parser.h"
14
15#include "mlir/AsmParser/AsmParserState.h"
16#include "mlir/IR/AffineMap.h"
17#include "mlir/IR/BuiltinAttributes.h"
18#include "mlir/IR/BuiltinDialect.h"
19#include "mlir/IR/BuiltinTypes.h"
20#include "mlir/IR/DialectResourceBlobManager.h"
21#include "mlir/IR/IntegerSet.h"
22#include <optional>
23
24using namespace mlir;
25using namespace mlir::detail;
26
27/// Parse an arbitrary attribute.
28///
29/// attribute-value ::= `unit`
30/// | bool-literal
31/// | integer-literal (`:` (index-type | integer-type))?
32/// | float-literal (`:` float-type)?
33/// | string-literal (`:` type)?
34/// | type
35/// | `[` `:` (integer-type | float-type) tensor-literal `]`
36/// | `[` (attribute-value (`,` attribute-value)*)? `]`
37/// | `{` (attribute-entry (`,` attribute-entry)*)? `}`
38/// | symbol-ref-id (`::` symbol-ref-id)*
39/// | `dense` `<` tensor-literal `>` `:`
40/// (tensor-type | vector-type)
41/// | `sparse` `<` attribute-value `,` attribute-value `>`
42/// `:` (tensor-type | vector-type)
43/// | `strided` `<` `[` comma-separated-int-or-question `]`
44/// (`,` `offset` `:` integer-literal)? `>`
45/// | distinct-attribute
46/// | extended-attribute
47///
48Attribute Parser::parseAttribute(Type type) {
49 switch (getToken().getKind()) {
50 // Parse an AffineMap or IntegerSet attribute.
51 case Token::kw_affine_map: {
52 consumeToken(kind: Token::kw_affine_map);
53
54 AffineMap map;
55 if (parseToken(expectedToken: Token::less, message: "expected '<' in affine map") ||
56 parseAffineMapReference(map) ||
57 parseToken(expectedToken: Token::greater, message: "expected '>' in affine map"))
58 return Attribute();
59 return AffineMapAttr::get(map);
60 }
61 case Token::kw_affine_set: {
62 consumeToken(kind: Token::kw_affine_set);
63
64 IntegerSet set;
65 if (parseToken(expectedToken: Token::less, message: "expected '<' in integer set") ||
66 parseIntegerSetReference(set) ||
67 parseToken(expectedToken: Token::greater, message: "expected '>' in integer set"))
68 return Attribute();
69 return IntegerSetAttr::get(set);
70 }
71
72 // Parse an array attribute.
73 case Token::l_square: {
74 consumeToken(kind: Token::l_square);
75 SmallVector<Attribute, 4> elements;
76 auto parseElt = [&]() -> ParseResult {
77 elements.push_back(Elt: parseAttribute());
78 return elements.back() ? success() : failure();
79 };
80
81 if (parseCommaSeparatedListUntil(rightToken: Token::r_square, parseElement: parseElt))
82 return nullptr;
83 return builder.getArrayAttr(elements);
84 }
85
86 // Parse a boolean attribute.
87 case Token::kw_false:
88 consumeToken(kind: Token::kw_false);
89 return builder.getBoolAttr(value: false);
90 case Token::kw_true:
91 consumeToken(kind: Token::kw_true);
92 return builder.getBoolAttr(value: true);
93
94 // Parse a dense elements attribute.
95 case Token::kw_dense:
96 return parseDenseElementsAttr(attrType: type);
97
98 // Parse a dense resource elements attribute.
99 case Token::kw_dense_resource:
100 return parseDenseResourceElementsAttr(attrType: type);
101
102 // Parse a dense array attribute.
103 case Token::kw_array:
104 return parseDenseArrayAttr(type);
105
106 // Parse a dictionary attribute.
107 case Token::l_brace: {
108 NamedAttrList elements;
109 if (parseAttributeDict(attributes&: elements))
110 return nullptr;
111 return elements.getDictionary(getContext());
112 }
113
114 // Parse an extended attribute, i.e. alias or dialect attribute.
115 case Token::hash_identifier:
116 return parseExtendedAttr(type);
117
118 // Parse floating point and integer attributes.
119 case Token::floatliteral:
120 return parseFloatAttr(type, /*isNegative=*/false);
121 case Token::integer:
122 return parseDecOrHexAttr(type, /*isNegative=*/false);
123 case Token::minus: {
124 consumeToken(kind: Token::minus);
125 if (getToken().is(k: Token::integer))
126 return parseDecOrHexAttr(type, /*isNegative=*/true);
127 if (getToken().is(k: Token::floatliteral))
128 return parseFloatAttr(type, /*isNegative=*/true);
129
130 return (emitWrongTokenError(
131 message: "expected constant integer or floating point value"),
132 nullptr);
133 }
134
135 // Parse a location attribute.
136 case Token::kw_loc: {
137 consumeToken(kind: Token::kw_loc);
138
139 LocationAttr locAttr;
140 if (parseToken(expectedToken: Token::l_paren, message: "expected '(' in inline location") ||
141 parseLocationInstance(loc&: locAttr) ||
142 parseToken(expectedToken: Token::r_paren, message: "expected ')' in inline location"))
143 return Attribute();
144 return locAttr;
145 }
146
147 // Parse a sparse elements attribute.
148 case Token::kw_sparse:
149 return parseSparseElementsAttr(attrType: type);
150
151 // Parse a strided layout attribute.
152 case Token::kw_strided:
153 return parseStridedLayoutAttr();
154
155 // Parse a distinct attribute.
156 case Token::kw_distinct:
157 return parseDistinctAttr(type);
158
159 // Parse a string attribute.
160 case Token::string: {
161 auto val = getToken().getStringValue();
162 consumeToken(kind: Token::string);
163 // Parse the optional trailing colon type if one wasn't explicitly provided.
164 if (!type && consumeIf(kind: Token::colon) && !(type = parseType()))
165 return Attribute();
166
167 return type ? StringAttr::get(val, type)
168 : StringAttr::get(getContext(), val);
169 }
170
171 // Parse a symbol reference attribute.
172 case Token::at_identifier: {
173 // When populating the parser state, this is a list of locations for all of
174 // the nested references.
175 SmallVector<SMRange> referenceLocations;
176 if (state.asmState)
177 referenceLocations.push_back(Elt: getToken().getLocRange());
178
179 // Parse the top-level reference.
180 std::string nameStr = getToken().getSymbolReference();
181 consumeToken(kind: Token::at_identifier);
182
183 // Parse any nested references.
184 std::vector<FlatSymbolRefAttr> nestedRefs;
185 while (getToken().is(k: Token::colon)) {
186 // Check for the '::' prefix.
187 const char *curPointer = getToken().getLoc().getPointer();
188 consumeToken(kind: Token::colon);
189 if (!consumeIf(kind: Token::colon)) {
190 if (getToken().isNot(k1: Token::eof, k2: Token::error)) {
191 state.lex.resetPointer(newPointer: curPointer);
192 consumeToken();
193 }
194 break;
195 }
196 // Parse the reference itself.
197 auto curLoc = getToken().getLoc();
198 if (getToken().isNot(k: Token::at_identifier)) {
199 emitError(loc: curLoc, message: "expected nested symbol reference identifier");
200 return Attribute();
201 }
202
203 // If we are populating the assembly state, add the location for this
204 // reference.
205 if (state.asmState)
206 referenceLocations.push_back(Elt: getToken().getLocRange());
207
208 std::string nameStr = getToken().getSymbolReference();
209 consumeToken(kind: Token::at_identifier);
210 nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
211 }
212 SymbolRefAttr symbolRefAttr =
213 SymbolRefAttr::get(getContext(), nameStr, nestedRefs);
214
215 // If we are populating the assembly state, record this symbol reference.
216 if (state.asmState)
217 state.asmState->addUses(symbolRefAttr, referenceLocations);
218 return symbolRefAttr;
219 }
220
221 // Parse a 'unit' attribute.
222 case Token::kw_unit:
223 consumeToken(kind: Token::kw_unit);
224 return builder.getUnitAttr();
225
226 // Handle completion of an attribute.
227 case Token::code_complete:
228 if (getToken().isCodeCompletionFor(kind: Token::hash_identifier))
229 return parseExtendedAttr(type);
230 return codeCompleteAttribute();
231
232 default:
233 // Parse a type attribute. We parse `Optional` here to allow for providing a
234 // better error message.
235 Type type;
236 OptionalParseResult result = parseOptionalType(type);
237 if (!result.has_value())
238 return emitWrongTokenError(message: "expected attribute value"), Attribute();
239 return failed(*result) ? Attribute() : TypeAttr::get(type);
240 }
241}
242
243/// Parse an optional attribute with the provided type.
244OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
245 Type type) {
246 switch (getToken().getKind()) {
247 case Token::at_identifier:
248 case Token::floatliteral:
249 case Token::integer:
250 case Token::hash_identifier:
251 case Token::kw_affine_map:
252 case Token::kw_affine_set:
253 case Token::kw_dense:
254 case Token::kw_dense_resource:
255 case Token::kw_false:
256 case Token::kw_loc:
257 case Token::kw_sparse:
258 case Token::kw_true:
259 case Token::kw_unit:
260 case Token::l_brace:
261 case Token::l_square:
262 case Token::minus:
263 case Token::string:
264 attribute = parseAttribute(type);
265 return success(IsSuccess: attribute != nullptr);
266
267 default:
268 // Parse an optional type attribute.
269 Type type;
270 OptionalParseResult result = parseOptionalType(type);
271 if (result.has_value() && succeeded(*result))
272 attribute = TypeAttr::get(type);
273 return result;
274 }
275}
276OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
277 Type type) {
278 return parseOptionalAttributeWithToken(kind: Token::l_square, attr&: attribute, type);
279}
280OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
281 Type type) {
282 return parseOptionalAttributeWithToken(kind: Token::string, attr&: attribute, type);
283}
284OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result,
285 Type type) {
286 return parseOptionalAttributeWithToken(kind: Token::at_identifier, attr&: result, type);
287}
288
289/// Attribute dictionary.
290///
291/// attribute-dict ::= `{` `}`
292/// | `{` attribute-entry (`,` attribute-entry)* `}`
293/// attribute-entry ::= (bare-id | string-literal) `=` attribute-value
294///
295ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
296 llvm::SmallDenseSet<StringAttr> seenKeys;
297 auto parseElt = [&]() -> ParseResult {
298 // The name of an attribute can either be a bare identifier, or a string.
299 std::optional<StringAttr> nameId;
300 if (getToken().is(k: Token::string))
301 nameId = builder.getStringAttr(getToken().getStringValue());
302 else if (getToken().isAny(k1: Token::bare_identifier, k2: Token::inttype) ||
303 getToken().isKeyword())
304 nameId = builder.getStringAttr(getTokenSpelling());
305 else
306 return emitWrongTokenError(message: "expected attribute name");
307
308 if (nameId->empty())
309 return emitError(message: "expected valid attribute name");
310
311 if (!seenKeys.insert(*nameId).second)
312 return emitError(message: "duplicate key '")
313 << nameId->getValue() << "' in dictionary attribute";
314 consumeToken();
315
316 // Lazy load a dialect in the context if there is a possible namespace.
317 auto splitName = nameId->strref().split('.');
318 if (!splitName.second.empty())
319 getContext()->getOrLoadDialect(splitName.first);
320
321 // Try to parse the '=' for the attribute value.
322 if (!consumeIf(kind: Token::equal)) {
323 // If there is no '=', we treat this as a unit attribute.
324 attributes.push_back(newAttribute: {*nameId, builder.getUnitAttr()});
325 return success();
326 }
327
328 auto attr = parseAttribute();
329 if (!attr)
330 return failure();
331 attributes.push_back(newAttribute: {*nameId, attr});
332 return success();
333 };
334
335 return parseCommaSeparatedList(delimiter: Delimiter::Braces, parseElementFn: parseElt,
336 contextMessage: " in attribute dictionary");
337}
338
339/// Parse a float attribute.
340Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
341 auto val = getToken().getFloatingPointValue();
342 if (!val)
343 return (emitError(message: "floating point value too large for attribute"), nullptr);
344 consumeToken(kind: Token::floatliteral);
345 if (!type) {
346 // Default to F64 when no type is specified.
347 if (!consumeIf(kind: Token::colon))
348 type = builder.getF64Type();
349 else if (!(type = parseType()))
350 return nullptr;
351 }
352 if (!isa<FloatType>(Val: type))
353 return (emitError(message: "floating point value not valid for specified type"),
354 nullptr);
355 return FloatAttr::get(type, isNegative ? -*val : *val);
356}
357
358/// Construct an APint from a parsed value, a known attribute type and
359/// sign.
360static std::optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
361 StringRef spelling) {
362 // Parse the integer value into an APInt that is big enough to hold the value.
363 APInt result;
364 bool isHex = spelling.size() > 1 && spelling[1] == 'x';
365 if (spelling.getAsInteger(Radix: isHex ? 0 : 10, Result&: result))
366 return std::nullopt;
367
368 // Extend or truncate the bitwidth to the right size.
369 unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
370 : type.getIntOrFloatBitWidth();
371
372 if (width > result.getBitWidth()) {
373 result = result.zext(width);
374 } else if (width < result.getBitWidth()) {
375 // The parser can return an unnecessarily wide result with leading zeros.
376 // This isn't a problem, but truncating off bits is bad.
377 if (result.countl_zero() < result.getBitWidth() - width)
378 return std::nullopt;
379
380 result = result.trunc(width);
381 }
382
383 if (width == 0) {
384 // 0 bit integers cannot be negative and manipulation of their sign bit will
385 // assert, so short-cut validation here.
386 if (isNegative)
387 return std::nullopt;
388 } else if (isNegative) {
389 // The value is negative, we have an overflow if the sign bit is not set
390 // in the negated apInt.
391 result.negate();
392 if (!result.isSignBitSet())
393 return std::nullopt;
394 } else if ((type.isSignedInteger() || type.isIndex()) &&
395 result.isSignBitSet()) {
396 // The value is a positive signed integer or index,
397 // we have an overflow if the sign bit is set.
398 return std::nullopt;
399 }
400
401 return result;
402}
403
404/// Parse a decimal or a hexadecimal literal, which can be either an integer
405/// or a float attribute.
406Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
407 Token tok = getToken();
408 StringRef spelling = tok.getSpelling();
409 SMLoc loc = tok.getLoc();
410
411 consumeToken(kind: Token::integer);
412 if (!type) {
413 // Default to i64 if not type is specified.
414 if (!consumeIf(kind: Token::colon))
415 type = builder.getIntegerType(64);
416 else if (!(type = parseType()))
417 return nullptr;
418 }
419
420 if (auto floatType = dyn_cast<FloatType>(type)) {
421 std::optional<APFloat> result;
422 if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
423 semantics: floatType.getFloatSemantics())))
424 return Attribute();
425 return FloatAttr::get(floatType, *result);
426 }
427
428 if (!isa<IntegerType, IndexType>(Val: type))
429 return emitError(loc, message: "integer literal not valid for specified type"),
430 nullptr;
431
432 if (isNegative && type.isUnsignedInteger()) {
433 emitError(loc,
434 message: "negative integer literal not valid for unsigned integer type");
435 return nullptr;
436 }
437
438 std::optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling);
439 if (!apInt)
440 return emitError(loc, message: "integer constant out of range for attribute"),
441 nullptr;
442 return builder.getIntegerAttr(type, *apInt);
443}
444
445//===----------------------------------------------------------------------===//
446// TensorLiteralParser
447//===----------------------------------------------------------------------===//
448
449/// Parse elements values stored within a hex string. On success, the values are
450/// stored into 'result'.
451static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
452 std::string &result) {
453 if (std::optional<std::string> value = tok.getHexStringValue()) {
454 result = std::move(*value);
455 return success();
456 }
457 return parser.emitError(
458 loc: tok.getLoc(), message: "expected string containing hex digits starting with `0x`");
459}
460
461namespace {
462/// This class implements a parser for TensorLiterals. A tensor literal is
463/// either a single element (e.g, 5) or a multi-dimensional list of elements
464/// (e.g., [[5, 5]]).
465class TensorLiteralParser {
466public:
467 TensorLiteralParser(Parser &p) : p(p) {}
468
469 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
470 /// may also parse a tensor literal that is store as a hex string.
471 ParseResult parse(bool allowHex);
472
473 /// Build a dense attribute instance with the parsed elements and the given
474 /// shaped type.
475 DenseElementsAttr getAttr(SMLoc loc, ShapedType type);
476
477 ArrayRef<int64_t> getShape() const { return shape; }
478
479private:
480 /// Get the parsed elements for an integer attribute.
481 ParseResult getIntAttrElements(SMLoc loc, Type eltTy,
482 std::vector<APInt> &intValues);
483
484 /// Get the parsed elements for a float attribute.
485 ParseResult getFloatAttrElements(SMLoc loc, FloatType eltTy,
486 std::vector<APFloat> &floatValues);
487
488 /// Build a Dense String attribute for the given type.
489 DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy);
490
491 /// Build a Dense attribute with hex data for the given type.
492 DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type);
493
494 /// Parse a single element, returning failure if it isn't a valid element
495 /// literal. For example:
496 /// parseElement(1) -> Success, 1
497 /// parseElement([1]) -> Failure
498 ParseResult parseElement();
499
500 /// Parse a list of either lists or elements, returning the dimensions of the
501 /// parsed sub-tensors in dims. For example:
502 /// parseList([1, 2, 3]) -> Success, [3]
503 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
504 /// parseList([[1, 2], 3]) -> Failure
505 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
506 ParseResult parseList(SmallVectorImpl<int64_t> &dims);
507
508 /// Parse a literal that was printed as a hex string.
509 ParseResult parseHexElements();
510
511 Parser &p;
512
513 /// The shape inferred from the parsed elements.
514 SmallVector<int64_t, 4> shape;
515
516 /// Storage used when parsing elements, this is a pair of <is_negated, token>.
517 std::vector<std::pair<bool, Token>> storage;
518
519 /// Storage used when parsing elements that were stored as hex values.
520 std::optional<Token> hexStorage;
521};
522} // namespace
523
524/// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
525/// may also parse a tensor literal that is store as a hex string.
526ParseResult TensorLiteralParser::parse(bool allowHex) {
527 // If hex is allowed, check for a string literal.
528 if (allowHex && p.getToken().is(k: Token::string)) {
529 hexStorage = p.getToken();
530 p.consumeToken(kind: Token::string);
531 return success();
532 }
533 // Otherwise, parse a list or an individual element.
534 if (p.getToken().is(k: Token::l_square))
535 return parseList(dims&: shape);
536 return parseElement();
537}
538
539/// Build a dense attribute instance with the parsed elements and the given
540/// shaped type.
541DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
542 Type eltType = type.getElementType();
543
544 // Check to see if we parse the literal from a hex string.
545 if (hexStorage &&
546 (eltType.isIntOrIndexOrFloat() || isa<ComplexType>(eltType)))
547 return getHexAttr(loc, type);
548
549 // Check that the parsed storage size has the same number of elements to the
550 // type, or is a known splat.
551 if (!shape.empty() && getShape() != type.getShape()) {
552 p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
553 << "]) does not match type ([" << type.getShape() << "])";
554 return nullptr;
555 }
556
557 // Handle the case where no elements were parsed.
558 if (!hexStorage && storage.empty() && type.getNumElements()) {
559 p.emitError(loc) << "parsed zero elements, but type (" << type
560 << ") expected at least 1";
561 return nullptr;
562 }
563
564 // Handle complex types in the specific element type cases below.
565 bool isComplex = false;
566 if (ComplexType complexTy = dyn_cast<ComplexType>(eltType)) {
567 eltType = complexTy.getElementType();
568 isComplex = true;
569 // Complex types have N*2 elements or complex splat.
570 // Empty shape may mean a splat or empty literal, only validate splats.
571 bool isSplat = shape.empty() && type.getNumElements() != 0;
572 if (isSplat && storage.size() != 2) {
573 p.emitError(loc) << "parsed " << storage.size() << " elements, but type ("
574 << complexTy << ") expected 2 elements";
575 return nullptr;
576 }
577 if (!shape.empty() &&
578 storage.size() != static_cast<size_t>(type.getNumElements()) * 2) {
579 p.emitError(loc) << "parsed " << storage.size() << " elements, but type ("
580 << type << ") expected " << type.getNumElements() * 2
581 << " elements";
582 return nullptr;
583 }
584 }
585
586 // Handle integer and index types.
587 if (eltType.isIntOrIndex()) {
588 std::vector<APInt> intValues;
589 if (failed(Result: getIntAttrElements(loc, eltTy: eltType, intValues)))
590 return nullptr;
591 if (isComplex) {
592 // If this is a complex, treat the parsed values as complex values.
593 auto complexData = llvm::ArrayRef(
594 reinterpret_cast<std::complex<APInt> *>(intValues.data()),
595 intValues.size() / 2);
596 return DenseElementsAttr::get(type, complexData);
597 }
598 return DenseElementsAttr::get(type, intValues);
599 }
600 // Handle floating point types.
601 if (FloatType floatTy = dyn_cast<FloatType>(eltType)) {
602 std::vector<APFloat> floatValues;
603 if (failed(getFloatAttrElements(loc, eltTy: floatTy, floatValues)))
604 return nullptr;
605 if (isComplex) {
606 // If this is a complex, treat the parsed values as complex values.
607 auto complexData = llvm::ArrayRef(
608 reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
609 floatValues.size() / 2);
610 return DenseElementsAttr::get(type, complexData);
611 }
612 return DenseElementsAttr::get(type, floatValues);
613 }
614
615 // Other types are assumed to be string representations.
616 return getStringAttr(loc, type, type.getElementType());
617}
618
619/// Build a Dense Integer attribute for the given type.
620ParseResult
621TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
622 std::vector<APInt> &intValues) {
623 intValues.reserve(n: storage.size());
624 bool isUintType = eltTy.isUnsignedInteger();
625 for (const auto &signAndToken : storage) {
626 bool isNegative = signAndToken.first;
627 const Token &token = signAndToken.second;
628 auto tokenLoc = token.getLoc();
629
630 if (isNegative && isUintType) {
631 return p.emitError(loc: tokenLoc)
632 << "expected unsigned integer elements, but parsed negative value";
633 }
634
635 // Check to see if floating point values were parsed.
636 if (token.is(k: Token::floatliteral)) {
637 return p.emitError(loc: tokenLoc)
638 << "expected integer elements, but parsed floating-point";
639 }
640
641 assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
642 "unexpected token type");
643 if (token.isAny(k1: Token::kw_true, k2: Token::kw_false)) {
644 if (!eltTy.isInteger(width: 1)) {
645 return p.emitError(loc: tokenLoc)
646 << "expected i1 type for 'true' or 'false' values";
647 }
648 APInt apInt(1, token.is(k: Token::kw_true), /*isSigned=*/false);
649 intValues.push_back(x: apInt);
650 continue;
651 }
652
653 // Create APInt values for each element with the correct bitwidth.
654 std::optional<APInt> apInt =
655 buildAttributeAPInt(type: eltTy, isNegative, spelling: token.getSpelling());
656 if (!apInt)
657 return p.emitError(loc: tokenLoc, message: "integer constant out of range for type");
658 intValues.push_back(x: *apInt);
659 }
660 return success();
661}
662
663/// Build a Dense Float attribute for the given type.
664ParseResult
665TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
666 std::vector<APFloat> &floatValues) {
667 floatValues.reserve(n: storage.size());
668 for (const auto &signAndToken : storage) {
669 bool isNegative = signAndToken.first;
670 const Token &token = signAndToken.second;
671 std::optional<APFloat> result;
672 if (failed(p.parseFloatFromLiteral(result, tok: token, isNegative,
673 semantics: eltTy.getFloatSemantics())))
674 return failure();
675 floatValues.push_back(x: *result);
676 }
677 return success();
678}
679
680/// Build a Dense String attribute for the given type.
681DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type,
682 Type eltTy) {
683 if (hexStorage.has_value()) {
684 auto stringValue = hexStorage->getStringValue();
685 return DenseStringElementsAttr::get(type, {stringValue});
686 }
687
688 std::vector<std::string> stringValues;
689 std::vector<StringRef> stringRefValues;
690 stringValues.reserve(n: storage.size());
691 stringRefValues.reserve(n: storage.size());
692
693 for (auto val : storage) {
694 if (!val.second.is(k: Token::string)) {
695 p.emitError(loc) << "expected string token, got "
696 << val.second.getSpelling();
697 return nullptr;
698 }
699 stringValues.push_back(x: val.second.getStringValue());
700 stringRefValues.emplace_back(args&: stringValues.back());
701 }
702
703 return DenseStringElementsAttr::get(type, stringRefValues);
704}
705
706/// Build a Dense attribute with hex data for the given type.
707DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) {
708 Type elementType = type.getElementType();
709 if (!elementType.isIntOrIndexOrFloat() && !isa<ComplexType>(elementType)) {
710 p.emitError(loc)
711 << "expected floating-point, integer, or complex element type, got "
712 << elementType;
713 return nullptr;
714 }
715
716 std::string data;
717 if (parseElementAttrHexValues(parser&: p, tok: *hexStorage, result&: data))
718 return nullptr;
719
720 ArrayRef<char> rawData(data.data(), data.size());
721 bool detectedSplat = false;
722 if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
723 p.emitError(loc) << "elements hex data size is invalid for provided type: "
724 << type;
725 return nullptr;
726 }
727
728 if (llvm::endianness::native == llvm::endianness::big) {
729 // Convert endianess in big-endian(BE) machines. `rawData` is
730 // little-endian(LE) because HEX in raw data of dense element attribute
731 // is always LE format. It is converted into BE here to be used in BE
732 // machines.
733 SmallVector<char, 64> outDataVec(rawData.size());
734 MutableArrayRef<char> convRawData(outDataVec);
735 DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
736 rawData, convRawData, type);
737 return DenseElementsAttr::getFromRawBuffer(type, convRawData);
738 }
739
740 return DenseElementsAttr::getFromRawBuffer(type, rawData);
741}
742
743ParseResult TensorLiteralParser::parseElement() {
744 switch (p.getToken().getKind()) {
745 // Parse a boolean element.
746 case Token::kw_true:
747 case Token::kw_false:
748 case Token::floatliteral:
749 case Token::integer:
750 storage.emplace_back(/*isNegative=*/args: false, args: p.getToken());
751 p.consumeToken();
752 break;
753
754 // Parse a signed integer or a negative floating-point element.
755 case Token::minus:
756 p.consumeToken(kind: Token::minus);
757 if (!p.getToken().isAny(k1: Token::floatliteral, k2: Token::integer))
758 return p.emitError(message: "expected integer or floating point literal");
759 storage.emplace_back(/*isNegative=*/args: true, args: p.getToken());
760 p.consumeToken();
761 break;
762
763 case Token::string:
764 storage.emplace_back(/*isNegative=*/args: false, args: p.getToken());
765 p.consumeToken();
766 break;
767
768 // Parse a complex element of the form '(' element ',' element ')'.
769 case Token::l_paren:
770 p.consumeToken(kind: Token::l_paren);
771 if (parseElement() ||
772 p.parseToken(expectedToken: Token::comma, message: "expected ',' between complex elements") ||
773 parseElement() ||
774 p.parseToken(expectedToken: Token::r_paren, message: "expected ')' after complex elements"))
775 return failure();
776 break;
777
778 default:
779 return p.emitError(message: "expected element literal of primitive type");
780 }
781
782 return success();
783}
784
785/// Parse a list of either lists or elements, returning the dimensions of the
786/// parsed sub-tensors in dims. For example:
787/// parseList([1, 2, 3]) -> Success, [3]
788/// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
789/// parseList([[1, 2], 3]) -> Failure
790/// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
791ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
792 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
793 const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
794 if (prevDims == newDims)
795 return success();
796 return p.emitError(message: "tensor literal is invalid; ranks are not consistent "
797 "between elements");
798 };
799
800 bool first = true;
801 SmallVector<int64_t, 4> newDims;
802 unsigned size = 0;
803 auto parseOneElement = [&]() -> ParseResult {
804 SmallVector<int64_t, 4> thisDims;
805 if (p.getToken().getKind() == Token::l_square) {
806 if (parseList(dims&: thisDims))
807 return failure();
808 } else if (parseElement()) {
809 return failure();
810 }
811 ++size;
812 if (!first)
813 return checkDims(newDims, thisDims);
814 newDims = thisDims;
815 first = false;
816 return success();
817 };
818 if (p.parseCommaSeparatedList(delimiter: Parser::Delimiter::Square, parseElementFn: parseOneElement))
819 return failure();
820
821 // Return the sublists' dimensions with 'size' prepended.
822 dims.clear();
823 dims.push_back(Elt: size);
824 dims.append(in_start: newDims.begin(), in_end: newDims.end());
825 return success();
826}
827
828//===----------------------------------------------------------------------===//
829// DenseArrayAttr Parser
830//===----------------------------------------------------------------------===//
831
832namespace {
833/// A generic dense array element parser. It parsers integer and floating point
834/// elements.
835class DenseArrayElementParser {
836public:
837 explicit DenseArrayElementParser(Type type) : type(type) {}
838
839 /// Parse an integer element.
840 ParseResult parseIntegerElement(Parser &p);
841
842 /// Parse a floating point element.
843 ParseResult parseFloatElement(Parser &p);
844
845 /// Convert the current contents to a dense array.
846 DenseArrayAttr getAttr() { return DenseArrayAttr::get(type, size, rawData); }
847
848private:
849 /// Append the raw data of an APInt to the result.
850 void append(const APInt &data);
851
852 /// The array element type.
853 Type type;
854 /// The resultant byte array representing the contents of the array.
855 std::vector<char> rawData;
856 /// The number of elements in the array.
857 int64_t size = 0;
858};
859} // namespace
860
861void DenseArrayElementParser::append(const APInt &data) {
862 if (data.getBitWidth()) {
863 assert(data.getBitWidth() % 8 == 0);
864 unsigned byteSize = data.getBitWidth() / 8;
865 size_t offset = rawData.size();
866 rawData.insert(position: rawData.end(), n: byteSize, x: 0);
867 llvm::StoreIntToMemory(
868 IntVal: data, Dst: reinterpret_cast<uint8_t *>(rawData.data() + offset), StoreBytes: byteSize);
869 }
870 ++size;
871}
872
873ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) {
874 bool isNegative = p.consumeIf(kind: Token::minus);
875
876 // Parse an integer literal as an APInt.
877 std::optional<APInt> value;
878 StringRef spelling = p.getToken().getSpelling();
879 if (p.getToken().isAny(k1: Token::kw_true, k2: Token::kw_false)) {
880 if (!type.isInteger(width: 1))
881 return p.emitError(message: "expected i1 type for 'true' or 'false' values");
882 value = APInt(/*numBits=*/8, p.getToken().is(k: Token::kw_true),
883 !type.isUnsignedInteger());
884 p.consumeToken();
885 } else if (p.consumeIf(kind: Token::integer)) {
886 value = buildAttributeAPInt(type, isNegative, spelling);
887 if (!value)
888 return p.emitError(message: "integer constant out of range");
889 } else {
890 return p.emitError(message: "expected integer literal");
891 }
892 append(data: *value);
893 return success();
894}
895
896ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
897 bool isNegative = p.consumeIf(kind: Token::minus);
898 Token token = p.getToken();
899 std::optional<APFloat> fromIntLit;
900 if (failed(
901 p.parseFloatFromLiteral(result&: fromIntLit, tok: token, isNegative,
902 semantics: cast<FloatType>(type).getFloatSemantics())))
903 return failure();
904 p.consumeToken();
905 append(data: fromIntLit->bitcastToAPInt());
906 return success();
907}
908
909/// Parse a dense array attribute.
910Attribute Parser::parseDenseArrayAttr(Type attrType) {
911 consumeToken(kind: Token::kw_array);
912 if (parseToken(expectedToken: Token::less, message: "expected '<' after 'array'"))
913 return {};
914
915 SMLoc typeLoc = getToken().getLoc();
916 Type eltType = parseType();
917 if (!eltType) {
918 emitError(loc: typeLoc, message: "expected an integer or floating point type");
919 return {};
920 }
921
922 // Only bool or integer and floating point elements divisible by bytes are
923 // supported.
924 if (!eltType.isIntOrIndexOrFloat()) {
925 emitError(loc: typeLoc, message: "expected integer or float type, got: ") << eltType;
926 return {};
927 }
928 if (!eltType.isInteger(width: 1) && eltType.getIntOrFloatBitWidth() % 8 != 0) {
929 emitError(loc: typeLoc, message: "element type bitwidth must be a multiple of 8");
930 return {};
931 }
932
933 // Check for empty list.
934 if (consumeIf(Token::greater))
935 return DenseArrayAttr::get(eltType, 0, {});
936
937 if (parseToken(expectedToken: Token::colon, message: "expected ':' after dense array type"))
938 return {};
939
940 DenseArrayElementParser eltParser(eltType);
941 if (eltType.isIntOrIndex()) {
942 if (parseCommaSeparatedList(
943 parseElementFn: [&] { return eltParser.parseIntegerElement(p&: *this); }))
944 return {};
945 } else {
946 if (parseCommaSeparatedList(
947 parseElementFn: [&] { return eltParser.parseFloatElement(p&: *this); }))
948 return {};
949 }
950 if (parseToken(expectedToken: Token::greater, message: "expected '>' to close an array attribute"))
951 return {};
952 return eltParser.getAttr();
953}
954
955/// Parse a dense elements attribute.
956Attribute Parser::parseDenseElementsAttr(Type attrType) {
957 auto attribLoc = getToken().getLoc();
958 consumeToken(kind: Token::kw_dense);
959 if (parseToken(expectedToken: Token::less, message: "expected '<' after 'dense'"))
960 return nullptr;
961
962 // Parse the literal data if necessary.
963 TensorLiteralParser literalParser(*this);
964 if (!consumeIf(kind: Token::greater)) {
965 if (literalParser.parse(/*allowHex=*/true) ||
966 parseToken(expectedToken: Token::greater, message: "expected '>'"))
967 return nullptr;
968 }
969
970 auto type = parseElementsLiteralType(attribLoc, attrType);
971 if (!type)
972 return nullptr;
973 return literalParser.getAttr(attribLoc, type);
974}
975
976Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
977 auto loc = getToken().getLoc();
978 consumeToken(kind: Token::kw_dense_resource);
979 if (parseToken(expectedToken: Token::less, message: "expected '<' after 'dense_resource'"))
980 return nullptr;
981
982 // Parse the resource handle.
983 FailureOr<AsmDialectResourceHandle> rawHandle =
984 parseResourceHandle(getContext()->getLoadedDialect<BuiltinDialect>());
985 if (failed(Result: rawHandle) || parseToken(expectedToken: Token::greater, message: "expected '>'"))
986 return nullptr;
987
988 auto *handle = dyn_cast<DenseResourceElementsHandle>(Val: &*rawHandle);
989 if (!handle)
990 return emitError(loc, message: "invalid `dense_resource` handle type"), nullptr;
991
992 // Parse the type of the attribute if the user didn't provide one.
993 SMLoc typeLoc = loc;
994 if (!attrType) {
995 typeLoc = getToken().getLoc();
996 if (parseToken(expectedToken: Token::colon, message: "expected ':'") || !(attrType = parseType()))
997 return nullptr;
998 }
999
1000 ShapedType shapedType = dyn_cast<ShapedType>(attrType);
1001 if (!shapedType) {
1002 emitError(loc: typeLoc, message: "`dense_resource` expected a shaped type");
1003 return nullptr;
1004 }
1005
1006 return DenseResourceElementsAttr::get(shapedType, *handle);
1007}
1008
1009/// Shaped type for elements attribute.
1010///
1011/// elements-literal-type ::= vector-type | ranked-tensor-type
1012///
1013/// This method also checks the type has static shape.
1014ShapedType Parser::parseElementsLiteralType(SMLoc loc, Type type) {
1015 // If the user didn't provide a type, parse the colon type for the literal.
1016 if (!type) {
1017 if (parseToken(expectedToken: Token::colon, message: "expected ':'"))
1018 return nullptr;
1019 if (!(type = parseType()))
1020 return nullptr;
1021 }
1022
1023 auto sType = dyn_cast<ShapedType>(type);
1024 if (!sType) {
1025 emitError(loc, message: "elements literal must be a shaped type");
1026 return nullptr;
1027 }
1028
1029 if (!sType.hasStaticShape()) {
1030 emitError(loc, message: "elements literal type must have static shape");
1031 return nullptr;
1032 }
1033
1034 return sType;
1035}
1036
1037/// Parse a sparse elements attribute.
1038Attribute Parser::parseSparseElementsAttr(Type attrType) {
1039 SMLoc loc = getToken().getLoc();
1040 consumeToken(kind: Token::kw_sparse);
1041 if (parseToken(expectedToken: Token::less, message: "Expected '<' after 'sparse'"))
1042 return nullptr;
1043
1044 // Check for the case where all elements are sparse. The indices are
1045 // represented by a 2-dimensional shape where the second dimension is the rank
1046 // of the type.
1047 Type indiceEltType = builder.getIntegerType(64);
1048 if (consumeIf(kind: Token::greater)) {
1049 ShapedType type = parseElementsLiteralType(loc, attrType);
1050 if (!type)
1051 return nullptr;
1052
1053 // Construct the sparse elements attr using zero element indice/value
1054 // attributes.
1055 ShapedType indicesType =
1056 RankedTensorType::get({0, type.getRank()}, indiceEltType);
1057 ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
1058 return getChecked<SparseElementsAttr>(
1059 loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
1060 DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
1061 }
1062
1063 /// Parse the indices. We don't allow hex values here as we may need to use
1064 /// the inferred shape.
1065 auto indicesLoc = getToken().getLoc();
1066 TensorLiteralParser indiceParser(*this);
1067 if (indiceParser.parse(/*allowHex=*/false))
1068 return nullptr;
1069
1070 if (parseToken(expectedToken: Token::comma, message: "expected ','"))
1071 return nullptr;
1072
1073 /// Parse the values.
1074 auto valuesLoc = getToken().getLoc();
1075 TensorLiteralParser valuesParser(*this);
1076 if (valuesParser.parse(/*allowHex=*/true))
1077 return nullptr;
1078
1079 if (parseToken(expectedToken: Token::greater, message: "expected '>'"))
1080 return nullptr;
1081
1082 auto type = parseElementsLiteralType(loc, attrType);
1083 if (!type)
1084 return nullptr;
1085
1086 // If the indices are a splat, i.e. the literal parser parsed an element and
1087 // not a list, we set the shape explicitly. The indices are represented by a
1088 // 2-dimensional shape where the second dimension is the rank of the type.
1089 // Given that the parsed indices is a splat, we know that we only have one
1090 // indice and thus one for the first dimension.
1091 ShapedType indicesType;
1092 if (indiceParser.getShape().empty()) {
1093 indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
1094 } else {
1095 // Otherwise, set the shape to the one parsed by the literal parser.
1096 indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
1097 }
1098 auto indices = indiceParser.getAttr(indicesLoc, indicesType);
1099 if (!indices)
1100 return nullptr;
1101
1102 // If the values are a splat, set the shape explicitly based on the number of
1103 // indices. The number of indices is encoded in the first dimension of the
1104 // indice shape type.
1105 auto valuesEltType = type.getElementType();
1106 ShapedType valuesType =
1107 valuesParser.getShape().empty()
1108 ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
1109 : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
1110 auto values = valuesParser.getAttr(valuesLoc, valuesType);
1111 if (!values)
1112 return nullptr;
1113
1114 // Build the sparse elements attribute by the indices and values.
1115 return getChecked<SparseElementsAttr>(loc, type, indices, values);
1116}
1117
1118Attribute Parser::parseStridedLayoutAttr() {
1119 // Callback for error emissing at the keyword token location.
1120 llvm::SMLoc loc = getToken().getLoc();
1121 auto errorEmitter = [&] { return emitError(loc); };
1122
1123 consumeToken(kind: Token::kw_strided);
1124 if (failed(Result: parseToken(expectedToken: Token::less, message: "expected '<' after 'strided'")) ||
1125 failed(Result: parseToken(expectedToken: Token::l_square, message: "expected '['")))
1126 return nullptr;
1127
1128 // Parses either an integer token or a question mark token. Reports an error
1129 // and returns std::nullopt if the current token is neither. The integer token
1130 // must fit into int64_t limits.
1131 auto parseStrideOrOffset = [&]() -> std::optional<int64_t> {
1132 if (consumeIf(Token::question))
1133 return ShapedType::kDynamic;
1134
1135 SMLoc loc = getToken().getLoc();
1136 auto emitWrongTokenError = [&] {
1137 emitError(loc, message: "expected a 64-bit signed integer or '?'");
1138 return std::nullopt;
1139 };
1140
1141 bool negative = consumeIf(kind: Token::minus);
1142
1143 if (getToken().is(k: Token::integer)) {
1144 std::optional<uint64_t> value = getToken().getUInt64IntegerValue();
1145 if (!value ||
1146 *value > static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
1147 return emitWrongTokenError();
1148 consumeToken();
1149 auto result = static_cast<int64_t>(*value);
1150 if (negative)
1151 result = -result;
1152
1153 return result;
1154 }
1155
1156 return emitWrongTokenError();
1157 };
1158
1159 // Parse strides.
1160 SmallVector<int64_t> strides;
1161 if (!getToken().is(k: Token::r_square)) {
1162 do {
1163 std::optional<int64_t> stride = parseStrideOrOffset();
1164 if (!stride)
1165 return nullptr;
1166 strides.push_back(Elt: *stride);
1167 } while (consumeIf(kind: Token::comma));
1168 }
1169
1170 if (failed(Result: parseToken(expectedToken: Token::r_square, message: "expected ']'")))
1171 return nullptr;
1172
1173 // Fast path in absence of offset.
1174 if (consumeIf(kind: Token::greater)) {
1175 if (failed(StridedLayoutAttr::verify(errorEmitter,
1176 /*offset=*/0, strides)))
1177 return nullptr;
1178 return StridedLayoutAttr::get(getContext(), /*offset=*/0, strides);
1179 }
1180
1181 if (failed(Result: parseToken(expectedToken: Token::comma, message: "expected ','")) ||
1182 failed(Result: parseToken(expectedToken: Token::kw_offset, message: "expected 'offset' after comma")) ||
1183 failed(Result: parseToken(expectedToken: Token::colon, message: "expected ':' after 'offset'")))
1184 return nullptr;
1185
1186 std::optional<int64_t> offset = parseStrideOrOffset();
1187 if (!offset || failed(Result: parseToken(expectedToken: Token::greater, message: "expected '>'")))
1188 return nullptr;
1189
1190 if (failed(StridedLayoutAttr::verify(errorEmitter, *offset, strides)))
1191 return nullptr;
1192 return StridedLayoutAttr::get(getContext(), *offset, strides);
1193 // return getChecked<StridedLayoutAttr>(loc,getContext(), *offset, strides);
1194}
1195
1196/// Parse a distinct attribute.
1197///
1198/// distinct-attribute ::= `distinct`
1199/// `[` integer-literal `]<` attribute-value `>`
1200///
1201Attribute Parser::parseDistinctAttr(Type type) {
1202 SMLoc loc = getToken().getLoc();
1203 consumeToken(kind: Token::kw_distinct);
1204 if (parseToken(expectedToken: Token::l_square, message: "expected '[' after 'distinct'"))
1205 return {};
1206
1207 // Parse the distinct integer identifier.
1208 Token token = getToken();
1209 if (parseToken(expectedToken: Token::integer, message: "expected distinct ID"))
1210 return {};
1211 std::optional<uint64_t> value = token.getUInt64IntegerValue();
1212 if (!value) {
1213 emitError(message: "expected an unsigned 64-bit integer");
1214 return {};
1215 }
1216
1217 // Parse the referenced attribute.
1218 if (parseToken(expectedToken: Token::r_square, message: "expected ']' to close distinct ID") ||
1219 parseToken(expectedToken: Token::less, message: "expected '<' after distinct ID"))
1220 return {};
1221
1222 Attribute referencedAttr;
1223 if (getToken().is(k: Token::greater)) {
1224 consumeToken();
1225 referencedAttr = builder.getUnitAttr();
1226 } else {
1227 referencedAttr = parseAttribute(type);
1228 if (!referencedAttr) {
1229 emitError(message: "expected attribute");
1230 return {};
1231 }
1232
1233 if (parseToken(expectedToken: Token::greater, message: "expected '>' to close distinct attribute"))
1234 return {};
1235 }
1236
1237 // Add the distinct attribute to the parser state, if it has not been parsed
1238 // before. Otherwise, check if the parsed reference attribute matches the one
1239 // found in the parser state.
1240 DenseMap<uint64_t, DistinctAttr> &distinctAttrs =
1241 state.symbols.distinctAttributes;
1242 auto it = distinctAttrs.find(Val: *value);
1243 if (it == distinctAttrs.end()) {
1244 DistinctAttr distinctAttr = DistinctAttr::create(referencedAttr);
1245 it = distinctAttrs.try_emplace(Key: *value, Args&: distinctAttr).first;
1246 } else if (it->getSecond().getReferencedAttr() != referencedAttr) {
1247 emitError(loc, message: "referenced attribute does not match previous definition: ")
1248 << it->getSecond().getReferencedAttr();
1249 return {};
1250 }
1251
1252 return it->getSecond();
1253}
1254

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/AsmParser/AttributeParser.cpp