1 | //===- FormatGen.cpp - Utilities for custom assembly formats ----*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "FormatGen.h" |
10 | #include "llvm/ADT/StringSwitch.h" |
11 | #include "llvm/Support/SourceMgr.h" |
12 | #include "llvm/TableGen/Error.h" |
13 | |
14 | using namespace mlir; |
15 | using namespace mlir::tblgen; |
16 | using llvm::SourceMgr; |
17 | |
18 | //===----------------------------------------------------------------------===// |
19 | // FormatToken |
20 | //===----------------------------------------------------------------------===// |
21 | |
22 | SMLoc FormatToken::getLoc() const { |
23 | return SMLoc::getFromPointer(Ptr: spelling.data()); |
24 | } |
25 | |
26 | //===----------------------------------------------------------------------===// |
27 | // FormatLexer |
28 | //===----------------------------------------------------------------------===// |
29 | |
30 | FormatLexer::FormatLexer(SourceMgr &mgr, SMLoc loc) |
31 | : mgr(mgr), loc(loc), |
32 | curBuffer(mgr.getMemoryBuffer(i: mgr.getMainFileID())->getBuffer()), |
33 | curPtr(curBuffer.begin()) {} |
34 | |
35 | FormatToken FormatLexer::emitError(SMLoc loc, const Twine &msg) { |
36 | mgr.PrintMessage(Loc: loc, Kind: SourceMgr::DK_Error, Msg: msg); |
37 | llvm::SrcMgr.PrintMessage(Loc: this->loc, Kind: SourceMgr::DK_Note, |
38 | Msg: "in custom assembly format for this operation" ); |
39 | return formToken(kind: FormatToken::error, tokStart: loc.getPointer()); |
40 | } |
41 | |
42 | FormatToken FormatLexer::emitError(const char *loc, const Twine &msg) { |
43 | return emitError(loc: SMLoc::getFromPointer(Ptr: loc), msg); |
44 | } |
45 | |
46 | FormatToken FormatLexer::emitErrorAndNote(SMLoc loc, const Twine &msg, |
47 | const Twine ¬e) { |
48 | mgr.PrintMessage(Loc: loc, Kind: SourceMgr::DK_Error, Msg: msg); |
49 | llvm::SrcMgr.PrintMessage(Loc: this->loc, Kind: SourceMgr::DK_Note, |
50 | Msg: "in custom assembly format for this operation" ); |
51 | mgr.PrintMessage(Loc: loc, Kind: SourceMgr::DK_Note, Msg: note); |
52 | return formToken(kind: FormatToken::error, tokStart: loc.getPointer()); |
53 | } |
54 | |
55 | int FormatLexer::getNextChar() { |
56 | char curChar = *curPtr++; |
57 | switch (curChar) { |
58 | default: |
59 | return (unsigned char)curChar; |
60 | case 0: { |
61 | // A nul character in the stream is either the end of the current buffer or |
62 | // a random nul in the file. Disambiguate that here. |
63 | if (curPtr - 1 != curBuffer.end()) |
64 | return 0; |
65 | |
66 | // Otherwise, return end of file. |
67 | --curPtr; |
68 | return EOF; |
69 | } |
70 | case '\n': |
71 | case '\r': |
72 | // Handle the newline character by ignoring it and incrementing the line |
73 | // count. However, be careful about 'dos style' files with \n\r in them. |
74 | // Only treat a \n\r or \r\n as a single line. |
75 | if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar) |
76 | ++curPtr; |
77 | return '\n'; |
78 | } |
79 | } |
80 | |
81 | FormatToken FormatLexer::lexToken() { |
82 | const char *tokStart = curPtr; |
83 | |
84 | // This always consumes at least one character. |
85 | int curChar = getNextChar(); |
86 | switch (curChar) { |
87 | default: |
88 | // Handle identifiers: [a-zA-Z_] |
89 | if (isalpha(curChar) || curChar == '_') |
90 | return lexIdentifier(tokStart); |
91 | |
92 | // Unknown character, emit an error. |
93 | return emitError(loc: tokStart, msg: "unexpected character" ); |
94 | case EOF: |
95 | // Return EOF denoting the end of lexing. |
96 | return formToken(kind: FormatToken::eof, tokStart); |
97 | |
98 | // Lex punctuation. |
99 | case '^': |
100 | return formToken(kind: FormatToken::caret, tokStart); |
101 | case ':': |
102 | return formToken(kind: FormatToken::colon, tokStart); |
103 | case ',': |
104 | return formToken(kind: FormatToken::comma, tokStart); |
105 | case '=': |
106 | return formToken(kind: FormatToken::equal, tokStart); |
107 | case '<': |
108 | return formToken(kind: FormatToken::less, tokStart); |
109 | case '>': |
110 | return formToken(kind: FormatToken::greater, tokStart); |
111 | case '?': |
112 | return formToken(kind: FormatToken::question, tokStart); |
113 | case '(': |
114 | return formToken(kind: FormatToken::l_paren, tokStart); |
115 | case ')': |
116 | return formToken(kind: FormatToken::r_paren, tokStart); |
117 | case '*': |
118 | return formToken(kind: FormatToken::star, tokStart); |
119 | case '|': |
120 | return formToken(kind: FormatToken::pipe, tokStart); |
121 | |
122 | // Ignore whitespace characters. |
123 | case 0: |
124 | case ' ': |
125 | case '\t': |
126 | case '\n': |
127 | return lexToken(); |
128 | |
129 | case '`': |
130 | return lexLiteral(tokStart); |
131 | case '$': |
132 | return lexVariable(tokStart); |
133 | case '"': |
134 | return lexString(tokStart); |
135 | } |
136 | } |
137 | |
138 | FormatToken FormatLexer::lexLiteral(const char *tokStart) { |
139 | assert(curPtr[-1] == '`'); |
140 | |
141 | // Lex a literal surrounded by ``. |
142 | while (const char curChar = *curPtr++) { |
143 | if (curChar == '`') |
144 | return formToken(kind: FormatToken::literal, tokStart); |
145 | } |
146 | return emitError(loc: curPtr - 1, msg: "unexpected end of file in literal" ); |
147 | } |
148 | |
149 | FormatToken FormatLexer::lexVariable(const char *tokStart) { |
150 | if (!isalpha(curPtr[0]) && curPtr[0] != '_') |
151 | return emitError(loc: curPtr - 1, msg: "expected variable name" ); |
152 | |
153 | // Otherwise, consume the rest of the characters. |
154 | while (isalnum(*curPtr) || *curPtr == '_') |
155 | ++curPtr; |
156 | return formToken(kind: FormatToken::variable, tokStart); |
157 | } |
158 | |
159 | FormatToken FormatLexer::lexString(const char *tokStart) { |
160 | // Lex until another quote, respecting escapes. |
161 | bool escape = false; |
162 | while (const char curChar = *curPtr++) { |
163 | if (!escape && curChar == '"') |
164 | return formToken(kind: FormatToken::string, tokStart); |
165 | escape = curChar == '\\'; |
166 | } |
167 | return emitError(loc: curPtr - 1, msg: "unexpected end of file in string" ); |
168 | } |
169 | |
170 | FormatToken FormatLexer::lexIdentifier(const char *tokStart) { |
171 | // Match the rest of the identifier regex: [0-9a-zA-Z_\-]* |
172 | while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-') |
173 | ++curPtr; |
174 | |
175 | // Check to see if this identifier is a keyword. |
176 | StringRef str(tokStart, curPtr - tokStart); |
177 | auto kind = |
178 | StringSwitch<FormatToken::Kind>(str) |
179 | .Case(S: "attr-dict" , Value: FormatToken::kw_attr_dict) |
180 | .Case(S: "attr-dict-with-keyword" , Value: FormatToken::kw_attr_dict_w_keyword) |
181 | .Case(S: "prop-dict" , Value: FormatToken::kw_prop_dict) |
182 | .Case(S: "custom" , Value: FormatToken::kw_custom) |
183 | .Case(S: "functional-type" , Value: FormatToken::kw_functional_type) |
184 | .Case(S: "oilist" , Value: FormatToken::kw_oilist) |
185 | .Case(S: "operands" , Value: FormatToken::kw_operands) |
186 | .Case(S: "params" , Value: FormatToken::kw_params) |
187 | .Case(S: "ref" , Value: FormatToken::kw_ref) |
188 | .Case(S: "regions" , Value: FormatToken::kw_regions) |
189 | .Case(S: "results" , Value: FormatToken::kw_results) |
190 | .Case(S: "struct" , Value: FormatToken::kw_struct) |
191 | .Case(S: "successors" , Value: FormatToken::kw_successors) |
192 | .Case(S: "type" , Value: FormatToken::kw_type) |
193 | .Case(S: "qualified" , Value: FormatToken::kw_qualified) |
194 | .Default(Value: FormatToken::identifier); |
195 | return FormatToken(kind, str); |
196 | } |
197 | |
198 | //===----------------------------------------------------------------------===// |
199 | // FormatParser |
200 | //===----------------------------------------------------------------------===// |
201 | |
202 | FormatElement::~FormatElement() = default; |
203 | |
204 | FormatParser::~FormatParser() = default; |
205 | |
206 | FailureOr<std::vector<FormatElement *>> FormatParser::parse() { |
207 | SMLoc loc = curToken.getLoc(); |
208 | |
209 | // Parse each of the format elements into the main format. |
210 | std::vector<FormatElement *> elements; |
211 | while (curToken.getKind() != FormatToken::eof) { |
212 | FailureOr<FormatElement *> element = parseElement(ctx: TopLevelContext); |
213 | if (failed(Result: element)) |
214 | return failure(); |
215 | elements.push_back(x: *element); |
216 | } |
217 | |
218 | // Verify the format. |
219 | if (failed(Result: verify(loc, elements))) |
220 | return failure(); |
221 | return elements; |
222 | } |
223 | |
224 | //===----------------------------------------------------------------------===// |
225 | // Element Parsing |
226 | //===----------------------------------------------------------------------===// |
227 | |
228 | FailureOr<FormatElement *> FormatParser::parseElement(Context ctx) { |
229 | if (curToken.is(kind: FormatToken::literal)) |
230 | return parseLiteral(ctx); |
231 | if (curToken.is(kind: FormatToken::string)) |
232 | return parseString(ctx); |
233 | if (curToken.is(kind: FormatToken::variable)) |
234 | return parseVariable(ctx); |
235 | if (curToken.isKeyword()) |
236 | return parseDirective(ctx); |
237 | if (curToken.is(kind: FormatToken::l_paren)) |
238 | return parseOptionalGroup(ctx); |
239 | return emitError(loc: curToken.getLoc(), |
240 | msg: "expected literal, variable, directive, or optional group" ); |
241 | } |
242 | |
243 | FailureOr<FormatElement *> FormatParser::parseLiteral(Context ctx) { |
244 | FormatToken tok = curToken; |
245 | SMLoc loc = tok.getLoc(); |
246 | consumeToken(); |
247 | |
248 | if (ctx != TopLevelContext) { |
249 | return emitError( |
250 | loc, |
251 | msg: "literals may only be used in the top-level section of the format" ); |
252 | } |
253 | // Get the spelling without the surrounding backticks. |
254 | StringRef value = tok.getSpelling(); |
255 | // Prevents things like `$arg0` or empty literals (when a literal is expected |
256 | // but not found) from getting segmentation faults. |
257 | if (value.size() < 2 || value[0] != '`' || value[value.size() - 1] != '`') |
258 | return emitError(loc: tok.getLoc(), msg: "expected literal, but got '" + value + "'" ); |
259 | value = value.drop_front().drop_back(); |
260 | |
261 | // The parsed literal is a space element (`` or ` `) or a newline. |
262 | if (value.empty() || value == " " || value == "\\n" ) |
263 | return create<WhitespaceElement>(args&: value); |
264 | |
265 | // Check that the parsed literal is valid. |
266 | if (!isValidLiteral(value, emitError: [&](Twine msg) { |
267 | (void)emitError(loc, msg: "expected valid literal but got '" + value + |
268 | "': " + msg); |
269 | })) |
270 | return failure(); |
271 | return create<LiteralElement>(args&: value); |
272 | } |
273 | |
274 | FailureOr<FormatElement *> FormatParser::parseString(Context ctx) { |
275 | FormatToken tok = curToken; |
276 | SMLoc loc = tok.getLoc(); |
277 | consumeToken(); |
278 | |
279 | if (ctx != CustomDirectiveContext) { |
280 | return emitError( |
281 | loc, msg: "strings may only be used as 'custom' directive arguments" ); |
282 | } |
283 | // Escape the string. |
284 | std::string value; |
285 | StringRef contents = tok.getSpelling().drop_front().drop_back(); |
286 | value.reserve(res: contents.size()); |
287 | bool escape = false; |
288 | for (char c : contents) { |
289 | escape = c == '\\'; |
290 | if (!escape) |
291 | value.push_back(c: c); |
292 | } |
293 | return create<StringElement>(args: std::move(value)); |
294 | } |
295 | |
296 | FailureOr<FormatElement *> FormatParser::parseVariable(Context ctx) { |
297 | FormatToken tok = curToken; |
298 | SMLoc loc = tok.getLoc(); |
299 | consumeToken(); |
300 | |
301 | // Get the name of the variable without the leading `$`. |
302 | StringRef name = tok.getSpelling().drop_front(); |
303 | return parseVariableImpl(loc, name, ctx); |
304 | } |
305 | |
306 | FailureOr<FormatElement *> FormatParser::parseDirective(Context ctx) { |
307 | FormatToken tok = curToken; |
308 | SMLoc loc = tok.getLoc(); |
309 | consumeToken(); |
310 | |
311 | if (tok.is(kind: FormatToken::kw_custom)) |
312 | return parseCustomDirective(loc, ctx); |
313 | if (tok.is(kind: FormatToken::kw_ref)) |
314 | return parseRefDirective(loc, context: ctx); |
315 | if (tok.is(kind: FormatToken::kw_qualified)) |
316 | return parseQualifiedDirective(loc, ctx); |
317 | return parseDirectiveImpl(loc, kind: tok.getKind(), ctx); |
318 | } |
319 | |
320 | FailureOr<FormatElement *> FormatParser::parseOptionalGroup(Context ctx) { |
321 | SMLoc loc = curToken.getLoc(); |
322 | consumeToken(); |
323 | if (ctx != TopLevelContext) { |
324 | return emitError(loc, |
325 | msg: "optional groups can only be used as top-level elements" ); |
326 | } |
327 | |
328 | // Parse the child elements for this optional group. |
329 | std::vector<FormatElement *> thenElements, elseElements; |
330 | FormatElement *anchor = nullptr; |
331 | auto parseChildElements = |
332 | [this, &anchor](std::vector<FormatElement *> &elements) -> LogicalResult { |
333 | do { |
334 | FailureOr<FormatElement *> element = parseElement(ctx: TopLevelContext); |
335 | if (failed(Result: element)) |
336 | return failure(); |
337 | // Check for an anchor. |
338 | if (curToken.is(kind: FormatToken::caret)) { |
339 | if (anchor) { |
340 | return emitError(loc: curToken.getLoc(), |
341 | msg: "only one element can be marked as the anchor of an " |
342 | "optional group" ); |
343 | } |
344 | anchor = *element; |
345 | consumeToken(); |
346 | } |
347 | elements.push_back(x: *element); |
348 | } while (!curToken.is(kind: FormatToken::r_paren)); |
349 | return success(); |
350 | }; |
351 | |
352 | // Parse the 'then' elements. If the anchor was found in this group, then the |
353 | // optional is not inverted. |
354 | if (failed(Result: parseChildElements(thenElements))) |
355 | return failure(); |
356 | consumeToken(); |
357 | bool inverted = !anchor; |
358 | |
359 | // Parse the `else` elements of this optional group. |
360 | if (curToken.is(kind: FormatToken::colon)) { |
361 | consumeToken(); |
362 | if (failed(Result: parseToken( |
363 | kind: FormatToken::l_paren, |
364 | msg: "expected '(' to start else branch of optional group" )) || |
365 | failed(Result: parseChildElements(elseElements))) |
366 | return failure(); |
367 | consumeToken(); |
368 | } |
369 | if (failed(Result: parseToken(kind: FormatToken::question, |
370 | msg: "expected '?' after optional group" ))) |
371 | return failure(); |
372 | |
373 | // The optional group is required to have an anchor. |
374 | if (!anchor) |
375 | return emitError(loc, msg: "optional group has no anchor element" ); |
376 | |
377 | // Verify the child elements. |
378 | if (failed(Result: verifyOptionalGroupElements(loc, elements: thenElements, anchor)) || |
379 | failed(Result: verifyOptionalGroupElements(loc, elements: elseElements, anchor: nullptr))) |
380 | return failure(); |
381 | |
382 | // Get the first parsable element. It must be an element that can be |
383 | // optionally-parsed. |
384 | auto isWhitespace = [](FormatElement *element) { |
385 | return isa<WhitespaceElement>(Val: element); |
386 | }; |
387 | auto thenParseBegin = llvm::find_if_not(Range&: thenElements, P: isWhitespace); |
388 | auto elseParseBegin = llvm::find_if_not(Range&: elseElements, P: isWhitespace); |
389 | unsigned thenParseStart = std::distance(first: thenElements.begin(), last: thenParseBegin); |
390 | unsigned elseParseStart = std::distance(first: elseElements.begin(), last: elseParseBegin); |
391 | |
392 | if (!isa<LiteralElement, VariableElement, CustomDirective>(Val: *thenParseBegin)) { |
393 | return emitError(loc, msg: "first parsable element of an optional group must be " |
394 | "a literal, variable, or custom directive" ); |
395 | } |
396 | return create<OptionalElement>(args: std::move(thenElements), |
397 | args: std::move(elseElements), args&: thenParseStart, |
398 | args&: elseParseStart, args&: anchor, args&: inverted); |
399 | } |
400 | |
401 | FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc, |
402 | Context ctx) { |
403 | if (ctx != TopLevelContext && ctx != StructDirectiveContext) { |
404 | return emitError(loc, msg: "`custom` can only be used at the top-level context " |
405 | "or within a `struct` directive" ); |
406 | } |
407 | |
408 | FailureOr<FormatToken> nameTok; |
409 | if (failed(Result: parseToken(kind: FormatToken::less, |
410 | msg: "expected '<' before custom directive name" )) || |
411 | failed(Result: nameTok = |
412 | parseToken(kind: FormatToken::identifier, |
413 | msg: "expected custom directive name identifier" )) || |
414 | failed(Result: parseToken(kind: FormatToken::greater, |
415 | msg: "expected '>' after custom directive name" )) || |
416 | failed(Result: parseToken(kind: FormatToken::l_paren, |
417 | msg: "expected '(' before custom directive parameters" ))) |
418 | return failure(); |
419 | |
420 | // Parse the arguments. |
421 | std::vector<FormatElement *> arguments; |
422 | while (true) { |
423 | FailureOr<FormatElement *> argument = parseElement(ctx: CustomDirectiveContext); |
424 | if (failed(Result: argument)) |
425 | return failure(); |
426 | arguments.push_back(x: *argument); |
427 | if (!curToken.is(kind: FormatToken::comma)) |
428 | break; |
429 | consumeToken(); |
430 | } |
431 | |
432 | if (failed(Result: parseToken(kind: FormatToken::r_paren, |
433 | msg: "expected ')' after custom directive parameters" ))) |
434 | return failure(); |
435 | |
436 | if (failed(Result: verifyCustomDirectiveArguments(loc, arguments))) |
437 | return failure(); |
438 | return create<CustomDirective>(args: nameTok->getSpelling(), args: std::move(arguments)); |
439 | } |
440 | |
441 | FailureOr<FormatElement *> FormatParser::parseRefDirective(SMLoc loc, |
442 | Context context) { |
443 | if (context != CustomDirectiveContext) |
444 | return emitError(loc, msg: "'ref' is only valid within a `custom` directive" ); |
445 | |
446 | FailureOr<FormatElement *> arg; |
447 | if (failed(Result: parseToken(kind: FormatToken::l_paren, |
448 | msg: "expected '(' before argument list" )) || |
449 | failed(Result: arg = parseElement(ctx: RefDirectiveContext)) || |
450 | failed( |
451 | Result: parseToken(kind: FormatToken::r_paren, msg: "expected ')' after argument list" ))) |
452 | return failure(); |
453 | |
454 | return create<RefDirective>(args&: *arg); |
455 | } |
456 | |
457 | FailureOr<FormatElement *> FormatParser::parseQualifiedDirective(SMLoc loc, |
458 | Context ctx) { |
459 | if (failed(Result: parseToken(kind: FormatToken::l_paren, |
460 | msg: "expected '(' before argument list" ))) |
461 | return failure(); |
462 | FailureOr<FormatElement *> var = parseElement(ctx); |
463 | if (failed(Result: var)) |
464 | return var; |
465 | if (failed(Result: markQualified(loc, element: *var))) |
466 | return failure(); |
467 | if (failed( |
468 | Result: parseToken(kind: FormatToken::r_paren, msg: "expected ')' after argument list" ))) |
469 | return failure(); |
470 | return var; |
471 | } |
472 | |
473 | //===----------------------------------------------------------------------===// |
474 | // Utility Functions |
475 | //===----------------------------------------------------------------------===// |
476 | |
477 | bool mlir::tblgen::shouldEmitSpaceBefore(StringRef value, |
478 | bool lastWasPunctuation) { |
479 | if (value.size() != 1 && value != "->" ) |
480 | return true; |
481 | if (lastWasPunctuation) |
482 | return !StringRef(">)}]," ).contains(C: value.front()); |
483 | return !StringRef("<>(){}[]," ).contains(C: value.front()); |
484 | } |
485 | |
486 | bool mlir::tblgen::canFormatStringAsKeyword( |
487 | StringRef value, function_ref<void(Twine)> emitError) { |
488 | if (value.empty()) { |
489 | if (emitError) |
490 | emitError("keywords cannot be empty" ); |
491 | return false; |
492 | } |
493 | if (!isalpha(value.front()) && value.front() != '_') { |
494 | if (emitError) |
495 | emitError("valid keyword starts with a letter or '_'" ); |
496 | return false; |
497 | } |
498 | if (!llvm::all_of(Range: value.drop_front(), P: [](char c) { |
499 | return isalnum(c) || c == '_' || c == '$' || c == '.'; |
500 | })) { |
501 | if (emitError) |
502 | emitError( |
503 | "keywords should contain only alphanum, '_', '$', or '.' characters" ); |
504 | return false; |
505 | } |
506 | return true; |
507 | } |
508 | |
509 | bool mlir::tblgen::isValidLiteral(StringRef value, |
510 | function_ref<void(Twine)> emitError) { |
511 | if (value.empty()) { |
512 | if (emitError) |
513 | emitError("literal can't be empty" ); |
514 | return false; |
515 | } |
516 | char front = value.front(); |
517 | |
518 | // If there is only one character, this must either be punctuation or a |
519 | // single character bare identifier. |
520 | if (value.size() == 1) { |
521 | StringRef bare = "_:,=<>()[]{}?+*" ; |
522 | if (isalpha(front) || bare.contains(C: front)) |
523 | return true; |
524 | if (emitError) |
525 | emitError("single character literal must be a letter or one of '" + bare + |
526 | "'" ); |
527 | return false; |
528 | } |
529 | // Check the punctuation that are larger than a single character. |
530 | if (value == "->" ) |
531 | return true; |
532 | if (value == "..." ) |
533 | return true; |
534 | |
535 | // Otherwise, this must be an identifier. |
536 | return canFormatStringAsKeyword(value, emitError); |
537 | } |
538 | |
539 | //===----------------------------------------------------------------------===// |
540 | // Commandline Options |
541 | //===----------------------------------------------------------------------===// |
542 | |
543 | llvm::cl::opt<bool> mlir::tblgen::formatErrorIsFatal( |
544 | "asmformat-error-is-fatal" , |
545 | llvm::cl::desc("Emit a fatal error if format parsing fails" ), |
546 | llvm::cl::init(Val: true)); |
547 | |