1 | //===- FormatGen.cpp - Utilities for custom assembly formats ----*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "FormatGen.h" |
10 | #include "llvm/ADT/StringSwitch.h" |
11 | #include "llvm/Support/SourceMgr.h" |
12 | #include "llvm/TableGen/Error.h" |
13 | |
14 | using namespace mlir; |
15 | using namespace mlir::tblgen; |
16 | |
17 | //===----------------------------------------------------------------------===// |
18 | // FormatToken |
19 | //===----------------------------------------------------------------------===// |
20 | |
21 | SMLoc FormatToken::getLoc() const { |
22 | return SMLoc::getFromPointer(Ptr: spelling.data()); |
23 | } |
24 | |
25 | //===----------------------------------------------------------------------===// |
26 | // FormatLexer |
27 | //===----------------------------------------------------------------------===// |
28 | |
29 | FormatLexer::FormatLexer(llvm::SourceMgr &mgr, SMLoc loc) |
30 | : mgr(mgr), loc(loc), |
31 | curBuffer(mgr.getMemoryBuffer(i: mgr.getMainFileID())->getBuffer()), |
32 | curPtr(curBuffer.begin()) {} |
33 | |
34 | FormatToken FormatLexer::emitError(SMLoc loc, const Twine &msg) { |
35 | mgr.PrintMessage(Loc: loc, Kind: llvm::SourceMgr::DK_Error, Msg: msg); |
36 | llvm::SrcMgr.PrintMessage(Loc: this->loc, Kind: llvm::SourceMgr::DK_Note, |
37 | Msg: "in custom assembly format for this operation" ); |
38 | return formToken(kind: FormatToken::error, tokStart: loc.getPointer()); |
39 | } |
40 | |
41 | FormatToken FormatLexer::emitError(const char *loc, const Twine &msg) { |
42 | return emitError(loc: SMLoc::getFromPointer(Ptr: loc), msg); |
43 | } |
44 | |
45 | FormatToken FormatLexer::emitErrorAndNote(SMLoc loc, const Twine &msg, |
46 | const Twine ¬e) { |
47 | mgr.PrintMessage(Loc: loc, Kind: llvm::SourceMgr::DK_Error, Msg: msg); |
48 | llvm::SrcMgr.PrintMessage(Loc: this->loc, Kind: llvm::SourceMgr::DK_Note, |
49 | Msg: "in custom assembly format for this operation" ); |
50 | mgr.PrintMessage(Loc: loc, Kind: llvm::SourceMgr::DK_Note, Msg: note); |
51 | return formToken(kind: FormatToken::error, tokStart: loc.getPointer()); |
52 | } |
53 | |
54 | int FormatLexer::getNextChar() { |
55 | char curChar = *curPtr++; |
56 | switch (curChar) { |
57 | default: |
58 | return (unsigned char)curChar; |
59 | case 0: { |
60 | // A nul character in the stream is either the end of the current buffer or |
61 | // a random nul in the file. Disambiguate that here. |
62 | if (curPtr - 1 != curBuffer.end()) |
63 | return 0; |
64 | |
65 | // Otherwise, return end of file. |
66 | --curPtr; |
67 | return EOF; |
68 | } |
69 | case '\n': |
70 | case '\r': |
71 | // Handle the newline character by ignoring it and incrementing the line |
72 | // count. However, be careful about 'dos style' files with \n\r in them. |
73 | // Only treat a \n\r or \r\n as a single line. |
74 | if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar) |
75 | ++curPtr; |
76 | return '\n'; |
77 | } |
78 | } |
79 | |
80 | FormatToken FormatLexer::lexToken() { |
81 | const char *tokStart = curPtr; |
82 | |
83 | // This always consumes at least one character. |
84 | int curChar = getNextChar(); |
85 | switch (curChar) { |
86 | default: |
87 | // Handle identifiers: [a-zA-Z_] |
88 | if (isalpha(curChar) || curChar == '_') |
89 | return lexIdentifier(tokStart); |
90 | |
91 | // Unknown character, emit an error. |
92 | return emitError(loc: tokStart, msg: "unexpected character" ); |
93 | case EOF: |
94 | // Return EOF denoting the end of lexing. |
95 | return formToken(kind: FormatToken::eof, tokStart); |
96 | |
97 | // Lex punctuation. |
98 | case '^': |
99 | return formToken(kind: FormatToken::caret, tokStart); |
100 | case ':': |
101 | return formToken(kind: FormatToken::colon, tokStart); |
102 | case ',': |
103 | return formToken(kind: FormatToken::comma, tokStart); |
104 | case '=': |
105 | return formToken(kind: FormatToken::equal, tokStart); |
106 | case '<': |
107 | return formToken(kind: FormatToken::less, tokStart); |
108 | case '>': |
109 | return formToken(kind: FormatToken::greater, tokStart); |
110 | case '?': |
111 | return formToken(kind: FormatToken::question, tokStart); |
112 | case '(': |
113 | return formToken(kind: FormatToken::l_paren, tokStart); |
114 | case ')': |
115 | return formToken(kind: FormatToken::r_paren, tokStart); |
116 | case '*': |
117 | return formToken(kind: FormatToken::star, tokStart); |
118 | case '|': |
119 | return formToken(kind: FormatToken::pipe, tokStart); |
120 | |
121 | // Ignore whitespace characters. |
122 | case 0: |
123 | case ' ': |
124 | case '\t': |
125 | case '\n': |
126 | return lexToken(); |
127 | |
128 | case '`': |
129 | return lexLiteral(tokStart); |
130 | case '$': |
131 | return lexVariable(tokStart); |
132 | case '"': |
133 | return lexString(tokStart); |
134 | } |
135 | } |
136 | |
137 | FormatToken FormatLexer::lexLiteral(const char *tokStart) { |
138 | assert(curPtr[-1] == '`'); |
139 | |
140 | // Lex a literal surrounded by ``. |
141 | while (const char curChar = *curPtr++) { |
142 | if (curChar == '`') |
143 | return formToken(kind: FormatToken::literal, tokStart); |
144 | } |
145 | return emitError(loc: curPtr - 1, msg: "unexpected end of file in literal" ); |
146 | } |
147 | |
148 | FormatToken FormatLexer::lexVariable(const char *tokStart) { |
149 | if (!isalpha(curPtr[0]) && curPtr[0] != '_') |
150 | return emitError(loc: curPtr - 1, msg: "expected variable name" ); |
151 | |
152 | // Otherwise, consume the rest of the characters. |
153 | while (isalnum(*curPtr) || *curPtr == '_') |
154 | ++curPtr; |
155 | return formToken(kind: FormatToken::variable, tokStart); |
156 | } |
157 | |
158 | FormatToken FormatLexer::lexString(const char *tokStart) { |
159 | // Lex until another quote, respecting escapes. |
160 | bool escape = false; |
161 | while (const char curChar = *curPtr++) { |
162 | if (!escape && curChar == '"') |
163 | return formToken(kind: FormatToken::string, tokStart); |
164 | escape = curChar == '\\'; |
165 | } |
166 | return emitError(loc: curPtr - 1, msg: "unexpected end of file in string" ); |
167 | } |
168 | |
169 | FormatToken FormatLexer::lexIdentifier(const char *tokStart) { |
170 | // Match the rest of the identifier regex: [0-9a-zA-Z_\-]* |
171 | while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-') |
172 | ++curPtr; |
173 | |
174 | // Check to see if this identifier is a keyword. |
175 | StringRef str(tokStart, curPtr - tokStart); |
176 | auto kind = |
177 | StringSwitch<FormatToken::Kind>(str) |
178 | .Case(S: "attr-dict" , Value: FormatToken::kw_attr_dict) |
179 | .Case(S: "attr-dict-with-keyword" , Value: FormatToken::kw_attr_dict_w_keyword) |
180 | .Case(S: "prop-dict" , Value: FormatToken::kw_prop_dict) |
181 | .Case(S: "custom" , Value: FormatToken::kw_custom) |
182 | .Case(S: "functional-type" , Value: FormatToken::kw_functional_type) |
183 | .Case(S: "oilist" , Value: FormatToken::kw_oilist) |
184 | .Case(S: "operands" , Value: FormatToken::kw_operands) |
185 | .Case(S: "params" , Value: FormatToken::kw_params) |
186 | .Case(S: "ref" , Value: FormatToken::kw_ref) |
187 | .Case(S: "regions" , Value: FormatToken::kw_regions) |
188 | .Case(S: "results" , Value: FormatToken::kw_results) |
189 | .Case(S: "struct" , Value: FormatToken::kw_struct) |
190 | .Case(S: "successors" , Value: FormatToken::kw_successors) |
191 | .Case(S: "type" , Value: FormatToken::kw_type) |
192 | .Case(S: "qualified" , Value: FormatToken::kw_qualified) |
193 | .Default(Value: FormatToken::identifier); |
194 | return FormatToken(kind, str); |
195 | } |
196 | |
197 | //===----------------------------------------------------------------------===// |
198 | // FormatParser |
199 | //===----------------------------------------------------------------------===// |
200 | |
// Out-of-line defaulted destructor for the format element base class.
FormatElement::~FormatElement() = default;

