1//===- Parser.cpp ---------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Tools/PDLL/Parser/Parser.h"
10#include "Lexer.h"
11#include "mlir/Support/IndentedOstream.h"
12#include "mlir/TableGen/Argument.h"
13#include "mlir/TableGen/Attribute.h"
14#include "mlir/TableGen/Constraint.h"
15#include "mlir/TableGen/Format.h"
16#include "mlir/TableGen/Operator.h"
17#include "mlir/Tools/PDLL/AST/Context.h"
18#include "mlir/Tools/PDLL/AST/Diagnostic.h"
19#include "mlir/Tools/PDLL/AST/Nodes.h"
20#include "mlir/Tools/PDLL/AST/Types.h"
21#include "mlir/Tools/PDLL/ODS/Constraint.h"
22#include "mlir/Tools/PDLL/ODS/Context.h"
23#include "mlir/Tools/PDLL/ODS/Operation.h"
24#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
25#include "llvm/ADT/StringExtras.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/FormatVariadic.h"
28#include "llvm/Support/ManagedStatic.h"
29#include "llvm/Support/SaveAndRestore.h"
30#include "llvm/Support/ScopedPrinter.h"
31#include "llvm/TableGen/Error.h"
32#include "llvm/TableGen/Parser.h"
33#include <optional>
34#include <string>
35
36using namespace mlir;
37using namespace mlir::pdll;
38
39//===----------------------------------------------------------------------===//
40// Parser
41//===----------------------------------------------------------------------===//
42
43namespace {
44class Parser {
45public:
46 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
47 bool enableDocumentation, CodeCompleteContext *codeCompleteContext)
48 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
49 curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
50 typeTy(ast::TypeType::get(context&: ctx)), valueTy(ast::ValueType::get(context&: ctx)),
51 typeRangeTy(ast::TypeRangeType::get(context&: ctx)),
52 valueRangeTy(ast::ValueRangeType::get(context&: ctx)),
53 attrTy(ast::AttributeType::get(context&: ctx)),
54 codeCompleteContext(codeCompleteContext) {}
55
56 /// Try to parse a new module. Returns nullptr in the case of failure.
57 FailureOr<ast::Module *> parseModule();
58
59private:
60 /// The current context of the parser. It allows for the parser to know a bit
61 /// about the construct it is nested within during parsing. This is used
62 /// specifically to provide additional verification during parsing, e.g. to
63 /// prevent using rewrites within a match context, matcher constraints within
64 /// a rewrite section, etc.
65 enum class ParserContext {
66 /// The parser is in the global context.
67 Global,
68 /// The parser is currently within a Constraint, which disallows all types
69 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
70 Constraint,
71 /// The parser is currently within the matcher portion of a Pattern, which
72 /// is allows a terminal operation rewrite statement but no other rewrite
73 /// transformations.
74 PatternMatch,
75 /// The parser is currently within a Rewrite, which disallows calls to
76 /// constraints, requires operation expressions to have names, etc.
77 Rewrite,
78 };
79
80 /// The current specification context of an operations result type. This
81 /// indicates how the result types of an operation may be inferred.
82 enum class OpResultTypeContext {
83 /// The result types of the operation are not known to be inferred.
84 Explicit,
85 /// The result types of the operation are inferred from the root input of a
86 /// `replace` statement.
87 Replacement,
88 /// The result types of the operation are inferred by using the
89 /// `InferTypeOpInterface` interface provided by the operation.
90 Interface,
91 };
92
93 //===--------------------------------------------------------------------===//
94 // Parsing
95 //===--------------------------------------------------------------------===//
96
97 /// Push a new decl scope onto the lexer.
98 ast::DeclScope *pushDeclScope() {
99 ast::DeclScope *newScope =
100 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
101 return (curDeclScope = newScope);
102 }
103 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
104
105 /// Pop the last decl scope from the lexer.
106 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
107
108 /// Parse the body of an AST module.
109 LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
110
111 /// Try to convert the given expression to `type`. Returns failure and emits
112 /// an error if a conversion is not viable. On failure, `noteAttachFn` is
113 /// invoked to attach notes to the emitted error diagnostic. On success,
114 /// `expr` is updated to the expression used to convert to `type`.
115 LogicalResult convertExpressionTo(
116 ast::Expr *&expr, ast::Type type,
117 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
118 LogicalResult
119 convertOpExpressionTo(ast::Expr *&expr, ast::OperationType exprType,
120 ast::Type type,
121 function_ref<ast::InFlightDiagnostic()> emitErrorFn);
122 LogicalResult convertTupleExpressionTo(
123 ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
124 function_ref<ast::InFlightDiagnostic()> emitErrorFn,
125 function_ref<void(ast::Diagnostic &diag)> noteAttachFn);
126
127 /// Given an operation expression, convert it to a Value or ValueRange
128 /// typed expression.
129 ast::Expr *convertOpToValue(const ast::Expr *opExpr);
130
131 /// Lookup ODS information for the given operation, returns nullptr if no
132 /// information is found.
133 const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) {
134 return opName ? ctx.getODSContext().lookupOperation(name: *opName) : nullptr;
135 }
136
137 /// Process the given documentation string, or return an empty string if
138 /// documentation isn't enabled.
139 StringRef processDoc(StringRef doc) {
140 return enableDocumentation ? doc : StringRef();
141 }
142
143 /// Process the given documentation string and format it, or return an empty
144 /// string if documentation isn't enabled.
145 std::string processAndFormatDoc(const Twine &doc) {
146 if (!enableDocumentation)
147 return "";
148 std::string docStr;
149 {
150 llvm::raw_string_ostream docOS(docStr);
151 raw_indented_ostream(docOS).printReindented(
152 str: StringRef(docStr).rtrim(Chars: " \t"));
153 }
154 return docStr;
155 }
156
157 //===--------------------------------------------------------------------===//
158 // Directives
159
160 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
161 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
162 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
163 SmallVectorImpl<ast::Decl *> &decls);
164
165 /// Process the records of a parsed tablegen include file.
166 void processTdIncludeRecords(const llvm::RecordKeeper &tdRecords,
167 SmallVectorImpl<ast::Decl *> &decls);
168
169 /// Create a user defined native constraint for a constraint imported from
170 /// ODS.
171 template <typename ConstraintT>
172 ast::Decl *
173 createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
174 SMRange loc, ast::Type type,
175 StringRef nativeType, StringRef docString);
176 template <typename ConstraintT>
177 ast::Decl *
178 createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
179 SMRange loc, ast::Type type,
180 StringRef nativeType);
181
182 //===--------------------------------------------------------------------===//
183 // Decls
184
185 /// This structure contains the set of pattern metadata that may be parsed.
186 struct ParsedPatternMetadata {
187 std::optional<uint16_t> benefit;
188 bool hasBoundedRecursion = false;
189 };
190
191 FailureOr<ast::Decl *> parseTopLevelDecl();
192 FailureOr<ast::NamedAttributeDecl *>
193 parseNamedAttributeDecl(std::optional<StringRef> parentOpName);
194
195 /// Parse an argument variable as part of the signature of a
196 /// UserConstraintDecl or UserRewriteDecl.
197 FailureOr<ast::VariableDecl *> parseArgumentDecl();
198
199 /// Parse a result variable as part of the signature of a UserConstraintDecl
200 /// or UserRewriteDecl.
201 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
202
203 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
204 /// defined in a non-global context.
205 FailureOr<ast::UserConstraintDecl *>
206 parseUserConstraintDecl(bool isInline = false);
207
208 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
209 /// non-global context, such as within a Pattern/Constraint/etc.
210 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
211
212 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
213 /// PDLL constructs.
214 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
215 const ast::Name &name, bool isInline,
216 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
217 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
218
219 /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being
220 /// defined in a non-global context.
221 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
222
223 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
224 /// non-global context, such as within a Pattern/Rewrite/etc.
225 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
226
227 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
228 /// PDLL constructs.
229 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
230 const ast::Name &name, bool isInline,
231 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
232 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
233
234 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
235 /// effectively the same syntax, and only differ on slight semantics (given
236 /// the different parsing contexts).
237 template <typename T, typename ParseUserPDLLDeclFnT>
238 FailureOr<T *> parseUserConstraintOrRewriteDecl(
239 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
240 StringRef anonymousNamePrefix, bool isInline);
241
242 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
243 /// These decls have effectively the same syntax.
244 template <typename T>
245 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
246 const ast::Name &name, bool isInline,
247 ArrayRef<ast::VariableDecl *> arguments,
248 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
249
250 /// Parse the functional signature (i.e. the arguments and results) of a
251 /// UserConstraintDecl or UserRewriteDecl.
252 LogicalResult parseUserConstraintOrRewriteSignature(
253 SmallVectorImpl<ast::VariableDecl *> &arguments,
254 SmallVectorImpl<ast::VariableDecl *> &results,
255 ast::DeclScope *&argumentScope, ast::Type &resultType);
256
257 /// Validate the return (which if present is specified by bodyIt) of a
258 /// UserConstraintDecl or UserRewriteDecl.
259 LogicalResult validateUserConstraintOrRewriteReturn(
260 StringRef declType, ast::CompoundStmt *body,
261 ArrayRef<ast::Stmt *>::iterator bodyIt,
262 ArrayRef<ast::Stmt *>::iterator bodyE,
263 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
264
265 FailureOr<ast::CompoundStmt *>
266 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
267 bool expectTerminalSemicolon = true);
268 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
269 FailureOr<ast::Decl *> parsePatternDecl();
270 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
271
272 /// Check to see if a decl has already been defined with the given name, if
273 /// one has emit and error and return failure. Returns success otherwise.
274 LogicalResult checkDefineNamedDecl(const ast::Name &name);
275
276 /// Try to define a variable decl with the given components, returns the
277 /// variable on success.
278 FailureOr<ast::VariableDecl *>
279 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
280 ast::Expr *initExpr,
281 ArrayRef<ast::ConstraintRef> constraints);
282 FailureOr<ast::VariableDecl *>
283 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
284 ArrayRef<ast::ConstraintRef> constraints);
285
286 /// Parse the constraint reference list for a variable decl.
287 LogicalResult parseVariableDeclConstraintList(
288 SmallVectorImpl<ast::ConstraintRef> &constraints);
289
290 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
291 FailureOr<ast::Expr *> parseTypeConstraintExpr();
292
293 /// Try to parse a single reference to a constraint. `typeConstraint` is the
294 /// location of a previously parsed type constraint for the entity that will
295 /// be constrained by the parsed constraint. `existingConstraints` are any
296 /// existing constraints that have already been parsed for the same entity
297 /// that will be constrained by this constraint. `allowInlineTypeConstraints`
298 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
299 FailureOr<ast::ConstraintRef>
300 parseConstraint(std::optional<SMRange> &typeConstraint,
301 ArrayRef<ast::ConstraintRef> existingConstraints,
302 bool allowInlineTypeConstraints);
303
304 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
305 /// argument or result variable. The constraints for these variables do not
306 /// allow inline type constraints, and only permit a single constraint.
307 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
308
309 //===--------------------------------------------------------------------===//
310 // Exprs
311
312 FailureOr<ast::Expr *> parseExpr();
313
314 /// Identifier expressions.
315 FailureOr<ast::Expr *> parseAttributeExpr();
316 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr,
317 bool isNegated = false);
318 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
319 FailureOr<ast::Expr *> parseIdentifierExpr();
320 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
321 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
322 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
323 FailureOr<ast::Expr *> parseNegatedExpr();
324 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
325 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
326 FailureOr<ast::Expr *>
327 parseOperationExpr(OpResultTypeContext inputResultTypeContext =
328 OpResultTypeContext::Explicit);
329 FailureOr<ast::Expr *> parseTupleExpr();
330 FailureOr<ast::Expr *> parseTypeExpr();
331 FailureOr<ast::Expr *> parseUnderscoreExpr();
332
333 //===--------------------------------------------------------------------===//
334 // Stmts
335
336 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
337 FailureOr<ast::CompoundStmt *> parseCompoundStmt();
338 FailureOr<ast::EraseStmt *> parseEraseStmt();
339 FailureOr<ast::LetStmt *> parseLetStmt();
340 FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
341 FailureOr<ast::ReturnStmt *> parseReturnStmt();
342 FailureOr<ast::RewriteStmt *> parseRewriteStmt();
343
344 //===--------------------------------------------------------------------===//
345 // Creation+Analysis
346 //===--------------------------------------------------------------------===//
347
348 //===--------------------------------------------------------------------===//
349 // Decls
350
351 /// Try to extract a callable from the given AST node. Returns nullptr on
352 /// failure.
353 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
354
355 /// Try to create a pattern decl with the given components, returning the
356 /// Pattern on success.
357 FailureOr<ast::PatternDecl *>
358 createPatternDecl(SMRange loc, const ast::Name *name,
359 const ParsedPatternMetadata &metadata,
360 ast::CompoundStmt *body);
361
362 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
363 /// of results, defined as part of the signature.
364 ast::Type
365 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
366
367 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
368 template <typename T>
369 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
370 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
371 ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
372 ast::CompoundStmt *body);
373
374 /// Try to create a variable decl with the given components, returning the
375 /// Variable on success.
376 FailureOr<ast::VariableDecl *>
377 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
378 ArrayRef<ast::ConstraintRef> constraints);
379
380 /// Create a variable for an argument or result defined as part of the
381 /// signature of a UserConstraintDecl/UserRewriteDecl.
382 FailureOr<ast::VariableDecl *>
383 createArgOrResultVariableDecl(StringRef name, SMRange loc,
384 const ast::ConstraintRef &constraint);
385
386 /// Validate the constraints used to constraint a variable decl.
387 /// `inferredType` is the type of the variable inferred by the constraints
388 /// within the list, and is updated to the most refined type as determined by
389 /// the constraints. Returns success if the constraint list is valid, failure
390 /// otherwise.
391 LogicalResult
392 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
393 ast::Type &inferredType);
394 /// Validate a single reference to a constraint. `inferredType` contains the
395 /// currently inferred variabled type and is refined within the type defined
396 /// by the constraint. Returns success if the constraint is valid, failure
397 /// otherwise.
398 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
399 ast::Type &inferredType);
400 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
401 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
402
403 //===--------------------------------------------------------------------===//
404 // Exprs
405
406 FailureOr<ast::CallExpr *>
407 createCallExpr(SMRange loc, ast::Expr *parentExpr,
408 MutableArrayRef<ast::Expr *> arguments,
409 bool isNegated = false);
410 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
411 FailureOr<ast::DeclRefExpr *>
412 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
413 ArrayRef<ast::ConstraintRef> constraints);
414 FailureOr<ast::MemberAccessExpr *>
415 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
416
417 /// Validate the member access `name` into the given parent expression. On
418 /// success, this also returns the type of the member accessed.
419 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
420 StringRef name, SMRange loc);
421 FailureOr<ast::OperationExpr *>
422 createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
423 OpResultTypeContext resultTypeContext,
424 SmallVectorImpl<ast::Expr *> &operands,
425 MutableArrayRef<ast::NamedAttributeDecl *> attributes,
426 SmallVectorImpl<ast::Expr *> &results);
427 LogicalResult
428 validateOperationOperands(SMRange loc, std::optional<StringRef> name,
429 const ods::Operation *odsOp,
430 SmallVectorImpl<ast::Expr *> &operands);
431 LogicalResult validateOperationResults(SMRange loc,
432 std::optional<StringRef> name,
433 const ods::Operation *odsOp,
434 SmallVectorImpl<ast::Expr *> &results);
435 void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
436 const ods::Operation *odsOp);
437 LogicalResult validateOperationOperandsOrResults(
438 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
439 std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
440 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
441 ast::RangeType rangeTy);
442 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
443 ArrayRef<ast::Expr *> elements,
444 ArrayRef<StringRef> elementNames);
445
446 //===--------------------------------------------------------------------===//
447 // Stmts
448
449 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
450 FailureOr<ast::ReplaceStmt *>
451 createReplaceStmt(SMRange loc, ast::Expr *rootOp,
452 MutableArrayRef<ast::Expr *> replValues);
453 FailureOr<ast::RewriteStmt *>
454 createRewriteStmt(SMRange loc, ast::Expr *rootOp,
455 ast::CompoundStmt *rewriteBody);
456
457 //===--------------------------------------------------------------------===//
458 // Code Completion
459 //===--------------------------------------------------------------------===//
460
461 /// The set of various code completion methods. Every completion method
462 /// returns `failure` to stop the parsing process after providing completion
463 /// results.
464
465 LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
466 LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName);
467 LogicalResult codeCompleteConstraintName(ast::Type inferredType,
468 bool allowInlineTypeConstraints);
469 LogicalResult codeCompleteDialectName();
470 LogicalResult codeCompleteOperationName(StringRef dialectName);
471 LogicalResult codeCompletePatternMetadata();
472 LogicalResult codeCompleteIncludeFilename(StringRef curPath);
473
474 void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
475 void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
476 unsigned currentNumOperands);
477 void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
478 unsigned currentNumResults);
479
480 //===--------------------------------------------------------------------===//
481 // Lexer Utilities
482 //===--------------------------------------------------------------------===//
483
484 /// If the current token has the specified kind, consume it and return true.
485 /// If not, return false.
486 bool consumeIf(Token::Kind kind) {
487 if (curToken.isNot(k: kind))
488 return false;
489 consumeToken(kind);
490 return true;
491 }
492
493 /// Advance the current lexer onto the next token.
494 void consumeToken() {
495 assert(curToken.isNot(Token::eof, Token::error) &&
496 "shouldn't advance past EOF or errors");
497 curToken = lexer.lexToken();
498 }
499
500 /// Advance the current lexer onto the next token, asserting what the expected
501 /// current token is. This is preferred to the above method because it leads
502 /// to more self-documenting code with better checking.
503 void consumeToken(Token::Kind kind) {
504 assert(curToken.is(kind) && "consumed an unexpected token");
505 consumeToken();
506 }
507
508 /// Reset the lexer to the location at the given position.
509 void resetToken(SMRange tokLoc) {
510 lexer.resetPointer(newPointer: tokLoc.Start.getPointer());
511 curToken = lexer.lexToken();
512 }
513
514 /// Consume the specified token if present and return success. On failure,
515 /// output a diagnostic and return failure.
516 LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
517 if (curToken.getKind() != kind)
518 return emitError(loc: curToken.getLoc(), msg);
519 consumeToken();
520 return success();
521 }
522 LogicalResult emitError(SMRange loc, const Twine &msg) {
523 lexer.emitError(loc, msg);
524 return failure();
525 }
526 LogicalResult emitError(const Twine &msg) {
527 return emitError(loc: curToken.getLoc(), msg);
528 }
529 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
530 const Twine &note) {
531 lexer.emitErrorAndNote(loc, msg, noteLoc, note);
532 return failure();
533 }
534
535 //===--------------------------------------------------------------------===//
536 // Fields
537 //===--------------------------------------------------------------------===//
538
539 /// The owning AST context.
540 ast::Context &ctx;
541
542 /// The lexer of this parser.
543 Lexer lexer;
544
545 /// The current token within the lexer.
546 Token curToken;
547
548 /// A flag indicating if the parser should add documentation to AST nodes when
549 /// viable.
550 bool enableDocumentation;
551
552 /// The most recently defined decl scope.
553 ast::DeclScope *curDeclScope = nullptr;
554 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
555
556 /// The current context of the parser.
557 ParserContext parserContext = ParserContext::Global;
558
559 /// Cached types to simplify verification and expression creation.
560 ast::Type typeTy, valueTy;
561 ast::RangeType typeRangeTy, valueRangeTy;
562 ast::Type attrTy;
563
564 /// A counter used when naming anonymous constraints and rewrites.
565 unsigned anonymousDeclNameCounter = 0;
566
567 /// The optional code completion context.
568 CodeCompleteContext *codeCompleteContext;
569};
570} // namespace
571
572FailureOr<ast::Module *> Parser::parseModule() {
573 SMLoc moduleLoc = curToken.getStartLoc();
574 pushDeclScope();
575
576 // Parse the top-level decls of the module.
577 SmallVector<ast::Decl *> decls;
578 if (failed(Result: parseModuleBody(decls)))
579 return popDeclScope(), failure();
580
581 popDeclScope();
582 return ast::Module::create(ctx, loc: moduleLoc, children: decls);
583}
584
585LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
586 while (curToken.isNot(k: Token::eof)) {
587 if (curToken.is(k: Token::directive)) {
588 if (failed(Result: parseDirective(decls)))
589 return failure();
590 continue;
591 }
592
593 FailureOr<ast::Decl *> decl = parseTopLevelDecl();
594 if (failed(Result: decl))
595 return failure();
596 decls.push_back(Elt: *decl);
597 }
598 return success();
599}
600
601ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
602 return ast::AllResultsMemberAccessExpr::create(ctx, loc: opExpr->getLoc(), parentExpr: opExpr,
603 type: valueRangeTy);
604}
605
606LogicalResult Parser::convertExpressionTo(
607 ast::Expr *&expr, ast::Type type,
608 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
609 ast::Type exprType = expr->getType();
610 if (exprType == type)
611 return success();
612
613 auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
614 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
615 loc: expr->getLoc(), msg: llvm::formatv(Fmt: "unable to convert expression of type "
616 "`{0}` to the expected type of "
617 "`{1}`",
618 Vals&: exprType, Vals&: type));
619 if (noteAttachFn)
620 noteAttachFn(*diag);
621 return diag;
622 };
623
624 if (auto exprOpType = dyn_cast<ast::OperationType>(Val&: exprType))
625 return convertOpExpressionTo(expr, exprType: exprOpType, type, emitErrorFn: emitConvertError);
626
627 // FIXME: Decide how to allow/support converting a single result to multiple,
628 // and multiple to a single result. For now, we just allow Single->Range,
629 // but this isn't something really supported in the PDL dialect. We should
630 // figure out some way to support both.
631 if ((exprType == valueTy || exprType == valueRangeTy) &&
632 (type == valueTy || type == valueRangeTy))
633 return success();
634 if ((exprType == typeTy || exprType == typeRangeTy) &&
635 (type == typeTy || type == typeRangeTy))
636 return success();
637
638 // Handle tuple types.
639 if (auto exprTupleType = dyn_cast<ast::TupleType>(Val&: exprType))
640 return convertTupleExpressionTo(expr, exprType: exprTupleType, type, emitErrorFn: emitConvertError,
641 noteAttachFn);
642
643 return emitConvertError();
644}
645
646LogicalResult Parser::convertOpExpressionTo(
647 ast::Expr *&expr, ast::OperationType exprType, ast::Type type,
648 function_ref<ast::InFlightDiagnostic()> emitErrorFn) {
649 // Two operation types are compatible if they have the same name, or if the
650 // expected type is more general.
651 if (auto opType = dyn_cast<ast::OperationType>(Val&: type)) {
652 if (opType.getName())
653 return emitErrorFn();
654 return success();
655 }
656
657 // An operation can always convert to a ValueRange.
658 if (type == valueRangeTy) {
659 expr = ast::AllResultsMemberAccessExpr::create(ctx, loc: expr->getLoc(), parentExpr: expr,
660 type: valueRangeTy);
661 return success();
662 }
663
664 // Allow conversion to a single value by constraining the result range.
665 if (type == valueTy) {
666 // If the operation is registered, we can verify if it can ever have a
667 // single result.
668 if (const ods::Operation *odsOp = exprType.getODSOperation()) {
669 if (odsOp->getResults().empty()) {
670 return emitErrorFn()->attachNote(
671 msg: llvm::formatv(Fmt: "see the definition of `{0}`, which was defined "
672 "with zero results",
673 Vals: odsOp->getName()),
674 noteLoc: odsOp->getLoc());
675 }
676
677 unsigned numSingleResults = llvm::count_if(
678 Range: odsOp->getResults(), P: [](const ods::OperandOrResult &result) {
679 return result.getVariableLengthKind() ==
680 ods::VariableLengthKind::Single;
681 });
682 if (numSingleResults > 1) {
683 return emitErrorFn()->attachNote(
684 msg: llvm::formatv(Fmt: "see the definition of `{0}`, which was defined "
685 "with at least {1} results",
686 Vals: odsOp->getName(), Vals&: numSingleResults),
687 noteLoc: odsOp->getLoc());
688 }
689 }
690
691 expr = ast::AllResultsMemberAccessExpr::create(ctx, loc: expr->getLoc(), parentExpr: expr,
692 type: valueTy);
693 return success();
694 }
695 return emitErrorFn();
696}
697
698LogicalResult Parser::convertTupleExpressionTo(
699 ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
700 function_ref<ast::InFlightDiagnostic()> emitErrorFn,
701 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
702 // Handle conversions between tuples.
703 if (auto tupleType = dyn_cast<ast::TupleType>(Val&: type)) {
704 if (tupleType.size() != exprType.size())
705 return emitErrorFn();
706
707 // Build a new tuple expression using each of the elements of the current
708 // tuple.
709 SmallVector<ast::Expr *> newExprs;
710 for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
711 newExprs.push_back(Elt: ast::MemberAccessExpr::create(
712 ctx, loc: expr->getLoc(), parentExpr: expr, memberName: llvm::to_string(Value: i),
713 type: exprType.getElementTypes()[i]));
714
715 auto diagFn = [&](ast::Diagnostic &diag) {
716 diag.attachNote(msg: llvm::formatv(Fmt: "when converting element #{0} of `{1}`",
717 Vals&: i, Vals&: exprType));
718 if (noteAttachFn)
719 noteAttachFn(diag);
720 };
721 if (failed(Result: convertExpressionTo(expr&: newExprs.back(),
722 type: tupleType.getElementTypes()[i], noteAttachFn: diagFn)))
723 return failure();
724 }
725 expr = ast::TupleExpr::create(ctx, loc: expr->getLoc(), elements: newExprs,
726 elementNames: tupleType.getElementNames());
727 return success();
728 }
729
730 // Handle conversion to a range.
731 auto convertToRange = [&](ArrayRef<ast::Type> allowedElementTypes,
732 ast::RangeType resultTy) -> LogicalResult {
733 // TODO: We currently only allow range conversion within a rewrite context.
734 if (parserContext != ParserContext::Rewrite) {
735 return emitErrorFn()->attachNote(msg: "Tuple to Range conversion is currently "
736 "only allowed within a rewrite context");
737 }
738
739 // All of the tuple elements must be allowed types.
740 for (ast::Type elementType : exprType.getElementTypes())
741 if (!llvm::is_contained(Range&: allowedElementTypes, Element: elementType))
742 return emitErrorFn();
743
744 // Build a new tuple expression using each of the elements of the current
745 // tuple.
746 SmallVector<ast::Expr *> newExprs;
747 for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
748 newExprs.push_back(Elt: ast::MemberAccessExpr::create(
749 ctx, loc: expr->getLoc(), parentExpr: expr, memberName: llvm::to_string(Value: i),
750 type: exprType.getElementTypes()[i]));
751 }
752 expr = ast::RangeExpr::create(ctx, loc: expr->getLoc(), elements: newExprs, type: resultTy);
753 return success();
754 };
755 if (type == valueRangeTy)
756 return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
757 if (type == typeRangeTy)
758 return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
759
760 return emitErrorFn();
761}
762
763//===----------------------------------------------------------------------===//
764// Directives
765//===----------------------------------------------------------------------===//
766
767LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
768 StringRef directive = curToken.getSpelling();
769 if (directive == "#include")
770 return parseInclude(decls);
771
772 return emitError(msg: "unknown directive `" + directive + "`");
773}
774
775LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
776 SMRange loc = curToken.getLoc();
777 consumeToken(kind: Token::directive);
778
779 // Handle code completion of the include file path.
780 if (curToken.is(k: Token::code_complete_string))
781 return codeCompleteIncludeFilename(curPath: curToken.getStringValue());
782
783 // Parse the file being included.
784 if (!curToken.isString())
785 return emitError(loc,
786 msg: "expected string file name after `include` directive");
787 SMRange fileLoc = curToken.getLoc();
788 std::string filenameStr = curToken.getStringValue();
789 StringRef filename = filenameStr;
790 consumeToken();
791
792 // Check the type of include. If ending with `.pdll`, this is another pdl file
793 // to be parsed along with the current module.
794 if (filename.ends_with(Suffix: ".pdll")) {
795 if (failed(Result: lexer.pushInclude(filename, includeLoc: fileLoc)))
796 return emitError(loc: fileLoc,
797 msg: "unable to open include file `" + filename + "`");
798
799 // If we added the include successfully, parse it into the current module.
800 // Make sure to update to the next token after we finish parsing the nested
801 // file.
802 curToken = lexer.lexToken();
803 LogicalResult result = parseModuleBody(decls);
804 curToken = lexer.lexToken();
805 return result;
806 }
807
808 // Otherwise, this must be a `.td` include.
809 if (filename.ends_with(Suffix: ".td"))
810 return parseTdInclude(filename, fileLoc, decls);
811
812 return emitError(loc: fileLoc,
813 msg: "expected include filename to end with `.pdll` or `.td`");
814}
815
816LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
817 SmallVectorImpl<ast::Decl *> &decls) {
818 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
819
820 // Use the source manager to open the file, but don't yet add it.
821 std::string includedFile;
822 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
823 parserSrcMgr.OpenIncludeFile(Filename: filename.str(), IncludedFile&: includedFile);
824 if (!includeBuffer)
825 return emitError(loc: fileLoc, msg: "unable to open include file `" + filename + "`");
826
827 // Setup the source manager for parsing the tablegen file.
828 llvm::SourceMgr tdSrcMgr;
829 tdSrcMgr.AddNewSourceBuffer(F: std::move(*includeBuffer), IncludeLoc: SMLoc());
830 tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
831
832 // This class provides a context argument for the llvm::SourceMgr diagnostic
833 // handler.
834 struct DiagHandlerContext {
835 Parser &parser;
836 StringRef filename;
837 llvm::SMRange loc;
838 } handlerContext{.parser: *this, .filename: filename, .loc: fileLoc};
839
840 // Set the diagnostic handler for the tablegen source manager.
841 tdSrcMgr.setDiagHandler(
842 DH: [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
843 auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
844 (void)ctx->parser.emitError(
845 loc: ctx->loc,
846 msg: llvm::formatv(Fmt: "error while processing include file `{0}`: {1}",
847 Vals&: ctx->filename, Vals: diag.getMessage()));
848 },
849 Ctx: &handlerContext);
850
851 // Parse the tablegen file.
852 llvm::RecordKeeper tdRecords;
853 if (llvm::TableGenParseFile(InputSrcMgr&: tdSrcMgr, Records&: tdRecords))
854 return failure();
855
856 // Process the parsed records.
857 processTdIncludeRecords(tdRecords, decls);
858
859 // After we are done processing, move all of the tablegen source buffers to
860 // the main parser source mgr. This allows for directly using source locations
861 // from the .td files without needing to remap them.
862 parserSrcMgr.takeSourceBuffersFrom(SrcMgr&: tdSrcMgr, MainBufferIncludeLoc: fileLoc.End);
863 return success();
864}
865
866void Parser::processTdIncludeRecords(const llvm::RecordKeeper &tdRecords,
867 SmallVectorImpl<ast::Decl *> &decls) {
868 // Return the length kind of the given value.
869 auto getLengthKind = [](const auto &value) {
870 if (value.isOptional())
871 return ods::VariableLengthKind::Optional;
872 return value.isVariadic() ? ods::VariableLengthKind::Variadic
873 : ods::VariableLengthKind::Single;
874 };
875
876 // Insert a type constraint into the ODS context.
877 ods::Context &odsContext = ctx.getODSContext();
878 auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
879 -> const ods::TypeConstraint & {
880 return odsContext.insertTypeConstraint(
881 name: cst.constraint.getUniqueDefName(),
882 summary: processDoc(doc: cst.constraint.getSummary()), cppClass: cst.constraint.getCppType());
883 };
884 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
885 return {loc, llvm::SMLoc::getFromPointer(Ptr: loc.getPointer() + 1)};
886 };
887
888 // Process the parsed tablegen records to build ODS information.
889 /// Operations.
890 for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions(ClassName: "Op")) {
891 tblgen::Operator op(def);
892
893 // Check to see if this operation is known to support type inferrence.
894 bool supportsResultTypeInferrence =
895 op.getTrait(trait: "::mlir::InferTypeOpInterface::Trait");
896
897 auto [odsOp, inserted] = odsContext.insertOperation(
898 name: op.getOperationName(), summary: processDoc(doc: op.getSummary()),
899 desc: processAndFormatDoc(doc: op.getDescription()), nativeClassName: op.getQualCppClassName(),
900 supportsResultTypeInferrence, loc: op.getLoc().front());
901
902 // Ignore operations that have already been added.
903 if (!inserted)
904 continue;
905
906 for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
907 odsOp->appendAttribute(name: attr.name, optional: attr.attr.isOptional(),
908 constraint: odsContext.insertAttributeConstraint(
909 name: attr.attr.getUniqueDefName(),
910 summary: processDoc(doc: attr.attr.getSummary()),
911 cppClass: attr.attr.getStorageType()));
912 }
913 for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
914 odsOp->appendOperand(name: operand.name, variableLengthKind: getLengthKind(operand),
915 constraint: addTypeConstraint(operand));
916 }
917 for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
918 odsOp->appendResult(name: result.name, variableLengthKind: getLengthKind(result),
919 constraint: addTypeConstraint(result));
920 }
921 }
922
923 auto shouldBeSkipped = [this](const llvm::Record *def) {
924 return def->isAnonymous() || curDeclScope->lookup(name: def->getName()) ||
925 def->isSubClassOf(Name: "DeclareInterfaceMethods");
926 };
927
928 /// Attr constraints.
929 for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions(ClassName: "Attr")) {
930 if (shouldBeSkipped(def))
931 continue;
932
933 tblgen::Attribute constraint(def);
934 decls.push_back(Elt: createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
935 constraint, loc: convertLocToRange(def->getLoc().front()), type: attrTy,
936 nativeType: constraint.getStorageType()));
937 }
938 /// Type constraints.
939 for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions(ClassName: "Type")) {
940 if (shouldBeSkipped(def))
941 continue;
942
943 tblgen::TypeConstraint constraint(def);
944 decls.push_back(Elt: createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
945 constraint, loc: convertLocToRange(def->getLoc().front()), type: typeTy,
946 nativeType: constraint.getCppType()));
947 }
948 /// OpInterfaces.
949 ast::Type opTy = ast::OperationType::get(context&: ctx);
950 for (const llvm::Record *def :
951 tdRecords.getAllDerivedDefinitions(ClassName: "OpInterface")) {
952 if (shouldBeSkipped(def))
953 continue;
954
955 SMRange loc = convertLocToRange(def->getLoc().front());
956
957 std::string cppClassName =
958 llvm::formatv(Fmt: "{0}::{1}", Vals: def->getValueAsString(FieldName: "cppNamespace"),
959 Vals: def->getValueAsString(FieldName: "cppInterfaceName"))
960 .str();
961 std::string codeBlock =
962 llvm::formatv(Fmt: "return ::mlir::success(llvm::isa<{0}>(self));",
963 Vals&: cppClassName)
964 .str();
965
966 std::string desc =
967 processAndFormatDoc(doc: def->getValueAsString(FieldName: "description"));
968 decls.push_back(Elt: createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
969 name: def->getName(), codeBlock, loc, type: opTy, nativeType: cppClassName, docString: desc));
970 }
971}
972
973template <typename ConstraintT>
974ast::Decl *Parser::createODSNativePDLLConstraintDecl(
975 StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
976 StringRef nativeType, StringRef docString) {
977 // Build the single input parameter.
978 ast::DeclScope *argScope = pushDeclScope();
979 auto *paramVar = ast::VariableDecl::create(
980 ctx, name: ast::Name::create(ctx, name: "self", location: loc), type,
981 /*initExpr=*/nullptr, constraints: ast::ConstraintRef(ConstraintT::create(ctx, loc)));
982 argScope->add(decl: paramVar);
983 popDeclScope();
984
985 // Build the native constraint.
986 auto *constraintDecl = ast::UserConstraintDecl::createNative(
987 ctx, name: ast::Name::create(ctx, name, location: loc), inputs: paramVar,
988 /*results=*/std::nullopt, codeBlock, resultType: ast::TupleType::get(context&: ctx),
989 nativeInputTypes: nativeType);
990 constraintDecl->setDocComment(ctx, comment: docString);
991 curDeclScope->add(decl: constraintDecl);
992 return constraintDecl;
993}
994
995template <typename ConstraintT>
996ast::Decl *
997Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
998 SMRange loc, ast::Type type,
999 StringRef nativeType) {
1000 // Format the condition template.
1001 tblgen::FmtContext fmtContext;
1002 fmtContext.withSelf(subst: "self");
1003 std::string codeBlock = tblgen::tgfmt(
1004 fmt: "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
1005 ctx: &fmtContext);
1006
1007 // If documentation was enabled, build the doc string for the generated
1008 // constraint. It would be nice to do this lazily, but TableGen information is
1009 // destroyed after we finish parsing the file.
1010 std::string docString;
1011 if (enableDocumentation) {
1012 StringRef desc = constraint.getDescription();
1013 docString = processAndFormatDoc(
1014 doc: constraint.getSummary() +
1015 (desc.empty() ? "" : ("\n\n" + constraint.getDescription())));
1016 }
1017
1018 return createODSNativePDLLConstraintDecl<ConstraintT>(
1019 constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
1020 docString);
1021}
1022
1023//===----------------------------------------------------------------------===//
1024// Decls
1025//===----------------------------------------------------------------------===//
1026
1027FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
1028 FailureOr<ast::Decl *> decl;
1029 switch (curToken.getKind()) {
1030 case Token::kw_Constraint:
1031 decl = parseUserConstraintDecl();
1032 break;
1033 case Token::kw_Pattern:
1034 decl = parsePatternDecl();
1035 break;
1036 case Token::kw_Rewrite:
1037 decl = parseUserRewriteDecl();
1038 break;
1039 default:
1040 return emitError(msg: "expected top-level declaration, such as a `Pattern`");
1041 }
1042 if (failed(Result: decl))
1043 return failure();
1044
1045 // If the decl has a name, add it to the current scope.
1046 if (const ast::Name *name = (*decl)->getName()) {
1047 if (failed(Result: checkDefineNamedDecl(name: *name)))
1048 return failure();
1049 curDeclScope->add(decl: *decl);
1050 }
1051 return decl;
1052}
1053
1054FailureOr<ast::NamedAttributeDecl *>
1055Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) {
1056 // Check for name code completion.
1057 if (curToken.is(k: Token::code_complete))
1058 return codeCompleteAttributeName(opName: parentOpName);
1059
1060 std::string attrNameStr;
1061 if (curToken.isString())
1062 attrNameStr = curToken.getStringValue();
1063 else if (curToken.is(k: Token::identifier) || curToken.isKeyword())
1064 attrNameStr = curToken.getSpelling().str();
1065 else
1066 return emitError(msg: "expected identifier or string attribute name");
1067 const auto &name = ast::Name::create(ctx, name: attrNameStr, location: curToken.getLoc());
1068 consumeToken();
1069
1070 // Check for a value of the attribute.
1071 ast::Expr *attrValue = nullptr;
1072 if (consumeIf(kind: Token::equal)) {
1073 FailureOr<ast::Expr *> attrExpr = parseExpr();
1074 if (failed(Result: attrExpr))
1075 return failure();
1076 attrValue = *attrExpr;
1077 } else {
1078 // If there isn't a concrete value, create an expression representing a
1079 // UnitAttr.
1080 attrValue = ast::AttributeExpr::create(ctx, loc: name.getLoc(), value: "unit");
1081 }
1082
1083 return ast::NamedAttributeDecl::create(ctx, name, value: attrValue);
1084}
1085
1086FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1087 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
1088 bool expectTerminalSemicolon) {
1089 consumeToken(kind: Token::equal_arrow);
1090
1091 // Parse the single statement of the lambda body.
1092 SMLoc bodyStartLoc = curToken.getStartLoc();
1093 pushDeclScope();
1094 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1095 bool failedToParse =
1096 failed(Result: singleStatement) || failed(Result: processStatementFn(*singleStatement));
1097 popDeclScope();
1098 if (failedToParse)
1099 return failure();
1100
1101 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1102 return ast::CompoundStmt::create(ctx, location: bodyLoc, children: *singleStatement);
1103}
1104
1105FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1106 // Ensure that the argument is named.
1107 if (curToken.isNot(k: Token::identifier) && !curToken.isDependentKeyword())
1108 return emitError(msg: "expected identifier argument name");
1109
1110 // Parse the argument similarly to a normal variable.
1111 StringRef name = curToken.getSpelling();
1112 SMRange nameLoc = curToken.getLoc();
1113 consumeToken();
1114
1115 if (failed(
1116 Result: parseToken(kind: Token::colon, msg: "expected `:` before argument constraint")))
1117 return failure();
1118
1119 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1120 if (failed(Result: cst))
1121 return failure();
1122
1123 return createArgOrResultVariableDecl(name, loc: nameLoc, constraint: *cst);
1124}
1125
1126FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
1127 // Check to see if this result is named.
1128 if (curToken.is(k: Token::identifier) || curToken.isDependentKeyword()) {
1129 // Check to see if this name actually refers to a Constraint.
1130 if (!curDeclScope->lookup<ast::ConstraintDecl>(name: curToken.getSpelling())) {
1131 // If it wasn't a constraint, parse the result similarly to a variable. If
1132 // there is already an existing decl, we will emit an error when defining
1133 // this variable later.
1134 StringRef name = curToken.getSpelling();
1135 SMRange nameLoc = curToken.getLoc();
1136 consumeToken();
1137
1138 if (failed(Result: parseToken(kind: Token::colon,
1139 msg: "expected `:` before result constraint")))
1140 return failure();
1141
1142 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1143 if (failed(Result: cst))
1144 return failure();
1145
1146 return createArgOrResultVariableDecl(name, loc: nameLoc, constraint: *cst);
1147 }
1148 }
1149
1150 // If it isn't named, we parse the constraint directly and create an unnamed
1151 // result variable.
1152 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1153 if (failed(Result: cst))
1154 return failure();
1155
1156 return createArgOrResultVariableDecl(name: "", loc: cst->referenceLoc, constraint: *cst);
1157}
1158
1159FailureOr<ast::UserConstraintDecl *>
1160Parser::parseUserConstraintDecl(bool isInline) {
1161 // Constraints and rewrites have very similar formats, dispatch to a shared
1162 // interface for parsing.
1163 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1164 parseUserPDLLFn: [&](auto &&...args) {
1165 return this->parseUserPDLLConstraintDecl(name: args...);
1166 },
1167 declContext: ParserContext::Constraint, anonymousNamePrefix: "constraint", isInline);
1168}
1169
1170FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1171 FailureOr<ast::UserConstraintDecl *> decl =
1172 parseUserConstraintDecl(/*isInline=*/true);
1173 if (failed(Result: decl) || failed(Result: checkDefineNamedDecl(name: (*decl)->getName())))
1174 return failure();
1175
1176 curDeclScope->add(decl: *decl);
1177 return decl;
1178}
1179
1180FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1181 const ast::Name &name, bool isInline,
1182 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1183 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1184 // Push the argument scope back onto the list, so that the body can
1185 // reference arguments.
1186 pushDeclScope(scope: argumentScope);
1187
1188 // Parse the body of the constraint. The body is either defined as a compound
1189 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1190 ast::CompoundStmt *body;
1191 if (curToken.is(k: Token::equal_arrow)) {
1192 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1193 processStatementFn: [&](ast::Stmt *&stmt) -> LogicalResult {
1194 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(Val: stmt);
1195 if (!stmtExpr) {
1196 return emitError(loc: stmt->getLoc(),
1197 msg: "expected `Constraint` lambda body to contain a "
1198 "single expression");
1199 }
1200 stmt = ast::ReturnStmt::create(ctx, loc: stmt->getLoc(), resultExpr: stmtExpr);
1201 return success();
1202 },
1203 /*expectTerminalSemicolon=*/!isInline);
1204 if (failed(Result: bodyResult))
1205 return failure();
1206 body = *bodyResult;
1207 } else {
1208 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1209 if (failed(Result: bodyResult))
1210 return failure();
1211 body = *bodyResult;
1212
1213 // Verify the structure of the body.
1214 auto bodyIt = body->begin(), bodyE = body->end();
1215 for (; bodyIt != bodyE; ++bodyIt)
1216 if (isa<ast::ReturnStmt>(Val: *bodyIt))
1217 break;
1218 if (failed(Result: validateUserConstraintOrRewriteReturn(
1219 declType: "Constraint", body, bodyIt, bodyE, results, resultType)))
1220 return failure();
1221 }
1222 popDeclScope();
1223
1224 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1225 name, arguments, results, resultType, body);
1226}
1227
1228FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1229 // Constraints and rewrites have very similar formats, dispatch to a shared
1230 // interface for parsing.
1231 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1232 parseUserPDLLFn: [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(name: args...); },
1233 declContext: ParserContext::Rewrite, anonymousNamePrefix: "rewrite", isInline);
1234}
1235
1236FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1237 FailureOr<ast::UserRewriteDecl *> decl =
1238 parseUserRewriteDecl(/*isInline=*/true);
1239 if (failed(Result: decl) || failed(Result: checkDefineNamedDecl(name: (*decl)->getName())))
1240 return failure();
1241
1242 curDeclScope->add(decl: *decl);
1243 return decl;
1244}
1245
1246FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1247 const ast::Name &name, bool isInline,
1248 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1249 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1250 // Push the argument scope back onto the list, so that the body can
1251 // reference arguments.
1252 curDeclScope = argumentScope;
1253 ast::CompoundStmt *body;
1254 if (curToken.is(k: Token::equal_arrow)) {
1255 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1256 processStatementFn: [&](ast::Stmt *&statement) -> LogicalResult {
1257 if (isa<ast::OpRewriteStmt>(Val: statement))
1258 return success();
1259
1260 ast::Expr *statementExpr = dyn_cast<ast::Expr>(Val: statement);
1261 if (!statementExpr) {
1262 return emitError(
1263 loc: statement->getLoc(),
1264 msg: "expected `Rewrite` lambda body to contain a single expression "
1265 "or an operation rewrite statement; such as `erase`, "
1266 "`replace`, or `rewrite`");
1267 }
1268 statement =
1269 ast::ReturnStmt::create(ctx, loc: statement->getLoc(), resultExpr: statementExpr);
1270 return success();
1271 },
1272 /*expectTerminalSemicolon=*/!isInline);
1273 if (failed(Result: bodyResult))
1274 return failure();
1275 body = *bodyResult;
1276 } else {
1277 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1278 if (failed(Result: bodyResult))
1279 return failure();
1280 body = *bodyResult;
1281 }
1282 popDeclScope();
1283
1284 // Verify the structure of the body.
1285 auto bodyIt = body->begin(), bodyE = body->end();
1286 for (; bodyIt != bodyE; ++bodyIt)
1287 if (isa<ast::ReturnStmt>(Val: *bodyIt))
1288 break;
1289 if (failed(Result: validateUserConstraintOrRewriteReturn(declType: "Rewrite", body, bodyIt,
1290 bodyE, results, resultType)))
1291 return failure();
1292 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1293 name, arguments, results, resultType, body);
1294}
1295
1296template <typename T, typename ParseUserPDLLDeclFnT>
1297FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1298 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1299 StringRef anonymousNamePrefix, bool isInline) {
1300 SMRange loc = curToken.getLoc();
1301 consumeToken();
1302 llvm::SaveAndRestore saveCtx(parserContext, declContext);
1303
1304 // Parse the name of the decl.
1305 const ast::Name *name = nullptr;
1306 if (curToken.isNot(k: Token::identifier)) {
1307 // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1308 // in C++, so being unnamed is fine.
1309 if (!isInline)
1310 return emitError(msg: "expected identifier name");
1311
1312 // Create a unique anonymous name to use, as the name for this decl is not
1313 // important.
1314 std::string anonName =
1315 llvm::formatv(Fmt: "<anonymous_{0}_{1}>", Vals&: anonymousNamePrefix,
1316 Vals: anonymousDeclNameCounter++)
1317 .str();
1318 name = &ast::Name::create(ctx, name: anonName, location: loc);
1319 } else {
1320 // If a name was provided, we can use it directly.
1321 name = &ast::Name::create(ctx, name: curToken.getSpelling(), location: curToken.getLoc());
1322 consumeToken(kind: Token::identifier);
1323 }
1324
1325 // Parse the functional signature of the decl.
1326 SmallVector<ast::VariableDecl *> arguments, results;
1327 ast::DeclScope *argumentScope;
1328 ast::Type resultType;
1329 if (failed(Result: parseUserConstraintOrRewriteSignature(arguments, results,
1330 argumentScope, resultType)))
1331 return failure();
1332
1333 // Check to see which type of constraint this is. If the constraint contains a
1334 // compound body, this is a PDLL decl.
1335 if (curToken.isAny(k1: Token::l_brace, k2: Token::equal_arrow))
1336 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1337 resultType);
1338
1339 // Otherwise, this is a native decl.
1340 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1341 results, resultType);
1342}
1343
1344template <typename T>
1345FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1346 const ast::Name &name, bool isInline,
1347 ArrayRef<ast::VariableDecl *> arguments,
1348 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1349 // If followed by a string, the native code body has also been specified.
1350 std::string codeStrStorage;
1351 std::optional<StringRef> optCodeStr;
1352 if (curToken.isString()) {
1353 codeStrStorage = curToken.getStringValue();
1354 optCodeStr = codeStrStorage;
1355 consumeToken();
1356 } else if (isInline) {
1357 return emitError(loc: name.getLoc(),
1358 msg: "external declarations must be declared in global scope");
1359 } else if (curToken.is(k: Token::error)) {
1360 return failure();
1361 }
1362 if (failed(Result: parseToken(kind: Token::semicolon,
1363 msg: "expected `;` after native declaration")))
1364 return failure();
1365 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1366}
1367
1368LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1369 SmallVectorImpl<ast::VariableDecl *> &arguments,
1370 SmallVectorImpl<ast::VariableDecl *> &results,
1371 ast::DeclScope *&argumentScope, ast::Type &resultType) {
1372 // Parse the argument list of the decl.
1373 if (failed(Result: parseToken(kind: Token::l_paren, msg: "expected `(` to start argument list")))
1374 return failure();
1375
1376 argumentScope = pushDeclScope();
1377 if (curToken.isNot(k: Token::r_paren)) {
1378 do {
1379 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1380 if (failed(Result: argument))
1381 return failure();
1382 arguments.emplace_back(Args&: *argument);
1383 } while (consumeIf(kind: Token::comma));
1384 }
1385 popDeclScope();
1386 if (failed(Result: parseToken(kind: Token::r_paren, msg: "expected `)` to end argument list")))
1387 return failure();
1388
1389 // Parse the results of the decl.
1390 pushDeclScope();
1391 if (consumeIf(kind: Token::arrow)) {
1392 auto parseResultFn = [&]() -> LogicalResult {
1393 FailureOr<ast::VariableDecl *> result = parseResultDecl(resultNum: results.size());
1394 if (failed(Result: result))
1395 return failure();
1396 results.emplace_back(Args&: *result);
1397 return success();
1398 };
1399
1400 // Check for a list of results.
1401 if (consumeIf(kind: Token::l_paren)) {
1402 do {
1403 if (failed(Result: parseResultFn()))
1404 return failure();
1405 } while (consumeIf(kind: Token::comma));
1406 if (failed(Result: parseToken(kind: Token::r_paren, msg: "expected `)` to end result list")))
1407 return failure();
1408
1409 // Otherwise, there is only one result.
1410 } else if (failed(Result: parseResultFn())) {
1411 return failure();
1412 }
1413 }
1414 popDeclScope();
1415
1416 // Compute the result type of the decl.
1417 resultType = createUserConstraintRewriteResultType(results);
1418
1419 // Verify that results are only named if there are more than one.
1420 if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1421 return emitError(
1422 loc: results.front()->getLoc(),
1423 msg: "cannot create a single-element tuple with an element label");
1424 }
1425 return success();
1426}
1427
1428LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1429 StringRef declType, ast::CompoundStmt *body,
1430 ArrayRef<ast::Stmt *>::iterator bodyIt,
1431 ArrayRef<ast::Stmt *>::iterator bodyE,
1432 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1433 // Handle if a `return` was provided.
1434 if (bodyIt != bodyE) {
1435 // Emit an error if we have trailing statements after the return.
1436 if (std::next(x: bodyIt) != bodyE) {
1437 return emitError(
1438 loc: (*std::next(x: bodyIt))->getLoc(),
1439 msg: llvm::formatv(Fmt: "`return` terminated the `{0}` body, but found "
1440 "trailing statements afterwards",
1441 Vals&: declType));
1442 }
1443
1444 // Otherwise if a return wasn't provided, check that no results are
1445 // expected.
1446 } else if (!results.empty()) {
1447 return emitError(
1448 loc: {body->getLoc().End, body->getLoc().End},
1449 msg: llvm::formatv(Fmt: "missing return in a `{0}` expected to return `{1}`",
1450 Vals&: declType, Vals&: resultType));
1451 }
1452 return success();
1453}
1454
1455FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1456 return parseLambdaBody(processStatementFn: [&](ast::Stmt *&statement) -> LogicalResult {
1457 if (isa<ast::OpRewriteStmt>(Val: statement))
1458 return success();
1459 return emitError(
1460 loc: statement->getLoc(),
1461 msg: "expected Pattern lambda body to contain a single operation "
1462 "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1463 });
1464}
1465
1466FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1467 SMRange loc = curToken.getLoc();
1468 consumeToken(kind: Token::kw_Pattern);
1469 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);
1470
1471 // Check for an optional identifier for the pattern name.
1472 const ast::Name *name = nullptr;
1473 if (curToken.is(k: Token::identifier)) {
1474 name = &ast::Name::create(ctx, name: curToken.getSpelling(), location: curToken.getLoc());
1475 consumeToken(kind: Token::identifier);
1476 }
1477
1478 // Parse any pattern metadata.
1479 ParsedPatternMetadata metadata;
1480 if (consumeIf(kind: Token::kw_with) && failed(Result: parsePatternDeclMetadata(metadata)))
1481 return failure();
1482
1483 // Parse the pattern body.
1484 ast::CompoundStmt *body;
1485
1486 // Handle a lambda body.
1487 if (curToken.is(k: Token::equal_arrow)) {
1488 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1489 if (failed(Result: bodyResult))
1490 return failure();
1491 body = *bodyResult;
1492 } else {
1493 if (curToken.isNot(k: Token::l_brace))
1494 return emitError(msg: "expected `{` or `=>` to start pattern body");
1495 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1496 if (failed(Result: bodyResult))
1497 return failure();
1498 body = *bodyResult;
1499
1500 // Verify the body of the pattern.
1501 auto bodyIt = body->begin(), bodyE = body->end();
1502 for (; bodyIt != bodyE; ++bodyIt) {
1503 if (isa<ast::ReturnStmt>(Val: *bodyIt)) {
1504 return emitError(loc: (*bodyIt)->getLoc(),
1505 msg: "`return` statements are only permitted within a "
1506 "`Constraint` or `Rewrite` body");
1507 }
1508 // Break when we've found the rewrite statement.
1509 if (isa<ast::OpRewriteStmt>(Val: *bodyIt))
1510 break;
1511 }
1512 if (bodyIt == bodyE) {
1513 return emitError(loc,
1514 msg: "expected Pattern body to terminate with an operation "
1515 "rewrite statement, such as `erase`");
1516 }
1517 if (std::next(x: bodyIt) != bodyE) {
1518 return emitError(loc: (*std::next(x: bodyIt))->getLoc(),
1519 msg: "Pattern body was terminated by an operation "
1520 "rewrite statement, but found trailing statements");
1521 }
1522 }
1523
1524 return createPatternDecl(loc, name, metadata, body);
1525}
1526
1527LogicalResult
1528Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1529 std::optional<SMRange> benefitLoc;
1530 std::optional<SMRange> hasBoundedRecursionLoc;
1531
1532 do {
1533 // Handle metadata code completion.
1534 if (curToken.is(k: Token::code_complete))
1535 return codeCompletePatternMetadata();
1536
1537 if (curToken.isNot(k: Token::identifier))
1538 return emitError(msg: "expected pattern metadata identifier");
1539 StringRef metadataStr = curToken.getSpelling();
1540 SMRange metadataLoc = curToken.getLoc();
1541 consumeToken(kind: Token::identifier);
1542
1543 // Parse the benefit metadata: benefit(<integer-value>)
1544 if (metadataStr == "benefit") {
1545 if (benefitLoc) {
1546 return emitErrorAndNote(loc: metadataLoc,
1547 msg: "pattern benefit has already been specified",
1548 noteLoc: *benefitLoc, note: "see previous definition here");
1549 }
1550 if (failed(Result: parseToken(kind: Token::l_paren,
1551 msg: "expected `(` before pattern benefit")))
1552 return failure();
1553
1554 uint16_t benefitValue = 0;
1555 if (curToken.isNot(k: Token::integer))
1556 return emitError(msg: "expected integral pattern benefit");
1557 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, Result&: benefitValue))
1558 return emitError(
1559 msg: "expected pattern benefit to fit within a 16-bit integer");
1560 consumeToken(kind: Token::integer);
1561
1562 metadata.benefit = benefitValue;
1563 benefitLoc = metadataLoc;
1564
1565 if (failed(
1566 Result: parseToken(kind: Token::r_paren, msg: "expected `)` after pattern benefit")))
1567 return failure();
1568 continue;
1569 }
1570
1571 // Parse the bounded recursion metadata: recursion
1572 if (metadataStr == "recursion") {
1573 if (hasBoundedRecursionLoc) {
1574 return emitErrorAndNote(
1575 loc: metadataLoc,
1576 msg: "pattern recursion metadata has already been specified",
1577 noteLoc: *hasBoundedRecursionLoc, note: "see previous definition here");
1578 }
1579 metadata.hasBoundedRecursion = true;
1580 hasBoundedRecursionLoc = metadataLoc;
1581 continue;
1582 }
1583
1584 return emitError(loc: metadataLoc, msg: "unknown pattern metadata");
1585 } while (consumeIf(kind: Token::comma));
1586
1587 return success();
1588}
1589
1590FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1591 consumeToken(kind: Token::less);
1592
1593 FailureOr<ast::Expr *> typeExpr = parseExpr();
1594 if (failed(Result: typeExpr) ||
1595 failed(Result: parseToken(kind: Token::greater,
1596 msg: "expected `>` after variable type constraint")))
1597 return failure();
1598 return typeExpr;
1599}
1600
1601LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1602 assert(curDeclScope && "defining decl outside of a decl scope");
1603 if (ast::Decl *lastDecl = curDeclScope->lookup(name: name.getName())) {
1604 return emitErrorAndNote(
1605 loc: name.getLoc(), msg: "`" + name.getName() + "` has already been defined",
1606 noteLoc: lastDecl->getName()->getLoc(), note: "see previous definition here");
1607 }
1608 return success();
1609}
1610
1611FailureOr<ast::VariableDecl *>
1612Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1613 ast::Expr *initExpr,
1614 ArrayRef<ast::ConstraintRef> constraints) {
1615 assert(curDeclScope && "defining variable outside of decl scope");
1616 const ast::Name &nameDecl = ast::Name::create(ctx, name, location: nameLoc);
1617
1618 // If the name of the variable indicates a special variable, we don't add it
1619 // to the scope. This variable is local to the definition point.
1620 if (name.empty() || name == "_") {
1621 return ast::VariableDecl::create(ctx, name: nameDecl, type, initExpr,
1622 constraints);
1623 }
1624 if (failed(Result: checkDefineNamedDecl(name: nameDecl)))
1625 return failure();
1626
1627 auto *varDecl =
1628 ast::VariableDecl::create(ctx, name: nameDecl, type, initExpr, constraints);
1629 curDeclScope->add(decl: varDecl);
1630 return varDecl;
1631}
1632
1633FailureOr<ast::VariableDecl *>
1634Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1635 ArrayRef<ast::ConstraintRef> constraints) {
1636 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1637 constraints);
1638}
1639
1640LogicalResult Parser::parseVariableDeclConstraintList(
1641 SmallVectorImpl<ast::ConstraintRef> &constraints) {
1642 std::optional<SMRange> typeConstraint;
1643 auto parseSingleConstraint = [&] {
1644 FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1645 typeConstraint, existingConstraints: constraints, /*allowInlineTypeConstraints=*/true);
1646 if (failed(Result: constraint))
1647 return failure();
1648 constraints.push_back(Elt: *constraint);
1649 return success();
1650 };
1651
1652 // Check to see if this is a single constraint, or a list.
1653 if (!consumeIf(kind: Token::l_square))
1654 return parseSingleConstraint();
1655
1656 do {
1657 if (failed(Result: parseSingleConstraint()))
1658 return failure();
1659 } while (consumeIf(kind: Token::comma));
1660 return parseToken(kind: Token::r_square, msg: "expected `]` after constraint list");
1661}
1662
/// Parse a single constraint reference. Accepted forms:
///   - one of the builtin constraint keywords `Attr`, `Op`, `Type`,
///     `TypeRange`, `Value`, or `ValueRange`; `Attr`, `Value`, and
///     `ValueRange` may carry an inline `<expr>` type constraint, and `Op`
///     an optional `<dialect.op>` name;
///   - an inline `Constraint` definition;
///   - an identifier naming a previously defined constraint.
/// A code completion token triggers constraint name completion instead.
///
/// \param typeConstraint Location of an inline type constraint already parsed
///        for the same variable, if any. Set when a new one is parsed so that
///        a second one can be diagnosed.
/// \param existingConstraints Constraints already parsed for the same
///        variable; only used here to infer a type for code completion.
/// \param allowInlineTypeConstraints Whether an inline `<...>` type constraint
///        is permitted; if false, one is diagnosed as an error.
/// \returns The parsed constraint reference, or failure.
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(std::optional<SMRange> &typeConstraint,
                        ArrayRef<ast::ConstraintRef> existingConstraints,
                        bool allowInlineTypeConstraints) {
  // Parses an inline `<expr>` type constraint into `typeExpr`, enforcing that
  // such constraints are allowed here and appear at most once per variable.
  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
    if (!allowInlineTypeConstraints) {
      return emitError(
          loc: curToken.getLoc(),
          msg: "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
          "permitted on arguments or results");
    }
    if (typeConstraint)
      return emitErrorAndNote(
          loc: curToken.getLoc(),
          msg: "the type of this variable has already been constrained",
          noteLoc: *typeConstraint, note: "see previous constraint location here");
    FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
    if (failed(Result: constraintExpr))
      return failure();
    typeExpr = *constraintExpr;
    // Remember where the type was constrained for later duplicate diagnostics.
    typeConstraint = typeExpr->getLoc();
    return success();
  };

  // Every constraint reference is attributed to the location of its first
  // token.
  SMRange loc = curToken.getLoc();
  switch (curToken.getKind()) {
  case Token::kw_Attr: {
    consumeToken(kind: Token::kw_Attr);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(k: Token::less) && failed(Result: parseTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_Op: {
    consumeToken(kind: Token::kw_Op);

    // Parse an optional operation name. If the name isn't provided, this refers
    // to "any" operation.
    FailureOr<ast::OpNameDecl *> opName =
        parseWrappedOperationName(/*allowEmptyName=*/true);
    if (failed(Result: opName))
      return failure();

    return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, nameDecl: *opName),
                              loc);
  }
  case Token::kw_Type:
    consumeToken(kind: Token::kw_Type);
    return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
  case Token::kw_TypeRange:
    consumeToken(kind: Token::kw_TypeRange);
    return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
                              loc);
  case Token::kw_Value: {
    consumeToken(kind: Token::kw_Value);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(k: Token::less) && failed(Result: parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_ValueRange: {
    consumeToken(kind: Token::kw_ValueRange);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(k: Token::less) && failed(Result: parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
  }

  case Token::kw_Constraint: {
    // Handle an inline constraint.
    FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
    if (failed(Result: decl))
      return failure();
    return ast::ConstraintRef(*decl, loc);
  }
  case Token::identifier: {
    StringRef constraintName = curToken.getSpelling();
    consumeToken(kind: Token::identifier);

    // Lookup the referenced constraint.
    ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(name: constraintName);
    if (!cstDecl) {
      return emitError(loc, msg: "unknown reference to constraint `" +
                                constraintName + "`");
    }

    // Handle a reference to a proper constraint.
    if (auto *cst = dyn_cast<ast::ConstraintDecl>(Val: cstDecl))
      return ast::ConstraintRef(cst, loc);

    // The name resolved, but to something that is not a constraint.
    return emitErrorAndNote(
        loc, msg: "invalid reference to non-constraint", noteLoc: cstDecl->getLoc(),
        note: "see the definition of `" + constraintName + "` here");
  }
  // Handle single entity constraint code completion.
  case Token::code_complete: {
    // Try to infer the current type for use by code completion.
    ast::Type inferredType;
    if (failed(Result: validateVariableConstraints(constraints: existingConstraints, inferredType)))
      return failure();

    return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
  }
  default:
    break;
  }
  return emitError(loc, msg: "expected identifier constraint");
}
1782
1783FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1784 std::optional<SMRange> typeConstraint;
1785 return parseConstraint(typeConstraint, /*existingConstraints=*/std::nullopt,
1786 /*allowInlineTypeConstraints=*/false);
1787}
1788
1789//===----------------------------------------------------------------------===//
1790// Exprs
1791//===----------------------------------------------------------------------===//
1792
1793FailureOr<ast::Expr *> Parser::parseExpr() {
1794 if (curToken.is(k: Token::underscore))
1795 return parseUnderscoreExpr();
1796
1797 // Parse the LHS expression.
1798 FailureOr<ast::Expr *> lhsExpr;
1799 switch (curToken.getKind()) {
1800 case Token::kw_attr:
1801 lhsExpr = parseAttributeExpr();
1802 break;
1803 case Token::kw_Constraint:
1804 lhsExpr = parseInlineConstraintLambdaExpr();
1805 break;
1806 case Token::kw_not:
1807 lhsExpr = parseNegatedExpr();
1808 break;
1809 case Token::identifier:
1810 lhsExpr = parseIdentifierExpr();
1811 break;
1812 case Token::kw_op:
1813 lhsExpr = parseOperationExpr();
1814 break;
1815 case Token::kw_Rewrite:
1816 lhsExpr = parseInlineRewriteLambdaExpr();
1817 break;
1818 case Token::kw_type:
1819 lhsExpr = parseTypeExpr();
1820 break;
1821 case Token::l_paren:
1822 lhsExpr = parseTupleExpr();
1823 break;
1824 default:
1825 return emitError(msg: "expected expression");
1826 }
1827 if (failed(Result: lhsExpr))
1828 return failure();
1829
1830 // Check for an operator expression.
1831 while (true) {
1832 switch (curToken.getKind()) {
1833 case Token::dot:
1834 lhsExpr = parseMemberAccessExpr(parentExpr: *lhsExpr);
1835 break;
1836 case Token::l_paren:
1837 lhsExpr = parseCallExpr(parentExpr: *lhsExpr);
1838 break;
1839 default:
1840 return lhsExpr;
1841 }
1842 if (failed(Result: lhsExpr))
1843 return failure();
1844 }
1845}
1846
1847FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1848 SMRange loc = curToken.getLoc();
1849 consumeToken(kind: Token::kw_attr);
1850
1851 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1852 // identifier.
1853 if (!consumeIf(kind: Token::less)) {
1854 resetToken(tokLoc: loc);
1855 return parseIdentifierExpr();
1856 }
1857
1858 if (!curToken.isString())
1859 return emitError(msg: "expected string literal containing MLIR attribute");
1860 std::string attrExpr = curToken.getStringValue();
1861 consumeToken();
1862
1863 loc.End = curToken.getEndLoc();
1864 if (failed(
1865 Result: parseToken(kind: Token::greater, msg: "expected `>` after attribute literal")))
1866 return failure();
1867 return ast::AttributeExpr::create(ctx, loc, value: attrExpr);
1868}
1869
1870FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr,
1871 bool isNegated) {
1872 consumeToken(kind: Token::l_paren);
1873
1874 // Parse the arguments of the call.
1875 SmallVector<ast::Expr *> arguments;
1876 if (curToken.isNot(k: Token::r_paren)) {
1877 do {
1878 // Handle code completion for the call arguments.
1879 if (curToken.is(k: Token::code_complete)) {
1880 codeCompleteCallSignature(parent: parentExpr, currentNumArgs: arguments.size());
1881 return failure();
1882 }
1883
1884 FailureOr<ast::Expr *> argument = parseExpr();
1885 if (failed(Result: argument))
1886 return failure();
1887 arguments.push_back(Elt: *argument);
1888 } while (consumeIf(kind: Token::comma));
1889 }
1890
1891 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1892 if (failed(Result: parseToken(kind: Token::r_paren, msg: "expected `)` after argument list")))
1893 return failure();
1894
1895 return createCallExpr(loc, parentExpr, arguments, isNegated);
1896}
1897
1898FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1899 ast::Decl *decl = curDeclScope->lookup(name);
1900 if (!decl)
1901 return emitError(loc, msg: "undefined reference to `" + name + "`");
1902
1903 return createDeclRefExpr(loc, decl);
1904}
1905
1906FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1907 StringRef name = curToken.getSpelling();
1908 SMRange nameLoc = curToken.getLoc();
1909 consumeToken();
1910
1911 // Check to see if this is a decl ref expression that defines a variable
1912 // inline.
1913 if (consumeIf(kind: Token::colon)) {
1914 SmallVector<ast::ConstraintRef> constraints;
1915 if (failed(Result: parseVariableDeclConstraintList(constraints)))
1916 return failure();
1917 ast::Type type;
1918 if (failed(Result: validateVariableConstraints(constraints, inferredType&: type)))
1919 return failure();
1920 return createInlineVariableExpr(type, name, loc: nameLoc, constraints);
1921 }
1922
1923 return parseDeclRefExpr(name, loc: nameLoc);
1924}
1925
1926FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1927 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1928 if (failed(Result: decl))
1929 return failure();
1930
1931 return ast::DeclRefExpr::create(ctx, loc: (*decl)->getLoc(), decl: *decl,
1932 type: ast::ConstraintType::get(context&: ctx));
1933}
1934
1935FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1936 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1937 if (failed(Result: decl))
1938 return failure();
1939
1940 return ast::DeclRefExpr::create(ctx, loc: (*decl)->getLoc(), decl: *decl,
1941 type: ast::RewriteType::get(context&: ctx));
1942}
1943
1944FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1945 SMRange dotLoc = curToken.getLoc();
1946 consumeToken(kind: Token::dot);
1947
1948 // Check for code completion of the member name.
1949 if (curToken.is(k: Token::code_complete))
1950 return codeCompleteMemberAccess(parentExpr);
1951
1952 // Parse the member name.
1953 Token memberNameTok = curToken;
1954 if (memberNameTok.isNot(k1: Token::identifier, k2: Token::integer) &&
1955 !memberNameTok.isKeyword())
1956 return emitError(loc: dotLoc, msg: "expected identifier or numeric member name");
1957 StringRef memberName = memberNameTok.getSpelling();
1958 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1959 consumeToken();
1960
1961 return createMemberAccessExpr(parentExpr, name: memberName, loc);
1962}
1963
1964FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
1965 consumeToken(kind: Token::kw_not);
1966 // Only native constraints are supported after negation
1967 if (!curToken.is(k: Token::identifier))
1968 return emitError(msg: "expected native constraint");
1969 FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
1970 if (failed(Result: identifierExpr))
1971 return failure();
1972 if (!curToken.is(k: Token::l_paren))
1973 return emitError(msg: "expected `(` after function name");
1974 return parseCallExpr(parentExpr: *identifierExpr, /*isNegated = */ true);
1975}
1976
/// Parse an operation name: a dialect namespace, a `.`, and an operation
/// name, followed by any further identifier, `.`, or keyword tokens. Each
/// name component may be an identifier or a keyword. If no name is present
/// and `allowEmptyName` is true, an unnamed OpNameDecl is returned rather
/// than an error. Code completion is triggered at the dialect or the
/// operation name position.
FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
  SMRange loc = curToken.getLoc();

  // Check for code completion for the dialect name.
  if (curToken.is(k: Token::code_complete))
    return codeCompleteDialectName();

  // Handle the case where no operation name is present.
  if (curToken.isNot(k: Token::identifier) && !curToken.isKeyword()) {
    if (allowEmptyName)
      return ast::OpNameDecl::create(ctx, loc: SMRange());
    return emitError(msg: "expected dialect namespace");
  }
  StringRef name = curToken.getSpelling();
  consumeToken();

  // Otherwise, this is a literal operation name.
  if (failed(Result: parseToken(kind: Token::dot, msg: "expected `.` after dialect namespace")))
    return failure();

  // Check for code completion for the operation name.
  if (curToken.is(k: Token::code_complete))
    return codeCompleteOperationName(dialectName: name);

  if (curToken.isNot(k: Token::identifier) && !curToken.isKeyword())
    return emitError(msg: "expected operation name after dialect namespace");

  // Build the full name by growing `name` in place over the source buffer:
  // first by one character to cover the `.`, then by the spelling length of
  // each subsequent token.
  // NOTE(review): this assumes the name's tokens are contiguous in the buffer
  // (no intervening whitespace or comments); otherwise the resulting name
  // would not match the tokens that were consumed.
  name = StringRef(name.data(), name.size() + 1);
  do {
    name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
    loc.End = curToken.getEndLoc();
    consumeToken();
  } while (curToken.isAny(k1: Token::identifier, k2: Token::dot) ||
           curToken.isKeyword());
  return ast::OpNameDecl::create(ctx, name: ast::Name::create(ctx, name, location: loc));
}
2013
2014FailureOr<ast::OpNameDecl *>
2015Parser::parseWrappedOperationName(bool allowEmptyName) {
2016 if (!consumeIf(kind: Token::less))
2017 return ast::OpNameDecl::create(ctx, loc: SMRange());
2018
2019 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
2020 if (failed(Result: opNameDecl))
2021 return failure();
2022
2023 if (failed(Result: parseToken(kind: Token::greater, msg: "expected `>` after operation name")))
2024 return failure();
2025 return opNameDecl;
2026}
2027
/// Parse an operation expression of the form
///   `op<dialect.name>(operands) {attributes} -> (result-types)`
/// where the operand, attribute, and result lists are each optional. If `op`
/// is not followed by `<`, the lexer is rewound and `op` is parsed as an
/// ordinary identifier expression.
///
/// Outside of a rewrite context, an omitted operand or result list is
/// replaced by an implicit, unconstrained `_` range variable, so that
/// "unspecified" is distinguished from "zero operands/results".
///
/// \param inputResultTypeContext The initial result-type context. It becomes
///        `Explicit` when a result list is written. When the list is omitted
///        in a rewrite context, an `Explicit` context is changed to
///        `Interface`.
FailureOr<ast::Expr *>
Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
  SMRange loc = curToken.getLoc();
  consumeToken(kind: Token::kw_op);

  // If it isn't followed by a `<`, the `op` keyword is treated as a normal
  // identifier.
  if (curToken.isNot(k: Token::less)) {
    resetToken(tokLoc: loc);
    return parseIdentifierExpr();
  }

  // Parse the operation name. The name may be elided, in which case the
  // operation refers to "any" operation (i.e. a difference between `MyOp` and
  // `Operation*`). Operations within a rewrite context must be named.
  bool allowEmptyName = parserContext != ParserContext::Rewrite;
  FailureOr<ast::OpNameDecl *> opNameDecl =
      parseWrappedOperationName(allowEmptyName);
  if (failed(Result: opNameDecl))
    return failure();
  std::optional<StringRef> opName = (*opNameDecl)->getName();

  // Functor used to create an implicit range variable, used for implicit "all"
  // operand or results variables.
  auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
    // A `_` variable is never registered in the scope, so defining it cannot
    // fail with a redefinition error.
    FailureOr<ast::VariableDecl *> rangeVar =
        defineVariableDecl(name: "_", nameLoc: loc, type, constraints: ast::ConstraintRef(cst, loc));
    assert(succeeded(rangeVar) && "expected range variable to be valid");
    return ast::DeclRefExpr::create(ctx, loc, decl: *rangeVar, type);
  };

  // Check for the optional list of operands.
  SmallVector<ast::Expr *> operands;
  if (!consumeIf(kind: Token::l_paren)) {
    // If the operand list isn't specified and we are in a match context, define
    // an in-place unconstrained operand range corresponding to all of the
    // operands of the operation. This avoids treating zero operands the same
    // way as "unconstrained operands".
    if (parserContext != ParserContext::Rewrite) {
      operands.push_back(Elt: createImplicitRangeVar(
          ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
    }
  } else if (!consumeIf(kind: Token::r_paren)) {
    // If the operand list was specified and non-empty, parse the operands.
    do {
      // Check for operand signature code completion.
      if (curToken.is(k: Token::code_complete)) {
        codeCompleteOperationOperandsSignature(opName, currentNumOperands: operands.size());
        return failure();
      }

      FailureOr<ast::Expr *> operand = parseExpr();
      if (failed(Result: operand))
        return failure();
      operands.push_back(Elt: *operand);
    } while (consumeIf(kind: Token::comma));

    if (failed(Result: parseToken(kind: Token::r_paren,
                          msg: "expected `)` after operation operand list")))
      return failure();
  }

  // Check for the optional list of attributes.
  SmallVector<ast::NamedAttributeDecl *> attributes;
  if (consumeIf(kind: Token::l_brace)) {
    do {
      FailureOr<ast::NamedAttributeDecl *> decl =
          parseNamedAttributeDecl(parentOpName: opName);
      if (failed(Result: decl))
        return failure();
      attributes.emplace_back(Args&: *decl);
    } while (consumeIf(kind: Token::comma));

    if (failed(Result: parseToken(kind: Token::r_brace,
                          msg: "expected `}` after operation attribute list")))
      return failure();
  }

  // Handle the result types of the operation.
  SmallVector<ast::Expr *> resultTypes;
  OpResultTypeContext resultTypeContext = inputResultTypeContext;

  // Check for an explicit list of result types.
  if (consumeIf(kind: Token::arrow)) {
    if (failed(Result: parseToken(kind: Token::l_paren,
                          msg: "expected `(` before operation result type list")))
      return failure();

    // If result types are provided, initially assume that the operation does
    // not rely on type inference. We don't assert that it isn't, because we
    // may be inferring the value of some type/type range variables, but given
    // that these variables may be defined in calls we can't always discern when
    // this is the case.
    resultTypeContext = OpResultTypeContext::Explicit;

    // Handle the case of an empty result list.
    if (!consumeIf(kind: Token::r_paren)) {
      do {
        // Check for result signature code completion.
        if (curToken.is(k: Token::code_complete)) {
          codeCompleteOperationResultsSignature(opName, currentNumResults: resultTypes.size());
          return failure();
        }

        FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
        if (failed(Result: resultTypeExpr))
          return failure();
        resultTypes.push_back(Elt: *resultTypeExpr);
      } while (consumeIf(kind: Token::comma));

      if (failed(Result: parseToken(kind: Token::r_paren,
                            msg: "expected `)` after operation result type list")))
        return failure();
    }
  } else if (parserContext != ParserContext::Rewrite) {
    // If the result list isn't specified and we are in a match context, define
    // an in-place unconstrained result range corresponding to all of the results
    // of the operation. This avoids treating zero results the same way as
    // "unconstrained results".
    resultTypes.push_back(Elt: createImplicitRangeVar(
        ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
  } else if (resultTypeContext == OpResultTypeContext::Explicit) {
    // If the result list isn't specified and we are in a rewrite, try to infer
    // them at runtime instead.
    resultTypeContext = OpResultTypeContext::Interface;
  }

  return createOperationExpr(loc, name: *opNameDecl, resultTypeContext, operands,
                             attributes, results&: resultTypes);
}
2158
/// Parse a parenthesized tuple expression `(elt, ...)`, which may be empty.
/// Each element may be preceded by an optional label, `label = expr`, where
/// the label is an identifier or a dependent keyword. Duplicate labels are
/// diagnosed, and unlabeled elements record an empty name.
FailureOr<ast::Expr *> Parser::parseTupleExpr() {
  SMRange loc = curToken.getLoc();
  consumeToken(kind: Token::l_paren);

  // Maps each label already used in this tuple to its location, for duplicate
  // diagnostics.
  DenseMap<StringRef, SMRange> usedNames;
  SmallVector<StringRef> elementNames;
  SmallVector<ast::Expr *> elements;
  if (curToken.isNot(k: Token::r_paren)) {
    do {
      // Check for the optional element name assignment before the value.
      StringRef elementName;
      if (curToken.is(k: Token::identifier) || curToken.isDependentKeyword()) {
        Token elementNameTok = curToken;
        consumeToken();

        // The element name is only present if followed by an `=`.
        if (consumeIf(kind: Token::equal)) {
          elementName = elementNameTok.getSpelling();

          // Check to see if this name is already used.
          auto elementNameIt =
              usedNames.try_emplace(Key: elementName, Args: elementNameTok.getLoc());
          if (!elementNameIt.second) {
            return emitErrorAndNote(
                loc: elementNameTok.getLoc(),
                msg: llvm::formatv(Fmt: "duplicate tuple element label `{0}`",
                              Vals&: elementName),
                noteLoc: elementNameIt.first->getSecond(),
                note: "see previous label use here");
          }
        } else {
          // Otherwise, we treat this as part of an expression so reset the
          // lexer.
          resetToken(tokLoc: elementNameTok.getLoc());
        }
      }
      // Keep `elementNames` parallel to `elements`; unlabeled entries are empty.
      elementNames.push_back(Elt: elementName);

      // Parse the tuple element value.
      FailureOr<ast::Expr *> element = parseExpr();
      if (failed(Result: element))
        return failure();
      elements.push_back(Elt: *element);
    } while (consumeIf(kind: Token::comma));
  }
  // The tuple's range extends through the closing `)`.
  loc.End = curToken.getEndLoc();
  if (failed(
          Result: parseToken(kind: Token::r_paren, msg: "expected `)` after tuple element list")))
    return failure();
  return createTupleExpr(loc, elements, elementNames);
}
2210
2211FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2212 SMRange loc = curToken.getLoc();
2213 consumeToken(kind: Token::kw_type);
2214
2215 // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2216 // identifier.
2217 if (!consumeIf(kind: Token::less)) {
2218 resetToken(tokLoc: loc);
2219 return parseIdentifierExpr();
2220 }
2221
2222 if (!curToken.isString())
2223 return emitError(msg: "expected string literal containing MLIR type");
2224 std::string attrExpr = curToken.getStringValue();
2225 consumeToken();
2226
2227 loc.End = curToken.getEndLoc();
2228 if (failed(Result: parseToken(kind: Token::greater, msg: "expected `>` after type literal")))
2229 return failure();
2230 return ast::TypeExpr::create(ctx, loc, value: attrExpr);
2231}
2232
2233FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2234 StringRef name = curToken.getSpelling();
2235 SMRange nameLoc = curToken.getLoc();
2236 consumeToken(kind: Token::underscore);
2237
2238 // Underscore expressions require a constraint list.
2239 if (failed(Result: parseToken(kind: Token::colon, msg: "expected `:` after `_` variable")))
2240 return failure();
2241
2242 // Parse the constraints for the expression.
2243 SmallVector<ast::ConstraintRef> constraints;
2244 if (failed(Result: parseVariableDeclConstraintList(constraints)))
2245 return failure();
2246
2247 ast::Type type;
2248 if (failed(Result: validateVariableConstraints(constraints, inferredType&: type)))
2249 return failure();
2250 return createInlineVariableExpr(type, name, loc: nameLoc, constraints);
2251}
2252
2253//===----------------------------------------------------------------------===//
2254// Stmts
2255//===----------------------------------------------------------------------===//
2256
2257FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2258 FailureOr<ast::Stmt *> stmt;
2259 switch (curToken.getKind()) {
2260 case Token::kw_erase:
2261 stmt = parseEraseStmt();
2262 break;
2263 case Token::kw_let:
2264 stmt = parseLetStmt();
2265 break;
2266 case Token::kw_replace:
2267 stmt = parseReplaceStmt();
2268 break;
2269 case Token::kw_return:
2270 stmt = parseReturnStmt();
2271 break;
2272 case Token::kw_rewrite:
2273 stmt = parseRewriteStmt();
2274 break;
2275 default:
2276 stmt = parseExpr();
2277 break;
2278 }
2279 if (failed(Result: stmt) ||
2280 (expectTerminalSemicolon &&
2281 failed(Result: parseToken(kind: Token::semicolon, msg: "expected `;` after statement"))))
2282 return failure();
2283 return stmt;
2284}
2285
2286FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2287 SMLoc startLoc = curToken.getStartLoc();
2288 consumeToken(kind: Token::l_brace);
2289
2290 // Push a new block scope and parse any nested statements.
2291 pushDeclScope();
2292 SmallVector<ast::Stmt *> statements;
2293 while (curToken.isNot(k: Token::r_brace)) {
2294 FailureOr<ast::Stmt *> statement = parseStmt();
2295 if (failed(Result: statement))
2296 return popDeclScope(), failure();
2297 statements.push_back(Elt: *statement);
2298 }
2299 popDeclScope();
2300
2301 // Consume the end brace.
2302 SMRange location(startLoc, curToken.getEndLoc());
2303 consumeToken(kind: Token::r_brace);
2304
2305 return ast::CompoundStmt::create(ctx, location, children: statements);
2306}
2307
2308FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2309 if (parserContext == ParserContext::Constraint)
2310 return emitError(msg: "`erase` cannot be used within a Constraint");
2311 SMRange loc = curToken.getLoc();
2312 consumeToken(kind: Token::kw_erase);
2313
2314 // Parse the root operation expression.
2315 FailureOr<ast::Expr *> rootOp = parseExpr();
2316 if (failed(Result: rootOp))
2317 return failure();
2318
2319 return createEraseStmt(loc, rootOp: *rootOp);
2320}
2321
/// Parse a variable definition `let name[: constraints] [= initializer]`.
/// The name must be an identifier or a dependent keyword, and `_` is
/// rejected. Inline type constraints on `Attr`, `Value`, and `ValueRange`
/// may not be combined with an initializer.
FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
  SMRange loc = curToken.getLoc();
  consumeToken(kind: Token::kw_let);

  // Parse the name of the new variable.
  SMRange varLoc = curToken.getLoc();
  if (curToken.isNot(k: Token::identifier) && !curToken.isDependentKeyword()) {
    // `_` is a reserved variable name.
    if (curToken.is(k: Token::underscore)) {
      return emitError(loc: varLoc,
                       msg: "`_` may only be used to define \"inline\" variables");
    }
    return emitError(loc: varLoc,
                     msg: "expected identifier after `let` to name a new variable");
  }
  StringRef varName = curToken.getSpelling();
  consumeToken();

  // Parse the optional set of constraints.
  SmallVector<ast::ConstraintRef> constraints;
  if (consumeIf(kind: Token::colon) &&
      failed(Result: parseVariableDeclConstraintList(constraints)))
    return failure();

  // Parse the optional initializer expression.
  ast::Expr *initializer = nullptr;
  if (consumeIf(kind: Token::equal)) {
    FailureOr<ast::Expr *> initOrFailure = parseExpr();
    if (failed(Result: initOrFailure))
      return failure();
    initializer = *initOrFailure;

    // Check that the constraints are compatible with having an initializer,
    // e.g. type constraints cannot be used with initializers.
    for (ast::ConstraintRef constraint : constraints) {
      LogicalResult result =
          TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
              .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
                    ast::ValueRangeConstraintDecl>(caseFn: [&](const auto *cst) {
                // A non-null type expression means an inline `<...>` type
                // constraint was written on this constraint.
                if (cst->getTypeExpr()) {
                  return this->emitError(
                      loc: constraint.referenceLoc,
                      msg: "type constraints are not permitted on variables with "
                      "initializers");
                }
                return success();
              })
              .Default(defaultResult: success());
      if (failed(Result: result))
        return failure();
    }
  }

  FailureOr<ast::VariableDecl *> varDecl =
      createVariableDecl(name: varName, loc: varLoc, initializer, constraints);
  if (failed(Result: varDecl))
    return failure();
  return ast::LetStmt::create(ctx, loc, varDecl: *varDecl);
}
2381
2382FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2383 if (parserContext == ParserContext::Constraint)
2384 return emitError(msg: "`replace` cannot be used within a Constraint");
2385 SMRange loc = curToken.getLoc();
2386 consumeToken(kind: Token::kw_replace);
2387
2388 // Parse the root operation expression.
2389 FailureOr<ast::Expr *> rootOp = parseExpr();
2390 if (failed(Result: rootOp))
2391 return failure();
2392
2393 if (failed(
2394 Result: parseToken(kind: Token::kw_with, msg: "expected `with` after root operation")))
2395 return failure();
2396
2397 // The replacement portion of this statement is within a rewrite context.
2398 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2399
2400 // Parse the replacement values.
2401 SmallVector<ast::Expr *> replValues;
2402 if (consumeIf(kind: Token::l_paren)) {
2403 if (consumeIf(kind: Token::r_paren)) {
2404 return emitError(
2405 loc, msg: "expected at least one replacement value, consider using "
2406 "`erase` if no replacement values are desired");
2407 }
2408
2409 do {
2410 FailureOr<ast::Expr *> replExpr = parseExpr();
2411 if (failed(Result: replExpr))
2412 return failure();
2413 replValues.emplace_back(Args&: *replExpr);
2414 } while (consumeIf(kind: Token::comma));
2415
2416 if (failed(Result: parseToken(kind: Token::r_paren,
2417 msg: "expected `)` after replacement values")))
2418 return failure();
2419 } else {
2420 // Handle replacement with an operation uniquely, as the replacement
2421 // operation supports type inferrence from the root operation.
2422 FailureOr<ast::Expr *> replExpr;
2423 if (curToken.is(k: Token::kw_op))
2424 replExpr = parseOperationExpr(inputResultTypeContext: OpResultTypeContext::Replacement);
2425 else
2426 replExpr = parseExpr();
2427 if (failed(Result: replExpr))
2428 return failure();
2429 replValues.emplace_back(Args&: *replExpr);
2430 }
2431
2432 return createReplaceStmt(loc, rootOp: *rootOp, replValues);
2433}
2434
2435FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2436 SMRange loc = curToken.getLoc();
2437 consumeToken(kind: Token::kw_return);
2438
2439 // Parse the result value.
2440 FailureOr<ast::Expr *> resultExpr = parseExpr();
2441 if (failed(Result: resultExpr))
2442 return failure();
2443
2444 return ast::ReturnStmt::create(ctx, loc, resultExpr: *resultExpr);
2445}
2446
2447FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2448 if (parserContext == ParserContext::Constraint)
2449 return emitError(msg: "`rewrite` cannot be used within a Constraint");
2450 SMRange loc = curToken.getLoc();
2451 consumeToken(kind: Token::kw_rewrite);
2452
2453 // Parse the root operation.
2454 FailureOr<ast::Expr *> rootOp = parseExpr();
2455 if (failed(Result: rootOp))
2456 return failure();
2457
2458 if (failed(Result: parseToken(kind: Token::kw_with, msg: "expected `with` before rewrite body")))
2459 return failure();
2460
2461 if (curToken.isNot(k: Token::l_brace))
2462 return emitError(msg: "expected `{` to start rewrite body");
2463
2464 // The rewrite body of this statement is within a rewrite context.
2465 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2466
2467 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2468 if (failed(Result: rewriteBody))
2469 return failure();
2470
2471 // Verify the rewrite body.
2472 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2473 if (isa<ast::ReturnStmt>(Val: stmt)) {
2474 return emitError(loc: stmt->getLoc(),
2475 msg: "`return` statements are only permitted within a "
2476 "`Constraint` or `Rewrite` body");
2477 }
2478 }
2479
2480 return createRewriteStmt(loc, rootOp: *rootOp, rewriteBody: *rewriteBody);
2481}
2482
2483//===----------------------------------------------------------------------===//
2484// Creation+Analysis
2485//===----------------------------------------------------------------------===//
2486
2487//===----------------------------------------------------------------------===//
2488// Decls
2489//===----------------------------------------------------------------------===//
2490
2491ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2492 // Unwrap reference expressions.
2493 if (auto *init = dyn_cast<ast::DeclRefExpr>(Val: node))
2494 node = init->getDecl();
2495 return dyn_cast<ast::CallableDecl>(Val: node);
2496}
2497
2498FailureOr<ast::PatternDecl *>
2499Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2500 const ParsedPatternMetadata &metadata,
2501 ast::CompoundStmt *body) {
2502 return ast::PatternDecl::create(ctx, location: loc, name, benefit: metadata.benefit,
2503 hasBoundedRecursion: metadata.hasBoundedRecursion, body);
2504}
2505
2506ast::Type Parser::createUserConstraintRewriteResultType(
2507 ArrayRef<ast::VariableDecl *> results) {
2508 // Single result decls use the type of the single result.
2509 if (results.size() == 1)
2510 return results[0]->getType();
2511
2512 // Multiple results use a tuple type, with the types and names grabbed from
2513 // the result variable decls.
2514 auto resultTypes = llvm::map_range(
2515 C&: results, F: [&](const auto *result) { return result->getType(); });
2516 auto resultNames = llvm::map_range(
2517 C&: results, F: [&](const auto *result) { return result->getName().getName(); });
2518 return ast::TupleType::get(context&: ctx, elementTypes: llvm::to_vector(Range&: resultTypes),
2519 elementNames: llvm::to_vector(Range&: resultNames));
2520}
2521
2522template <typename T>
2523FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2524 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2525 ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2526 ast::CompoundStmt *body) {
2527 if (!body->getChildren().empty()) {
2528 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(Val: body->getChildren().back())) {
2529 ast::Expr *resultExpr = retStmt->getResultExpr();
2530
2531 // Process the result of the decl. If no explicit signature results
2532 // were provided, check for return type inference. Otherwise, check that
2533 // the return expression can be converted to the expected type.
2534 if (results.empty())
2535 resultType = resultExpr->getType();
2536 else if (failed(Result: convertExpressionTo(expr&: resultExpr, type: resultType)))
2537 return failure();
2538 else
2539 retStmt->setResultExpr(resultExpr);
2540 }
2541 }
2542 return T::createPDLL(ctx, name, arguments, results, body, resultType);
2543}
2544
2545FailureOr<ast::VariableDecl *>
2546Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2547 ArrayRef<ast::ConstraintRef> constraints) {
2548 // The type of the variable, which is expected to be inferred by either a
2549 // constraint or an initializer expression.
2550 ast::Type type;
2551 if (failed(Result: validateVariableConstraints(constraints, inferredType&: type)))
2552 return failure();
2553
2554 if (initializer) {
2555 // Update the variable type based on the initializer, or try to convert the
2556 // initializer to the existing type.
2557 if (!type)
2558 type = initializer->getType();
2559 else if (ast::Type mergedType = type.refineWith(other: initializer->getType()))
2560 type = mergedType;
2561 else if (failed(Result: convertExpressionTo(expr&: initializer, type)))
2562 return failure();
2563
2564 // Otherwise, if there is no initializer check that the type has already
2565 // been resolved from the constraint list.
2566 } else if (!type) {
2567 return emitErrorAndNote(
2568 loc, msg: "unable to infer type for variable `" + name + "`", noteLoc: loc,
2569 note: "the type of a variable must be inferable from the constraint "
2570 "list or the initializer");
2571 }
2572
2573 // Constraint types cannot be used when defining variables.
2574 if (isa<ast::ConstraintType, ast::RewriteType>(Val: type)) {
2575 return emitError(
2576 loc, msg: llvm::formatv(Fmt: "unable to define variable of `{0}` type", Vals&: type));
2577 }
2578
2579 // Try to define a variable with the given name.
2580 FailureOr<ast::VariableDecl *> varDecl =
2581 defineVariableDecl(name, nameLoc: loc, type, initExpr: initializer, constraints);
2582 if (failed(Result: varDecl))
2583 return failure();
2584
2585 return *varDecl;
2586}
2587
2588FailureOr<ast::VariableDecl *>
2589Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2590 const ast::ConstraintRef &constraint) {
2591 ast::Type argType;
2592 if (failed(Result: validateVariableConstraint(ref: constraint, inferredType&: argType)))
2593 return failure();
2594 return defineVariableDecl(name, nameLoc: loc, type: argType, constraints: constraint);
2595}
2596
2597LogicalResult
2598Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2599 ast::Type &inferredType) {
2600 for (const ast::ConstraintRef &ref : constraints)
2601 if (failed(Result: validateVariableConstraint(ref, inferredType)))
2602 return failure();
2603 return success();
2604}
2605
2606LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2607 ast::Type &inferredType) {
2608 ast::Type constraintType;
2609 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(Val: ref.constraint)) {
2610 if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2611 if (failed(Result: validateTypeConstraintExpr(typeExpr)))
2612 return failure();
2613 }
2614 constraintType = ast::AttributeType::get(context&: ctx);
2615 } else if (const auto *cst =
2616 dyn_cast<ast::OpConstraintDecl>(Val: ref.constraint)) {
2617 constraintType = ast::OperationType::get(
2618 context&: ctx, name: cst->getName(), odsOp: lookupODSOperation(opName: cst->getName()));
2619 } else if (isa<ast::TypeConstraintDecl>(Val: ref.constraint)) {
2620 constraintType = typeTy;
2621 } else if (isa<ast::TypeRangeConstraintDecl>(Val: ref.constraint)) {
2622 constraintType = typeRangeTy;
2623 } else if (const auto *cst =
2624 dyn_cast<ast::ValueConstraintDecl>(Val: ref.constraint)) {
2625 if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2626 if (failed(Result: validateTypeConstraintExpr(typeExpr)))
2627 return failure();
2628 }
2629 constraintType = valueTy;
2630 } else if (const auto *cst =
2631 dyn_cast<ast::ValueRangeConstraintDecl>(Val: ref.constraint)) {
2632 if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2633 if (failed(Result: validateTypeRangeConstraintExpr(typeExpr)))
2634 return failure();
2635 }
2636 constraintType = valueRangeTy;
2637 } else if (const auto *cst =
2638 dyn_cast<ast::UserConstraintDecl>(Val: ref.constraint)) {
2639 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2640 if (inputs.size() != 1) {
2641 return emitErrorAndNote(loc: ref.referenceLoc,
2642 msg: "`Constraint`s applied via a variable constraint "
2643 "list must take a single input, but got " +
2644 Twine(inputs.size()),
2645 noteLoc: cst->getLoc(),
2646 note: "see definition of constraint here");
2647 }
2648 constraintType = inputs.front()->getType();
2649 } else {
2650 llvm_unreachable("unknown constraint type");
2651 }
2652
2653 // Check that the constraint type is compatible with the current inferred
2654 // type.
2655 if (!inferredType) {
2656 inferredType = constraintType;
2657 } else if (ast::Type mergedTy = inferredType.refineWith(other: constraintType)) {
2658 inferredType = mergedTy;
2659 } else {
2660 return emitError(loc: ref.referenceLoc,
2661 msg: llvm::formatv(Fmt: "constraint type `{0}` is incompatible "
2662 "with the previously inferred type `{1}`",
2663 Vals&: constraintType, Vals&: inferredType));
2664 }
2665 return success();
2666}
2667
2668LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2669 ast::Type typeExprType = typeExpr->getType();
2670 if (typeExprType != typeTy) {
2671 return emitError(loc: typeExpr->getLoc(),
2672 msg: "expected expression of `Type` in type constraint");
2673 }
2674 return success();
2675}
2676
2677LogicalResult
2678Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2679 ast::Type typeExprType = typeExpr->getType();
2680 if (typeExprType != typeRangeTy) {
2681 return emitError(loc: typeExpr->getLoc(),
2682 msg: "expected expression of `TypeRange` in type constraint");
2683 }
2684 return success();
2685}
2686
2687//===----------------------------------------------------------------------===//
2688// Exprs
2689//===----------------------------------------------------------------------===//
2690
2691FailureOr<ast::CallExpr *>
2692Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2693 MutableArrayRef<ast::Expr *> arguments, bool isNegated) {
2694 ast::Type parentType = parentExpr->getType();
2695
2696 ast::CallableDecl *callableDecl = tryExtractCallableDecl(node: parentExpr);
2697 if (!callableDecl) {
2698 return emitError(loc,
2699 msg: llvm::formatv(Fmt: "expected a reference to a callable "
2700 "`Constraint` or `Rewrite`, but got: `{0}`",
2701 Vals&: parentType));
2702 }
2703 if (parserContext == ParserContext::Rewrite) {
2704 if (isa<ast::UserConstraintDecl>(Val: callableDecl))
2705 return emitError(
2706 loc, msg: "unable to invoke `Constraint` within a rewrite section");
2707 if (isNegated)
2708 return emitError(loc, msg: "unable to negate a Rewrite");
2709 } else {
2710 if (isa<ast::UserRewriteDecl>(Val: callableDecl))
2711 return emitError(loc,
2712 msg: "unable to invoke `Rewrite` within a match section");
2713 if (isNegated && cast<ast::UserConstraintDecl>(Val: callableDecl)->getBody())
2714 return emitError(loc, msg: "unable to negate non native constraints");
2715 }
2716
2717 // Verify the arguments of the call.
2718 /// Handle size mismatch.
2719 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2720 if (callArgs.size() != arguments.size()) {
2721 return emitErrorAndNote(
2722 loc,
2723 msg: llvm::formatv(Fmt: "invalid number of arguments for {0} call; expected "
2724 "{1}, but got {2}",
2725 Vals: callableDecl->getCallableType(), Vals: callArgs.size(),
2726 Vals: arguments.size()),
2727 noteLoc: callableDecl->getLoc(),
2728 note: llvm::formatv(Fmt: "see the definition of {0} here",
2729 Vals: callableDecl->getName()->getName()));
2730 }
2731
2732 /// Handle argument type mismatch.
2733 auto attachDiagFn = [&](ast::Diagnostic &diag) {
2734 diag.attachNote(msg: llvm::formatv(Fmt: "see the definition of `{0}` here",
2735 Vals: callableDecl->getName()->getName()),
2736 noteLoc: callableDecl->getLoc());
2737 };
2738 for (auto it : llvm::zip(t&: callArgs, u&: arguments)) {
2739 if (failed(Result: convertExpressionTo(expr&: std::get<1>(t&: it), type: std::get<0>(t&: it)->getType(),
2740 noteAttachFn: attachDiagFn)))
2741 return failure();
2742 }
2743
2744 return ast::CallExpr::create(ctx, loc, callable: parentExpr, arguments,
2745 resultType: callableDecl->getResultType(), isNegated);
2746}
2747
2748FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2749 ast::Decl *decl) {
2750 // Check the type of decl being referenced.
2751 ast::Type declType;
2752 if (isa<ast::ConstraintDecl>(Val: decl))
2753 declType = ast::ConstraintType::get(context&: ctx);
2754 else if (isa<ast::UserRewriteDecl>(Val: decl))
2755 declType = ast::RewriteType::get(context&: ctx);
2756 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(Val: decl))
2757 declType = varDecl->getType();
2758 else
2759 return emitError(loc, msg: "invalid reference to `" +
2760 decl->getName()->getName() + "`");
2761
2762 return ast::DeclRefExpr::create(ctx, loc, decl, type: declType);
2763}
2764
2765FailureOr<ast::DeclRefExpr *>
2766Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2767 ArrayRef<ast::ConstraintRef> constraints) {
2768 FailureOr<ast::VariableDecl *> decl =
2769 defineVariableDecl(name, nameLoc: loc, type, constraints);
2770 if (failed(Result: decl))
2771 return failure();
2772 return ast::DeclRefExpr::create(ctx, loc, decl: *decl, type);
2773}
2774
2775FailureOr<ast::MemberAccessExpr *>
2776Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2777 SMRange loc) {
2778 // Validate the member name for the given parent expression.
2779 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2780 if (failed(Result: memberType))
2781 return failure();
2782
2783 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, memberName: name, type: *memberType);
2784}
2785
2786FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2787 StringRef name, SMRange loc) {
2788 ast::Type parentType = parentExpr->getType();
2789 if (ast::OperationType opType = dyn_cast<ast::OperationType>(Val&: parentType)) {
2790 if (name == ast::AllResultsMemberAccessExpr::getMemberName())
2791 return valueRangeTy;
2792
2793 // Verify member access based on the operation type.
2794 if (const ods::Operation *odsOp = opType.getODSOperation()) {
2795 auto results = odsOp->getResults();
2796
2797 // Handle indexed results.
2798 unsigned index = 0;
2799 if (llvm::isDigit(C: name[0]) && !name.getAsInteger(/*Radix=*/10, Result&: index) &&
2800 index < results.size()) {
2801 return results[index].isVariadic() ? valueRangeTy : valueTy;
2802 }
2803
2804 // Handle named results.
2805 const auto *it = llvm::find_if(Range&: results, P: [&](const auto &result) {
2806 return result.getName() == name;
2807 });
2808 if (it != results.end())
2809 return it->isVariadic() ? valueRangeTy : valueTy;
2810 } else if (llvm::isDigit(C: name[0])) {
2811 // Allow unchecked numeric indexing of the results of unregistered
2812 // operations. It returns a single value.
2813 return valueTy;
2814 }
2815 } else if (auto tupleType = dyn_cast<ast::TupleType>(Val&: parentType)) {
2816 // Handle indexed results.
2817 unsigned index = 0;
2818 if (llvm::isDigit(C: name[0]) && !name.getAsInteger(/*Radix=*/10, Result&: index) &&
2819 index < tupleType.size()) {
2820 return tupleType.getElementTypes()[index];
2821 }
2822
2823 // Handle named results.
2824 auto elementNames = tupleType.getElementNames();
2825 const auto *it = llvm::find(Range&: elementNames, Val: name);
2826 if (it != elementNames.end())
2827 return tupleType.getElementTypes()[it - elementNames.begin()];
2828 }
2829 return emitError(
2830 loc,
2831 msg: llvm::formatv(Fmt: "invalid member access `{0}` on expression of type `{1}`",
2832 Vals&: name, Vals&: parentType));
2833}
2834
2835FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2836 SMRange loc, const ast::OpNameDecl *name,
2837 OpResultTypeContext resultTypeContext,
2838 SmallVectorImpl<ast::Expr *> &operands,
2839 MutableArrayRef<ast::NamedAttributeDecl *> attributes,
2840 SmallVectorImpl<ast::Expr *> &results) {
2841 std::optional<StringRef> opNameRef = name->getName();
2842 const ods::Operation *odsOp = lookupODSOperation(opName: opNameRef);
2843
2844 // Verify the inputs operands.
2845 if (failed(Result: validateOperationOperands(loc, name: opNameRef, odsOp, operands)))
2846 return failure();
2847
2848 // Verify the attribute list.
2849 for (ast::NamedAttributeDecl *attr : attributes) {
2850 // Check for an attribute type, or a type awaiting resolution.
2851 ast::Type attrType = attr->getValue()->getType();
2852 if (!isa<ast::AttributeType>(Val: attrType)) {
2853 return emitError(
2854 loc: attr->getValue()->getLoc(),
2855 msg: llvm::formatv(Fmt: "expected `Attr` expression, but got `{0}`", Vals&: attrType));
2856 }
2857 }
2858
2859 assert(
2860 (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2861 "unexpected inferrence when results were explicitly specified");
2862
2863 // If we aren't relying on type inferrence, or explicit results were provided,
2864 // validate them.
2865 if (resultTypeContext == OpResultTypeContext::Explicit) {
2866 if (failed(Result: validateOperationResults(loc, name: opNameRef, odsOp, results)))
2867 return failure();
2868
2869 // Validate the use of interface based type inferrence for this operation.
2870 } else if (resultTypeContext == OpResultTypeContext::Interface) {
2871 assert(opNameRef &&
2872 "expected valid operation name when inferring operation results");
2873 checkOperationResultTypeInferrence(loc, name: *opNameRef, odsOp);
2874 }
2875
2876 return ast::OperationExpr::create(ctx, loc, odsOp, nameDecl: name, operands, resultTypes: results,
2877 attributes);
2878}
2879
2880LogicalResult
2881Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name,
2882 const ods::Operation *odsOp,
2883 SmallVectorImpl<ast::Expr *> &operands) {
2884 return validateOperationOperandsOrResults(
2885 groupName: "operand", loc, odsOpLoc: odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2886 values&: operands, odsValues: odsOp ? odsOp->getOperands() : std::nullopt, singleTy: valueTy,
2887 rangeTy: valueRangeTy);
2888}
2889
2890LogicalResult
2891Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name,
2892 const ods::Operation *odsOp,
2893 SmallVectorImpl<ast::Expr *> &results) {
2894 return validateOperationOperandsOrResults(
2895 groupName: "result", loc, odsOpLoc: odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2896 values&: results, odsValues: odsOp ? odsOp->getResults() : std::nullopt, singleTy: typeTy, rangeTy: typeRangeTy);
2897}
2898
2899void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2900 const ods::Operation *odsOp) {
2901 // If the operation might not have inferrence support, emit a warning to the
2902 // user. We don't emit an error because the interface might be added to the
2903 // operation at runtime. It's rare, but it could still happen. We emit a
2904 // warning here instead.
2905
2906 // Handle inferrence warnings for unknown operations.
2907 if (!odsOp) {
2908 ctx.getDiagEngine().emitWarning(
2909 loc, msg: llvm::formatv(
2910 Fmt: "operation result types are marked to be inferred, but "
2911 "`{0}` is unknown. Ensure that `{0}` supports zero "
2912 "results or implements `InferTypeOpInterface`. Include "
2913 "the ODS definition of this operation to remove this warning.",
2914 Vals&: opName));
2915 return;
2916 }
2917
2918 // Handle inferrence warnings for known operations that expected at least one
2919 // result, but don't have inference support. An elided results list can mean
2920 // "zero-results", and we don't want to warn when that is the expected
2921 // behavior.
2922 bool requiresInferrence =
2923 llvm::any_of(Range: odsOp->getResults(), P: [](const ods::OperandOrResult &result) {
2924 return !result.isVariableLength();
2925 });
2926 if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
2927 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning(
2928 loc,
2929 msg: llvm::formatv(Fmt: "operation result types are marked to be inferred, but "
2930 "`{0}` does not provide an implementation of "
2931 "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2932 "`InferTypeOpInterface` at runtime, or add support to "
2933 "the ODS definition to remove this warning.",
2934 Vals&: opName));
2935 diag->attachNote(msg: llvm::formatv(Fmt: "see the definition of `{0}` here", Vals&: opName),
2936 noteLoc: odsOp->getLoc());
2937 return;
2938 }
2939}
2940
2941LogicalResult Parser::validateOperationOperandsOrResults(
2942 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
2943 std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
2944 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2945 ast::RangeType rangeTy) {
2946 // All operation types accept a single range parameter.
2947 if (values.size() == 1) {
2948 if (failed(Result: convertExpressionTo(expr&: values[0], type: rangeTy)))
2949 return failure();
2950 return success();
2951 }
2952
2953 /// If the operation has ODS information, we can more accurately verify the
2954 /// values.
2955 if (odsOpLoc) {
2956 auto emitSizeMismatchError = [&] {
2957 return emitErrorAndNote(
2958 loc,
2959 msg: llvm::formatv(Fmt: "invalid number of {0} groups for `{1}`; expected "
2960 "{2}, but got {3}",
2961 Vals&: groupName, Vals&: *name, Vals: odsValues.size(), Vals: values.size()),
2962 noteLoc: *odsOpLoc, note: llvm::formatv(Fmt: "see the definition of `{0}` here", Vals&: *name));
2963 };
2964
2965 // Handle the case where no values were provided.
2966 if (values.empty()) {
2967 // If we don't expect any on the ODS side, we are done.
2968 if (odsValues.empty())
2969 return success();
2970
2971 // If we do, check if we actually need to provide values (i.e. if any of
2972 // the values are actually required).
2973 unsigned numVariadic = 0;
2974 for (const auto &odsValue : odsValues) {
2975 if (!odsValue.isVariableLength())
2976 return emitSizeMismatchError();
2977 ++numVariadic;
2978 }
2979
2980 // If we are in a non-rewrite context, we don't need to do anything more.
2981 // Zero-values is a valid constraint on the operation.
2982 if (parserContext != ParserContext::Rewrite)
2983 return success();
2984
2985 // Otherwise, when in a rewrite we may need to provide values to match the
2986 // ODS signature of the operation to create.
2987
2988 // If we only have one variadic value, just use an empty list.
2989 if (numVariadic == 1)
2990 return success();
2991
2992 // Otherwise, create dummy values for each of the entries so that we
2993 // adhere to the ODS signature.
2994 for (unsigned i = 0, e = odsValues.size(); i < e; ++i) {
2995 values.push_back(Elt: ast::RangeExpr::create(
2996 ctx, loc, /*elements=*/std::nullopt, type: rangeTy));
2997 }
2998 return success();
2999 }
3000
3001 // Verify that the number of values provided matches the number of value
3002 // groups ODS expects.
3003 if (odsValues.size() != values.size())
3004 return emitSizeMismatchError();
3005
3006 auto diagFn = [&](ast::Diagnostic &diag) {
3007 diag.attachNote(msg: llvm::formatv(Fmt: "see the definition of `{0}` here", Vals&: *name),
3008 noteLoc: *odsOpLoc);
3009 };
3010 for (unsigned i = 0, e = values.size(); i < e; ++i) {
3011 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
3012 if (failed(Result: convertExpressionTo(expr&: values[i], type: expectedType, noteAttachFn: diagFn)))
3013 return failure();
3014 }
3015 return success();
3016 }
3017
3018 // Otherwise, accept the value groups as they have been defined and just
3019 // ensure they are one of the expected types.
3020 for (ast::Expr *&valueExpr : values) {
3021 ast::Type valueExprType = valueExpr->getType();
3022
3023 // Check if this is one of the expected types.
3024 if (valueExprType == rangeTy || valueExprType == singleTy)
3025 continue;
3026
3027 // If the operand is an Operation, allow converting to a Value or
3028 // ValueRange. This situations arises quite often with nested operation
3029 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
3030 if (singleTy == valueTy) {
3031 if (isa<ast::OperationType>(Val: valueExprType)) {
3032 valueExpr = convertOpToValue(opExpr: valueExpr);
3033 continue;
3034 }
3035 }
3036
3037 // Otherwise, try to convert the expression to a range.
3038 if (succeeded(Result: convertExpressionTo(expr&: valueExpr, type: rangeTy)))
3039 continue;
3040
3041 return emitError(
3042 loc: valueExpr->getLoc(),
3043 msg: llvm::formatv(
3044 Fmt: "expected `{0}` or `{1}` convertible expression, but got `{2}`",
3045 Vals&: singleTy, Vals&: rangeTy, Vals&: valueExprType));
3046 }
3047 return success();
3048}
3049
3050FailureOr<ast::TupleExpr *>
3051Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
3052 ArrayRef<StringRef> elementNames) {
3053 for (const ast::Expr *element : elements) {
3054 ast::Type eleTy = element->getType();
3055 if (isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>(Val: eleTy)) {
3056 return emitError(
3057 loc: element->getLoc(),
3058 msg: llvm::formatv(Fmt: "unable to build a tuple with `{0}` element", Vals&: eleTy));
3059 }
3060 }
3061 return ast::TupleExpr::create(ctx, loc, elements, elementNames);
3062}
3063
3064//===----------------------------------------------------------------------===//
3065// Stmts
3066//===----------------------------------------------------------------------===//
3067
3068FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
3069 ast::Expr *rootOp) {
3070 // Check that root is an Operation.
3071 ast::Type rootType = rootOp->getType();
3072 if (!isa<ast::OperationType>(Val: rootType))
3073 return emitError(loc: rootOp->getLoc(), msg: "expected `Op` expression");
3074
3075 return ast::EraseStmt::create(ctx, loc, rootOp);
3076}
3077
3078FailureOr<ast::ReplaceStmt *>
3079Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
3080 MutableArrayRef<ast::Expr *> replValues) {
3081 // Check that root is an Operation.
3082 ast::Type rootType = rootOp->getType();
3083 if (!isa<ast::OperationType>(Val: rootType)) {
3084 return emitError(
3085 loc: rootOp->getLoc(),
3086 msg: llvm::formatv(Fmt: "expected `Op` expression, but got `{0}`", Vals&: rootType));
3087 }
3088
3089 // If there are multiple replacement values, we implicitly convert any Op
3090 // expressions to the value form.
3091 bool shouldConvertOpToValues = replValues.size() > 1;
3092 for (ast::Expr *&replExpr : replValues) {
3093 ast::Type replType = replExpr->getType();
3094
3095 // Check that replExpr is an Operation, Value, or ValueRange.
3096 if (isa<ast::OperationType>(Val: replType)) {
3097 if (shouldConvertOpToValues)
3098 replExpr = convertOpToValue(opExpr: replExpr);
3099 continue;
3100 }
3101
3102 if (replType != valueTy && replType != valueRangeTy) {
3103 return emitError(loc: replExpr->getLoc(),
3104 msg: llvm::formatv(Fmt: "expected `Op`, `Value` or `ValueRange` "
3105 "expression, but got `{0}`",
3106 Vals&: replType));
3107 }
3108 }
3109
3110 return ast::ReplaceStmt::create(ctx, loc, rootOp, replExprs: replValues);
3111}
3112
3113FailureOr<ast::RewriteStmt *>
3114Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
3115 ast::CompoundStmt *rewriteBody) {
3116 // Check that root is an Operation.
3117 ast::Type rootType = rootOp->getType();
3118 if (!isa<ast::OperationType>(Val: rootType)) {
3119 return emitError(
3120 loc: rootOp->getLoc(),
3121 msg: llvm::formatv(Fmt: "expected `Op` expression, but got `{0}`", Vals&: rootType));
3122 }
3123
3124 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
3125}
3126
3127//===----------------------------------------------------------------------===//
3128// Code Completion
3129//===----------------------------------------------------------------------===//
3130
3131LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
3132 ast::Type parentType = parentExpr->getType();
3133 if (ast::OperationType opType = dyn_cast<ast::OperationType>(Val&: parentType))
3134 codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3135 else if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(Val&: parentType))
3136 codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3137 return failure();
3138}
3139
3140LogicalResult
3141Parser::codeCompleteAttributeName(std::optional<StringRef> opName) {
3142 if (opName)
3143 codeCompleteContext->codeCompleteOperationAttributeName(opName: *opName);
3144 return failure();
3145}
3146
3147LogicalResult
3148Parser::codeCompleteConstraintName(ast::Type inferredType,
3149 bool allowInlineTypeConstraints) {
3150 codeCompleteContext->codeCompleteConstraintName(
3151 currentType: inferredType, allowInlineTypeConstraints, scope: curDeclScope);
3152 return failure();
3153}
3154
3155LogicalResult Parser::codeCompleteDialectName() {
3156 codeCompleteContext->codeCompleteDialectName();
3157 return failure();
3158}
3159
3160LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3161 codeCompleteContext->codeCompleteOperationName(dialectName);
3162 return failure();
3163}
3164
3165LogicalResult Parser::codeCompletePatternMetadata() {
3166 codeCompleteContext->codeCompletePatternMetadata();
3167 return failure();
3168}
3169
3170LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3171 codeCompleteContext->codeCompleteIncludeFilename(curPath);
3172 return failure();
3173}
3174
3175void Parser::codeCompleteCallSignature(ast::Node *parent,
3176 unsigned currentNumArgs) {
3177 ast::CallableDecl *callableDecl = tryExtractCallableDecl(node: parent);
3178 if (!callableDecl)
3179 return;
3180
3181 codeCompleteContext->codeCompleteCallSignature(callable: callableDecl, currentNumArgs);
3182}
3183
3184void Parser::codeCompleteOperationOperandsSignature(
3185 std::optional<StringRef> opName, unsigned currentNumOperands) {
3186 codeCompleteContext->codeCompleteOperationOperandsSignature(
3187 opName, currentNumOperands);
3188}
3189
3190void Parser::codeCompleteOperationResultsSignature(
3191 std::optional<StringRef> opName, unsigned currentNumResults) {
3192 codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3193 currentNumResults);
3194}
3195
3196//===----------------------------------------------------------------------===//
3197// Parser
3198//===----------------------------------------------------------------------===//
3199
3200FailureOr<ast::Module *>
3201mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
3202 bool enableDocumentation,
3203 CodeCompleteContext *codeCompleteContext) {
3204 Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3205 return parser.parseModule();
3206}
3207

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Tools/PDLL/Parser/Parser.cpp