//===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the parser for the dialect symbols, such as extended
// attributes and types.
//
//===----------------------------------------------------------------------===//
13
14#include "AsmParserImpl.h"
15#include "Parser.h"
16#include "mlir/AsmParser/AsmParserState.h"
17#include "mlir/IR/AsmState.h"
18#include "mlir/IR/Attributes.h"
19#include "mlir/IR/BuiltinAttributeInterfaces.h"
20#include "mlir/IR/BuiltinAttributes.h"
21#include "mlir/IR/BuiltinTypes.h"
22#include "mlir/IR/Dialect.h"
23#include "mlir/IR/DialectImplementation.h"
24#include "mlir/IR/MLIRContext.h"
25#include "mlir/Support/LLVM.h"
26#include "llvm/Support/MemoryBuffer.h"
27#include "llvm/Support/SourceMgr.h"
28#include <cassert>
29#include <cstddef>
30#include <utility>
31
32using namespace mlir;
33using namespace mlir::detail;
34using llvm::MemoryBuffer;
35using llvm::SourceMgr;
36
37namespace {
38/// This class provides the main implementation of the DialectAsmParser that
39/// allows for dialects to parse attributes and types. This allows for dialect
40/// hooking into the main MLIR parsing logic.
41class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
42public:
43 CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
44 : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser),
45 fullSpec(fullSpec) {}
46 ~CustomDialectAsmParser() override = default;
47
48 /// Returns the full specification of the symbol being parsed. This allows
49 /// for using a separate parser if necessary.
50 StringRef getFullSymbolSpec() const override { return fullSpec; }
51
52private:
53 /// The full symbol specification.
54 StringRef fullSpec;
55};
56} // namespace
57
58///
59/// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
60/// pretty-dialect-sym-contents ::= pretty-dialect-sym-body
61/// | '(' pretty-dialect-sym-contents+ ')'
62/// | '[' pretty-dialect-sym-contents+ ']'
63/// | '{' pretty-dialect-sym-contents+ '}'
64/// | '[^[<({>\])}\0]+'
65///
66ParseResult Parser::parseDialectSymbolBody(StringRef &body,
67 bool &isCodeCompletion) {
68 // Symbol bodies are a relatively unstructured format that contains a series
69 // of properly nested punctuation, with anything else in the middle. Scan
70 // ahead to find it and consume it if successful, otherwise emit an error.
71 const char *curPtr = getTokenSpelling().data();
72
73 // Scan over the nested punctuation, bailing out on error and consuming until
74 // we find the end. We know that we're currently looking at the '<', so we can
75 // go until we find the matching '>' character.
76 assert(*curPtr == '<');
77 SmallVector<char, 8> nestedPunctuation;
78 const char *codeCompleteLoc = state.lex.getCodeCompleteLoc();
79
80 // Functor used to emit an unbalanced punctuation error.
81 auto emitPunctError = [&] {
82 return emitError() << "unbalanced '" << nestedPunctuation.back()
83 << "' character in pretty dialect name";
84 };
85 // Functor used to check for unbalanced punctuation.
86 auto checkNestedPunctuation = [&](char expectedToken) -> ParseResult {
87 if (nestedPunctuation.back() != expectedToken)
88 return emitPunctError();
89 nestedPunctuation.pop_back();
90 return success();
91 };
92 const char *curBufferEnd = state.lex.getBufferEnd();
93 do {
94 // Handle code completions, which may appear in the middle of the symbol
95 // body.
96 if (curPtr == codeCompleteLoc) {
97 isCodeCompletion = true;
98 nestedPunctuation.clear();
99 break;
100 }
101
102 if (curBufferEnd == curPtr) {
103 if (!nestedPunctuation.empty())
104 return emitPunctError();
105 return emitError(message: "unexpected nul or EOF in pretty dialect name");
106 }
107
108 char c = *curPtr++;
109 switch (c) {
110 case '\0':
111 // This also handles the EOF case.
112 if (!nestedPunctuation.empty())
113 return emitPunctError();
114 return emitError(message: "unexpected nul or EOF in pretty dialect name");
115 case '<':
116 case '[':
117 case '(':
118 case '{':
119 nestedPunctuation.push_back(Elt: c);
120 continue;
121
122 case '-':
123 // The sequence `->` is treated as special token.
124 if (*curPtr == '>')
125 ++curPtr;
126 continue;
127
128 case '>':
129 if (failed(Result: checkNestedPunctuation('<')))
130 return failure();
131 break;
132 case ']':
133 if (failed(Result: checkNestedPunctuation('[')))
134 return failure();
135 break;
136 case ')':
137 if (failed(Result: checkNestedPunctuation('(')))
138 return failure();
139 break;
140 case '}':
141 if (failed(Result: checkNestedPunctuation('{')))
142 return failure();
143 break;
144 case '"': {
145 // Dispatch to the lexer to lex past strings.
146 resetToken(tokPos: curPtr - 1);
147 curPtr = state.curToken.getEndLoc().getPointer();
148
149 // Handle code completions, which may appear in the middle of the symbol
150 // body.
151 if (state.curToken.isCodeCompletion()) {
152 isCodeCompletion = true;
153 nestedPunctuation.clear();
154 break;
155 }
156
157 // Otherwise, ensure this token was actually a string.
158 if (state.curToken.isNot(k: Token::string))
159 return failure();
160 break;
161 }
162
163 default:
164 continue;
165 }
166 } while (!nestedPunctuation.empty());
167
168 // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
169 // consuming all this stuff, and return.
170 resetToken(tokPos: curPtr);
171
172 unsigned length = curPtr - body.begin();
173 body = StringRef(body.data(), length);
174 return success();
175}
176
177/// Parse an extended dialect symbol.
178template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
179static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
180 SymbolAliasMap &aliases,
181 CreateFn &&createSymbol) {
182 Token tok = p.getToken();
183
184 // Handle code completion of the extended symbol.
185 StringRef identifier = tok.getSpelling().drop_front();
186 if (tok.isCodeCompletion() && identifier.empty())
187 return p.codeCompleteDialectSymbol(aliases);
188
189 // Parse the dialect namespace.
190 SMRange range = p.getToken().getLocRange();
191 SMLoc loc = p.getToken().getLoc();
192 p.consumeToken();
193
194 // Check to see if this is a pretty name.
195 auto [dialectName, symbolData] = identifier.split(Separator: '.');
196 bool isPrettyName = !symbolData.empty() || identifier.back() == '.';
197
198 // Check to see if the symbol has trailing data, i.e. has an immediately
199 // following '<'.
200 bool hasTrailingData =
201 p.getToken().is(k: Token::less) &&
202 identifier.bytes_end() == p.getTokenSpelling().bytes_begin();
203
204 // If there is no '<' token following this, and if the typename contains no
205 // dot, then we are parsing a symbol alias.
206 if (!hasTrailingData && !isPrettyName) {
207 // Check for an alias for this type.
208 auto aliasIt = aliases.find(identifier);
209 if (aliasIt == aliases.end())
210 return (p.emitWrongTokenError(message: "undefined symbol alias id '" + identifier +
211 "'"),
212 nullptr);
213 if (asmState) {
214 if constexpr (std::is_same_v<Symbol, Type>)
215 asmState->addTypeAliasUses(name: identifier, locations: range);
216 else
217 asmState->addAttrAliasUses(name: identifier, locations: range);
218 }
219 return aliasIt->second;
220 }
221
222 // If this isn't an alias, we are parsing a dialect-specific symbol. If the
223 // name contains a dot, then this is the "pretty" form. If not, it is the
224 // verbose form that looks like <...>.
225 if (!isPrettyName) {
226 // Point the symbol data to the end of the dialect name to start.
227 symbolData = StringRef(dialectName.end(), 0);
228
229 // Parse the body of the symbol.
230 bool isCodeCompletion = false;
231 if (p.parseDialectSymbolBody(body&: symbolData, isCodeCompletion))
232 return nullptr;
233 symbolData = symbolData.drop_front();
234
235 // If the body contained a code completion it won't have the trailing `>`
236 // token, so don't drop it.
237 if (!isCodeCompletion)
238 symbolData = symbolData.drop_back();
239 } else {
240 loc = SMLoc::getFromPointer(Ptr: symbolData.data());
241
242 // If the dialect's symbol is followed immediately by a <, then lex the body
243 // of it into prettyName.
244 if (hasTrailingData && p.parseDialectSymbolBody(body&: symbolData))
245 return nullptr;
246 }
247
248 if constexpr (std::is_same_v<Symbol, Attribute>) {
249 auto &cache = p.getState().symbols.attributesCache;
250 auto cacheIt = cache.find(Key: symbolData);
251 // Skip cached attribute if it has type.
252 if (cacheIt != cache.end() && !p.getToken().is(k: Token::colon))
253 return cacheIt->second;
254
255 return cache[symbolData] = createSymbol(dialectName, symbolData, loc);
256 }
257 return createSymbol(dialectName, symbolData, loc);
258}
259
260/// Parse an extended attribute.
261///
262/// extended-attribute ::= (dialect-attribute | attribute-alias)
263/// dialect-attribute ::= `#` dialect-namespace `<` attr-data `>`
264/// (`:` type)?
265/// | `#` alias-name pretty-dialect-sym-body? (`:` type)?
266/// attribute-alias ::= `#` alias-name
267///
268Attribute Parser::parseExtendedAttr(Type type) {
269 MLIRContext *ctx = getContext();
270 Attribute attr = parseExtendedSymbol<Attribute>(
271 p&: *this, asmState: state.asmState, aliases&: state.symbols.attributeAliasDefinitions,
272 createSymbol: [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute {
273 // Parse an optional trailing colon type.
274 Type attrType = type;
275 if (consumeIf(kind: Token::colon) && !(attrType = parseType()))
276 return Attribute();
277
278 // If we found a registered dialect, then ask it to parse the attribute.
279 if (Dialect *dialect =
280 builder.getContext()->getOrLoadDialect(name: dialectName)) {
281 // Temporarily reset the lexer to let the dialect parse the attribute.
282 const char *curLexerPos = getToken().getLoc().getPointer();
283 resetToken(tokPos: symbolData.data());
284
285 // Parse the attribute.
286 CustomDialectAsmParser customParser(symbolData, *this);
287 Attribute attr = dialect->parseAttribute(parser&: customParser, type: attrType);
288 resetToken(tokPos: curLexerPos);
289 return attr;
290 }
291
292 // Otherwise, form a new opaque attribute.
293 return OpaqueAttr::getChecked(
294 emitError: [&] { return emitError(loc); }, dialect: StringAttr::get(context: ctx, bytes: dialectName),
295 attrData: symbolData, type: attrType ? attrType : NoneType::get(context: ctx));
296 });
297
298 // Ensure that the attribute has the same type as requested.
299 auto typedAttr = dyn_cast_or_null<TypedAttr>(Val&: attr);
300 if (type && typedAttr && typedAttr.getType() != type) {
301 emitError(message: "attribute type different than expected: expected ")
302 << type << ", but got " << typedAttr.getType();
303 return nullptr;
304 }
305 return attr;
306}
307
308/// Parse an extended type.
309///
310/// extended-type ::= (dialect-type | type-alias)
311/// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>`
312/// dialect-type ::= `!` alias-name pretty-dialect-attribute-body?
313/// type-alias ::= `!` alias-name
314///
315Type Parser::parseExtendedType() {
316 MLIRContext *ctx = getContext();
317 return parseExtendedSymbol<Type>(
318 p&: *this, asmState: state.asmState, aliases&: state.symbols.typeAliasDefinitions,
319 createSymbol: [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {
320 // If we found a registered dialect, then ask it to parse the type.
321 if (auto *dialect = ctx->getOrLoadDialect(name: dialectName)) {
322 // Temporarily reset the lexer to let the dialect parse the type.
323 const char *curLexerPos = getToken().getLoc().getPointer();
324 resetToken(tokPos: symbolData.data());
325
326 // Parse the type.
327 CustomDialectAsmParser customParser(symbolData, *this);
328 Type type = dialect->parseType(parser&: customParser);
329 resetToken(tokPos: curLexerPos);
330 return type;
331 }
332
333 // Otherwise, form a new opaque type.
334 return OpaqueType::getChecked(emitError: [&] { return emitError(loc); },
335 dialectNamespace: StringAttr::get(context: ctx, bytes: dialectName),
336 typeData: symbolData);
337 });
338}
339
340//===----------------------------------------------------------------------===//
341// mlir::parseAttribute/parseType
342//===----------------------------------------------------------------------===//
343
344/// Parses a symbol, of type 'T', and returns it if parsing was successful. If
345/// parsing failed, nullptr is returned.
346template <typename T, typename ParserFn>
347static T parseSymbol(StringRef inputStr, MLIRContext *context,
348 size_t *numReadOut, bool isKnownNullTerminated,
349 llvm::StringMap<Attribute> *attributesCache,
350 ParserFn &&parserFn) {
351 // Set the buffer name to the string being parsed, so that it appears in error
352 // diagnostics.
353 auto memBuffer =
354 isKnownNullTerminated
355 ? MemoryBuffer::getMemBuffer(InputData: inputStr,
356 /*BufferName=*/inputStr)
357 : MemoryBuffer::getMemBufferCopy(InputData: inputStr, /*BufferName=*/inputStr);
358 SourceMgr sourceMgr;
359 sourceMgr.AddNewSourceBuffer(F: std::move(memBuffer), IncludeLoc: SMLoc());
360 SymbolState aliasState;
361 if (attributesCache)
362 aliasState.attributesCache = *attributesCache;
363
364 ParserConfig config(context);
365 ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr,
366 /*codeCompleteContext=*/nullptr);
367 Parser parser(state);
368
369 Token startTok = parser.getToken();
370 T symbol = parserFn(parser);
371 if (!symbol)
372 return T();
373
374 if constexpr (std::is_same_v<T, Attribute>) {
375 if (attributesCache)
376 *attributesCache = state.symbols.attributesCache;
377 }
378
379 // Provide the number of bytes that were read.
380 Token endTok = parser.getToken();
381 size_t numRead =
382 endTok.getLoc().getPointer() - startTok.getLoc().getPointer();
383 if (numReadOut) {
384 *numReadOut = numRead;
385 } else if (numRead != inputStr.size()) {
386 parser.emitError(loc: endTok.getLoc()) << "found trailing characters: '"
387 << inputStr.drop_front(N: numRead) << "'";
388 return T();
389 }
390 return symbol;
391}
392
393Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
394 Type type, size_t *numRead,
395 bool isKnownNullTerminated,
396 llvm::StringMap<Attribute> *attributesCache) {
397 return parseSymbol<Attribute>(
398 inputStr: attrStr, context, numReadOut: numRead, isKnownNullTerminated, attributesCache,
399 parserFn: [type](Parser &parser) { return parser.parseAttribute(type); });
400}
401Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead,
402 bool isKnownNullTerminated) {
403 return parseSymbol<Type>(inputStr: typeStr, context, numReadOut: numRead, isKnownNullTerminated,
404 /*attributesCache=*/nullptr,
405 parserFn: [](Parser &parser) { return parser.parseType(); });
406}
407

// Source: mlir/lib/AsmParser/DialectSymbolParser.cpp