1//===- TypeParser.cpp - MLIR Type Parser Implementation -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the parser for the MLIR Types.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Parser.h"
14#include "mlir/IR/AffineMap.h"
15#include "mlir/IR/BuiltinAttributeInterfaces.h"
16#include "mlir/IR/BuiltinTypeInterfaces.h"
17#include "mlir/IR/BuiltinTypes.h"
18#include "mlir/IR/OpDefinition.h"
19#include "mlir/IR/TensorEncoding.h"
20#include "mlir/IR/Types.h"
21#include "mlir/Support/LLVM.h"
22#include "mlir/Support/LogicalResult.h"
23#include "llvm/ADT/STLExtras.h"
24#include <cassert>
25#include <cstdint>
26#include <limits>
27#include <optional>
28
29using namespace mlir;
30using namespace mlir::detail;
31
32/// Optionally parse a type.
33OptionalParseResult Parser::parseOptionalType(Type &type) {
34 // There are many different starting tokens for a type, check them here.
35 switch (getToken().getKind()) {
36 case Token::l_paren:
37 case Token::kw_memref:
38 case Token::kw_tensor:
39 case Token::kw_complex:
40 case Token::kw_tuple:
41 case Token::kw_vector:
42 case Token::inttype:
43 case Token::kw_f8E5M2:
44 case Token::kw_f8E4M3FN:
45 case Token::kw_f8E5M2FNUZ:
46 case Token::kw_f8E4M3FNUZ:
47 case Token::kw_f8E4M3B11FNUZ:
48 case Token::kw_bf16:
49 case Token::kw_f16:
50 case Token::kw_tf32:
51 case Token::kw_f32:
52 case Token::kw_f64:
53 case Token::kw_f80:
54 case Token::kw_f128:
55 case Token::kw_index:
56 case Token::kw_none:
57 case Token::exclamation_identifier:
58 return failure(isFailure: !(type = parseType()));
59
60 default:
61 return std::nullopt;
62 }
63}
64
65/// Parse an arbitrary type.
66///
67/// type ::= function-type
68/// | non-function-type
69///
70Type Parser::parseType() {
71 if (getToken().is(k: Token::l_paren))
72 return parseFunctionType();
73 return parseNonFunctionType();
74}
75
76/// Parse a function result type.
77///
78/// function-result-type ::= type-list-parens
79/// | non-function-type
80///
81ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
82 if (getToken().is(k: Token::l_paren))
83 return parseTypeListParens(elements);
84
85 Type t = parseNonFunctionType();
86 if (!t)
87 return failure();
88 elements.push_back(Elt: t);
89 return success();
90}
91
92/// Parse a list of types without an enclosing parenthesis. The list must have
93/// at least one member.
94///
95/// type-list-no-parens ::= type (`,` type)*
96///
97ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
98 auto parseElt = [&]() -> ParseResult {
99 auto elt = parseType();
100 elements.push_back(Elt: elt);
101 return elt ? success() : failure();
102 };
103
104 return parseCommaSeparatedList(parseElementFn: parseElt);
105}
106
107/// Parse a parenthesized list of types.
108///
109/// type-list-parens ::= `(` `)`
110/// | `(` type-list-no-parens `)`
111///
112ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
113 if (parseToken(expectedToken: Token::l_paren, message: "expected '('"))
114 return failure();
115
116 // Handle empty lists.
117 if (getToken().is(k: Token::r_paren))
118 return consumeToken(), success();
119
120 if (parseTypeListNoParens(elements) ||
121 parseToken(expectedToken: Token::r_paren, message: "expected ')'"))
122 return failure();
123 return success();
124}
125
126/// Parse a complex type.
127///
128/// complex-type ::= `complex` `<` type `>`
129///
130Type Parser::parseComplexType() {
131 consumeToken(kind: Token::kw_complex);
132
133 // Parse the '<'.
134 if (parseToken(expectedToken: Token::less, message: "expected '<' in complex type"))
135 return nullptr;
136
137 SMLoc elementTypeLoc = getToken().getLoc();
138 auto elementType = parseType();
139 if (!elementType ||
140 parseToken(expectedToken: Token::greater, message: "expected '>' in complex type"))
141 return nullptr;
142 if (!isa<FloatType>(Val: elementType) && !isa<IntegerType>(Val: elementType))
143 return emitError(loc: elementTypeLoc, message: "invalid element type for complex"),
144 nullptr;
145
146 return ComplexType::get(elementType);
147}
148
149/// Parse a function type.
150///
151/// function-type ::= type-list-parens `->` function-result-type
152///
153Type Parser::parseFunctionType() {
154 assert(getToken().is(Token::l_paren));
155
156 SmallVector<Type, 4> arguments, results;
157 if (parseTypeListParens(elements&: arguments) ||
158 parseToken(expectedToken: Token::arrow, message: "expected '->' in function type") ||
159 parseFunctionResultTypes(elements&: results))
160 return nullptr;
161
162 return builder.getFunctionType(arguments, results);
163}
164
165/// Parse a memref type.
166///
167/// memref-type ::= ranked-memref-type | unranked-memref-type
168///
169/// ranked-memref-type ::= `memref` `<` dimension-list-ranked type
170/// (`,` layout-specification)? (`,` memory-space)? `>`
171///
172/// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
173///
174/// stride-list ::= `[` (dimension (`,` dimension)*)? `]`
175/// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
176/// layout-specification ::= semi-affine-map | strided-layout | attribute
177/// memory-space ::= integer-literal | attribute
178///
179Type Parser::parseMemRefType() {
180 SMLoc loc = getToken().getLoc();
181 consumeToken(kind: Token::kw_memref);
182
183 if (parseToken(expectedToken: Token::less, message: "expected '<' in memref type"))
184 return nullptr;
185
186 bool isUnranked;
187 SmallVector<int64_t, 4> dimensions;
188
189 if (consumeIf(kind: Token::star)) {
190 // This is an unranked memref type.
191 isUnranked = true;
192 if (parseXInDimensionList())
193 return nullptr;
194
195 } else {
196 isUnranked = false;
197 if (parseDimensionListRanked(dimensions))
198 return nullptr;
199 }
200
201 // Parse the element type.
202 auto typeLoc = getToken().getLoc();
203 auto elementType = parseType();
204 if (!elementType)
205 return nullptr;
206
207 // Check that memref is formed from allowed types.
208 if (!BaseMemRefType::isValidElementType(type: elementType))
209 return emitError(loc: typeLoc, message: "invalid memref element type"), nullptr;
210
211 MemRefLayoutAttrInterface layout;
212 Attribute memorySpace;
213
214 auto parseElt = [&]() -> ParseResult {
215 // Either it is MemRefLayoutAttrInterface or memory space attribute.
216 Attribute attr = parseAttribute();
217 if (!attr)
218 return failure();
219
220 if (isa<MemRefLayoutAttrInterface>(attr)) {
221 layout = cast<MemRefLayoutAttrInterface>(attr);
222 } else if (memorySpace) {
223 return emitError(message: "multiple memory spaces specified in memref type");
224 } else {
225 memorySpace = attr;
226 return success();
227 }
228
229 if (isUnranked)
230 return emitError(message: "cannot have affine map for unranked memref type");
231 if (memorySpace)
232 return emitError(message: "expected memory space to be last in memref type");
233
234 return success();
235 };
236
237 // Parse a list of mappings and address space if present.
238 if (!consumeIf(kind: Token::greater)) {
239 // Parse comma separated list of affine maps, followed by memory space.
240 if (parseToken(expectedToken: Token::comma, message: "expected ',' or '>' in memref type") ||
241 parseCommaSeparatedListUntil(rightToken: Token::greater, parseElement: parseElt,
242 /*allowEmptyList=*/false)) {
243 return nullptr;
244 }
245 }
246
247 if (isUnranked)
248 return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
249
250 return getChecked<MemRefType>(loc, dimensions, elementType, layout,
251 memorySpace);
252}
253
254/// Parse any type except the function type.
255///
256/// non-function-type ::= integer-type
257/// | index-type
258/// | float-type
259/// | extended-type
260/// | vector-type
261/// | tensor-type
262/// | memref-type
263/// | complex-type
264/// | tuple-type
265/// | none-type
266///
267/// index-type ::= `index`
268/// float-type ::= `f16` | `bf16` | `f32` | `f64` | `f80` | `f128`
269/// none-type ::= `none`
270///
271Type Parser::parseNonFunctionType() {
272 switch (getToken().getKind()) {
273 default:
274 return (emitWrongTokenError(message: "expected non-function type"), nullptr);
275 case Token::kw_memref:
276 return parseMemRefType();
277 case Token::kw_tensor:
278 return parseTensorType();
279 case Token::kw_complex:
280 return parseComplexType();
281 case Token::kw_tuple:
282 return parseTupleType();
283 case Token::kw_vector:
284 return parseVectorType();
285 // integer-type
286 case Token::inttype: {
287 auto width = getToken().getIntTypeBitwidth();
288 if (!width.has_value())
289 return (emitError(message: "invalid integer width"), nullptr);
290 if (*width > IntegerType::kMaxWidth) {
291 emitError(getToken().getLoc(), "integer bitwidth is limited to ")
292 << IntegerType::kMaxWidth << " bits";
293 return nullptr;
294 }
295
296 IntegerType::SignednessSemantics signSemantics = IntegerType::Signless;
297 if (std::optional<bool> signedness = getToken().getIntTypeSignedness())
298 signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned;
299
300 consumeToken(kind: Token::inttype);
301 return IntegerType::get(getContext(), *width, signSemantics);
302 }
303
304 // float-type
305 case Token::kw_f8E5M2:
306 consumeToken(kind: Token::kw_f8E5M2);
307 return builder.getFloat8E5M2Type();
308 case Token::kw_f8E4M3FN:
309 consumeToken(kind: Token::kw_f8E4M3FN);
310 return builder.getFloat8E4M3FNType();
311 case Token::kw_f8E5M2FNUZ:
312 consumeToken(kind: Token::kw_f8E5M2FNUZ);
313 return builder.getFloat8E5M2FNUZType();
314 case Token::kw_f8E4M3FNUZ:
315 consumeToken(kind: Token::kw_f8E4M3FNUZ);
316 return builder.getFloat8E4M3FNUZType();
317 case Token::kw_f8E4M3B11FNUZ:
318 consumeToken(kind: Token::kw_f8E4M3B11FNUZ);
319 return builder.getFloat8E4M3B11FNUZType();
320 case Token::kw_bf16:
321 consumeToken(kind: Token::kw_bf16);
322 return builder.getBF16Type();
323 case Token::kw_f16:
324 consumeToken(kind: Token::kw_f16);
325 return builder.getF16Type();
326 case Token::kw_tf32:
327 consumeToken(kind: Token::kw_tf32);
328 return builder.getTF32Type();
329 case Token::kw_f32:
330 consumeToken(kind: Token::kw_f32);
331 return builder.getF32Type();
332 case Token::kw_f64:
333 consumeToken(kind: Token::kw_f64);
334 return builder.getF64Type();
335 case Token::kw_f80:
336 consumeToken(kind: Token::kw_f80);
337 return builder.getF80Type();
338 case Token::kw_f128:
339 consumeToken(kind: Token::kw_f128);
340 return builder.getF128Type();
341
342 // index-type
343 case Token::kw_index:
344 consumeToken(kind: Token::kw_index);
345 return builder.getIndexType();
346
347 // none-type
348 case Token::kw_none:
349 consumeToken(kind: Token::kw_none);
350 return builder.getNoneType();
351
352 // extended type
353 case Token::exclamation_identifier:
354 return parseExtendedType();
355
356 // Handle completion of a dialect type.
357 case Token::code_complete:
358 if (getToken().isCodeCompletionFor(kind: Token::exclamation_identifier))
359 return parseExtendedType();
360 return codeCompleteType();
361 }
362}
363
364/// Parse a tensor type.
365///
366/// tensor-type ::= `tensor` `<` dimension-list type `>`
367/// dimension-list ::= dimension-list-ranked | `*x`
368///
369Type Parser::parseTensorType() {
370 consumeToken(kind: Token::kw_tensor);
371
372 if (parseToken(expectedToken: Token::less, message: "expected '<' in tensor type"))
373 return nullptr;
374
375 bool isUnranked;
376 SmallVector<int64_t, 4> dimensions;
377
378 if (consumeIf(kind: Token::star)) {
379 // This is an unranked tensor type.
380 isUnranked = true;
381
382 if (parseXInDimensionList())
383 return nullptr;
384
385 } else {
386 isUnranked = false;
387 if (parseDimensionListRanked(dimensions))
388 return nullptr;
389 }
390
391 // Parse the element type.
392 auto elementTypeLoc = getToken().getLoc();
393 auto elementType = parseType();
394
395 // Parse an optional encoding attribute.
396 Attribute encoding;
397 if (consumeIf(kind: Token::comma)) {
398 auto parseResult = parseOptionalAttribute(attribute&: encoding);
399 if (parseResult.has_value()) {
400 if (failed(result: parseResult.value()))
401 return nullptr;
402 if (auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) {
403 if (failed(v.verifyEncoding(dimensions, elementType,
404 [&] { return emitError(); })))
405 return nullptr;
406 }
407 }
408 }
409
410 if (!elementType || parseToken(expectedToken: Token::greater, message: "expected '>' in tensor type"))
411 return nullptr;
412 if (!TensorType::isValidElementType(type: elementType))
413 return emitError(loc: elementTypeLoc, message: "invalid tensor element type"), nullptr;
414
415 if (isUnranked) {
416 if (encoding)
417 return emitError(message: "cannot apply encoding to unranked tensor"), nullptr;
418 return UnrankedTensorType::get(elementType);
419 }
420 return RankedTensorType::get(dimensions, elementType, encoding);
421}
422
423/// Parse a tuple type.
424///
425/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
426///
427Type Parser::parseTupleType() {
428 consumeToken(kind: Token::kw_tuple);
429
430 // Parse the '<'.
431 if (parseToken(expectedToken: Token::less, message: "expected '<' in tuple type"))
432 return nullptr;
433
434 // Check for an empty tuple by directly parsing '>'.
435 if (consumeIf(Token::greater))
436 return TupleType::get(getContext());
437
438 // Parse the element types and the '>'.
439 SmallVector<Type, 4> types;
440 if (parseTypeListNoParens(elements&: types) ||
441 parseToken(expectedToken: Token::greater, message: "expected '>' in tuple type"))
442 return nullptr;
443
444 return TupleType::get(getContext(), types);
445}
446
447/// Parse a vector type.
448///
449/// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
450/// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
451/// static-dim-list ::= decimal-literal (`x` decimal-literal)*
452///
453VectorType Parser::parseVectorType() {
454 consumeToken(kind: Token::kw_vector);
455
456 if (parseToken(expectedToken: Token::less, message: "expected '<' in vector type"))
457 return nullptr;
458
459 SmallVector<int64_t, 4> dimensions;
460 SmallVector<bool, 4> scalableDims;
461 if (parseVectorDimensionList(dimensions, scalableDims))
462 return nullptr;
463 if (any_of(Range&: dimensions, P: [](int64_t i) { return i <= 0; }))
464 return emitError(loc: getToken().getLoc(),
465 message: "vector types must have positive constant sizes"),
466 nullptr;
467
468 // Parse the element type.
469 auto typeLoc = getToken().getLoc();
470 auto elementType = parseType();
471 if (!elementType || parseToken(expectedToken: Token::greater, message: "expected '>' in vector type"))
472 return nullptr;
473
474 if (!VectorType::isValidElementType(elementType))
475 return emitError(loc: typeLoc, message: "vector elements must be int/index/float type"),
476 nullptr;
477
478 return VectorType::get(dimensions, elementType, scalableDims);
479}
480
481/// Parse a dimension list in a vector type. This populates the dimension list.
482/// For i-th dimension, `scalableDims[i]` contains either:
483/// * `false` for a non-scalable dimension (e.g. `4`),
484/// * `true` for a scalable dimension (e.g. `[4]`).
485///
486/// vector-dim-list := (static-dim-list `x`)?
487/// static-dim-list ::= static-dim (`x` static-dim)*
488/// static-dim ::= (decimal-literal | `[` decimal-literal `]`)
489///
490ParseResult
491Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
492 SmallVectorImpl<bool> &scalableDims) {
493 // If there is a set of fixed-length dimensions, consume it
494 while (getToken().is(k: Token::integer) || getToken().is(k: Token::l_square)) {
495 int64_t value;
496 bool scalable = consumeIf(kind: Token::l_square);
497 if (parseIntegerInDimensionList(value))
498 return failure();
499 dimensions.push_back(Elt: value);
500 if (scalable) {
501 if (!consumeIf(kind: Token::r_square))
502 return emitWrongTokenError(message: "missing ']' closing scalable dimension");
503 }
504 scalableDims.push_back(Elt: scalable);
505 // Make sure we have an 'x' or something like 'xbf32'.
506 if (parseXInDimensionList())
507 return failure();
508 }
509
510 return success();
511}
512
513/// Parse a dimension list of a tensor or memref type. This populates the
514/// dimension list, using ShapedType::kDynamic for the `?` dimensions if
515/// `allowDynamic` is set and errors out on `?` otherwise. Parsing the trailing
516/// `x` is configurable.
517///
518/// dimension-list ::= eps | dimension (`x` dimension)*
519/// dimension-list-with-trailing-x ::= (dimension `x`)*
520/// dimension ::= `?` | decimal-literal
521///
522/// When `allowDynamic` is not set, this is used to parse:
523///
524/// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
525/// static-dimension-list-with-trailing-x ::= (dimension `x`)*
526ParseResult
527Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
528 bool allowDynamic, bool withTrailingX) {
529 auto parseDim = [&]() -> LogicalResult {
530 auto loc = getToken().getLoc();
531 if (consumeIf(kind: Token::question)) {
532 if (!allowDynamic)
533 return emitError(loc, message: "expected static shape");
534 dimensions.push_back(ShapedType::kDynamic);
535 } else {
536 int64_t value;
537 if (failed(result: parseIntegerInDimensionList(value)))
538 return failure();
539 dimensions.push_back(Elt: value);
540 }
541 return success();
542 };
543
544 if (withTrailingX) {
545 while (getToken().isAny(k1: Token::integer, k2: Token::question)) {
546 if (failed(result: parseDim()) || failed(result: parseXInDimensionList()))
547 return failure();
548 }
549 return success();
550 }
551
552 if (getToken().isAny(k1: Token::integer, k2: Token::question)) {
553 if (failed(result: parseDim()))
554 return failure();
555 while (getToken().is(k: Token::bare_identifier) &&
556 getTokenSpelling()[0] == 'x') {
557 if (failed(result: parseXInDimensionList()) || failed(result: parseDim()))
558 return failure();
559 }
560 }
561 return success();
562}
563
564ParseResult Parser::parseIntegerInDimensionList(int64_t &value) {
565 // Hexadecimal integer literals (starting with `0x`) are not allowed in
566 // aggregate type declarations. Therefore, `0xf32` should be processed as
567 // a sequence of separate elements `0`, `x`, `f32`.
568 if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
569 // We can get here only if the token is an integer literal. Hexadecimal
570 // integer literals can only start with `0x` (`1x` wouldn't lex as a
571 // literal, just `1` would, at which point we don't get into this
572 // branch).
573 assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
574 value = 0;
575 state.lex.resetPointer(newPointer: getTokenSpelling().data() + 1);
576 consumeToken();
577 } else {
578 // Make sure this integer value is in bound and valid.
579 std::optional<uint64_t> dimension = getToken().getUInt64IntegerValue();
580 if (!dimension ||
581 *dimension > (uint64_t)std::numeric_limits<int64_t>::max())
582 return emitError(message: "invalid dimension");
583 value = (int64_t)*dimension;
584 consumeToken(kind: Token::integer);
585 }
586 return success();
587}
588
589/// Parse an 'x' token in a dimension list, handling the case where the x is
590/// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
591/// token.
592ParseResult Parser::parseXInDimensionList() {
593 if (getToken().isNot(k: Token::bare_identifier) || getTokenSpelling()[0] != 'x')
594 return emitWrongTokenError(message: "expected 'x' in dimension list");
595
596 // If we had a prefix of 'x', lex the next token immediately after the 'x'.
597 if (getTokenSpelling().size() != 1)
598 state.lex.resetPointer(newPointer: getTokenSpelling().data() + 1);
599
600 // Consume the 'x'.
601 consumeToken(kind: Token::bare_identifier);
602
603 return success();
604}
605

source code of mlir/lib/AsmParser/TypeParser.cpp