1 | //===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser --------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the parser for the dialect symbols, such as extended |
10 | // attributes and types. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "AsmParserImpl.h" |
15 | #include "Parser.h" |
16 | #include "mlir/AsmParser/AsmParserState.h" |
17 | #include "mlir/IR/AsmState.h" |
18 | #include "mlir/IR/Attributes.h" |
19 | #include "mlir/IR/BuiltinAttributeInterfaces.h" |
20 | #include "mlir/IR/BuiltinAttributes.h" |
21 | #include "mlir/IR/BuiltinTypes.h" |
22 | #include "mlir/IR/Dialect.h" |
23 | #include "mlir/IR/DialectImplementation.h" |
24 | #include "mlir/IR/MLIRContext.h" |
25 | #include "mlir/Support/LLVM.h" |
26 | #include "llvm/Support/MemoryBuffer.h" |
27 | #include "llvm/Support/SourceMgr.h" |
28 | #include <cassert> |
29 | #include <cstddef> |
30 | #include <utility> |
31 | |
32 | using namespace mlir; |
33 | using namespace mlir::detail; |
34 | using llvm::MemoryBuffer; |
35 | using llvm::SourceMgr; |
36 | |
37 | namespace { |
38 | /// This class provides the main implementation of the DialectAsmParser that |
39 | /// allows for dialects to parse attributes and types. This allows for dialect |
40 | /// hooking into the main MLIR parsing logic. |
41 | class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> { |
42 | public: |
43 | CustomDialectAsmParser(StringRef fullSpec, Parser &parser) |
44 | : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser), |
45 | fullSpec(fullSpec) {} |
46 | ~CustomDialectAsmParser() override = default; |
47 | |
48 | /// Returns the full specification of the symbol being parsed. This allows |
49 | /// for using a separate parser if necessary. |
50 | StringRef getFullSymbolSpec() const override { return fullSpec; } |
51 | |
52 | private: |
53 | /// The full symbol specification. |
54 | StringRef fullSpec; |
55 | }; |
56 | } // namespace |
57 | |
58 | /// |
59 | /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>' |
60 | /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body |
61 | /// | '(' pretty-dialect-sym-contents+ ')' |
62 | /// | '[' pretty-dialect-sym-contents+ ']' |
63 | /// | '{' pretty-dialect-sym-contents+ '}' |
64 | /// | '[^[<({>\])}\0]+' |
65 | /// |
66 | ParseResult Parser::parseDialectSymbolBody(StringRef &body, |
67 | bool &isCodeCompletion) { |
68 | // Symbol bodies are a relatively unstructured format that contains a series |
69 | // of properly nested punctuation, with anything else in the middle. Scan |
70 | // ahead to find it and consume it if successful, otherwise emit an error. |
71 | const char *curPtr = getTokenSpelling().data(); |
72 | |
73 | // Scan over the nested punctuation, bailing out on error and consuming until |
74 | // we find the end. We know that we're currently looking at the '<', so we can |
75 | // go until we find the matching '>' character. |
76 | assert(*curPtr == '<'); |
77 | SmallVector<char, 8> nestedPunctuation; |
78 | const char *codeCompleteLoc = state.lex.getCodeCompleteLoc(); |
79 | |
80 | // Functor used to emit an unbalanced punctuation error. |
81 | auto emitPunctError = [&] { |
82 | return emitError() << "unbalanced '" << nestedPunctuation.back() |
83 | << "' character in pretty dialect name" ; |
84 | }; |
85 | // Functor used to check for unbalanced punctuation. |
86 | auto checkNestedPunctuation = [&](char expectedToken) -> ParseResult { |
87 | if (nestedPunctuation.back() != expectedToken) |
88 | return emitPunctError(); |
89 | nestedPunctuation.pop_back(); |
90 | return success(); |
91 | }; |
92 | do { |
93 | // Handle code completions, which may appear in the middle of the symbol |
94 | // body. |
95 | if (curPtr == codeCompleteLoc) { |
96 | isCodeCompletion = true; |
97 | nestedPunctuation.clear(); |
98 | break; |
99 | } |
100 | |
101 | char c = *curPtr++; |
102 | switch (c) { |
103 | case '\0': |
104 | // This also handles the EOF case. |
105 | if (!nestedPunctuation.empty()) |
106 | return emitPunctError(); |
107 | return emitError(message: "unexpected nul or EOF in pretty dialect name" ); |
108 | case '<': |
109 | case '[': |
110 | case '(': |
111 | case '{': |
112 | nestedPunctuation.push_back(Elt: c); |
113 | continue; |
114 | |
115 | case '-': |
116 | // The sequence `->` is treated as special token. |
117 | if (*curPtr == '>') |
118 | ++curPtr; |
119 | continue; |
120 | |
121 | case '>': |
122 | if (failed(Result: checkNestedPunctuation('<'))) |
123 | return failure(); |
124 | break; |
125 | case ']': |
126 | if (failed(Result: checkNestedPunctuation('['))) |
127 | return failure(); |
128 | break; |
129 | case ')': |
130 | if (failed(Result: checkNestedPunctuation('('))) |
131 | return failure(); |
132 | break; |
133 | case '}': |
134 | if (failed(Result: checkNestedPunctuation('{'))) |
135 | return failure(); |
136 | break; |
137 | case '"': { |
138 | // Dispatch to the lexer to lex past strings. |
139 | resetToken(tokPos: curPtr - 1); |
140 | curPtr = state.curToken.getEndLoc().getPointer(); |
141 | |
142 | // Handle code completions, which may appear in the middle of the symbol |
143 | // body. |
144 | if (state.curToken.isCodeCompletion()) { |
145 | isCodeCompletion = true; |
146 | nestedPunctuation.clear(); |
147 | break; |
148 | } |
149 | |
150 | // Otherwise, ensure this token was actually a string. |
151 | if (state.curToken.isNot(k: Token::string)) |
152 | return failure(); |
153 | break; |
154 | } |
155 | |
156 | default: |
157 | continue; |
158 | } |
159 | } while (!nestedPunctuation.empty()); |
160 | |
161 | // Ok, we succeeded, remember where we stopped, reset the lexer to know it is |
162 | // consuming all this stuff, and return. |
163 | resetToken(tokPos: curPtr); |
164 | |
165 | unsigned length = curPtr - body.begin(); |
166 | body = StringRef(body.data(), length); |
167 | return success(); |
168 | } |
169 | |
170 | /// Parse an extended dialect symbol. |
171 | template <typename Symbol, typename SymbolAliasMap, typename CreateFn> |
172 | static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState, |
173 | SymbolAliasMap &aliases, |
174 | CreateFn &&createSymbol) { |
175 | Token tok = p.getToken(); |
176 | |
177 | // Handle code completion of the extended symbol. |
178 | StringRef identifier = tok.getSpelling().drop_front(); |
179 | if (tok.isCodeCompletion() && identifier.empty()) |
180 | return p.codeCompleteDialectSymbol(aliases); |
181 | |
182 | // Parse the dialect namespace. |
183 | SMRange range = p.getToken().getLocRange(); |
184 | SMLoc loc = p.getToken().getLoc(); |
185 | p.consumeToken(); |
186 | |
187 | // Check to see if this is a pretty name. |
188 | auto [dialectName, symbolData] = identifier.split(Separator: '.'); |
189 | bool isPrettyName = !symbolData.empty() || identifier.back() == '.'; |
190 | |
191 | // Check to see if the symbol has trailing data, i.e. has an immediately |
192 | // following '<'. |
193 | bool hasTrailingData = |
194 | p.getToken().is(k: Token::less) && |
195 | identifier.bytes_end() == p.getTokenSpelling().bytes_begin(); |
196 | |
197 | // If there is no '<' token following this, and if the typename contains no |
198 | // dot, then we are parsing a symbol alias. |
199 | if (!hasTrailingData && !isPrettyName) { |
200 | // Check for an alias for this type. |
201 | auto aliasIt = aliases.find(identifier); |
202 | if (aliasIt == aliases.end()) |
203 | return (p.emitWrongTokenError(message: "undefined symbol alias id '" + identifier + |
204 | "'" ), |
205 | nullptr); |
206 | if (asmState) { |
207 | if constexpr (std::is_same_v<Symbol, Type>) |
208 | asmState->addTypeAliasUses(name: identifier, locations: range); |
209 | else |
210 | asmState->addAttrAliasUses(name: identifier, locations: range); |
211 | } |
212 | return aliasIt->second; |
213 | } |
214 | |
215 | // If this isn't an alias, we are parsing a dialect-specific symbol. If the |
216 | // name contains a dot, then this is the "pretty" form. If not, it is the |
217 | // verbose form that looks like <...>. |
218 | if (!isPrettyName) { |
219 | // Point the symbol data to the end of the dialect name to start. |
220 | symbolData = StringRef(dialectName.end(), 0); |
221 | |
222 | // Parse the body of the symbol. |
223 | bool isCodeCompletion = false; |
224 | if (p.parseDialectSymbolBody(body&: symbolData, isCodeCompletion)) |
225 | return nullptr; |
226 | symbolData = symbolData.drop_front(); |
227 | |
228 | // If the body contained a code completion it won't have the trailing `>` |
229 | // token, so don't drop it. |
230 | if (!isCodeCompletion) |
231 | symbolData = symbolData.drop_back(); |
232 | } else { |
233 | loc = SMLoc::getFromPointer(Ptr: symbolData.data()); |
234 | |
235 | // If the dialect's symbol is followed immediately by a <, then lex the body |
236 | // of it into prettyName. |
237 | if (hasTrailingData && p.parseDialectSymbolBody(body&: symbolData)) |
238 | return nullptr; |
239 | } |
240 | |
241 | return createSymbol(dialectName, symbolData, loc); |
242 | } |
243 | |
244 | /// Parse an extended attribute. |
245 | /// |
246 | /// extended-attribute ::= (dialect-attribute | attribute-alias) |
247 | /// dialect-attribute ::= `#` dialect-namespace `<` attr-data `>` |
248 | /// (`:` type)? |
249 | /// | `#` alias-name pretty-dialect-sym-body? (`:` type)? |
250 | /// attribute-alias ::= `#` alias-name |
251 | /// |
252 | Attribute Parser::parseExtendedAttr(Type type) { |
253 | MLIRContext *ctx = getContext(); |
254 | Attribute attr = parseExtendedSymbol<Attribute>( |
255 | p&: *this, asmState: state.asmState, aliases&: state.symbols.attributeAliasDefinitions, |
256 | createSymbol: [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute { |
257 | // Parse an optional trailing colon type. |
258 | Type attrType = type; |
259 | if (consumeIf(kind: Token::colon) && !(attrType = parseType())) |
260 | return Attribute(); |
261 | |
262 | // If we found a registered dialect, then ask it to parse the attribute. |
263 | if (Dialect *dialect = |
264 | builder.getContext()->getOrLoadDialect(name: dialectName)) { |
265 | // Temporarily reset the lexer to let the dialect parse the attribute. |
266 | const char *curLexerPos = getToken().getLoc().getPointer(); |
267 | resetToken(tokPos: symbolData.data()); |
268 | |
269 | // Parse the attribute. |
270 | CustomDialectAsmParser customParser(symbolData, *this); |
271 | Attribute attr = dialect->parseAttribute(customParser, attrType); |
272 | resetToken(tokPos: curLexerPos); |
273 | return attr; |
274 | } |
275 | |
276 | // Otherwise, form a new opaque attribute. |
277 | return OpaqueAttr::getChecked( |
278 | [&] { return emitError(loc); }, StringAttr::get(ctx, dialectName), |
279 | symbolData, attrType ? attrType : NoneType::get(ctx)); |
280 | }); |
281 | |
282 | // Ensure that the attribute has the same type as requested. |
283 | auto typedAttr = dyn_cast_or_null<TypedAttr>(attr); |
284 | if (type && typedAttr && typedAttr.getType() != type) { |
285 | emitError(message: "attribute type different than expected: expected " ) |
286 | << type << ", but got " << typedAttr.getType(); |
287 | return nullptr; |
288 | } |
289 | return attr; |
290 | } |
291 | |
292 | /// Parse an extended type. |
293 | /// |
294 | /// extended-type ::= (dialect-type | type-alias) |
295 | /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>` |
296 | /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body? |
297 | /// type-alias ::= `!` alias-name |
298 | /// |
299 | Type Parser::parseExtendedType() { |
300 | MLIRContext *ctx = getContext(); |
301 | return parseExtendedSymbol<Type>( |
302 | p&: *this, asmState: state.asmState, aliases&: state.symbols.typeAliasDefinitions, |
303 | createSymbol: [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type { |
304 | // If we found a registered dialect, then ask it to parse the type. |
305 | if (auto *dialect = ctx->getOrLoadDialect(name: dialectName)) { |
306 | // Temporarily reset the lexer to let the dialect parse the type. |
307 | const char *curLexerPos = getToken().getLoc().getPointer(); |
308 | resetToken(tokPos: symbolData.data()); |
309 | |
310 | // Parse the type. |
311 | CustomDialectAsmParser customParser(symbolData, *this); |
312 | Type type = dialect->parseType(customParser); |
313 | resetToken(tokPos: curLexerPos); |
314 | return type; |
315 | } |
316 | |
317 | // Otherwise, form a new opaque type. |
318 | return OpaqueType::getChecked([&] { return emitError(loc); }, |
319 | StringAttr::get(ctx, dialectName), |
320 | symbolData); |
321 | }); |
322 | } |
323 | |
324 | //===----------------------------------------------------------------------===// |
325 | // mlir::parseAttribute/parseType |
326 | //===----------------------------------------------------------------------===// |
327 | |
328 | /// Parses a symbol, of type 'T', and returns it if parsing was successful. If |
329 | /// parsing failed, nullptr is returned. |
330 | template <typename T, typename ParserFn> |
331 | static T parseSymbol(StringRef inputStr, MLIRContext *context, |
332 | size_t *numReadOut, bool isKnownNullTerminated, |
333 | ParserFn &&parserFn) { |
334 | // Set the buffer name to the string being parsed, so that it appears in error |
335 | // diagnostics. |
336 | auto memBuffer = |
337 | isKnownNullTerminated |
338 | ? MemoryBuffer::getMemBuffer(InputData: inputStr, |
339 | /*BufferName=*/inputStr) |
340 | : MemoryBuffer::getMemBufferCopy(InputData: inputStr, /*BufferName=*/inputStr); |
341 | SourceMgr sourceMgr; |
342 | sourceMgr.AddNewSourceBuffer(F: std::move(memBuffer), IncludeLoc: SMLoc()); |
343 | SymbolState aliasState; |
344 | ParserConfig config(context); |
345 | ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr, |
346 | /*codeCompleteContext=*/nullptr); |
347 | Parser parser(state); |
348 | |
349 | Token startTok = parser.getToken(); |
350 | T symbol = parserFn(parser); |
351 | if (!symbol) |
352 | return T(); |
353 | |
354 | // Provide the number of bytes that were read. |
355 | Token endTok = parser.getToken(); |
356 | size_t numRead = |
357 | endTok.getLoc().getPointer() - startTok.getLoc().getPointer(); |
358 | if (numReadOut) { |
359 | *numReadOut = numRead; |
360 | } else if (numRead != inputStr.size()) { |
361 | parser.emitError(loc: endTok.getLoc()) << "found trailing characters: '" |
362 | << inputStr.drop_front(N: numRead) << "'" ; |
363 | return T(); |
364 | } |
365 | return symbol; |
366 | } |
367 | |
368 | Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context, |
369 | Type type, size_t *numRead, |
370 | bool isKnownNullTerminated) { |
371 | return parseSymbol<Attribute>( |
372 | inputStr: attrStr, context, numReadOut: numRead, isKnownNullTerminated, |
373 | parserFn: [type](Parser &parser) { return parser.parseAttribute(type); }); |
374 | } |
375 | Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead, |
376 | bool isKnownNullTerminated) { |
377 | return parseSymbol<Type>(inputStr: typeStr, context, numReadOut: numRead, isKnownNullTerminated, |
378 | parserFn: [](Parser &parser) { return parser.parseType(); }); |
379 | } |
380 | |