1//===- FormatGen.cpp - Utilities for custom assembly formats ----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "FormatGen.h"
10#include "llvm/ADT/StringSwitch.h"
11#include "llvm/Support/SourceMgr.h"
12#include "llvm/TableGen/Error.h"
13
14using namespace mlir;
15using namespace mlir::tblgen;
16
17//===----------------------------------------------------------------------===//
18// FormatToken
19//===----------------------------------------------------------------------===//
20
21SMLoc FormatToken::getLoc() const {
22 return SMLoc::getFromPointer(Ptr: spelling.data());
23}
24
25//===----------------------------------------------------------------------===//
26// FormatLexer
27//===----------------------------------------------------------------------===//
28
29FormatLexer::FormatLexer(llvm::SourceMgr &mgr, SMLoc loc)
30 : mgr(mgr), loc(loc),
31 curBuffer(mgr.getMemoryBuffer(i: mgr.getMainFileID())->getBuffer()),
32 curPtr(curBuffer.begin()) {}
33
34FormatToken FormatLexer::emitError(SMLoc loc, const Twine &msg) {
35 mgr.PrintMessage(Loc: loc, Kind: llvm::SourceMgr::DK_Error, Msg: msg);
36 llvm::SrcMgr.PrintMessage(Loc: this->loc, Kind: llvm::SourceMgr::DK_Note,
37 Msg: "in custom assembly format for this operation");
38 return formToken(kind: FormatToken::error, tokStart: loc.getPointer());
39}
40
41FormatToken FormatLexer::emitError(const char *loc, const Twine &msg) {
42 return emitError(loc: SMLoc::getFromPointer(Ptr: loc), msg);
43}
44
45FormatToken FormatLexer::emitErrorAndNote(SMLoc loc, const Twine &msg,
46 const Twine &note) {
47 mgr.PrintMessage(Loc: loc, Kind: llvm::SourceMgr::DK_Error, Msg: msg);
48 llvm::SrcMgr.PrintMessage(Loc: this->loc, Kind: llvm::SourceMgr::DK_Note,
49 Msg: "in custom assembly format for this operation");
50 mgr.PrintMessage(Loc: loc, Kind: llvm::SourceMgr::DK_Note, Msg: note);
51 return formToken(kind: FormatToken::error, tokStart: loc.getPointer());
52}
53
54int FormatLexer::getNextChar() {
55 char curChar = *curPtr++;
56 switch (curChar) {
57 default:
58 return (unsigned char)curChar;
59 case 0: {
60 // A nul character in the stream is either the end of the current buffer or
61 // a random nul in the file. Disambiguate that here.
62 if (curPtr - 1 != curBuffer.end())
63 return 0;
64
65 // Otherwise, return end of file.
66 --curPtr;
67 return EOF;
68 }
69 case '\n':
70 case '\r':
71 // Handle the newline character by ignoring it and incrementing the line
72 // count. However, be careful about 'dos style' files with \n\r in them.
73 // Only treat a \n\r or \r\n as a single line.
74 if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
75 ++curPtr;
76 return '\n';
77 }
78}
79
80FormatToken FormatLexer::lexToken() {
81 const char *tokStart = curPtr;
82
83 // This always consumes at least one character.
84 int curChar = getNextChar();
85 switch (curChar) {
86 default:
87 // Handle identifiers: [a-zA-Z_]
88 if (isalpha(curChar) || curChar == '_')
89 return lexIdentifier(tokStart);
90
91 // Unknown character, emit an error.
92 return emitError(loc: tokStart, msg: "unexpected character");
93 case EOF:
94 // Return EOF denoting the end of lexing.
95 return formToken(kind: FormatToken::eof, tokStart);
96
97 // Lex punctuation.
98 case '^':
99 return formToken(kind: FormatToken::caret, tokStart);
100 case ':':
101 return formToken(kind: FormatToken::colon, tokStart);
102 case ',':
103 return formToken(kind: FormatToken::comma, tokStart);
104 case '=':
105 return formToken(kind: FormatToken::equal, tokStart);
106 case '<':
107 return formToken(kind: FormatToken::less, tokStart);
108 case '>':
109 return formToken(kind: FormatToken::greater, tokStart);
110 case '?':
111 return formToken(kind: FormatToken::question, tokStart);
112 case '(':
113 return formToken(kind: FormatToken::l_paren, tokStart);
114 case ')':
115 return formToken(kind: FormatToken::r_paren, tokStart);
116 case '*':
117 return formToken(kind: FormatToken::star, tokStart);
118 case '|':
119 return formToken(kind: FormatToken::pipe, tokStart);
120
121 // Ignore whitespace characters.
122 case 0:
123 case ' ':
124 case '\t':
125 case '\n':
126 return lexToken();
127
128 case '`':
129 return lexLiteral(tokStart);
130 case '$':
131 return lexVariable(tokStart);
132 case '"':
133 return lexString(tokStart);
134 }
135}
136
137FormatToken FormatLexer::lexLiteral(const char *tokStart) {
138 assert(curPtr[-1] == '`');
139
140 // Lex a literal surrounded by ``.
141 while (const char curChar = *curPtr++) {
142 if (curChar == '`')
143 return formToken(kind: FormatToken::literal, tokStart);
144 }
145 return emitError(loc: curPtr - 1, msg: "unexpected end of file in literal");
146}
147
148FormatToken FormatLexer::lexVariable(const char *tokStart) {
149 if (!isalpha(curPtr[0]) && curPtr[0] != '_')
150 return emitError(loc: curPtr - 1, msg: "expected variable name");
151
152 // Otherwise, consume the rest of the characters.
153 while (isalnum(*curPtr) || *curPtr == '_')
154 ++curPtr;
155 return formToken(kind: FormatToken::variable, tokStart);
156}
157
158FormatToken FormatLexer::lexString(const char *tokStart) {
159 // Lex until another quote, respecting escapes.
160 bool escape = false;
161 while (const char curChar = *curPtr++) {
162 if (!escape && curChar == '"')
163 return formToken(kind: FormatToken::string, tokStart);
164 escape = curChar == '\\';
165 }
166 return emitError(loc: curPtr - 1, msg: "unexpected end of file in string");
167}
168
169FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
170 // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
171 while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
172 ++curPtr;
173
174 // Check to see if this identifier is a keyword.
175 StringRef str(tokStart, curPtr - tokStart);
176 auto kind =
177 StringSwitch<FormatToken::Kind>(str)
178 .Case(S: "attr-dict", Value: FormatToken::kw_attr_dict)
179 .Case(S: "attr-dict-with-keyword", Value: FormatToken::kw_attr_dict_w_keyword)
180 .Case(S: "prop-dict", Value: FormatToken::kw_prop_dict)
181 .Case(S: "custom", Value: FormatToken::kw_custom)
182 .Case(S: "functional-type", Value: FormatToken::kw_functional_type)
183 .Case(S: "oilist", Value: FormatToken::kw_oilist)
184 .Case(S: "operands", Value: FormatToken::kw_operands)
185 .Case(S: "params", Value: FormatToken::kw_params)
186 .Case(S: "ref", Value: FormatToken::kw_ref)
187 .Case(S: "regions", Value: FormatToken::kw_regions)
188 .Case(S: "results", Value: FormatToken::kw_results)
189 .Case(S: "struct", Value: FormatToken::kw_struct)
190 .Case(S: "successors", Value: FormatToken::kw_successors)
191 .Case(S: "type", Value: FormatToken::kw_type)
192 .Case(S: "qualified", Value: FormatToken::kw_qualified)
193 .Default(Value: FormatToken::identifier);
194 return FormatToken(kind, str);
195}
196
197//===----------------------------------------------------------------------===//
198// FormatParser
199//===----------------------------------------------------------------------===//
200
// Out-of-line defaulted destructor for FormatElement.
// NOTE(review): presumably defined here so the class has a single home
// translation unit for its destructor -- confirm against the header.
201FormatElement::~FormatElement() = default;
202
// Out-of-line defaulted destructor for FormatParser.
203FormatParser::~FormatParser() = default;
204
205FailureOr<std::vector<FormatElement *>> FormatParser::parse() {
206 SMLoc loc = curToken.getLoc();
207
208 // Parse each of the format elements into the main format.
209 std::vector<FormatElement *> elements;
210 while (curToken.getKind() != FormatToken::eof) {
211 FailureOr<FormatElement *> element = parseElement(ctx: TopLevelContext);
212 if (failed(result: element))
213 return failure();
214 elements.push_back(x: *element);
215 }
216
217 // Verify the format.
218 if (failed(result: verify(loc, elements)))
219 return failure();
220 return elements;
221}
222
223//===----------------------------------------------------------------------===//
224// Element Parsing
225
226FailureOr<FormatElement *> FormatParser::parseElement(Context ctx) {
227 if (curToken.is(kind: FormatToken::literal))
228 return parseLiteral(ctx);
229 if (curToken.is(kind: FormatToken::string))
230 return parseString(ctx);
231 if (curToken.is(kind: FormatToken::variable))
232 return parseVariable(ctx);
233 if (curToken.isKeyword())
234 return parseDirective(ctx);
235 if (curToken.is(kind: FormatToken::l_paren))
236 return parseOptionalGroup(ctx);
237 return emitError(loc: curToken.getLoc(),
238 msg: "expected literal, variable, directive, or optional group");
239}
240
241FailureOr<FormatElement *> FormatParser::parseLiteral(Context ctx) {
242 FormatToken tok = curToken;
243 SMLoc loc = tok.getLoc();
244 consumeToken();
245
246 if (ctx != TopLevelContext) {
247 return emitError(
248 loc,
249 msg: "literals may only be used in the top-level section of the format");
250 }
251 // Get the spelling without the surrounding backticks.
252 StringRef value = tok.getSpelling();
253 // Prevents things like `$arg0` or empty literals (when a literal is expected
254 // but not found) from getting segmentation faults.
255 if (value.size() < 2 || value[0] != '`' || value[value.size() - 1] != '`')
256 return emitError(loc: tok.getLoc(), msg: "expected literal, but got '" + value + "'");
257 value = value.drop_front().drop_back();
258
259 // The parsed literal is a space element (`` or ` `) or a newline.
260 if (value.empty() || value == " " || value == "\\n")
261 return create<WhitespaceElement>(args&: value);
262
263 // Check that the parsed literal is valid.
264 if (!isValidLiteral(value, emitError: [&](Twine msg) {
265 (void)emitError(loc, msg: "expected valid literal but got '" + value +
266 "': " + msg);
267 }))
268 return failure();
269 return create<LiteralElement>(args&: value);
270}
271
272FailureOr<FormatElement *> FormatParser::parseString(Context ctx) {
273 FormatToken tok = curToken;
274 SMLoc loc = tok.getLoc();
275 consumeToken();
276
277 if (ctx != CustomDirectiveContext) {
278 return emitError(
279 loc, msg: "strings may only be used as 'custom' directive arguments");
280 }
281 // Escape the string.
282 std::string value;
283 StringRef contents = tok.getSpelling().drop_front().drop_back();
284 value.reserve(res: contents.size());
285 bool escape = false;
286 for (char c : contents) {
287 escape = c == '\\';
288 if (!escape)
289 value.push_back(c: c);
290 }
291 return create<StringElement>(args: std::move(value));
292}
293
294FailureOr<FormatElement *> FormatParser::parseVariable(Context ctx) {
295 FormatToken tok = curToken;
296 SMLoc loc = tok.getLoc();
297 consumeToken();
298
299 // Get the name of the variable without the leading `$`.
300 StringRef name = tok.getSpelling().drop_front();
301 return parseVariableImpl(loc, name, ctx);
302}
303
304FailureOr<FormatElement *> FormatParser::parseDirective(Context ctx) {
305 FormatToken tok = curToken;
306 SMLoc loc = tok.getLoc();
307 consumeToken();
308
309 if (tok.is(kind: FormatToken::kw_custom))
310 return parseCustomDirective(loc, ctx);
311 return parseDirectiveImpl(loc, kind: tok.getKind(), ctx);
312}
313
314FailureOr<FormatElement *> FormatParser::parseOptionalGroup(Context ctx) {
315 SMLoc loc = curToken.getLoc();
316 consumeToken();
317 if (ctx != TopLevelContext) {
318 return emitError(loc,
319 msg: "optional groups can only be used as top-level elements");
320 }
321
322 // Parse the child elements for this optional group.
323 std::vector<FormatElement *> thenElements, elseElements;
324 FormatElement *anchor = nullptr;
325 auto parseChildElements =
326 [this, &anchor](std::vector<FormatElement *> &elements) -> LogicalResult {
327 do {
328 FailureOr<FormatElement *> element = parseElement(ctx: TopLevelContext);
329 if (failed(result: element))
330 return failure();
331 // Check for an anchor.
332 if (curToken.is(kind: FormatToken::caret)) {
333 if (anchor) {
334 return emitError(loc: curToken.getLoc(),
335 msg: "only one element can be marked as the anchor of an "
336 "optional group");
337 }
338 anchor = *element;
339 consumeToken();
340 }
341 elements.push_back(x: *element);
342 } while (!curToken.is(kind: FormatToken::r_paren));
343 return success();
344 };
345
346 // Parse the 'then' elements. If the anchor was found in this group, then the
347 // optional is not inverted.
348 if (failed(result: parseChildElements(thenElements)))
349 return failure();
350 consumeToken();
351 bool inverted = !anchor;
352
353 // Parse the `else` elements of this optional group.
354 if (curToken.is(kind: FormatToken::colon)) {
355 consumeToken();
356 if (failed(result: parseToken(
357 kind: FormatToken::l_paren,
358 msg: "expected '(' to start else branch of optional group")) ||
359 failed(result: parseChildElements(elseElements)))
360 return failure();
361 consumeToken();
362 }
363 if (failed(result: parseToken(kind: FormatToken::question,
364 msg: "expected '?' after optional group")))
365 return failure();
366
367 // The optional group is required to have an anchor.
368 if (!anchor)
369 return emitError(loc, msg: "optional group has no anchor element");
370
371 // Verify the child elements.
372 if (failed(result: verifyOptionalGroupElements(loc, elements: thenElements, anchor)) ||
373 failed(result: verifyOptionalGroupElements(loc, elements: elseElements, anchor: nullptr)))
374 return failure();
375
376 // Get the first parsable element. It must be an element that can be
377 // optionally-parsed.
378 auto isWhitespace = [](FormatElement *element) {
379 return isa<WhitespaceElement>(Val: element);
380 };
381 auto thenParseBegin = llvm::find_if_not(Range&: thenElements, P: isWhitespace);
382 auto elseParseBegin = llvm::find_if_not(Range&: elseElements, P: isWhitespace);
383 unsigned thenParseStart = std::distance(first: thenElements.begin(), last: thenParseBegin);
384 unsigned elseParseStart = std::distance(first: elseElements.begin(), last: elseParseBegin);
385
386 if (!isa<LiteralElement, VariableElement, CustomDirective>(Val: *thenParseBegin)) {
387 return emitError(loc, msg: "first parsable element of an optional group must be "
388 "a literal, variable, or custom directive");
389 }
390 return create<OptionalElement>(args: std::move(thenElements),
391 args: std::move(elseElements), args&: thenParseStart,
392 args&: elseParseStart, args&: anchor, args&: inverted);
393}
394
395FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc,
396 Context ctx) {
397 if (ctx != TopLevelContext)
398 return emitError(loc, msg: "'custom' is only valid as a top-level directive");
399
400 FailureOr<FormatToken> nameTok;
401 if (failed(result: parseToken(kind: FormatToken::less,
402 msg: "expected '<' before custom directive name")) ||
403 failed(result: nameTok =
404 parseToken(kind: FormatToken::identifier,
405 msg: "expected custom directive name identifier")) ||
406 failed(result: parseToken(kind: FormatToken::greater,
407 msg: "expected '>' after custom directive name")) ||
408 failed(result: parseToken(kind: FormatToken::l_paren,
409 msg: "expected '(' before custom directive parameters")))
410 return failure();
411
412 // Parse the arguments.
413 std::vector<FormatElement *> arguments;
414 while (true) {
415 FailureOr<FormatElement *> argument = parseElement(ctx: CustomDirectiveContext);
416 if (failed(result: argument))
417 return failure();
418 arguments.push_back(x: *argument);
419 if (!curToken.is(kind: FormatToken::comma))
420 break;
421 consumeToken();
422 }
423
424 if (failed(result: parseToken(kind: FormatToken::r_paren,
425 msg: "expected ')' after custom directive parameters")))
426 return failure();
427
428 if (failed(result: verifyCustomDirectiveArguments(loc, arguments)))
429 return failure();
430 return create<CustomDirective>(args: nameTok->getSpelling(), args: std::move(arguments));
431}
432
433//===----------------------------------------------------------------------===//
434// Utility Functions
435//===----------------------------------------------------------------------===//
436
437bool mlir::tblgen::shouldEmitSpaceBefore(StringRef value,
438 bool lastWasPunctuation) {
439 if (value.size() != 1 && value != "->")
440 return true;
441 if (lastWasPunctuation)
442 return !StringRef(">)}],").contains(C: value.front());
443 return !StringRef("<>(){}[],").contains(C: value.front());
444}
445
446bool mlir::tblgen::canFormatStringAsKeyword(
447 StringRef value, function_ref<void(Twine)> emitError) {
448 if (value.empty()) {
449 if (emitError)
450 emitError("keywords cannot be empty");
451 return false;
452 }
453 if (!isalpha(value.front()) && value.front() != '_') {
454 if (emitError)
455 emitError("valid keyword starts with a letter or '_'");
456 return false;
457 }
458 if (!llvm::all_of(Range: value.drop_front(), P: [](char c) {
459 return isalnum(c) || c == '_' || c == '$' || c == '.';
460 })) {
461 if (emitError)
462 emitError(
463 "keywords should contain only alphanum, '_', '$', or '.' characters");
464 return false;
465 }
466 return true;
467}
468
469bool mlir::tblgen::isValidLiteral(StringRef value,
470 function_ref<void(Twine)> emitError) {
471 if (value.empty()) {
472 if (emitError)
473 emitError("literal can't be empty");
474 return false;
475 }
476 char front = value.front();
477
478 // If there is only one character, this must either be punctuation or a
479 // single character bare identifier.
480 if (value.size() == 1) {
481 StringRef bare = "_:,=<>()[]{}?+*";
482 if (isalpha(front) || bare.contains(C: front))
483 return true;
484 if (emitError)
485 emitError("single character literal must be a letter or one of '" + bare +
486 "'");
487 return false;
488 }
489 // Check the punctuation that are larger than a single character.
490 if (value == "->")
491 return true;
492 if (value == "...")
493 return true;
494
495 // Otherwise, this must be an identifier.
496 return canFormatStringAsKeyword(value, emitError);
497}
498
499//===----------------------------------------------------------------------===//
500// Commandline Options
501//===----------------------------------------------------------------------===//
502
503llvm::cl::opt<bool> mlir::tblgen::formatErrorIsFatal(
504 "asmformat-error-is-fatal",
505 llvm::cl::desc("Emit a fatal error if format parsing fails"),
506 llvm::cl::init(Val: true));
507

// Source: mlir/tools/mlir-tblgen/FormatGen.cpp