1 | //===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser --------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the parser for the dialect symbols, such as extended |
10 | // attributes and types. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "AsmParserImpl.h" |
15 | #include "Parser.h" |
16 | #include "mlir/AsmParser/AsmParserState.h" |
17 | #include "mlir/IR/AsmState.h" |
18 | #include "mlir/IR/Attributes.h" |
19 | #include "mlir/IR/BuiltinAttributeInterfaces.h" |
20 | #include "mlir/IR/BuiltinAttributes.h" |
21 | #include "mlir/IR/BuiltinTypes.h" |
22 | #include "mlir/IR/Dialect.h" |
23 | #include "mlir/IR/DialectImplementation.h" |
24 | #include "mlir/IR/MLIRContext.h" |
25 | #include "mlir/Support/LLVM.h" |
26 | #include "mlir/Support/LogicalResult.h" |
27 | #include "llvm/Support/MemoryBuffer.h" |
28 | #include "llvm/Support/SourceMgr.h" |
29 | #include <cassert> |
30 | #include <cstddef> |
31 | #include <utility> |
32 | |
33 | using namespace mlir; |
34 | using namespace mlir::detail; |
35 | using llvm::MemoryBuffer; |
36 | using llvm::SourceMgr; |
37 | |
38 | namespace { |
39 | /// This class provides the main implementation of the DialectAsmParser that |
40 | /// allows for dialects to parse attributes and types. This allows for dialect |
41 | /// hooking into the main MLIR parsing logic. |
42 | class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> { |
43 | public: |
44 | CustomDialectAsmParser(StringRef fullSpec, Parser &parser) |
45 | : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser), |
46 | fullSpec(fullSpec) {} |
47 | ~CustomDialectAsmParser() override = default; |
48 | |
49 | /// Returns the full specification of the symbol being parsed. This allows |
50 | /// for using a separate parser if necessary. |
51 | StringRef getFullSymbolSpec() const override { return fullSpec; } |
52 | |
53 | private: |
54 | /// The full symbol specification. |
55 | StringRef fullSpec; |
56 | }; |
57 | } // namespace |
58 | |
59 | /// |
60 | /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>' |
61 | /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body |
62 | /// | '(' pretty-dialect-sym-contents+ ')' |
63 | /// | '[' pretty-dialect-sym-contents+ ']' |
64 | /// | '{' pretty-dialect-sym-contents+ '}' |
65 | /// | '[^[<({>\])}\0]+' |
66 | /// |
67 | ParseResult Parser::parseDialectSymbolBody(StringRef &body, |
68 | bool &isCodeCompletion) { |
69 | // Symbol bodies are a relatively unstructured format that contains a series |
70 | // of properly nested punctuation, with anything else in the middle. Scan |
71 | // ahead to find it and consume it if successful, otherwise emit an error. |
72 | const char *curPtr = getTokenSpelling().data(); |
73 | |
74 | // Scan over the nested punctuation, bailing out on error and consuming until |
75 | // we find the end. We know that we're currently looking at the '<', so we can |
76 | // go until we find the matching '>' character. |
77 | assert(*curPtr == '<'); |
78 | SmallVector<char, 8> nestedPunctuation; |
79 | const char *codeCompleteLoc = state.lex.getCodeCompleteLoc(); |
80 | |
81 | // Functor used to emit an unbalanced punctuation error. |
82 | auto emitPunctError = [&] { |
83 | return emitError() << "unbalanced '" << nestedPunctuation.back() |
84 | << "' character in pretty dialect name" ; |
85 | }; |
86 | // Functor used to check for unbalanced punctuation. |
87 | auto checkNestedPunctuation = [&](char expectedToken) -> ParseResult { |
88 | if (nestedPunctuation.back() != expectedToken) |
89 | return emitPunctError(); |
90 | nestedPunctuation.pop_back(); |
91 | return success(); |
92 | }; |
93 | do { |
94 | // Handle code completions, which may appear in the middle of the symbol |
95 | // body. |
96 | if (curPtr == codeCompleteLoc) { |
97 | isCodeCompletion = true; |
98 | nestedPunctuation.clear(); |
99 | break; |
100 | } |
101 | |
102 | char c = *curPtr++; |
103 | switch (c) { |
104 | case '\0': |
105 | // This also handles the EOF case. |
106 | if (!nestedPunctuation.empty()) |
107 | return emitPunctError(); |
108 | return emitError(message: "unexpected nul or EOF in pretty dialect name" ); |
109 | case '<': |
110 | case '[': |
111 | case '(': |
112 | case '{': |
113 | nestedPunctuation.push_back(Elt: c); |
114 | continue; |
115 | |
116 | case '-': |
117 | // The sequence `->` is treated as special token. |
118 | if (*curPtr == '>') |
119 | ++curPtr; |
120 | continue; |
121 | |
122 | case '>': |
123 | if (failed(result: checkNestedPunctuation('<'))) |
124 | return failure(); |
125 | break; |
126 | case ']': |
127 | if (failed(result: checkNestedPunctuation('['))) |
128 | return failure(); |
129 | break; |
130 | case ')': |
131 | if (failed(result: checkNestedPunctuation('('))) |
132 | return failure(); |
133 | break; |
134 | case '}': |
135 | if (failed(result: checkNestedPunctuation('{'))) |
136 | return failure(); |
137 | break; |
138 | case '"': { |
139 | // Dispatch to the lexer to lex past strings. |
140 | resetToken(tokPos: curPtr - 1); |
141 | curPtr = state.curToken.getEndLoc().getPointer(); |
142 | |
143 | // Handle code completions, which may appear in the middle of the symbol |
144 | // body. |
145 | if (state.curToken.isCodeCompletion()) { |
146 | isCodeCompletion = true; |
147 | nestedPunctuation.clear(); |
148 | break; |
149 | } |
150 | |
151 | // Otherwise, ensure this token was actually a string. |
152 | if (state.curToken.isNot(k: Token::string)) |
153 | return failure(); |
154 | break; |
155 | } |
156 | |
157 | default: |
158 | continue; |
159 | } |
160 | } while (!nestedPunctuation.empty()); |
161 | |
162 | // Ok, we succeeded, remember where we stopped, reset the lexer to know it is |
163 | // consuming all this stuff, and return. |
164 | resetToken(tokPos: curPtr); |
165 | |
166 | unsigned length = curPtr - body.begin(); |
167 | body = StringRef(body.data(), length); |
168 | return success(); |
169 | } |
170 | |
171 | /// Parse an extended dialect symbol. |
172 | template <typename Symbol, typename SymbolAliasMap, typename CreateFn> |
173 | static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState, |
174 | SymbolAliasMap &aliases, |
175 | CreateFn &&createSymbol) { |
176 | Token tok = p.getToken(); |
177 | |
178 | // Handle code completion of the extended symbol. |
179 | StringRef identifier = tok.getSpelling().drop_front(); |
180 | if (tok.isCodeCompletion() && identifier.empty()) |
181 | return p.codeCompleteDialectSymbol(aliases); |
182 | |
183 | // Parse the dialect namespace. |
184 | SMRange range = p.getToken().getLocRange(); |
185 | SMLoc loc = p.getToken().getLoc(); |
186 | p.consumeToken(); |
187 | |
188 | // Check to see if this is a pretty name. |
189 | auto [dialectName, symbolData] = identifier.split(Separator: '.'); |
190 | bool isPrettyName = !symbolData.empty() || identifier.back() == '.'; |
191 | |
192 | // Check to see if the symbol has trailing data, i.e. has an immediately |
193 | // following '<'. |
194 | bool hasTrailingData = |
195 | p.getToken().is(k: Token::less) && |
196 | identifier.bytes_end() == p.getTokenSpelling().bytes_begin(); |
197 | |
198 | // If there is no '<' token following this, and if the typename contains no |
199 | // dot, then we are parsing a symbol alias. |
200 | if (!hasTrailingData && !isPrettyName) { |
201 | // Check for an alias for this type. |
202 | auto aliasIt = aliases.find(identifier); |
203 | if (aliasIt == aliases.end()) |
204 | return (p.emitWrongTokenError(message: "undefined symbol alias id '" + identifier + |
205 | "'" ), |
206 | nullptr); |
207 | if (asmState) { |
208 | if constexpr (std::is_same_v<Symbol, Type>) |
209 | asmState->addTypeAliasUses(name: identifier, locations: range); |
210 | else |
211 | asmState->addAttrAliasUses(name: identifier, locations: range); |
212 | } |
213 | return aliasIt->second; |
214 | } |
215 | |
216 | // If this isn't an alias, we are parsing a dialect-specific symbol. If the |
217 | // name contains a dot, then this is the "pretty" form. If not, it is the |
218 | // verbose form that looks like <...>. |
219 | if (!isPrettyName) { |
220 | // Point the symbol data to the end of the dialect name to start. |
221 | symbolData = StringRef(dialectName.end(), 0); |
222 | |
223 | // Parse the body of the symbol. |
224 | bool isCodeCompletion = false; |
225 | if (p.parseDialectSymbolBody(body&: symbolData, isCodeCompletion)) |
226 | return nullptr; |
227 | symbolData = symbolData.drop_front(); |
228 | |
229 | // If the body contained a code completion it won't have the trailing `>` |
230 | // token, so don't drop it. |
231 | if (!isCodeCompletion) |
232 | symbolData = symbolData.drop_back(); |
233 | } else { |
234 | loc = SMLoc::getFromPointer(Ptr: symbolData.data()); |
235 | |
236 | // If the dialect's symbol is followed immediately by a <, then lex the body |
237 | // of it into prettyName. |
238 | if (hasTrailingData && p.parseDialectSymbolBody(body&: symbolData)) |
239 | return nullptr; |
240 | } |
241 | |
242 | return createSymbol(dialectName, symbolData, loc); |
243 | } |
244 | |
245 | /// Parse an extended attribute. |
246 | /// |
247 | /// extended-attribute ::= (dialect-attribute | attribute-alias) |
248 | /// dialect-attribute ::= `#` dialect-namespace `<` attr-data `>` |
249 | /// (`:` type)? |
250 | /// | `#` alias-name pretty-dialect-sym-body? (`:` type)? |
251 | /// attribute-alias ::= `#` alias-name |
252 | /// |
253 | Attribute Parser::parseExtendedAttr(Type type) { |
254 | MLIRContext *ctx = getContext(); |
255 | Attribute attr = parseExtendedSymbol<Attribute>( |
256 | p&: *this, asmState: state.asmState, aliases&: state.symbols.attributeAliasDefinitions, |
257 | createSymbol: [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute { |
258 | // Parse an optional trailing colon type. |
259 | Type attrType = type; |
260 | if (consumeIf(kind: Token::colon) && !(attrType = parseType())) |
261 | return Attribute(); |
262 | |
263 | // If we found a registered dialect, then ask it to parse the attribute. |
264 | if (Dialect *dialect = |
265 | builder.getContext()->getOrLoadDialect(name: dialectName)) { |
266 | // Temporarily reset the lexer to let the dialect parse the attribute. |
267 | const char *curLexerPos = getToken().getLoc().getPointer(); |
268 | resetToken(tokPos: symbolData.data()); |
269 | |
270 | // Parse the attribute. |
271 | CustomDialectAsmParser customParser(symbolData, *this); |
272 | Attribute attr = dialect->parseAttribute(customParser, attrType); |
273 | resetToken(tokPos: curLexerPos); |
274 | return attr; |
275 | } |
276 | |
277 | // Otherwise, form a new opaque attribute. |
278 | return OpaqueAttr::getChecked( |
279 | [&] { return emitError(loc); }, StringAttr::get(ctx, dialectName), |
280 | symbolData, attrType ? attrType : NoneType::get(ctx)); |
281 | }); |
282 | |
283 | // Ensure that the attribute has the same type as requested. |
284 | auto typedAttr = dyn_cast_or_null<TypedAttr>(attr); |
285 | if (type && typedAttr && typedAttr.getType() != type) { |
286 | emitError(message: "attribute type different than expected: expected " ) |
287 | << type << ", but got " << typedAttr.getType(); |
288 | return nullptr; |
289 | } |
290 | return attr; |
291 | } |
292 | |
293 | /// Parse an extended type. |
294 | /// |
295 | /// extended-type ::= (dialect-type | type-alias) |
296 | /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>` |
297 | /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body? |
298 | /// type-alias ::= `!` alias-name |
299 | /// |
300 | Type Parser::parseExtendedType() { |
301 | MLIRContext *ctx = getContext(); |
302 | return parseExtendedSymbol<Type>( |
303 | p&: *this, asmState: state.asmState, aliases&: state.symbols.typeAliasDefinitions, |
304 | createSymbol: [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type { |
305 | // If we found a registered dialect, then ask it to parse the type. |
306 | if (auto *dialect = ctx->getOrLoadDialect(name: dialectName)) { |
307 | // Temporarily reset the lexer to let the dialect parse the type. |
308 | const char *curLexerPos = getToken().getLoc().getPointer(); |
309 | resetToken(tokPos: symbolData.data()); |
310 | |
311 | // Parse the type. |
312 | CustomDialectAsmParser customParser(symbolData, *this); |
313 | Type type = dialect->parseType(customParser); |
314 | resetToken(tokPos: curLexerPos); |
315 | return type; |
316 | } |
317 | |
318 | // Otherwise, form a new opaque type. |
319 | return OpaqueType::getChecked([&] { return emitError(loc); }, |
320 | StringAttr::get(ctx, dialectName), |
321 | symbolData); |
322 | }); |
323 | } |
324 | |
325 | //===----------------------------------------------------------------------===// |
326 | // mlir::parseAttribute/parseType |
327 | //===----------------------------------------------------------------------===// |
328 | |
329 | /// Parses a symbol, of type 'T', and returns it if parsing was successful. If |
330 | /// parsing failed, nullptr is returned. |
331 | template <typename T, typename ParserFn> |
332 | static T parseSymbol(StringRef inputStr, MLIRContext *context, |
333 | size_t *numReadOut, bool isKnownNullTerminated, |
334 | ParserFn &&parserFn) { |
335 | // Set the buffer name to the string being parsed, so that it appears in error |
336 | // diagnostics. |
337 | auto memBuffer = |
338 | isKnownNullTerminated |
339 | ? MemoryBuffer::getMemBuffer(InputData: inputStr, |
340 | /*BufferName=*/inputStr) |
341 | : MemoryBuffer::getMemBufferCopy(InputData: inputStr, /*BufferName=*/inputStr); |
342 | SourceMgr sourceMgr; |
343 | sourceMgr.AddNewSourceBuffer(F: std::move(memBuffer), IncludeLoc: SMLoc()); |
344 | SymbolState aliasState; |
345 | ParserConfig config(context); |
346 | ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr, |
347 | /*codeCompleteContext=*/nullptr); |
348 | Parser parser(state); |
349 | |
350 | Token startTok = parser.getToken(); |
351 | T symbol = parserFn(parser); |
352 | if (!symbol) |
353 | return T(); |
354 | |
355 | // Provide the number of bytes that were read. |
356 | Token endTok = parser.getToken(); |
357 | size_t numRead = |
358 | endTok.getLoc().getPointer() - startTok.getLoc().getPointer(); |
359 | if (numReadOut) { |
360 | *numReadOut = numRead; |
361 | } else if (numRead != inputStr.size()) { |
362 | parser.emitError(loc: endTok.getLoc()) << "found trailing characters: '" |
363 | << inputStr.drop_front(N: numRead) << "'" ; |
364 | return T(); |
365 | } |
366 | return symbol; |
367 | } |
368 | |
369 | Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context, |
370 | Type type, size_t *numRead, |
371 | bool isKnownNullTerminated) { |
372 | return parseSymbol<Attribute>( |
373 | inputStr: attrStr, context, numReadOut: numRead, isKnownNullTerminated, |
374 | parserFn: [type](Parser &parser) { return parser.parseAttribute(type); }); |
375 | } |
376 | Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead, |
377 | bool isKnownNullTerminated) { |
378 | return parseSymbol<Type>(inputStr: typeStr, context, numReadOut: numRead, isKnownNullTerminated, |
379 | parserFn: [](Parser &parser) { return parser.parseType(); }); |
380 | } |
381 | |