// Out-of-line defaulted destructor for the format parser base class.
FormatParser::~FormatParser() = default;
204 | |
205 | FailureOr<std::vector<FormatElement *>> FormatParser::parse() { |
206 | SMLoc loc = curToken.getLoc(); |
207 | |
208 | // Parse each of the format elements into the main format. |
209 | std::vector<FormatElement *> elements; |
210 | while (curToken.getKind() != FormatToken::eof) { |
211 | FailureOr<FormatElement *> element = parseElement(ctx: TopLevelContext); |
212 | if (failed(result: element)) |
213 | return failure(); |
214 | elements.push_back(x: *element); |
215 | } |
216 | |
217 | // Verify the format. |
218 | if (failed(result: verify(loc, elements))) |
219 | return failure(); |
220 | return elements; |
221 | } |
222 | |
223 | //===----------------------------------------------------------------------===// |
224 | // Element Parsing |
225 | |
226 | FailureOr<FormatElement *> FormatParser::parseElement(Context ctx) { |
227 | if (curToken.is(kind: FormatToken::literal)) |
228 | return parseLiteral(ctx); |
229 | if (curToken.is(kind: FormatToken::string)) |
230 | return parseString(ctx); |
231 | if (curToken.is(kind: FormatToken::variable)) |
232 | return parseVariable(ctx); |
233 | if (curToken.isKeyword()) |
234 | return parseDirective(ctx); |
235 | if (curToken.is(kind: FormatToken::l_paren)) |
236 | return parseOptionalGroup(ctx); |
237 | return emitError(loc: curToken.getLoc(), |
238 | msg: "expected literal, variable, directive, or optional group" ); |
239 | } |
240 | |
241 | FailureOr<FormatElement *> FormatParser::parseLiteral(Context ctx) { |
242 | FormatToken tok = curToken; |
243 | SMLoc loc = tok.getLoc(); |
244 | consumeToken(); |
245 | |
246 | if (ctx != TopLevelContext) { |
247 | return emitError( |
248 | loc, |
249 | msg: "literals may only be used in the top-level section of the format" ); |
250 | } |
251 | // Get the spelling without the surrounding backticks. |
252 | StringRef value = tok.getSpelling(); |
253 | // Prevents things like `$arg0` or empty literals (when a literal is expected |
254 | // but not found) from getting segmentation faults. |
255 | if (value.size() < 2 || value[0] != '`' || value[value.size() - 1] != '`') |
256 | return emitError(loc: tok.getLoc(), msg: "expected literal, but got '" + value + "'" ); |
257 | value = value.drop_front().drop_back(); |
258 | |
259 | // The parsed literal is a space element (`` or ` `) or a newline. |
260 | if (value.empty() || value == " " || value == "\\n" ) |
261 | return create<WhitespaceElement>(args&: value); |
262 | |
263 | // Check that the parsed literal is valid. |
264 | if (!isValidLiteral(value, emitError: [&](Twine msg) { |
265 | (void)emitError(loc, msg: "expected valid literal but got '" + value + |
266 | "': " + msg); |
267 | })) |
268 | return failure(); |
269 | return create<LiteralElement>(args&: value); |
270 | } |
271 | |
272 | FailureOr<FormatElement *> FormatParser::parseString(Context ctx) { |
273 | FormatToken tok = curToken; |
274 | SMLoc loc = tok.getLoc(); |
275 | consumeToken(); |
276 | |
277 | if (ctx != CustomDirectiveContext) { |
278 | return emitError( |
279 | loc, msg: "strings may only be used as 'custom' directive arguments" ); |
280 | } |
281 | // Escape the string. |
282 | std::string value; |
283 | StringRef contents = tok.getSpelling().drop_front().drop_back(); |
284 | value.reserve(res: contents.size()); |
285 | bool escape = false; |
286 | for (char c : contents) { |
287 | escape = c == '\\'; |
288 | if (!escape) |
289 | value.push_back(c: c); |
290 | } |
291 | return create<StringElement>(args: std::move(value)); |
292 | } |
293 | |
294 | FailureOr<FormatElement *> FormatParser::parseVariable(Context ctx) { |
295 | FormatToken tok = curToken; |
296 | SMLoc loc = tok.getLoc(); |
297 | consumeToken(); |
298 | |
299 | // Get the name of the variable without the leading `$`. |
300 | StringRef name = tok.getSpelling().drop_front(); |
301 | return parseVariableImpl(loc, name, ctx); |
302 | } |
303 | |
304 | FailureOr<FormatElement *> FormatParser::parseDirective(Context ctx) { |
305 | FormatToken tok = curToken; |
306 | SMLoc loc = tok.getLoc(); |
307 | consumeToken(); |
308 | |
309 | if (tok.is(kind: FormatToken::kw_custom)) |
310 | return parseCustomDirective(loc, ctx); |
311 | return parseDirectiveImpl(loc, kind: tok.getKind(), ctx); |
312 | } |
313 | |
314 | FailureOr<FormatElement *> FormatParser::parseOptionalGroup(Context ctx) { |
315 | SMLoc loc = curToken.getLoc(); |
316 | consumeToken(); |
317 | if (ctx != TopLevelContext) { |
318 | return emitError(loc, |
319 | msg: "optional groups can only be used as top-level elements" ); |
320 | } |
321 | |
322 | // Parse the child elements for this optional group. |
323 | std::vector<FormatElement *> thenElements, elseElements; |
324 | FormatElement *anchor = nullptr; |
325 | auto parseChildElements = |
326 | [this, &anchor](std::vector<FormatElement *> &elements) -> LogicalResult { |
327 | do { |
328 | FailureOr<FormatElement *> element = parseElement(ctx: TopLevelContext); |
329 | if (failed(result: element)) |
330 | return failure(); |
331 | // Check for an anchor. |
332 | if (curToken.is(kind: FormatToken::caret)) { |
333 | if (anchor) { |
334 | return emitError(loc: curToken.getLoc(), |
335 | msg: "only one element can be marked as the anchor of an " |
336 | "optional group" ); |
337 | } |
338 | anchor = *element; |
339 | consumeToken(); |
340 | } |
341 | elements.push_back(x: *element); |
342 | } while (!curToken.is(kind: FormatToken::r_paren)); |
343 | return success(); |
344 | }; |
345 | |
346 | // Parse the 'then' elements. If the anchor was found in this group, then the |
347 | // optional is not inverted. |
348 | if (failed(result: parseChildElements(thenElements))) |
349 | return failure(); |
350 | consumeToken(); |
351 | bool inverted = !anchor; |
352 | |
353 | // Parse the `else` elements of this optional group. |
354 | if (curToken.is(kind: FormatToken::colon)) { |
355 | consumeToken(); |
356 | if (failed(result: parseToken( |
357 | kind: FormatToken::l_paren, |
358 | msg: "expected '(' to start else branch of optional group" )) || |
359 | failed(result: parseChildElements(elseElements))) |
360 | return failure(); |
361 | consumeToken(); |
362 | } |
363 | if (failed(result: parseToken(kind: FormatToken::question, |
364 | msg: "expected '?' after optional group" ))) |
365 | return failure(); |
366 | |
367 | // The optional group is required to have an anchor. |
368 | if (!anchor) |
369 | return emitError(loc, msg: "optional group has no anchor element" ); |
370 | |
371 | // Verify the child elements. |
372 | if (failed(result: verifyOptionalGroupElements(loc, elements: thenElements, anchor)) || |
373 | failed(result: verifyOptionalGroupElements(loc, elements: elseElements, anchor: nullptr))) |
374 | return failure(); |
375 | |
376 | // Get the first parsable element. It must be an element that can be |
377 | // optionally-parsed. |
378 | auto isWhitespace = [](FormatElement *element) { |
379 | return isa<WhitespaceElement>(Val: element); |
380 | }; |
381 | auto thenParseBegin = llvm::find_if_not(Range&: thenElements, P: isWhitespace); |
382 | auto elseParseBegin = llvm::find_if_not(Range&: elseElements, P: isWhitespace); |
383 | unsigned thenParseStart = std::distance(first: thenElements.begin(), last: thenParseBegin); |
384 | unsigned elseParseStart = std::distance(first: elseElements.begin(), last: elseParseBegin); |
385 | |
386 | if (!isa<LiteralElement, VariableElement, CustomDirective>(Val: *thenParseBegin)) { |
387 | return emitError(loc, msg: "first parsable element of an optional group must be " |
388 | "a literal, variable, or custom directive" ); |
389 | } |
390 | return create<OptionalElement>(args: std::move(thenElements), |
391 | args: std::move(elseElements), args&: thenParseStart, |
392 | args&: elseParseStart, args&: anchor, args&: inverted); |
393 | } |
394 | |
395 | FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc, |
396 | Context ctx) { |
397 | if (ctx != TopLevelContext) |
398 | return emitError(loc, msg: "'custom' is only valid as a top-level directive" ); |
399 | |
400 | FailureOr<FormatToken> nameTok; |
401 | if (failed(result: parseToken(kind: FormatToken::less, |
402 | msg: "expected '<' before custom directive name" )) || |
403 | failed(result: nameTok = |
404 | parseToken(kind: FormatToken::identifier, |
405 | msg: "expected custom directive name identifier" )) || |
406 | failed(result: parseToken(kind: FormatToken::greater, |
407 | msg: "expected '>' after custom directive name" )) || |
408 | failed(result: parseToken(kind: FormatToken::l_paren, |
409 | msg: "expected '(' before custom directive parameters" ))) |
410 | return failure(); |
411 | |
412 | // Parse the arguments. |
413 | std::vector<FormatElement *> arguments; |
414 | while (true) { |
415 | FailureOr<FormatElement *> argument = parseElement(ctx: CustomDirectiveContext); |
416 | if (failed(result: argument)) |
417 | return failure(); |
418 | arguments.push_back(x: *argument); |
419 | if (!curToken.is(kind: FormatToken::comma)) |
420 | break; |
421 | consumeToken(); |
422 | } |
423 | |
424 | if (failed(result: parseToken(kind: FormatToken::r_paren, |
425 | msg: "expected ')' after custom directive parameters" ))) |
426 | return failure(); |
427 | |
428 | if (failed(result: verifyCustomDirectiveArguments(loc, arguments))) |
429 | return failure(); |
430 | return create<CustomDirective>(args: nameTok->getSpelling(), args: std::move(arguments)); |
431 | } |
432 | |
433 | //===----------------------------------------------------------------------===// |
434 | // Utility Functions |
435 | //===----------------------------------------------------------------------===// |
436 | |
437 | bool mlir::tblgen::shouldEmitSpaceBefore(StringRef value, |
438 | bool lastWasPunctuation) { |
439 | if (value.size() != 1 && value != "->" ) |
440 | return true; |
441 | if (lastWasPunctuation) |
442 | return !StringRef(">)}]," ).contains(C: value.front()); |
443 | return !StringRef("<>(){}[]," ).contains(C: value.front()); |
444 | } |
445 | |
446 | bool mlir::tblgen::canFormatStringAsKeyword( |
447 | StringRef value, function_ref<void(Twine)> emitError) { |
448 | if (value.empty()) { |
449 | if (emitError) |
450 | emitError("keywords cannot be empty" ); |
451 | return false; |
452 | } |
453 | if (!isalpha(value.front()) && value.front() != '_') { |
454 | if (emitError) |
455 | emitError("valid keyword starts with a letter or '_'" ); |
456 | return false; |
457 | } |
458 | if (!llvm::all_of(Range: value.drop_front(), P: [](char c) { |
459 | return isalnum(c) || c == '_' || c == '$' || c == '.'; |
460 | })) { |
461 | if (emitError) |
462 | emitError( |
463 | "keywords should contain only alphanum, '_', '$', or '.' characters" ); |
464 | return false; |
465 | } |
466 | return true; |
467 | } |
468 | |
469 | bool mlir::tblgen::isValidLiteral(StringRef value, |
470 | function_ref<void(Twine)> emitError) { |
471 | if (value.empty()) { |
472 | if (emitError) |
473 | emitError("literal can't be empty" ); |
474 | return false; |
475 | } |
476 | char front = value.front(); |
477 | |
478 | // If there is only one character, this must either be punctuation or a |
479 | // single character bare identifier. |
480 | if (value.size() == 1) { |
481 | StringRef bare = "_:,=<>()[]{}?+*" ; |
482 | if (isalpha(front) || bare.contains(C: front)) |
483 | return true; |
484 | if (emitError) |
485 | emitError("single character literal must be a letter or one of '" + bare + |
486 | "'" ); |
487 | return false; |
488 | } |
489 | // Check the punctuation that are larger than a single character. |
490 | if (value == "->" ) |
491 | return true; |
492 | if (value == "..." ) |
493 | return true; |
494 | |
495 | // Otherwise, this must be an identifier. |
496 | return canFormatStringAsKeyword(value, emitError); |
497 | } |
498 | |
499 | //===----------------------------------------------------------------------===// |
500 | // Commandline Options |
501 | //===----------------------------------------------------------------------===// |
502 | |
503 | llvm::cl::opt<bool> mlir::tblgen::formatErrorIsFatal( |
504 | "asmformat-error-is-fatal" , |
505 | llvm::cl::desc("Emit a fatal error if format parsing fails" ), |
506 | llvm::cl::init(Val: true)); |
507 | |