| 1 | //===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser --------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements the parser for the dialect symbols, such as extended |
| 10 | // attributes and types. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "AsmParserImpl.h" |
| 15 | #include "Parser.h" |
| 16 | #include "mlir/AsmParser/AsmParserState.h" |
| 17 | #include "mlir/IR/AsmState.h" |
| 18 | #include "mlir/IR/Attributes.h" |
| 19 | #include "mlir/IR/BuiltinAttributeInterfaces.h" |
| 20 | #include "mlir/IR/BuiltinAttributes.h" |
| 21 | #include "mlir/IR/BuiltinTypes.h" |
| 22 | #include "mlir/IR/Dialect.h" |
| 23 | #include "mlir/IR/DialectImplementation.h" |
| 24 | #include "mlir/IR/MLIRContext.h" |
| 25 | #include "mlir/Support/LLVM.h" |
| 26 | #include "llvm/Support/MemoryBuffer.h" |
| 27 | #include "llvm/Support/SourceMgr.h" |
| 28 | #include <cassert> |
| 29 | #include <cstddef> |
| 30 | #include <utility> |
| 31 | |
| 32 | using namespace mlir; |
| 33 | using namespace mlir::detail; |
| 34 | using llvm::MemoryBuffer; |
| 35 | using llvm::SourceMgr; |
| 36 | |
| 37 | namespace { |
| 38 | /// This class provides the main implementation of the DialectAsmParser that |
| 39 | /// allows for dialects to parse attributes and types. This allows for dialect |
| 40 | /// hooking into the main MLIR parsing logic. |
| 41 | class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> { |
| 42 | public: |
| 43 | CustomDialectAsmParser(StringRef fullSpec, Parser &parser) |
| 44 | : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser), |
| 45 | fullSpec(fullSpec) {} |
| 46 | ~CustomDialectAsmParser() override = default; |
| 47 | |
| 48 | /// Returns the full specification of the symbol being parsed. This allows |
| 49 | /// for using a separate parser if necessary. |
| 50 | StringRef getFullSymbolSpec() const override { return fullSpec; } |
| 51 | |
| 52 | private: |
| 53 | /// The full symbol specification. |
| 54 | StringRef fullSpec; |
| 55 | }; |
| 56 | } // namespace |
| 57 | |
| 58 | /// |
| 59 | /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>' |
| 60 | /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body |
| 61 | /// | '(' pretty-dialect-sym-contents+ ')' |
| 62 | /// | '[' pretty-dialect-sym-contents+ ']' |
| 63 | /// | '{' pretty-dialect-sym-contents+ '}' |
| 64 | /// | '[^[<({>\])}\0]+' |
| 65 | /// |
| 66 | ParseResult Parser::parseDialectSymbolBody(StringRef &body, |
| 67 | bool &isCodeCompletion) { |
| 68 | // Symbol bodies are a relatively unstructured format that contains a series |
| 69 | // of properly nested punctuation, with anything else in the middle. Scan |
| 70 | // ahead to find it and consume it if successful, otherwise emit an error. |
| 71 | const char *curPtr = getTokenSpelling().data(); |
| 72 | |
| 73 | // Scan over the nested punctuation, bailing out on error and consuming until |
| 74 | // we find the end. We know that we're currently looking at the '<', so we can |
| 75 | // go until we find the matching '>' character. |
| 76 | assert(*curPtr == '<'); |
| 77 | SmallVector<char, 8> nestedPunctuation; |
| 78 | const char *codeCompleteLoc = state.lex.getCodeCompleteLoc(); |
| 79 | |
| 80 | // Functor used to emit an unbalanced punctuation error. |
| 81 | auto emitPunctError = [&] { |
| 82 | return emitError() << "unbalanced '" << nestedPunctuation.back() |
| 83 | << "' character in pretty dialect name" ; |
| 84 | }; |
| 85 | // Functor used to check for unbalanced punctuation. |
| 86 | auto checkNestedPunctuation = [&](char expectedToken) -> ParseResult { |
| 87 | if (nestedPunctuation.back() != expectedToken) |
| 88 | return emitPunctError(); |
| 89 | nestedPunctuation.pop_back(); |
| 90 | return success(); |
| 91 | }; |
| 92 | do { |
| 93 | // Handle code completions, which may appear in the middle of the symbol |
| 94 | // body. |
| 95 | if (curPtr == codeCompleteLoc) { |
| 96 | isCodeCompletion = true; |
| 97 | nestedPunctuation.clear(); |
| 98 | break; |
| 99 | } |
| 100 | |
| 101 | char c = *curPtr++; |
| 102 | switch (c) { |
| 103 | case '\0': |
| 104 | // This also handles the EOF case. |
| 105 | if (!nestedPunctuation.empty()) |
| 106 | return emitPunctError(); |
| 107 | return emitError(message: "unexpected nul or EOF in pretty dialect name" ); |
| 108 | case '<': |
| 109 | case '[': |
| 110 | case '(': |
| 111 | case '{': |
| 112 | nestedPunctuation.push_back(Elt: c); |
| 113 | continue; |
| 114 | |
| 115 | case '-': |
| 116 | // The sequence `->` is treated as special token. |
| 117 | if (*curPtr == '>') |
| 118 | ++curPtr; |
| 119 | continue; |
| 120 | |
| 121 | case '>': |
| 122 | if (failed(Result: checkNestedPunctuation('<'))) |
| 123 | return failure(); |
| 124 | break; |
| 125 | case ']': |
| 126 | if (failed(Result: checkNestedPunctuation('['))) |
| 127 | return failure(); |
| 128 | break; |
| 129 | case ')': |
| 130 | if (failed(Result: checkNestedPunctuation('('))) |
| 131 | return failure(); |
| 132 | break; |
| 133 | case '}': |
| 134 | if (failed(Result: checkNestedPunctuation('{'))) |
| 135 | return failure(); |
| 136 | break; |
| 137 | case '"': { |
| 138 | // Dispatch to the lexer to lex past strings. |
| 139 | resetToken(tokPos: curPtr - 1); |
| 140 | curPtr = state.curToken.getEndLoc().getPointer(); |
| 141 | |
| 142 | // Handle code completions, which may appear in the middle of the symbol |
| 143 | // body. |
| 144 | if (state.curToken.isCodeCompletion()) { |
| 145 | isCodeCompletion = true; |
| 146 | nestedPunctuation.clear(); |
| 147 | break; |
| 148 | } |
| 149 | |
| 150 | // Otherwise, ensure this token was actually a string. |
| 151 | if (state.curToken.isNot(k: Token::string)) |
| 152 | return failure(); |
| 153 | break; |
| 154 | } |
| 155 | |
| 156 | default: |
| 157 | continue; |
| 158 | } |
| 159 | } while (!nestedPunctuation.empty()); |
| 160 | |
| 161 | // Ok, we succeeded, remember where we stopped, reset the lexer to know it is |
| 162 | // consuming all this stuff, and return. |
| 163 | resetToken(tokPos: curPtr); |
| 164 | |
| 165 | unsigned length = curPtr - body.begin(); |
| 166 | body = StringRef(body.data(), length); |
| 167 | return success(); |
| 168 | } |
| 169 | |
| 170 | /// Parse an extended dialect symbol. |
| 171 | template <typename Symbol, typename SymbolAliasMap, typename CreateFn> |
| 172 | static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState, |
| 173 | SymbolAliasMap &aliases, |
| 174 | CreateFn &&createSymbol) { |
| 175 | Token tok = p.getToken(); |
| 176 | |
| 177 | // Handle code completion of the extended symbol. |
| 178 | StringRef identifier = tok.getSpelling().drop_front(); |
| 179 | if (tok.isCodeCompletion() && identifier.empty()) |
| 180 | return p.codeCompleteDialectSymbol(aliases); |
| 181 | |
| 182 | // Parse the dialect namespace. |
| 183 | SMRange range = p.getToken().getLocRange(); |
| 184 | SMLoc loc = p.getToken().getLoc(); |
| 185 | p.consumeToken(); |
| 186 | |
| 187 | // Check to see if this is a pretty name. |
| 188 | auto [dialectName, symbolData] = identifier.split(Separator: '.'); |
| 189 | bool isPrettyName = !symbolData.empty() || identifier.back() == '.'; |
| 190 | |
| 191 | // Check to see if the symbol has trailing data, i.e. has an immediately |
| 192 | // following '<'. |
| 193 | bool hasTrailingData = |
| 194 | p.getToken().is(k: Token::less) && |
| 195 | identifier.bytes_end() == p.getTokenSpelling().bytes_begin(); |
| 196 | |
| 197 | // If there is no '<' token following this, and if the typename contains no |
| 198 | // dot, then we are parsing a symbol alias. |
| 199 | if (!hasTrailingData && !isPrettyName) { |
| 200 | // Check for an alias for this type. |
| 201 | auto aliasIt = aliases.find(identifier); |
| 202 | if (aliasIt == aliases.end()) |
| 203 | return (p.emitWrongTokenError(message: "undefined symbol alias id '" + identifier + |
| 204 | "'" ), |
| 205 | nullptr); |
| 206 | if (asmState) { |
| 207 | if constexpr (std::is_same_v<Symbol, Type>) |
| 208 | asmState->addTypeAliasUses(name: identifier, locations: range); |
| 209 | else |
| 210 | asmState->addAttrAliasUses(name: identifier, locations: range); |
| 211 | } |
| 212 | return aliasIt->second; |
| 213 | } |
| 214 | |
| 215 | // If this isn't an alias, we are parsing a dialect-specific symbol. If the |
| 216 | // name contains a dot, then this is the "pretty" form. If not, it is the |
| 217 | // verbose form that looks like <...>. |
| 218 | if (!isPrettyName) { |
| 219 | // Point the symbol data to the end of the dialect name to start. |
| 220 | symbolData = StringRef(dialectName.end(), 0); |
| 221 | |
| 222 | // Parse the body of the symbol. |
| 223 | bool isCodeCompletion = false; |
| 224 | if (p.parseDialectSymbolBody(body&: symbolData, isCodeCompletion)) |
| 225 | return nullptr; |
| 226 | symbolData = symbolData.drop_front(); |
| 227 | |
| 228 | // If the body contained a code completion it won't have the trailing `>` |
| 229 | // token, so don't drop it. |
| 230 | if (!isCodeCompletion) |
| 231 | symbolData = symbolData.drop_back(); |
| 232 | } else { |
| 233 | loc = SMLoc::getFromPointer(Ptr: symbolData.data()); |
| 234 | |
| 235 | // If the dialect's symbol is followed immediately by a <, then lex the body |
| 236 | // of it into prettyName. |
| 237 | if (hasTrailingData && p.parseDialectSymbolBody(body&: symbolData)) |
| 238 | return nullptr; |
| 239 | } |
| 240 | |
| 241 | return createSymbol(dialectName, symbolData, loc); |
| 242 | } |
| 243 | |
| 244 | /// Parse an extended attribute. |
| 245 | /// |
| 246 | /// extended-attribute ::= (dialect-attribute | attribute-alias) |
| 247 | /// dialect-attribute ::= `#` dialect-namespace `<` attr-data `>` |
| 248 | /// (`:` type)? |
| 249 | /// | `#` alias-name pretty-dialect-sym-body? (`:` type)? |
| 250 | /// attribute-alias ::= `#` alias-name |
| 251 | /// |
| 252 | Attribute Parser::parseExtendedAttr(Type type) { |
| 253 | MLIRContext *ctx = getContext(); |
| 254 | Attribute attr = parseExtendedSymbol<Attribute>( |
| 255 | p&: *this, asmState: state.asmState, aliases&: state.symbols.attributeAliasDefinitions, |
| 256 | createSymbol: [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute { |
| 257 | // Parse an optional trailing colon type. |
| 258 | Type attrType = type; |
| 259 | if (consumeIf(kind: Token::colon) && !(attrType = parseType())) |
| 260 | return Attribute(); |
| 261 | |
| 262 | // If we found a registered dialect, then ask it to parse the attribute. |
| 263 | if (Dialect *dialect = |
| 264 | builder.getContext()->getOrLoadDialect(name: dialectName)) { |
| 265 | // Temporarily reset the lexer to let the dialect parse the attribute. |
| 266 | const char *curLexerPos = getToken().getLoc().getPointer(); |
| 267 | resetToken(tokPos: symbolData.data()); |
| 268 | |
| 269 | // Parse the attribute. |
| 270 | CustomDialectAsmParser customParser(symbolData, *this); |
| 271 | Attribute attr = dialect->parseAttribute(customParser, attrType); |
| 272 | resetToken(tokPos: curLexerPos); |
| 273 | return attr; |
| 274 | } |
| 275 | |
| 276 | // Otherwise, form a new opaque attribute. |
| 277 | return OpaqueAttr::getChecked( |
| 278 | [&] { return emitError(loc); }, StringAttr::get(ctx, dialectName), |
| 279 | symbolData, attrType ? attrType : NoneType::get(ctx)); |
| 280 | }); |
| 281 | |
| 282 | // Ensure that the attribute has the same type as requested. |
| 283 | auto typedAttr = dyn_cast_or_null<TypedAttr>(attr); |
| 284 | if (type && typedAttr && typedAttr.getType() != type) { |
| 285 | emitError(message: "attribute type different than expected: expected " ) |
| 286 | << type << ", but got " << typedAttr.getType(); |
| 287 | return nullptr; |
| 288 | } |
| 289 | return attr; |
| 290 | } |
| 291 | |
| 292 | /// Parse an extended type. |
| 293 | /// |
| 294 | /// extended-type ::= (dialect-type | type-alias) |
| 295 | /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>` |
| 296 | /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body? |
| 297 | /// type-alias ::= `!` alias-name |
| 298 | /// |
| 299 | Type Parser::parseExtendedType() { |
| 300 | MLIRContext *ctx = getContext(); |
| 301 | return parseExtendedSymbol<Type>( |
| 302 | p&: *this, asmState: state.asmState, aliases&: state.symbols.typeAliasDefinitions, |
| 303 | createSymbol: [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type { |
| 304 | // If we found a registered dialect, then ask it to parse the type. |
| 305 | if (auto *dialect = ctx->getOrLoadDialect(name: dialectName)) { |
| 306 | // Temporarily reset the lexer to let the dialect parse the type. |
| 307 | const char *curLexerPos = getToken().getLoc().getPointer(); |
| 308 | resetToken(tokPos: symbolData.data()); |
| 309 | |
| 310 | // Parse the type. |
| 311 | CustomDialectAsmParser customParser(symbolData, *this); |
| 312 | Type type = dialect->parseType(customParser); |
| 313 | resetToken(tokPos: curLexerPos); |
| 314 | return type; |
| 315 | } |
| 316 | |
| 317 | // Otherwise, form a new opaque type. |
| 318 | return OpaqueType::getChecked([&] { return emitError(loc); }, |
| 319 | StringAttr::get(ctx, dialectName), |
| 320 | symbolData); |
| 321 | }); |
| 322 | } |
| 323 | |
| 324 | //===----------------------------------------------------------------------===// |
| 325 | // mlir::parseAttribute/parseType |
| 326 | //===----------------------------------------------------------------------===// |
| 327 | |
| 328 | /// Parses a symbol, of type 'T', and returns it if parsing was successful. If |
| 329 | /// parsing failed, nullptr is returned. |
| 330 | template <typename T, typename ParserFn> |
| 331 | static T parseSymbol(StringRef inputStr, MLIRContext *context, |
| 332 | size_t *numReadOut, bool isKnownNullTerminated, |
| 333 | ParserFn &&parserFn) { |
| 334 | // Set the buffer name to the string being parsed, so that it appears in error |
| 335 | // diagnostics. |
| 336 | auto memBuffer = |
| 337 | isKnownNullTerminated |
| 338 | ? MemoryBuffer::getMemBuffer(InputData: inputStr, |
| 339 | /*BufferName=*/inputStr) |
| 340 | : MemoryBuffer::getMemBufferCopy(InputData: inputStr, /*BufferName=*/inputStr); |
| 341 | SourceMgr sourceMgr; |
| 342 | sourceMgr.AddNewSourceBuffer(F: std::move(memBuffer), IncludeLoc: SMLoc()); |
| 343 | SymbolState aliasState; |
| 344 | ParserConfig config(context); |
| 345 | ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr, |
| 346 | /*codeCompleteContext=*/nullptr); |
| 347 | Parser parser(state); |
| 348 | |
| 349 | Token startTok = parser.getToken(); |
| 350 | T symbol = parserFn(parser); |
| 351 | if (!symbol) |
| 352 | return T(); |
| 353 | |
| 354 | // Provide the number of bytes that were read. |
| 355 | Token endTok = parser.getToken(); |
| 356 | size_t numRead = |
| 357 | endTok.getLoc().getPointer() - startTok.getLoc().getPointer(); |
| 358 | if (numReadOut) { |
| 359 | *numReadOut = numRead; |
| 360 | } else if (numRead != inputStr.size()) { |
| 361 | parser.emitError(loc: endTok.getLoc()) << "found trailing characters: '" |
| 362 | << inputStr.drop_front(N: numRead) << "'" ; |
| 363 | return T(); |
| 364 | } |
| 365 | return symbol; |
| 366 | } |
| 367 | |
| 368 | Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context, |
| 369 | Type type, size_t *numRead, |
| 370 | bool isKnownNullTerminated) { |
| 371 | return parseSymbol<Attribute>( |
| 372 | inputStr: attrStr, context, numReadOut: numRead, isKnownNullTerminated, |
| 373 | parserFn: [type](Parser &parser) { return parser.parseAttribute(type); }); |
| 374 | } |
| 375 | Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead, |
| 376 | bool isKnownNullTerminated) { |
| 377 | return parseSymbol<Type>(inputStr: typeStr, context, numReadOut: numRead, isKnownNullTerminated, |
| 378 | parserFn: [](Parser &parser) { return parser.parseType(); }); |
| 379 | } |
| 380 | |