| 1 | //===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements the parser for the MLIR Attributes. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "Parser.h" |
| 14 | |
| 15 | #include "mlir/AsmParser/AsmParserState.h" |
| 16 | #include "mlir/IR/AffineMap.h" |
| 17 | #include "mlir/IR/BuiltinAttributes.h" |
| 18 | #include "mlir/IR/BuiltinDialect.h" |
| 19 | #include "mlir/IR/BuiltinTypes.h" |
| 20 | #include "mlir/IR/DialectResourceBlobManager.h" |
| 21 | #include "mlir/IR/IntegerSet.h" |
| 22 | #include <optional> |
| 23 | |
| 24 | using namespace mlir; |
| 25 | using namespace mlir::detail; |
| 26 | |
| 27 | /// Parse an arbitrary attribute. |
| 28 | /// |
| 29 | /// attribute-value ::= `unit` |
| 30 | /// | bool-literal |
| 31 | /// | integer-literal (`:` (index-type | integer-type))? |
| 32 | /// | float-literal (`:` float-type)? |
| 33 | /// | string-literal (`:` type)? |
| 34 | /// | type |
| 35 | /// | `[` `:` (integer-type | float-type) tensor-literal `]` |
| 36 | /// | `[` (attribute-value (`,` attribute-value)*)? `]` |
| 37 | /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` |
| 38 | /// | symbol-ref-id (`::` symbol-ref-id)* |
| 39 | /// | `dense` `<` tensor-literal `>` `:` |
| 40 | /// (tensor-type | vector-type) |
| 41 | /// | `sparse` `<` attribute-value `,` attribute-value `>` |
| 42 | /// `:` (tensor-type | vector-type) |
| 43 | /// | `strided` `<` `[` comma-separated-int-or-question `]` |
| 44 | /// (`,` `offset` `:` integer-literal)? `>` |
| 45 | /// | distinct-attribute |
| 46 | /// | extended-attribute |
| 47 | /// |
| 48 | Attribute Parser::parseAttribute(Type type) { |
| 49 | switch (getToken().getKind()) { |
| 50 | // Parse an AffineMap or IntegerSet attribute. |
| 51 | case Token::kw_affine_map: { |
| 52 | consumeToken(kind: Token::kw_affine_map); |
| 53 | |
| 54 | AffineMap map; |
| 55 | if (parseToken(expectedToken: Token::less, message: "expected '<' in affine map" ) || |
| 56 | parseAffineMapReference(map) || |
| 57 | parseToken(expectedToken: Token::greater, message: "expected '>' in affine map" )) |
| 58 | return Attribute(); |
| 59 | return AffineMapAttr::get(map); |
| 60 | } |
| 61 | case Token::kw_affine_set: { |
| 62 | consumeToken(kind: Token::kw_affine_set); |
| 63 | |
| 64 | IntegerSet set; |
| 65 | if (parseToken(expectedToken: Token::less, message: "expected '<' in integer set" ) || |
| 66 | parseIntegerSetReference(set) || |
| 67 | parseToken(expectedToken: Token::greater, message: "expected '>' in integer set" )) |
| 68 | return Attribute(); |
| 69 | return IntegerSetAttr::get(set); |
| 70 | } |
| 71 | |
| 72 | // Parse an array attribute. |
| 73 | case Token::l_square: { |
| 74 | consumeToken(kind: Token::l_square); |
| 75 | SmallVector<Attribute, 4> elements; |
| 76 | auto parseElt = [&]() -> ParseResult { |
| 77 | elements.push_back(Elt: parseAttribute()); |
| 78 | return elements.back() ? success() : failure(); |
| 79 | }; |
| 80 | |
| 81 | if (parseCommaSeparatedListUntil(rightToken: Token::r_square, parseElement: parseElt)) |
| 82 | return nullptr; |
| 83 | return builder.getArrayAttr(elements); |
| 84 | } |
| 85 | |
| 86 | // Parse a boolean attribute. |
| 87 | case Token::kw_false: |
| 88 | consumeToken(kind: Token::kw_false); |
| 89 | return builder.getBoolAttr(value: false); |
| 90 | case Token::kw_true: |
| 91 | consumeToken(kind: Token::kw_true); |
| 92 | return builder.getBoolAttr(value: true); |
| 93 | |
| 94 | // Parse a dense elements attribute. |
| 95 | case Token::kw_dense: |
| 96 | return parseDenseElementsAttr(attrType: type); |
| 97 | |
| 98 | // Parse a dense resource elements attribute. |
| 99 | case Token::kw_dense_resource: |
| 100 | return parseDenseResourceElementsAttr(attrType: type); |
| 101 | |
| 102 | // Parse a dense array attribute. |
| 103 | case Token::kw_array: |
| 104 | return parseDenseArrayAttr(type); |
| 105 | |
| 106 | // Parse a dictionary attribute. |
| 107 | case Token::l_brace: { |
| 108 | NamedAttrList elements; |
| 109 | if (parseAttributeDict(attributes&: elements)) |
| 110 | return nullptr; |
| 111 | return elements.getDictionary(getContext()); |
| 112 | } |
| 113 | |
| 114 | // Parse an extended attribute, i.e. alias or dialect attribute. |
| 115 | case Token::hash_identifier: |
| 116 | return parseExtendedAttr(type); |
| 117 | |
| 118 | // Parse floating point and integer attributes. |
| 119 | case Token::floatliteral: |
| 120 | return parseFloatAttr(type, /*isNegative=*/false); |
| 121 | case Token::integer: |
| 122 | return parseDecOrHexAttr(type, /*isNegative=*/false); |
| 123 | case Token::minus: { |
| 124 | consumeToken(kind: Token::minus); |
| 125 | if (getToken().is(k: Token::integer)) |
| 126 | return parseDecOrHexAttr(type, /*isNegative=*/true); |
| 127 | if (getToken().is(k: Token::floatliteral)) |
| 128 | return parseFloatAttr(type, /*isNegative=*/true); |
| 129 | |
| 130 | return (emitWrongTokenError( |
| 131 | message: "expected constant integer or floating point value" ), |
| 132 | nullptr); |
| 133 | } |
| 134 | |
| 135 | // Parse a location attribute. |
| 136 | case Token::kw_loc: { |
| 137 | consumeToken(kind: Token::kw_loc); |
| 138 | |
| 139 | LocationAttr locAttr; |
| 140 | if (parseToken(expectedToken: Token::l_paren, message: "expected '(' in inline location" ) || |
| 141 | parseLocationInstance(loc&: locAttr) || |
| 142 | parseToken(expectedToken: Token::r_paren, message: "expected ')' in inline location" )) |
| 143 | return Attribute(); |
| 144 | return locAttr; |
| 145 | } |
| 146 | |
| 147 | // Parse a sparse elements attribute. |
| 148 | case Token::kw_sparse: |
| 149 | return parseSparseElementsAttr(attrType: type); |
| 150 | |
| 151 | // Parse a strided layout attribute. |
| 152 | case Token::kw_strided: |
| 153 | return parseStridedLayoutAttr(); |
| 154 | |
| 155 | // Parse a distinct attribute. |
| 156 | case Token::kw_distinct: |
| 157 | return parseDistinctAttr(type); |
| 158 | |
| 159 | // Parse a string attribute. |
| 160 | case Token::string: { |
| 161 | auto val = getToken().getStringValue(); |
| 162 | consumeToken(kind: Token::string); |
| 163 | // Parse the optional trailing colon type if one wasn't explicitly provided. |
| 164 | if (!type && consumeIf(kind: Token::colon) && !(type = parseType())) |
| 165 | return Attribute(); |
| 166 | |
| 167 | return type ? StringAttr::get(val, type) |
| 168 | : StringAttr::get(getContext(), val); |
| 169 | } |
| 170 | |
| 171 | // Parse a symbol reference attribute. |
| 172 | case Token::at_identifier: { |
| 173 | // When populating the parser state, this is a list of locations for all of |
| 174 | // the nested references. |
| 175 | SmallVector<SMRange> referenceLocations; |
| 176 | if (state.asmState) |
| 177 | referenceLocations.push_back(Elt: getToken().getLocRange()); |
| 178 | |
| 179 | // Parse the top-level reference. |
| 180 | std::string nameStr = getToken().getSymbolReference(); |
| 181 | consumeToken(kind: Token::at_identifier); |
| 182 | |
| 183 | // Parse any nested references. |
| 184 | std::vector<FlatSymbolRefAttr> nestedRefs; |
| 185 | while (getToken().is(k: Token::colon)) { |
| 186 | // Check for the '::' prefix. |
| 187 | const char *curPointer = getToken().getLoc().getPointer(); |
| 188 | consumeToken(kind: Token::colon); |
| 189 | if (!consumeIf(kind: Token::colon)) { |
| 190 | if (getToken().isNot(k1: Token::eof, k2: Token::error)) { |
| 191 | state.lex.resetPointer(newPointer: curPointer); |
| 192 | consumeToken(); |
| 193 | } |
| 194 | break; |
| 195 | } |
| 196 | // Parse the reference itself. |
| 197 | auto curLoc = getToken().getLoc(); |
| 198 | if (getToken().isNot(k: Token::at_identifier)) { |
| 199 | emitError(loc: curLoc, message: "expected nested symbol reference identifier" ); |
| 200 | return Attribute(); |
| 201 | } |
| 202 | |
| 203 | // If we are populating the assembly state, add the location for this |
| 204 | // reference. |
| 205 | if (state.asmState) |
| 206 | referenceLocations.push_back(Elt: getToken().getLocRange()); |
| 207 | |
| 208 | std::string nameStr = getToken().getSymbolReference(); |
| 209 | consumeToken(kind: Token::at_identifier); |
| 210 | nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr)); |
| 211 | } |
| 212 | SymbolRefAttr symbolRefAttr = |
| 213 | SymbolRefAttr::get(getContext(), nameStr, nestedRefs); |
| 214 | |
| 215 | // If we are populating the assembly state, record this symbol reference. |
| 216 | if (state.asmState) |
| 217 | state.asmState->addUses(symbolRefAttr, referenceLocations); |
| 218 | return symbolRefAttr; |
| 219 | } |
| 220 | |
| 221 | // Parse a 'unit' attribute. |
| 222 | case Token::kw_unit: |
| 223 | consumeToken(kind: Token::kw_unit); |
| 224 | return builder.getUnitAttr(); |
| 225 | |
| 226 | // Handle completion of an attribute. |
| 227 | case Token::code_complete: |
| 228 | if (getToken().isCodeCompletionFor(kind: Token::hash_identifier)) |
| 229 | return parseExtendedAttr(type); |
| 230 | return codeCompleteAttribute(); |
| 231 | |
| 232 | default: |
| 233 | // Parse a type attribute. We parse `Optional` here to allow for providing a |
| 234 | // better error message. |
| 235 | Type type; |
| 236 | OptionalParseResult result = parseOptionalType(type); |
| 237 | if (!result.has_value()) |
| 238 | return emitWrongTokenError(message: "expected attribute value" ), Attribute(); |
| 239 | return failed(*result) ? Attribute() : TypeAttr::get(type); |
| 240 | } |
| 241 | } |
| 242 | |
| 243 | /// Parse an optional attribute with the provided type. |
| 244 | OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute, |
| 245 | Type type) { |
| 246 | switch (getToken().getKind()) { |
| 247 | case Token::at_identifier: |
| 248 | case Token::floatliteral: |
| 249 | case Token::integer: |
| 250 | case Token::hash_identifier: |
| 251 | case Token::kw_affine_map: |
| 252 | case Token::kw_affine_set: |
| 253 | case Token::kw_dense: |
| 254 | case Token::kw_dense_resource: |
| 255 | case Token::kw_false: |
| 256 | case Token::kw_loc: |
| 257 | case Token::kw_sparse: |
| 258 | case Token::kw_true: |
| 259 | case Token::kw_unit: |
| 260 | case Token::l_brace: |
| 261 | case Token::l_square: |
| 262 | case Token::minus: |
| 263 | case Token::string: |
| 264 | attribute = parseAttribute(type); |
| 265 | return success(IsSuccess: attribute != nullptr); |
| 266 | |
| 267 | default: |
| 268 | // Parse an optional type attribute. |
| 269 | Type type; |
| 270 | OptionalParseResult result = parseOptionalType(type); |
| 271 | if (result.has_value() && succeeded(*result)) |
| 272 | attribute = TypeAttr::get(type); |
| 273 | return result; |
| 274 | } |
| 275 | } |
| 276 | OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute, |
| 277 | Type type) { |
| 278 | return parseOptionalAttributeWithToken(kind: Token::l_square, attr&: attribute, type); |
| 279 | } |
| 280 | OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute, |
| 281 | Type type) { |
| 282 | return parseOptionalAttributeWithToken(kind: Token::string, attr&: attribute, type); |
| 283 | } |
| 284 | OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result, |
| 285 | Type type) { |
| 286 | return parseOptionalAttributeWithToken(kind: Token::at_identifier, attr&: result, type); |
| 287 | } |
| 288 | |
| 289 | /// Attribute dictionary. |
| 290 | /// |
| 291 | /// attribute-dict ::= `{` `}` |
| 292 | /// | `{` attribute-entry (`,` attribute-entry)* `}` |
| 293 | /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value |
| 294 | /// |
| 295 | ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) { |
| 296 | llvm::SmallDenseSet<StringAttr> seenKeys; |
| 297 | auto parseElt = [&]() -> ParseResult { |
| 298 | // The name of an attribute can either be a bare identifier, or a string. |
| 299 | std::optional<StringAttr> nameId; |
| 300 | if (getToken().is(k: Token::string)) |
| 301 | nameId = builder.getStringAttr(getToken().getStringValue()); |
| 302 | else if (getToken().isAny(k1: Token::bare_identifier, k2: Token::inttype) || |
| 303 | getToken().isKeyword()) |
| 304 | nameId = builder.getStringAttr(getTokenSpelling()); |
| 305 | else |
| 306 | return emitWrongTokenError(message: "expected attribute name" ); |
| 307 | |
| 308 | if (nameId->empty()) |
| 309 | return emitError(message: "expected valid attribute name" ); |
| 310 | |
| 311 | if (!seenKeys.insert(*nameId).second) |
| 312 | return emitError(message: "duplicate key '" ) |
| 313 | << nameId->getValue() << "' in dictionary attribute" ; |
| 314 | consumeToken(); |
| 315 | |
| 316 | // Lazy load a dialect in the context if there is a possible namespace. |
| 317 | auto splitName = nameId->strref().split('.'); |
| 318 | if (!splitName.second.empty()) |
| 319 | getContext()->getOrLoadDialect(splitName.first); |
| 320 | |
| 321 | // Try to parse the '=' for the attribute value. |
| 322 | if (!consumeIf(kind: Token::equal)) { |
| 323 | // If there is no '=', we treat this as a unit attribute. |
| 324 | attributes.push_back(newAttribute: {*nameId, builder.getUnitAttr()}); |
| 325 | return success(); |
| 326 | } |
| 327 | |
| 328 | auto attr = parseAttribute(); |
| 329 | if (!attr) |
| 330 | return failure(); |
| 331 | attributes.push_back(newAttribute: {*nameId, attr}); |
| 332 | return success(); |
| 333 | }; |
| 334 | |
| 335 | return parseCommaSeparatedList(delimiter: Delimiter::Braces, parseElementFn: parseElt, |
| 336 | contextMessage: " in attribute dictionary" ); |
| 337 | } |
| 338 | |
| 339 | /// Parse a float attribute. |
| 340 | Attribute Parser::parseFloatAttr(Type type, bool isNegative) { |
| 341 | auto val = getToken().getFloatingPointValue(); |
| 342 | if (!val) |
| 343 | return (emitError(message: "floating point value too large for attribute" ), nullptr); |
| 344 | consumeToken(kind: Token::floatliteral); |
| 345 | if (!type) { |
| 346 | // Default to F64 when no type is specified. |
| 347 | if (!consumeIf(kind: Token::colon)) |
| 348 | type = builder.getF64Type(); |
| 349 | else if (!(type = parseType())) |
| 350 | return nullptr; |
| 351 | } |
| 352 | if (!isa<FloatType>(Val: type)) |
| 353 | return (emitError(message: "floating point value not valid for specified type" ), |
| 354 | nullptr); |
| 355 | return FloatAttr::get(type, isNegative ? -*val : *val); |
| 356 | } |
| 357 | |
| 358 | /// Construct an APint from a parsed value, a known attribute type and |
| 359 | /// sign. |
| 360 | static std::optional<APInt> buildAttributeAPInt(Type type, bool isNegative, |
| 361 | StringRef spelling) { |
| 362 | // Parse the integer value into an APInt that is big enough to hold the value. |
| 363 | APInt result; |
| 364 | bool isHex = spelling.size() > 1 && spelling[1] == 'x'; |
| 365 | if (spelling.getAsInteger(Radix: isHex ? 0 : 10, Result&: result)) |
| 366 | return std::nullopt; |
| 367 | |
| 368 | // Extend or truncate the bitwidth to the right size. |
| 369 | unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth |
| 370 | : type.getIntOrFloatBitWidth(); |
| 371 | |
| 372 | if (width > result.getBitWidth()) { |
| 373 | result = result.zext(width); |
| 374 | } else if (width < result.getBitWidth()) { |
| 375 | // The parser can return an unnecessarily wide result with leading zeros. |
| 376 | // This isn't a problem, but truncating off bits is bad. |
| 377 | if (result.countl_zero() < result.getBitWidth() - width) |
| 378 | return std::nullopt; |
| 379 | |
| 380 | result = result.trunc(width); |
| 381 | } |
| 382 | |
| 383 | if (width == 0) { |
| 384 | // 0 bit integers cannot be negative and manipulation of their sign bit will |
| 385 | // assert, so short-cut validation here. |
| 386 | if (isNegative) |
| 387 | return std::nullopt; |
| 388 | } else if (isNegative) { |
| 389 | // The value is negative, we have an overflow if the sign bit is not set |
| 390 | // in the negated apInt. |
| 391 | result.negate(); |
| 392 | if (!result.isSignBitSet()) |
| 393 | return std::nullopt; |
| 394 | } else if ((type.isSignedInteger() || type.isIndex()) && |
| 395 | result.isSignBitSet()) { |
| 396 | // The value is a positive signed integer or index, |
| 397 | // we have an overflow if the sign bit is set. |
| 398 | return std::nullopt; |
| 399 | } |
| 400 | |
| 401 | return result; |
| 402 | } |
| 403 | |
| 404 | /// Parse a decimal or a hexadecimal literal, which can be either an integer |
| 405 | /// or a float attribute. |
| 406 | Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { |
| 407 | Token tok = getToken(); |
| 408 | StringRef spelling = tok.getSpelling(); |
| 409 | SMLoc loc = tok.getLoc(); |
| 410 | |
| 411 | consumeToken(kind: Token::integer); |
| 412 | if (!type) { |
| 413 | // Default to i64 if not type is specified. |
| 414 | if (!consumeIf(kind: Token::colon)) |
| 415 | type = builder.getIntegerType(64); |
| 416 | else if (!(type = parseType())) |
| 417 | return nullptr; |
| 418 | } |
| 419 | |
| 420 | if (auto floatType = dyn_cast<FloatType>(type)) { |
| 421 | std::optional<APFloat> result; |
| 422 | if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative, |
| 423 | semantics: floatType.getFloatSemantics()))) |
| 424 | return Attribute(); |
| 425 | return FloatAttr::get(floatType, *result); |
| 426 | } |
| 427 | |
| 428 | if (!isa<IntegerType, IndexType>(Val: type)) |
| 429 | return emitError(loc, message: "integer literal not valid for specified type" ), |
| 430 | nullptr; |
| 431 | |
| 432 | if (isNegative && type.isUnsignedInteger()) { |
| 433 | emitError(loc, |
| 434 | message: "negative integer literal not valid for unsigned integer type" ); |
| 435 | return nullptr; |
| 436 | } |
| 437 | |
| 438 | std::optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling); |
| 439 | if (!apInt) |
| 440 | return emitError(loc, message: "integer constant out of range for attribute" ), |
| 441 | nullptr; |
| 442 | return builder.getIntegerAttr(type, *apInt); |
| 443 | } |
| 444 | |
| 445 | //===----------------------------------------------------------------------===// |
| 446 | // TensorLiteralParser |
| 447 | //===----------------------------------------------------------------------===// |
| 448 | |
| 449 | /// Parse elements values stored within a hex string. On success, the values are |
| 450 | /// stored into 'result'. |
| 451 | static ParseResult parseElementAttrHexValues(Parser &parser, Token tok, |
| 452 | std::string &result) { |
| 453 | if (std::optional<std::string> value = tok.getHexStringValue()) { |
| 454 | result = std::move(*value); |
| 455 | return success(); |
| 456 | } |
| 457 | return parser.emitError( |
| 458 | loc: tok.getLoc(), message: "expected string containing hex digits starting with `0x`" ); |
| 459 | } |
| 460 | |
| 461 | namespace { |
| 462 | /// This class implements a parser for TensorLiterals. A tensor literal is |
| 463 | /// either a single element (e.g, 5) or a multi-dimensional list of elements |
| 464 | /// (e.g., [[5, 5]]). |
| 465 | class TensorLiteralParser { |
| 466 | public: |
| 467 | TensorLiteralParser(Parser &p) : p(p) {} |
| 468 | |
| 469 | /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser |
| 470 | /// may also parse a tensor literal that is store as a hex string. |
| 471 | ParseResult parse(bool allowHex); |
| 472 | |
| 473 | /// Build a dense attribute instance with the parsed elements and the given |
| 474 | /// shaped type. |
| 475 | DenseElementsAttr getAttr(SMLoc loc, ShapedType type); |
| 476 | |
| 477 | ArrayRef<int64_t> getShape() const { return shape; } |
| 478 | |
| 479 | private: |
| 480 | /// Get the parsed elements for an integer attribute. |
| 481 | ParseResult getIntAttrElements(SMLoc loc, Type eltTy, |
| 482 | std::vector<APInt> &intValues); |
| 483 | |
| 484 | /// Get the parsed elements for a float attribute. |
| 485 | ParseResult getFloatAttrElements(SMLoc loc, FloatType eltTy, |
| 486 | std::vector<APFloat> &floatValues); |
| 487 | |
| 488 | /// Build a Dense String attribute for the given type. |
| 489 | DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy); |
| 490 | |
| 491 | /// Build a Dense attribute with hex data for the given type. |
| 492 | DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type); |
| 493 | |
| 494 | /// Parse a single element, returning failure if it isn't a valid element |
| 495 | /// literal. For example: |
| 496 | /// parseElement(1) -> Success, 1 |
| 497 | /// parseElement([1]) -> Failure |
| 498 | ParseResult parseElement(); |
| 499 | |
| 500 | /// Parse a list of either lists or elements, returning the dimensions of the |
| 501 | /// parsed sub-tensors in dims. For example: |
| 502 | /// parseList([1, 2, 3]) -> Success, [3] |
| 503 | /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] |
| 504 | /// parseList([[1, 2], 3]) -> Failure |
| 505 | /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure |
| 506 | ParseResult parseList(SmallVectorImpl<int64_t> &dims); |
| 507 | |
| 508 | /// Parse a literal that was printed as a hex string. |
| 509 | ParseResult parseHexElements(); |
| 510 | |
| 511 | Parser &p; |
| 512 | |
| 513 | /// The shape inferred from the parsed elements. |
| 514 | SmallVector<int64_t, 4> shape; |
| 515 | |
| 516 | /// Storage used when parsing elements, this is a pair of <is_negated, token>. |
| 517 | std::vector<std::pair<bool, Token>> storage; |
| 518 | |
| 519 | /// Storage used when parsing elements that were stored as hex values. |
| 520 | std::optional<Token> hexStorage; |
| 521 | }; |
| 522 | } // namespace |
| 523 | |
| 524 | /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser |
| 525 | /// may also parse a tensor literal that is store as a hex string. |
| 526 | ParseResult TensorLiteralParser::parse(bool allowHex) { |
| 527 | // If hex is allowed, check for a string literal. |
| 528 | if (allowHex && p.getToken().is(k: Token::string)) { |
| 529 | hexStorage = p.getToken(); |
| 530 | p.consumeToken(kind: Token::string); |
| 531 | return success(); |
| 532 | } |
| 533 | // Otherwise, parse a list or an individual element. |
| 534 | if (p.getToken().is(k: Token::l_square)) |
| 535 | return parseList(dims&: shape); |
| 536 | return parseElement(); |
| 537 | } |
| 538 | |
| 539 | /// Build a dense attribute instance with the parsed elements and the given |
| 540 | /// shaped type. |
| 541 | DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) { |
| 542 | Type eltType = type.getElementType(); |
| 543 | |
| 544 | // Check to see if we parse the literal from a hex string. |
| 545 | if (hexStorage && |
| 546 | (eltType.isIntOrIndexOrFloat() || isa<ComplexType>(eltType))) |
| 547 | return getHexAttr(loc, type); |
| 548 | |
| 549 | // Check that the parsed storage size has the same number of elements to the |
| 550 | // type, or is a known splat. |
| 551 | if (!shape.empty() && getShape() != type.getShape()) { |
| 552 | p.emitError(loc) << "inferred shape of elements literal ([" << getShape() |
| 553 | << "]) does not match type ([" << type.getShape() << "])" ; |
| 554 | return nullptr; |
| 555 | } |
| 556 | |
| 557 | // Handle the case where no elements were parsed. |
| 558 | if (!hexStorage && storage.empty() && type.getNumElements()) { |
| 559 | p.emitError(loc) << "parsed zero elements, but type (" << type |
| 560 | << ") expected at least 1" ; |
| 561 | return nullptr; |
| 562 | } |
| 563 | |
| 564 | // Handle complex types in the specific element type cases below. |
| 565 | bool isComplex = false; |
| 566 | if (ComplexType complexTy = dyn_cast<ComplexType>(eltType)) { |
| 567 | eltType = complexTy.getElementType(); |
| 568 | isComplex = true; |
| 569 | // Complex types have N*2 elements or complex splat. |
| 570 | // Empty shape may mean a splat or empty literal, only validate splats. |
| 571 | bool isSplat = shape.empty() && type.getNumElements() != 0; |
| 572 | if (isSplat && storage.size() != 2) { |
| 573 | p.emitError(loc) << "parsed " << storage.size() << " elements, but type (" |
| 574 | << complexTy << ") expected 2 elements" ; |
| 575 | return nullptr; |
| 576 | } |
| 577 | if (!shape.empty() && |
| 578 | storage.size() != static_cast<size_t>(type.getNumElements()) * 2) { |
| 579 | p.emitError(loc) << "parsed " << storage.size() << " elements, but type (" |
| 580 | << type << ") expected " << type.getNumElements() * 2 |
| 581 | << " elements" ; |
| 582 | return nullptr; |
| 583 | } |
| 584 | } |
| 585 | |
| 586 | // Handle integer and index types. |
| 587 | if (eltType.isIntOrIndex()) { |
| 588 | std::vector<APInt> intValues; |
| 589 | if (failed(Result: getIntAttrElements(loc, eltTy: eltType, intValues))) |
| 590 | return nullptr; |
| 591 | if (isComplex) { |
| 592 | // If this is a complex, treat the parsed values as complex values. |
| 593 | auto complexData = llvm::ArrayRef( |
| 594 | reinterpret_cast<std::complex<APInt> *>(intValues.data()), |
| 595 | intValues.size() / 2); |
| 596 | return DenseElementsAttr::get(type, complexData); |
| 597 | } |
| 598 | return DenseElementsAttr::get(type, intValues); |
| 599 | } |
| 600 | // Handle floating point types. |
| 601 | if (FloatType floatTy = dyn_cast<FloatType>(eltType)) { |
| 602 | std::vector<APFloat> floatValues; |
| 603 | if (failed(getFloatAttrElements(loc, eltTy: floatTy, floatValues))) |
| 604 | return nullptr; |
| 605 | if (isComplex) { |
| 606 | // If this is a complex, treat the parsed values as complex values. |
| 607 | auto complexData = llvm::ArrayRef( |
| 608 | reinterpret_cast<std::complex<APFloat> *>(floatValues.data()), |
| 609 | floatValues.size() / 2); |
| 610 | return DenseElementsAttr::get(type, complexData); |
| 611 | } |
| 612 | return DenseElementsAttr::get(type, floatValues); |
| 613 | } |
| 614 | |
| 615 | // Other types are assumed to be string representations. |
| 616 | return getStringAttr(loc, type, type.getElementType()); |
| 617 | } |
| 618 | |
| 619 | /// Build a Dense Integer attribute for the given type. |
| 620 | ParseResult |
| 621 | TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy, |
| 622 | std::vector<APInt> &intValues) { |
| 623 | intValues.reserve(n: storage.size()); |
| 624 | bool isUintType = eltTy.isUnsignedInteger(); |
| 625 | for (const auto &signAndToken : storage) { |
| 626 | bool isNegative = signAndToken.first; |
| 627 | const Token &token = signAndToken.second; |
| 628 | auto tokenLoc = token.getLoc(); |
| 629 | |
| 630 | if (isNegative && isUintType) { |
| 631 | return p.emitError(loc: tokenLoc) |
| 632 | << "expected unsigned integer elements, but parsed negative value" ; |
| 633 | } |
| 634 | |
| 635 | // Check to see if floating point values were parsed. |
| 636 | if (token.is(k: Token::floatliteral)) { |
| 637 | return p.emitError(loc: tokenLoc) |
| 638 | << "expected integer elements, but parsed floating-point" ; |
| 639 | } |
| 640 | |
| 641 | assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && |
| 642 | "unexpected token type" ); |
| 643 | if (token.isAny(k1: Token::kw_true, k2: Token::kw_false)) { |
| 644 | if (!eltTy.isInteger(width: 1)) { |
| 645 | return p.emitError(loc: tokenLoc) |
| 646 | << "expected i1 type for 'true' or 'false' values" ; |
| 647 | } |
| 648 | APInt apInt(1, token.is(k: Token::kw_true), /*isSigned=*/false); |
| 649 | intValues.push_back(x: apInt); |
| 650 | continue; |
| 651 | } |
| 652 | |
| 653 | // Create APInt values for each element with the correct bitwidth. |
| 654 | std::optional<APInt> apInt = |
| 655 | buildAttributeAPInt(type: eltTy, isNegative, spelling: token.getSpelling()); |
| 656 | if (!apInt) |
| 657 | return p.emitError(loc: tokenLoc, message: "integer constant out of range for type" ); |
| 658 | intValues.push_back(x: *apInt); |
| 659 | } |
| 660 | return success(); |
| 661 | } |
| 662 | |
| 663 | /// Build a Dense Float attribute for the given type. |
| 664 | ParseResult |
| 665 | TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy, |
| 666 | std::vector<APFloat> &floatValues) { |
| 667 | floatValues.reserve(n: storage.size()); |
| 668 | for (const auto &signAndToken : storage) { |
| 669 | bool isNegative = signAndToken.first; |
| 670 | const Token &token = signAndToken.second; |
| 671 | std::optional<APFloat> result; |
| 672 | if (failed(p.parseFloatFromLiteral(result, tok: token, isNegative, |
| 673 | semantics: eltTy.getFloatSemantics()))) |
| 674 | return failure(); |
| 675 | floatValues.push_back(x: *result); |
| 676 | } |
| 677 | return success(); |
| 678 | } |
| 679 | |
| 680 | /// Build a Dense String attribute for the given type. |
| 681 | DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type, |
| 682 | Type eltTy) { |
| 683 | if (hexStorage.has_value()) { |
| 684 | auto stringValue = hexStorage->getStringValue(); |
| 685 | return DenseStringElementsAttr::get(type, {stringValue}); |
| 686 | } |
| 687 | |
| 688 | std::vector<std::string> stringValues; |
| 689 | std::vector<StringRef> stringRefValues; |
| 690 | stringValues.reserve(n: storage.size()); |
| 691 | stringRefValues.reserve(n: storage.size()); |
| 692 | |
| 693 | for (auto val : storage) { |
| 694 | if (!val.second.is(k: Token::string)) { |
| 695 | p.emitError(loc) << "expected string token, got " |
| 696 | << val.second.getSpelling(); |
| 697 | return nullptr; |
| 698 | } |
| 699 | stringValues.push_back(x: val.second.getStringValue()); |
| 700 | stringRefValues.emplace_back(args&: stringValues.back()); |
| 701 | } |
| 702 | |
| 703 | return DenseStringElementsAttr::get(type, stringRefValues); |
| 704 | } |
| 705 | |
| 706 | /// Build a Dense attribute with hex data for the given type. |
| 707 | DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) { |
| 708 | Type elementType = type.getElementType(); |
| 709 | if (!elementType.isIntOrIndexOrFloat() && !isa<ComplexType>(elementType)) { |
| 710 | p.emitError(loc) |
| 711 | << "expected floating-point, integer, or complex element type, got " |
| 712 | << elementType; |
| 713 | return nullptr; |
| 714 | } |
| 715 | |
| 716 | std::string data; |
| 717 | if (parseElementAttrHexValues(parser&: p, tok: *hexStorage, result&: data)) |
| 718 | return nullptr; |
| 719 | |
| 720 | ArrayRef<char> rawData(data.data(), data.size()); |
| 721 | bool detectedSplat = false; |
| 722 | if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) { |
| 723 | p.emitError(loc) << "elements hex data size is invalid for provided type: " |
| 724 | << type; |
| 725 | return nullptr; |
| 726 | } |
| 727 | |
| 728 | if (llvm::endianness::native == llvm::endianness::big) { |
| 729 | // Convert endianess in big-endian(BE) machines. `rawData` is |
| 730 | // little-endian(LE) because HEX in raw data of dense element attribute |
| 731 | // is always LE format. It is converted into BE here to be used in BE |
| 732 | // machines. |
| 733 | SmallVector<char, 64> outDataVec(rawData.size()); |
| 734 | MutableArrayRef<char> convRawData(outDataVec); |
| 735 | DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( |
| 736 | rawData, convRawData, type); |
| 737 | return DenseElementsAttr::getFromRawBuffer(type, convRawData); |
| 738 | } |
| 739 | |
| 740 | return DenseElementsAttr::getFromRawBuffer(type, rawData); |
| 741 | } |
| 742 | |
| 743 | ParseResult TensorLiteralParser::parseElement() { |
| 744 | switch (p.getToken().getKind()) { |
| 745 | // Parse a boolean element. |
| 746 | case Token::kw_true: |
| 747 | case Token::kw_false: |
| 748 | case Token::floatliteral: |
| 749 | case Token::integer: |
| 750 | storage.emplace_back(/*isNegative=*/args: false, args: p.getToken()); |
| 751 | p.consumeToken(); |
| 752 | break; |
| 753 | |
| 754 | // Parse a signed integer or a negative floating-point element. |
| 755 | case Token::minus: |
| 756 | p.consumeToken(kind: Token::minus); |
| 757 | if (!p.getToken().isAny(k1: Token::floatliteral, k2: Token::integer)) |
| 758 | return p.emitError(message: "expected integer or floating point literal" ); |
| 759 | storage.emplace_back(/*isNegative=*/args: true, args: p.getToken()); |
| 760 | p.consumeToken(); |
| 761 | break; |
| 762 | |
| 763 | case Token::string: |
| 764 | storage.emplace_back(/*isNegative=*/args: false, args: p.getToken()); |
| 765 | p.consumeToken(); |
| 766 | break; |
| 767 | |
| 768 | // Parse a complex element of the form '(' element ',' element ')'. |
| 769 | case Token::l_paren: |
| 770 | p.consumeToken(kind: Token::l_paren); |
| 771 | if (parseElement() || |
| 772 | p.parseToken(expectedToken: Token::comma, message: "expected ',' between complex elements" ) || |
| 773 | parseElement() || |
| 774 | p.parseToken(expectedToken: Token::r_paren, message: "expected ')' after complex elements" )) |
| 775 | return failure(); |
| 776 | break; |
| 777 | |
| 778 | default: |
| 779 | return p.emitError(message: "expected element literal of primitive type" ); |
| 780 | } |
| 781 | |
| 782 | return success(); |
| 783 | } |
| 784 | |
| 785 | /// Parse a list of either lists or elements, returning the dimensions of the |
| 786 | /// parsed sub-tensors in dims. For example: |
| 787 | /// parseList([1, 2, 3]) -> Success, [3] |
| 788 | /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] |
| 789 | /// parseList([[1, 2], 3]) -> Failure |
| 790 | /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure |
| 791 | ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) { |
| 792 | auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims, |
| 793 | const SmallVectorImpl<int64_t> &newDims) -> ParseResult { |
| 794 | if (prevDims == newDims) |
| 795 | return success(); |
| 796 | return p.emitError(message: "tensor literal is invalid; ranks are not consistent " |
| 797 | "between elements" ); |
| 798 | }; |
| 799 | |
| 800 | bool first = true; |
| 801 | SmallVector<int64_t, 4> newDims; |
| 802 | unsigned size = 0; |
| 803 | auto parseOneElement = [&]() -> ParseResult { |
| 804 | SmallVector<int64_t, 4> thisDims; |
| 805 | if (p.getToken().getKind() == Token::l_square) { |
| 806 | if (parseList(dims&: thisDims)) |
| 807 | return failure(); |
| 808 | } else if (parseElement()) { |
| 809 | return failure(); |
| 810 | } |
| 811 | ++size; |
| 812 | if (!first) |
| 813 | return checkDims(newDims, thisDims); |
| 814 | newDims = thisDims; |
| 815 | first = false; |
| 816 | return success(); |
| 817 | }; |
| 818 | if (p.parseCommaSeparatedList(delimiter: Parser::Delimiter::Square, parseElementFn: parseOneElement)) |
| 819 | return failure(); |
| 820 | |
| 821 | // Return the sublists' dimensions with 'size' prepended. |
| 822 | dims.clear(); |
| 823 | dims.push_back(Elt: size); |
| 824 | dims.append(in_start: newDims.begin(), in_end: newDims.end()); |
| 825 | return success(); |
| 826 | } |
| 827 | |
| 828 | //===----------------------------------------------------------------------===// |
| 829 | // DenseArrayAttr Parser |
| 830 | //===----------------------------------------------------------------------===// |
| 831 | |
| 832 | namespace { |
| 833 | /// A generic dense array element parser. It parsers integer and floating point |
| 834 | /// elements. |
| 835 | class DenseArrayElementParser { |
| 836 | public: |
| 837 | explicit DenseArrayElementParser(Type type) : type(type) {} |
| 838 | |
| 839 | /// Parse an integer element. |
| 840 | ParseResult parseIntegerElement(Parser &p); |
| 841 | |
| 842 | /// Parse a floating point element. |
| 843 | ParseResult parseFloatElement(Parser &p); |
| 844 | |
| 845 | /// Convert the current contents to a dense array. |
| 846 | DenseArrayAttr getAttr() { return DenseArrayAttr::get(type, size, rawData); } |
| 847 | |
| 848 | private: |
| 849 | /// Append the raw data of an APInt to the result. |
| 850 | void append(const APInt &data); |
| 851 | |
| 852 | /// The array element type. |
| 853 | Type type; |
| 854 | /// The resultant byte array representing the contents of the array. |
| 855 | std::vector<char> rawData; |
| 856 | /// The number of elements in the array. |
| 857 | int64_t size = 0; |
| 858 | }; |
| 859 | } // namespace |
| 860 | |
| 861 | void DenseArrayElementParser::append(const APInt &data) { |
| 862 | if (data.getBitWidth()) { |
| 863 | assert(data.getBitWidth() % 8 == 0); |
| 864 | unsigned byteSize = data.getBitWidth() / 8; |
| 865 | size_t offset = rawData.size(); |
| 866 | rawData.insert(position: rawData.end(), n: byteSize, x: 0); |
| 867 | llvm::StoreIntToMemory( |
| 868 | IntVal: data, Dst: reinterpret_cast<uint8_t *>(rawData.data() + offset), StoreBytes: byteSize); |
| 869 | } |
| 870 | ++size; |
| 871 | } |
| 872 | |
| 873 | ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) { |
| 874 | bool isNegative = p.consumeIf(kind: Token::minus); |
| 875 | |
| 876 | // Parse an integer literal as an APInt. |
| 877 | std::optional<APInt> value; |
| 878 | StringRef spelling = p.getToken().getSpelling(); |
| 879 | if (p.getToken().isAny(k1: Token::kw_true, k2: Token::kw_false)) { |
| 880 | if (!type.isInteger(width: 1)) |
| 881 | return p.emitError(message: "expected i1 type for 'true' or 'false' values" ); |
| 882 | value = APInt(/*numBits=*/8, p.getToken().is(k: Token::kw_true), |
| 883 | !type.isUnsignedInteger()); |
| 884 | p.consumeToken(); |
| 885 | } else if (p.consumeIf(kind: Token::integer)) { |
| 886 | value = buildAttributeAPInt(type, isNegative, spelling); |
| 887 | if (!value) |
| 888 | return p.emitError(message: "integer constant out of range" ); |
| 889 | } else { |
| 890 | return p.emitError(message: "expected integer literal" ); |
| 891 | } |
| 892 | append(data: *value); |
| 893 | return success(); |
| 894 | } |
| 895 | |
| 896 | ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) { |
| 897 | bool isNegative = p.consumeIf(kind: Token::minus); |
| 898 | Token token = p.getToken(); |
| 899 | std::optional<APFloat> fromIntLit; |
| 900 | if (failed( |
| 901 | p.parseFloatFromLiteral(result&: fromIntLit, tok: token, isNegative, |
| 902 | semantics: cast<FloatType>(type).getFloatSemantics()))) |
| 903 | return failure(); |
| 904 | p.consumeToken(); |
| 905 | append(data: fromIntLit->bitcastToAPInt()); |
| 906 | return success(); |
| 907 | } |
| 908 | |
| 909 | /// Parse a dense array attribute. |
| 910 | Attribute Parser::parseDenseArrayAttr(Type attrType) { |
| 911 | consumeToken(kind: Token::kw_array); |
| 912 | if (parseToken(expectedToken: Token::less, message: "expected '<' after 'array'" )) |
| 913 | return {}; |
| 914 | |
| 915 | SMLoc typeLoc = getToken().getLoc(); |
| 916 | Type eltType = parseType(); |
| 917 | if (!eltType) { |
| 918 | emitError(loc: typeLoc, message: "expected an integer or floating point type" ); |
| 919 | return {}; |
| 920 | } |
| 921 | |
| 922 | // Only bool or integer and floating point elements divisible by bytes are |
| 923 | // supported. |
| 924 | if (!eltType.isIntOrIndexOrFloat()) { |
| 925 | emitError(loc: typeLoc, message: "expected integer or float type, got: " ) << eltType; |
| 926 | return {}; |
| 927 | } |
| 928 | if (!eltType.isInteger(width: 1) && eltType.getIntOrFloatBitWidth() % 8 != 0) { |
| 929 | emitError(loc: typeLoc, message: "element type bitwidth must be a multiple of 8" ); |
| 930 | return {}; |
| 931 | } |
| 932 | |
| 933 | // Check for empty list. |
| 934 | if (consumeIf(Token::greater)) |
| 935 | return DenseArrayAttr::get(eltType, 0, {}); |
| 936 | |
| 937 | if (parseToken(expectedToken: Token::colon, message: "expected ':' after dense array type" )) |
| 938 | return {}; |
| 939 | |
| 940 | DenseArrayElementParser eltParser(eltType); |
| 941 | if (eltType.isIntOrIndex()) { |
| 942 | if (parseCommaSeparatedList( |
| 943 | parseElementFn: [&] { return eltParser.parseIntegerElement(p&: *this); })) |
| 944 | return {}; |
| 945 | } else { |
| 946 | if (parseCommaSeparatedList( |
| 947 | parseElementFn: [&] { return eltParser.parseFloatElement(p&: *this); })) |
| 948 | return {}; |
| 949 | } |
| 950 | if (parseToken(expectedToken: Token::greater, message: "expected '>' to close an array attribute" )) |
| 951 | return {}; |
| 952 | return eltParser.getAttr(); |
| 953 | } |
| 954 | |
| 955 | /// Parse a dense elements attribute. |
| 956 | Attribute Parser::parseDenseElementsAttr(Type attrType) { |
| 957 | auto attribLoc = getToken().getLoc(); |
| 958 | consumeToken(kind: Token::kw_dense); |
| 959 | if (parseToken(expectedToken: Token::less, message: "expected '<' after 'dense'" )) |
| 960 | return nullptr; |
| 961 | |
| 962 | // Parse the literal data if necessary. |
| 963 | TensorLiteralParser literalParser(*this); |
| 964 | if (!consumeIf(kind: Token::greater)) { |
| 965 | if (literalParser.parse(/*allowHex=*/true) || |
| 966 | parseToken(expectedToken: Token::greater, message: "expected '>'" )) |
| 967 | return nullptr; |
| 968 | } |
| 969 | |
| 970 | auto type = parseElementsLiteralType(attribLoc, attrType); |
| 971 | if (!type) |
| 972 | return nullptr; |
| 973 | return literalParser.getAttr(attribLoc, type); |
| 974 | } |
| 975 | |
| 976 | Attribute Parser::parseDenseResourceElementsAttr(Type attrType) { |
| 977 | auto loc = getToken().getLoc(); |
| 978 | consumeToken(kind: Token::kw_dense_resource); |
| 979 | if (parseToken(expectedToken: Token::less, message: "expected '<' after 'dense_resource'" )) |
| 980 | return nullptr; |
| 981 | |
| 982 | // Parse the resource handle. |
| 983 | FailureOr<AsmDialectResourceHandle> rawHandle = |
| 984 | parseResourceHandle(getContext()->getLoadedDialect<BuiltinDialect>()); |
| 985 | if (failed(Result: rawHandle) || parseToken(expectedToken: Token::greater, message: "expected '>'" )) |
| 986 | return nullptr; |
| 987 | |
| 988 | auto *handle = dyn_cast<DenseResourceElementsHandle>(Val: &*rawHandle); |
| 989 | if (!handle) |
| 990 | return emitError(loc, message: "invalid `dense_resource` handle type" ), nullptr; |
| 991 | |
| 992 | // Parse the type of the attribute if the user didn't provide one. |
| 993 | SMLoc typeLoc = loc; |
| 994 | if (!attrType) { |
| 995 | typeLoc = getToken().getLoc(); |
| 996 | if (parseToken(expectedToken: Token::colon, message: "expected ':'" ) || !(attrType = parseType())) |
| 997 | return nullptr; |
| 998 | } |
| 999 | |
| 1000 | ShapedType shapedType = dyn_cast<ShapedType>(attrType); |
| 1001 | if (!shapedType) { |
| 1002 | emitError(loc: typeLoc, message: "`dense_resource` expected a shaped type" ); |
| 1003 | return nullptr; |
| 1004 | } |
| 1005 | |
| 1006 | return DenseResourceElementsAttr::get(shapedType, *handle); |
| 1007 | } |
| 1008 | |
| 1009 | /// Shaped type for elements attribute. |
| 1010 | /// |
| 1011 | /// elements-literal-type ::= vector-type | ranked-tensor-type |
| 1012 | /// |
| 1013 | /// This method also checks the type has static shape. |
| 1014 | ShapedType Parser::parseElementsLiteralType(SMLoc loc, Type type) { |
| 1015 | // If the user didn't provide a type, parse the colon type for the literal. |
| 1016 | if (!type) { |
| 1017 | if (parseToken(expectedToken: Token::colon, message: "expected ':'" )) |
| 1018 | return nullptr; |
| 1019 | if (!(type = parseType())) |
| 1020 | return nullptr; |
| 1021 | } |
| 1022 | |
| 1023 | auto sType = dyn_cast<ShapedType>(type); |
| 1024 | if (!sType) { |
| 1025 | emitError(loc, message: "elements literal must be a shaped type" ); |
| 1026 | return nullptr; |
| 1027 | } |
| 1028 | |
| 1029 | if (!sType.hasStaticShape()) { |
| 1030 | emitError(loc, message: "elements literal type must have static shape" ); |
| 1031 | return nullptr; |
| 1032 | } |
| 1033 | |
| 1034 | return sType; |
| 1035 | } |
| 1036 | |
| 1037 | /// Parse a sparse elements attribute. |
| 1038 | Attribute Parser::parseSparseElementsAttr(Type attrType) { |
| 1039 | SMLoc loc = getToken().getLoc(); |
| 1040 | consumeToken(kind: Token::kw_sparse); |
| 1041 | if (parseToken(expectedToken: Token::less, message: "Expected '<' after 'sparse'" )) |
| 1042 | return nullptr; |
| 1043 | |
| 1044 | // Check for the case where all elements are sparse. The indices are |
| 1045 | // represented by a 2-dimensional shape where the second dimension is the rank |
| 1046 | // of the type. |
| 1047 | Type indiceEltType = builder.getIntegerType(64); |
| 1048 | if (consumeIf(kind: Token::greater)) { |
| 1049 | ShapedType type = parseElementsLiteralType(loc, attrType); |
| 1050 | if (!type) |
| 1051 | return nullptr; |
| 1052 | |
| 1053 | // Construct the sparse elements attr using zero element indice/value |
| 1054 | // attributes. |
| 1055 | ShapedType indicesType = |
| 1056 | RankedTensorType::get({0, type.getRank()}, indiceEltType); |
| 1057 | ShapedType valuesType = RankedTensorType::get({0}, type.getElementType()); |
| 1058 | return getChecked<SparseElementsAttr>( |
| 1059 | loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()), |
| 1060 | DenseElementsAttr::get(valuesType, ArrayRef<Attribute>())); |
| 1061 | } |
| 1062 | |
| 1063 | /// Parse the indices. We don't allow hex values here as we may need to use |
| 1064 | /// the inferred shape. |
| 1065 | auto indicesLoc = getToken().getLoc(); |
| 1066 | TensorLiteralParser indiceParser(*this); |
| 1067 | if (indiceParser.parse(/*allowHex=*/false)) |
| 1068 | return nullptr; |
| 1069 | |
| 1070 | if (parseToken(expectedToken: Token::comma, message: "expected ','" )) |
| 1071 | return nullptr; |
| 1072 | |
| 1073 | /// Parse the values. |
| 1074 | auto valuesLoc = getToken().getLoc(); |
| 1075 | TensorLiteralParser valuesParser(*this); |
| 1076 | if (valuesParser.parse(/*allowHex=*/true)) |
| 1077 | return nullptr; |
| 1078 | |
| 1079 | if (parseToken(expectedToken: Token::greater, message: "expected '>'" )) |
| 1080 | return nullptr; |
| 1081 | |
| 1082 | auto type = parseElementsLiteralType(loc, attrType); |
| 1083 | if (!type) |
| 1084 | return nullptr; |
| 1085 | |
| 1086 | // If the indices are a splat, i.e. the literal parser parsed an element and |
| 1087 | // not a list, we set the shape explicitly. The indices are represented by a |
| 1088 | // 2-dimensional shape where the second dimension is the rank of the type. |
| 1089 | // Given that the parsed indices is a splat, we know that we only have one |
| 1090 | // indice and thus one for the first dimension. |
| 1091 | ShapedType indicesType; |
| 1092 | if (indiceParser.getShape().empty()) { |
| 1093 | indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); |
| 1094 | } else { |
| 1095 | // Otherwise, set the shape to the one parsed by the literal parser. |
| 1096 | indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); |
| 1097 | } |
| 1098 | auto indices = indiceParser.getAttr(indicesLoc, indicesType); |
| 1099 | if (!indices) |
| 1100 | return nullptr; |
| 1101 | |
| 1102 | // If the values are a splat, set the shape explicitly based on the number of |
| 1103 | // indices. The number of indices is encoded in the first dimension of the |
| 1104 | // indice shape type. |
| 1105 | auto valuesEltType = type.getElementType(); |
| 1106 | ShapedType valuesType = |
| 1107 | valuesParser.getShape().empty() |
| 1108 | ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) |
| 1109 | : RankedTensorType::get(valuesParser.getShape(), valuesEltType); |
| 1110 | auto values = valuesParser.getAttr(valuesLoc, valuesType); |
| 1111 | if (!values) |
| 1112 | return nullptr; |
| 1113 | |
| 1114 | // Build the sparse elements attribute by the indices and values. |
| 1115 | return getChecked<SparseElementsAttr>(loc, type, indices, values); |
| 1116 | } |
| 1117 | |
| 1118 | Attribute Parser::parseStridedLayoutAttr() { |
| 1119 | // Callback for error emissing at the keyword token location. |
| 1120 | llvm::SMLoc loc = getToken().getLoc(); |
| 1121 | auto errorEmitter = [&] { return emitError(loc); }; |
| 1122 | |
| 1123 | consumeToken(kind: Token::kw_strided); |
| 1124 | if (failed(Result: parseToken(expectedToken: Token::less, message: "expected '<' after 'strided'" )) || |
| 1125 | failed(Result: parseToken(expectedToken: Token::l_square, message: "expected '['" ))) |
| 1126 | return nullptr; |
| 1127 | |
| 1128 | // Parses either an integer token or a question mark token. Reports an error |
| 1129 | // and returns std::nullopt if the current token is neither. The integer token |
| 1130 | // must fit into int64_t limits. |
| 1131 | auto parseStrideOrOffset = [&]() -> std::optional<int64_t> { |
| 1132 | if (consumeIf(Token::question)) |
| 1133 | return ShapedType::kDynamic; |
| 1134 | |
| 1135 | SMLoc loc = getToken().getLoc(); |
| 1136 | auto emitWrongTokenError = [&] { |
| 1137 | emitError(loc, message: "expected a 64-bit signed integer or '?'" ); |
| 1138 | return std::nullopt; |
| 1139 | }; |
| 1140 | |
| 1141 | bool negative = consumeIf(kind: Token::minus); |
| 1142 | |
| 1143 | if (getToken().is(k: Token::integer)) { |
| 1144 | std::optional<uint64_t> value = getToken().getUInt64IntegerValue(); |
| 1145 | if (!value || |
| 1146 | *value > static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) |
| 1147 | return emitWrongTokenError(); |
| 1148 | consumeToken(); |
| 1149 | auto result = static_cast<int64_t>(*value); |
| 1150 | if (negative) |
| 1151 | result = -result; |
| 1152 | |
| 1153 | return result; |
| 1154 | } |
| 1155 | |
| 1156 | return emitWrongTokenError(); |
| 1157 | }; |
| 1158 | |
| 1159 | // Parse strides. |
| 1160 | SmallVector<int64_t> strides; |
| 1161 | if (!getToken().is(k: Token::r_square)) { |
| 1162 | do { |
| 1163 | std::optional<int64_t> stride = parseStrideOrOffset(); |
| 1164 | if (!stride) |
| 1165 | return nullptr; |
| 1166 | strides.push_back(Elt: *stride); |
| 1167 | } while (consumeIf(kind: Token::comma)); |
| 1168 | } |
| 1169 | |
| 1170 | if (failed(Result: parseToken(expectedToken: Token::r_square, message: "expected ']'" ))) |
| 1171 | return nullptr; |
| 1172 | |
| 1173 | // Fast path in absence of offset. |
| 1174 | if (consumeIf(kind: Token::greater)) { |
| 1175 | if (failed(StridedLayoutAttr::verify(errorEmitter, |
| 1176 | /*offset=*/0, strides))) |
| 1177 | return nullptr; |
| 1178 | return StridedLayoutAttr::get(getContext(), /*offset=*/0, strides); |
| 1179 | } |
| 1180 | |
| 1181 | if (failed(Result: parseToken(expectedToken: Token::comma, message: "expected ','" )) || |
| 1182 | failed(Result: parseToken(expectedToken: Token::kw_offset, message: "expected 'offset' after comma" )) || |
| 1183 | failed(Result: parseToken(expectedToken: Token::colon, message: "expected ':' after 'offset'" ))) |
| 1184 | return nullptr; |
| 1185 | |
| 1186 | std::optional<int64_t> offset = parseStrideOrOffset(); |
| 1187 | if (!offset || failed(Result: parseToken(expectedToken: Token::greater, message: "expected '>'" ))) |
| 1188 | return nullptr; |
| 1189 | |
| 1190 | if (failed(StridedLayoutAttr::verify(errorEmitter, *offset, strides))) |
| 1191 | return nullptr; |
| 1192 | return StridedLayoutAttr::get(getContext(), *offset, strides); |
| 1193 | // return getChecked<StridedLayoutAttr>(loc,getContext(), *offset, strides); |
| 1194 | } |
| 1195 | |
| 1196 | /// Parse a distinct attribute. |
| 1197 | /// |
| 1198 | /// distinct-attribute ::= `distinct` |
| 1199 | /// `[` integer-literal `]<` attribute-value `>` |
| 1200 | /// |
| 1201 | Attribute Parser::parseDistinctAttr(Type type) { |
| 1202 | SMLoc loc = getToken().getLoc(); |
| 1203 | consumeToken(kind: Token::kw_distinct); |
| 1204 | if (parseToken(expectedToken: Token::l_square, message: "expected '[' after 'distinct'" )) |
| 1205 | return {}; |
| 1206 | |
| 1207 | // Parse the distinct integer identifier. |
| 1208 | Token token = getToken(); |
| 1209 | if (parseToken(expectedToken: Token::integer, message: "expected distinct ID" )) |
| 1210 | return {}; |
| 1211 | std::optional<uint64_t> value = token.getUInt64IntegerValue(); |
| 1212 | if (!value) { |
| 1213 | emitError(message: "expected an unsigned 64-bit integer" ); |
| 1214 | return {}; |
| 1215 | } |
| 1216 | |
| 1217 | // Parse the referenced attribute. |
| 1218 | if (parseToken(expectedToken: Token::r_square, message: "expected ']' to close distinct ID" ) || |
| 1219 | parseToken(expectedToken: Token::less, message: "expected '<' after distinct ID" )) |
| 1220 | return {}; |
| 1221 | |
| 1222 | Attribute referencedAttr; |
| 1223 | if (getToken().is(k: Token::greater)) { |
| 1224 | consumeToken(); |
| 1225 | referencedAttr = builder.getUnitAttr(); |
| 1226 | } else { |
| 1227 | referencedAttr = parseAttribute(type); |
| 1228 | if (!referencedAttr) { |
| 1229 | emitError(message: "expected attribute" ); |
| 1230 | return {}; |
| 1231 | } |
| 1232 | |
| 1233 | if (parseToken(expectedToken: Token::greater, message: "expected '>' to close distinct attribute" )) |
| 1234 | return {}; |
| 1235 | } |
| 1236 | |
| 1237 | // Add the distinct attribute to the parser state, if it has not been parsed |
| 1238 | // before. Otherwise, check if the parsed reference attribute matches the one |
| 1239 | // found in the parser state. |
| 1240 | DenseMap<uint64_t, DistinctAttr> &distinctAttrs = |
| 1241 | state.symbols.distinctAttributes; |
| 1242 | auto it = distinctAttrs.find(Val: *value); |
| 1243 | if (it == distinctAttrs.end()) { |
| 1244 | DistinctAttr distinctAttr = DistinctAttr::create(referencedAttr); |
| 1245 | it = distinctAttrs.try_emplace(Key: *value, Args&: distinctAttr).first; |
| 1246 | } else if (it->getSecond().getReferencedAttr() != referencedAttr) { |
| 1247 | emitError(loc, message: "referenced attribute does not match previous definition: " ) |
| 1248 | << it->getSecond().getReferencedAttr(); |
| 1249 | return {}; |
| 1250 | } |
| 1251 | |
| 1252 | return it->getSecond(); |
| 1253 | } |
| 1254 | |