1//===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the parser for the dialect symbols, such as extended
10// attributes and types.
11//
12//===----------------------------------------------------------------------===//
13
14#include "AsmParserImpl.h"
15#include "Parser.h"
16#include "mlir/AsmParser/AsmParserState.h"
17#include "mlir/IR/AsmState.h"
18#include "mlir/IR/Attributes.h"
19#include "mlir/IR/BuiltinAttributeInterfaces.h"
20#include "mlir/IR/BuiltinAttributes.h"
21#include "mlir/IR/BuiltinTypes.h"
22#include "mlir/IR/Dialect.h"
23#include "mlir/IR/DialectImplementation.h"
24#include "mlir/IR/MLIRContext.h"
25#include "mlir/Support/LLVM.h"
26#include "mlir/Support/LogicalResult.h"
27#include "llvm/Support/MemoryBuffer.h"
28#include "llvm/Support/SourceMgr.h"
29#include <cassert>
30#include <cstddef>
31#include <utility>
32
33using namespace mlir;
34using namespace mlir::detail;
35using llvm::MemoryBuffer;
36using llvm::SourceMgr;
37
38namespace {
39/// This class provides the main implementation of the DialectAsmParser that
40/// allows for dialects to parse attributes and types. This allows for dialect
41/// hooking into the main MLIR parsing logic.
42class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
43public:
44 CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
45 : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser),
46 fullSpec(fullSpec) {}
47 ~CustomDialectAsmParser() override = default;
48
49 /// Returns the full specification of the symbol being parsed. This allows
50 /// for using a separate parser if necessary.
51 StringRef getFullSymbolSpec() const override { return fullSpec; }
52
53private:
54 /// The full symbol specification.
55 StringRef fullSpec;
56};
57} // namespace
58
59///
60/// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
61/// pretty-dialect-sym-contents ::= pretty-dialect-sym-body
62/// | '(' pretty-dialect-sym-contents+ ')'
63/// | '[' pretty-dialect-sym-contents+ ']'
64/// | '{' pretty-dialect-sym-contents+ '}'
65/// | '[^[<({>\])}\0]+'
66///
67ParseResult Parser::parseDialectSymbolBody(StringRef &body,
68 bool &isCodeCompletion) {
69 // Symbol bodies are a relatively unstructured format that contains a series
70 // of properly nested punctuation, with anything else in the middle. Scan
71 // ahead to find it and consume it if successful, otherwise emit an error.
72 const char *curPtr = getTokenSpelling().data();
73
74 // Scan over the nested punctuation, bailing out on error and consuming until
75 // we find the end. We know that we're currently looking at the '<', so we can
76 // go until we find the matching '>' character.
77 assert(*curPtr == '<');
78 SmallVector<char, 8> nestedPunctuation;
79 const char *codeCompleteLoc = state.lex.getCodeCompleteLoc();
80
81 // Functor used to emit an unbalanced punctuation error.
82 auto emitPunctError = [&] {
83 return emitError() << "unbalanced '" << nestedPunctuation.back()
84 << "' character in pretty dialect name";
85 };
86 // Functor used to check for unbalanced punctuation.
87 auto checkNestedPunctuation = [&](char expectedToken) -> ParseResult {
88 if (nestedPunctuation.back() != expectedToken)
89 return emitPunctError();
90 nestedPunctuation.pop_back();
91 return success();
92 };
93 do {
94 // Handle code completions, which may appear in the middle of the symbol
95 // body.
96 if (curPtr == codeCompleteLoc) {
97 isCodeCompletion = true;
98 nestedPunctuation.clear();
99 break;
100 }
101
102 char c = *curPtr++;
103 switch (c) {
104 case '\0':
105 // This also handles the EOF case.
106 if (!nestedPunctuation.empty())
107 return emitPunctError();
108 return emitError(message: "unexpected nul or EOF in pretty dialect name");
109 case '<':
110 case '[':
111 case '(':
112 case '{':
113 nestedPunctuation.push_back(Elt: c);
114 continue;
115
116 case '-':
117 // The sequence `->` is treated as special token.
118 if (*curPtr == '>')
119 ++curPtr;
120 continue;
121
122 case '>':
123 if (failed(result: checkNestedPunctuation('<')))
124 return failure();
125 break;
126 case ']':
127 if (failed(result: checkNestedPunctuation('[')))
128 return failure();
129 break;
130 case ')':
131 if (failed(result: checkNestedPunctuation('(')))
132 return failure();
133 break;
134 case '}':
135 if (failed(result: checkNestedPunctuation('{')))
136 return failure();
137 break;
138 case '"': {
139 // Dispatch to the lexer to lex past strings.
140 resetToken(tokPos: curPtr - 1);
141 curPtr = state.curToken.getEndLoc().getPointer();
142
143 // Handle code completions, which may appear in the middle of the symbol
144 // body.
145 if (state.curToken.isCodeCompletion()) {
146 isCodeCompletion = true;
147 nestedPunctuation.clear();
148 break;
149 }
150
151 // Otherwise, ensure this token was actually a string.
152 if (state.curToken.isNot(k: Token::string))
153 return failure();
154 break;
155 }
156
157 default:
158 continue;
159 }
160 } while (!nestedPunctuation.empty());
161
162 // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
163 // consuming all this stuff, and return.
164 resetToken(tokPos: curPtr);
165
166 unsigned length = curPtr - body.begin();
167 body = StringRef(body.data(), length);
168 return success();
169}
170
171/// Parse an extended dialect symbol.
172template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
173static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
174 SymbolAliasMap &aliases,
175 CreateFn &&createSymbol) {
176 Token tok = p.getToken();
177
178 // Handle code completion of the extended symbol.
179 StringRef identifier = tok.getSpelling().drop_front();
180 if (tok.isCodeCompletion() && identifier.empty())
181 return p.codeCompleteDialectSymbol(aliases);
182
183 // Parse the dialect namespace.
184 SMRange range = p.getToken().getLocRange();
185 SMLoc loc = p.getToken().getLoc();
186 p.consumeToken();
187
188 // Check to see if this is a pretty name.
189 auto [dialectName, symbolData] = identifier.split(Separator: '.');
190 bool isPrettyName = !symbolData.empty() || identifier.back() == '.';
191
192 // Check to see if the symbol has trailing data, i.e. has an immediately
193 // following '<'.
194 bool hasTrailingData =
195 p.getToken().is(k: Token::less) &&
196 identifier.bytes_end() == p.getTokenSpelling().bytes_begin();
197
198 // If there is no '<' token following this, and if the typename contains no
199 // dot, then we are parsing a symbol alias.
200 if (!hasTrailingData && !isPrettyName) {
201 // Check for an alias for this type.
202 auto aliasIt = aliases.find(identifier);
203 if (aliasIt == aliases.end())
204 return (p.emitWrongTokenError(message: "undefined symbol alias id '" + identifier +
205 "'"),
206 nullptr);
207 if (asmState) {
208 if constexpr (std::is_same_v<Symbol, Type>)
209 asmState->addTypeAliasUses(name: identifier, locations: range);
210 else
211 asmState->addAttrAliasUses(name: identifier, locations: range);
212 }
213 return aliasIt->second;
214 }
215
216 // If this isn't an alias, we are parsing a dialect-specific symbol. If the
217 // name contains a dot, then this is the "pretty" form. If not, it is the
218 // verbose form that looks like <...>.
219 if (!isPrettyName) {
220 // Point the symbol data to the end of the dialect name to start.
221 symbolData = StringRef(dialectName.end(), 0);
222
223 // Parse the body of the symbol.
224 bool isCodeCompletion = false;
225 if (p.parseDialectSymbolBody(body&: symbolData, isCodeCompletion))
226 return nullptr;
227 symbolData = symbolData.drop_front();
228
229 // If the body contained a code completion it won't have the trailing `>`
230 // token, so don't drop it.
231 if (!isCodeCompletion)
232 symbolData = symbolData.drop_back();
233 } else {
234 loc = SMLoc::getFromPointer(Ptr: symbolData.data());
235
236 // If the dialect's symbol is followed immediately by a <, then lex the body
237 // of it into prettyName.
238 if (hasTrailingData && p.parseDialectSymbolBody(body&: symbolData))
239 return nullptr;
240 }
241
242 return createSymbol(dialectName, symbolData, loc);
243}
244
245/// Parse an extended attribute.
246///
247/// extended-attribute ::= (dialect-attribute | attribute-alias)
248/// dialect-attribute ::= `#` dialect-namespace `<` attr-data `>`
249/// (`:` type)?
250/// | `#` alias-name pretty-dialect-sym-body? (`:` type)?
251/// attribute-alias ::= `#` alias-name
252///
253Attribute Parser::parseExtendedAttr(Type type) {
254 MLIRContext *ctx = getContext();
255 Attribute attr = parseExtendedSymbol<Attribute>(
256 p&: *this, asmState: state.asmState, aliases&: state.symbols.attributeAliasDefinitions,
257 createSymbol: [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute {
258 // Parse an optional trailing colon type.
259 Type attrType = type;
260 if (consumeIf(kind: Token::colon) && !(attrType = parseType()))
261 return Attribute();
262
263 // If we found a registered dialect, then ask it to parse the attribute.
264 if (Dialect *dialect =
265 builder.getContext()->getOrLoadDialect(name: dialectName)) {
266 // Temporarily reset the lexer to let the dialect parse the attribute.
267 const char *curLexerPos = getToken().getLoc().getPointer();
268 resetToken(tokPos: symbolData.data());
269
270 // Parse the attribute.
271 CustomDialectAsmParser customParser(symbolData, *this);
272 Attribute attr = dialect->parseAttribute(customParser, attrType);
273 resetToken(tokPos: curLexerPos);
274 return attr;
275 }
276
277 // Otherwise, form a new opaque attribute.
278 return OpaqueAttr::getChecked(
279 [&] { return emitError(loc); }, StringAttr::get(ctx, dialectName),
280 symbolData, attrType ? attrType : NoneType::get(ctx));
281 });
282
283 // Ensure that the attribute has the same type as requested.
284 auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
285 if (type && typedAttr && typedAttr.getType() != type) {
286 emitError(message: "attribute type different than expected: expected ")
287 << type << ", but got " << typedAttr.getType();
288 return nullptr;
289 }
290 return attr;
291}
292
293/// Parse an extended type.
294///
295/// extended-type ::= (dialect-type | type-alias)
296/// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>`
297/// dialect-type ::= `!` alias-name pretty-dialect-attribute-body?
298/// type-alias ::= `!` alias-name
299///
300Type Parser::parseExtendedType() {
301 MLIRContext *ctx = getContext();
302 return parseExtendedSymbol<Type>(
303 p&: *this, asmState: state.asmState, aliases&: state.symbols.typeAliasDefinitions,
304 createSymbol: [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {
305 // If we found a registered dialect, then ask it to parse the type.
306 if (auto *dialect = ctx->getOrLoadDialect(name: dialectName)) {
307 // Temporarily reset the lexer to let the dialect parse the type.
308 const char *curLexerPos = getToken().getLoc().getPointer();
309 resetToken(tokPos: symbolData.data());
310
311 // Parse the type.
312 CustomDialectAsmParser customParser(symbolData, *this);
313 Type type = dialect->parseType(customParser);
314 resetToken(tokPos: curLexerPos);
315 return type;
316 }
317
318 // Otherwise, form a new opaque type.
319 return OpaqueType::getChecked([&] { return emitError(loc); },
320 StringAttr::get(ctx, dialectName),
321 symbolData);
322 });
323}
324
325//===----------------------------------------------------------------------===//
326// mlir::parseAttribute/parseType
327//===----------------------------------------------------------------------===//
328
329/// Parses a symbol, of type 'T', and returns it if parsing was successful. If
330/// parsing failed, nullptr is returned.
331template <typename T, typename ParserFn>
332static T parseSymbol(StringRef inputStr, MLIRContext *context,
333 size_t *numReadOut, bool isKnownNullTerminated,
334 ParserFn &&parserFn) {
335 // Set the buffer name to the string being parsed, so that it appears in error
336 // diagnostics.
337 auto memBuffer =
338 isKnownNullTerminated
339 ? MemoryBuffer::getMemBuffer(InputData: inputStr,
340 /*BufferName=*/inputStr)
341 : MemoryBuffer::getMemBufferCopy(InputData: inputStr, /*BufferName=*/inputStr);
342 SourceMgr sourceMgr;
343 sourceMgr.AddNewSourceBuffer(F: std::move(memBuffer), IncludeLoc: SMLoc());
344 SymbolState aliasState;
345 ParserConfig config(context);
346 ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr,
347 /*codeCompleteContext=*/nullptr);
348 Parser parser(state);
349
350 Token startTok = parser.getToken();
351 T symbol = parserFn(parser);
352 if (!symbol)
353 return T();
354
355 // Provide the number of bytes that were read.
356 Token endTok = parser.getToken();
357 size_t numRead =
358 endTok.getLoc().getPointer() - startTok.getLoc().getPointer();
359 if (numReadOut) {
360 *numReadOut = numRead;
361 } else if (numRead != inputStr.size()) {
362 parser.emitError(loc: endTok.getLoc()) << "found trailing characters: '"
363 << inputStr.drop_front(N: numRead) << "'";
364 return T();
365 }
366 return symbol;
367}
368
369Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
370 Type type, size_t *numRead,
371 bool isKnownNullTerminated) {
372 return parseSymbol<Attribute>(
373 inputStr: attrStr, context, numReadOut: numRead, isKnownNullTerminated,
374 parserFn: [type](Parser &parser) { return parser.parseAttribute(type); });
375}
376Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead,
377 bool isKnownNullTerminated) {
378 return parseSymbol<Type>(inputStr: typeStr, context, numReadOut: numRead, isKnownNullTerminated,
379 parserFn: [](Parser &parser) { return parser.parseType(); });
380}
381

// Source: mlir/lib/AsmParser/DialectSymbolParser.cpp