1//===- Lexer.cpp ----------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Lexer.h"
10#include "mlir/Support/LogicalResult.h"
11#include "mlir/Tools/PDLL/AST/Diagnostic.h"
12#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
13#include "llvm/ADT/StringExtras.h"
14#include "llvm/ADT/StringSwitch.h"
15#include "llvm/Support/SourceMgr.h"
16
17using namespace mlir;
18using namespace mlir::pdll;
19
20//===----------------------------------------------------------------------===//
21// Token
22//===----------------------------------------------------------------------===//
23
24std::string Token::getStringValue() const {
25 assert(getKind() == string || getKind() == string_block ||
26 getKind() == code_complete_string);
27
28 // Start by dropping the quotes.
29 StringRef bytes = getSpelling();
30 if (is(k: string))
31 bytes = bytes.drop_front().drop_back();
32 else if (is(k: string_block))
33 bytes = bytes.drop_front(N: 2).drop_back(N: 2);
34
35 std::string result;
36 result.reserve(res: bytes.size());
37 for (unsigned i = 0, e = bytes.size(); i != e;) {
38 auto c = bytes[i++];
39 if (c != '\\') {
40 result.push_back(c: c);
41 continue;
42 }
43
44 assert(i + 1 <= e && "invalid string should be caught by lexer");
45 auto c1 = bytes[i++];
46 switch (c1) {
47 case '"':
48 case '\\':
49 result.push_back(c: c1);
50 continue;
51 case 'n':
52 result.push_back(c: '\n');
53 continue;
54 case 't':
55 result.push_back(c: '\t');
56 continue;
57 default:
58 break;
59 }
60
61 assert(i + 1 <= e && "invalid string should be caught by lexer");
62 auto c2 = bytes[i++];
63
64 assert(llvm::isHexDigit(c1) && llvm::isHexDigit(c2) && "invalid escape");
65 result.push_back(c: (llvm::hexDigitValue(C: c1) << 4) | llvm::hexDigitValue(C: c2));
66 }
67
68 return result;
69}
70
71//===----------------------------------------------------------------------===//
72// Lexer
73//===----------------------------------------------------------------------===//
74
75Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine,
76 CodeCompleteContext *codeCompleteContext)
77 : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false),
78 codeCompletionLocation(nullptr) {
79 curBufferID = mgr.getMainFileID();
80 curBuffer = srcMgr.getMemoryBuffer(i: curBufferID)->getBuffer();
81 curPtr = curBuffer.begin();
82
83 // Set the code completion location if necessary.
84 if (codeCompleteContext) {
85 codeCompletionLocation =
86 codeCompleteContext->getCodeCompleteLoc().getPointer();
87 }
88
89 // If the diag engine has no handler, add a default that emits to the
90 // SourceMgr.
91 if (!diagEngine.getHandlerFn()) {
92 diagEngine.setHandlerFn([&](const ast::Diagnostic &diag) {
93 srcMgr.PrintMessage(Loc: diag.getLocation().Start, Kind: diag.getSeverity(),
94 Msg: diag.getMessage());
95 for (const ast::Diagnostic &note : diag.getNotes())
96 srcMgr.PrintMessage(Loc: note.getLocation().Start, Kind: note.getSeverity(),
97 Msg: note.getMessage());
98 });
99 addedHandlerToDiagEngine = true;
100 }
101}
102
103Lexer::~Lexer() {
104 if (addedHandlerToDiagEngine)
105 diagEngine.setHandlerFn(nullptr);
106}
107
108LogicalResult Lexer::pushInclude(StringRef filename, SMRange includeLoc) {
109 std::string includedFile;
110 int bufferID =
111 srcMgr.AddIncludeFile(Filename: filename.str(), IncludeLoc: includeLoc.End, IncludedFile&: includedFile);
112 if (!bufferID)
113 return failure();
114
115 curBufferID = bufferID;
116 curBuffer = srcMgr.getMemoryBuffer(i: curBufferID)->getBuffer();
117 curPtr = curBuffer.begin();
118 return success();
119}
120
121Token Lexer::emitError(SMRange loc, const Twine &msg) {
122 diagEngine.emitError(loc, msg);
123 return formToken(kind: Token::error, tokStart: loc.Start.getPointer());
124}
125Token Lexer::emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
126 const Twine &note) {
127 diagEngine.emitError(loc, msg)->attachNote(msg: note, noteLoc);
128 return formToken(kind: Token::error, tokStart: loc.Start.getPointer());
129}
130Token Lexer::emitError(const char *loc, const Twine &msg) {
131 return emitError(
132 loc: SMRange(SMLoc::getFromPointer(Ptr: loc), SMLoc::getFromPointer(Ptr: loc + 1)), msg);
133}
134
135int Lexer::getNextChar() {
136 char curChar = *curPtr++;
137 switch (curChar) {
138 default:
139 return static_cast<unsigned char>(curChar);
140 case 0: {
141 // A nul character in the stream is either the end of the current buffer
142 // or a random nul in the file. Disambiguate that here.
143 if (curPtr - 1 != curBuffer.end())
144 return 0;
145
146 // Otherwise, return end of file.
147 --curPtr;
148 return EOF;
149 }
150 case '\n':
151 case '\r':
152 // Handle the newline character by ignoring it and incrementing the line
153 // count. However, be careful about 'dos style' files with \n\r in them.
154 // Only treat a \n\r or \r\n as a single line.
155 if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
156 ++curPtr;
157 return '\n';
158 }
159}
160
161Token Lexer::lexToken() {
162 while (true) {
163 const char *tokStart = curPtr;
164
165 // Check to see if this token is at the code completion location.
166 if (tokStart == codeCompletionLocation)
167 return formToken(kind: Token::code_complete, tokStart);
168
169 // This always consumes at least one character.
170 int curChar = getNextChar();
171 switch (curChar) {
172 default:
173 // Handle identifiers: [a-zA-Z_]
174 if (isalpha(curChar) || curChar == '_')
175 return lexIdentifier(tokStart);
176
177 // Unknown character, emit an error.
178 return emitError(loc: tokStart, msg: "unexpected character");
179 case EOF: {
180 // Return EOF denoting the end of lexing.
181 Token eof = formToken(kind: Token::eof, tokStart);
182
183 // Check to see if we are in an included file.
184 SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(i: curBufferID);
185 if (parentIncludeLoc.isValid()) {
186 curBufferID = srcMgr.FindBufferContainingLoc(Loc: parentIncludeLoc);
187 curBuffer = srcMgr.getMemoryBuffer(i: curBufferID)->getBuffer();
188 curPtr = parentIncludeLoc.getPointer();
189 }
190
191 return eof;
192 }
193
194 // Lex punctuation.
195 case '-':
196 if (*curPtr == '>') {
197 ++curPtr;
198 return formToken(kind: Token::arrow, tokStart);
199 }
200 return emitError(loc: tokStart, msg: "unexpected character");
201 case ':':
202 return formToken(kind: Token::colon, tokStart);
203 case ',':
204 return formToken(kind: Token::comma, tokStart);
205 case '.':
206 return formToken(kind: Token::dot, tokStart);
207 case '=':
208 if (*curPtr == '>') {
209 ++curPtr;
210 return formToken(kind: Token::equal_arrow, tokStart);
211 }
212 return formToken(kind: Token::equal, tokStart);
213 case ';':
214 return formToken(kind: Token::semicolon, tokStart);
215 case '[':
216 if (*curPtr == '{') {
217 ++curPtr;
218 return lexString(tokStart, /*isStringBlock=*/true);
219 }
220 return formToken(kind: Token::l_square, tokStart);
221 case ']':
222 return formToken(kind: Token::r_square, tokStart);
223
224 case '<':
225 return formToken(kind: Token::less, tokStart);
226 case '>':
227 return formToken(kind: Token::greater, tokStart);
228 case '{':
229 return formToken(kind: Token::l_brace, tokStart);
230 case '}':
231 return formToken(kind: Token::r_brace, tokStart);
232 case '(':
233 return formToken(kind: Token::l_paren, tokStart);
234 case ')':
235 return formToken(kind: Token::r_paren, tokStart);
236 case '/':
237 if (*curPtr == '/') {
238 lexComment();
239 continue;
240 }
241 return emitError(loc: tokStart, msg: "unexpected character");
242
243 // Ignore whitespace characters.
244 case 0:
245 case ' ':
246 case '\t':
247 case '\n':
248 return lexToken();
249
250 case '#':
251 return lexDirective(tokStart);
252 case '"':
253 return lexString(tokStart, /*isStringBlock=*/false);
254
255 case '0':
256 case '1':
257 case '2':
258 case '3':
259 case '4':
260 case '5':
261 case '6':
262 case '7':
263 case '8':
264 case '9':
265 return lexNumber(tokStart);
266 }
267 }
268}
269
270/// Skip a comment line, starting with a '//'.
271void Lexer::lexComment() {
272 // Advance over the second '/' in a '//' comment.
273 assert(*curPtr == '/');
274 ++curPtr;
275
276 while (true) {
277 switch (*curPtr++) {
278 case '\n':
279 case '\r':
280 // Newline is end of comment.
281 return;
282 case 0:
283 // If this is the end of the buffer, end the comment.
284 if (curPtr - 1 == curBuffer.end()) {
285 --curPtr;
286 return;
287 }
288 [[fallthrough]];
289 default:
290 // Skip over other characters.
291 break;
292 }
293 }
294}
295
296Token Lexer::lexDirective(const char *tokStart) {
297 // Match the rest with an identifier regex: [0-9a-zA-Z_]*
298 while (isalnum(*curPtr) || *curPtr == '_')
299 ++curPtr;
300
301 StringRef str(tokStart, curPtr - tokStart);
302 return Token(Token::directive, str);
303}
304
305Token Lexer::lexIdentifier(const char *tokStart) {
306 // Match the rest of the identifier regex: [0-9a-zA-Z_]*
307 while (isalnum(*curPtr) || *curPtr == '_')
308 ++curPtr;
309
310 // Check to see if this identifier is a keyword.
311 StringRef str(tokStart, curPtr - tokStart);
312 Token::Kind kind = StringSwitch<Token::Kind>(str)
313 .Case(S: "attr", Value: Token::kw_attr)
314 .Case(S: "Attr", Value: Token::kw_Attr)
315 .Case(S: "erase", Value: Token::kw_erase)
316 .Case(S: "let", Value: Token::kw_let)
317 .Case(S: "Constraint", Value: Token::kw_Constraint)
318 .Case(S: "not", Value: Token::kw_not)
319 .Case(S: "op", Value: Token::kw_op)
320 .Case(S: "Op", Value: Token::kw_Op)
321 .Case(S: "OpName", Value: Token::kw_OpName)
322 .Case(S: "Pattern", Value: Token::kw_Pattern)
323 .Case(S: "replace", Value: Token::kw_replace)
324 .Case(S: "return", Value: Token::kw_return)
325 .Case(S: "rewrite", Value: Token::kw_rewrite)
326 .Case(S: "Rewrite", Value: Token::kw_Rewrite)
327 .Case(S: "type", Value: Token::kw_type)
328 .Case(S: "Type", Value: Token::kw_Type)
329 .Case(S: "TypeRange", Value: Token::kw_TypeRange)
330 .Case(S: "Value", Value: Token::kw_Value)
331 .Case(S: "ValueRange", Value: Token::kw_ValueRange)
332 .Case(S: "with", Value: Token::kw_with)
333 .Case(S: "_", Value: Token::underscore)
334 .Default(Value: Token::identifier);
335 return Token(kind, str);
336}
337
338Token Lexer::lexNumber(const char *tokStart) {
339 assert(isdigit(curPtr[-1]));
340
341 // Handle the normal decimal case.
342 while (isdigit(*curPtr))
343 ++curPtr;
344
345 return formToken(kind: Token::integer, tokStart);
346}
347
348Token Lexer::lexString(const char *tokStart, bool isStringBlock) {
349 while (true) {
350 // Check to see if there is a code completion location within the string. In
351 // these cases we generate a completion location and place the currently
352 // lexed string within the token (without the quotes). This allows for the
353 // parser to use the partially lexed string when computing the completion
354 // results.
355 if (curPtr == codeCompletionLocation) {
356 return formToken(kind: Token::code_complete_string,
357 tokStart: tokStart + (isStringBlock ? 2 : 1));
358 }
359
360 switch (*curPtr++) {
361 case '"':
362 // If this is a string block, we only end the string when we encounter a
363 // `}]`.
364 if (!isStringBlock)
365 return formToken(kind: Token::string, tokStart);
366 continue;
367 case '}':
368 // If this is a string block, we only end the string when we encounter a
369 // `}]`.
370 if (!isStringBlock || *curPtr != ']')
371 continue;
372 ++curPtr;
373 return formToken(kind: Token::string_block, tokStart);
374 case 0: {
375 // If this is a random nul character in the middle of a string, just
376 // include it. If it is the end of file, then it is an error.
377 if (curPtr - 1 != curBuffer.end())
378 continue;
379 --curPtr;
380
381 StringRef expectedEndStr = isStringBlock ? "}]" : "\"";
382 return emitError(loc: curPtr - 1,
383 msg: "expected '" + expectedEndStr + "' in string literal");
384 }
385
386 case '\n':
387 case '\v':
388 case '\f':
389 // String blocks allow multiple lines.
390 if (!isStringBlock)
391 return emitError(loc: curPtr - 1, msg: "expected '\"' in string literal");
392 continue;
393
394 case '\\':
395 // Handle explicitly a few escapes.
396 if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' ||
397 *curPtr == 't') {
398 ++curPtr;
399 } else if (llvm::isHexDigit(C: *curPtr) && llvm::isHexDigit(C: curPtr[1])) {
400 // Support \xx for two hex digits.
401 curPtr += 2;
402 } else {
403 return emitError(loc: curPtr - 1, msg: "unknown escape in string literal");
404 }
405 continue;
406
407 default:
408 continue;
409 }
410 }
411}
412

source code of mlir/lib/Tools/PDLL/Parser/Lexer.cpp