1//===- Parser.cpp ---------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Tools/PDLL/Parser/Parser.h"
10#include "Lexer.h"
11#include "mlir/Support/IndentedOstream.h"
12#include "mlir/TableGen/Argument.h"
13#include "mlir/TableGen/Attribute.h"
14#include "mlir/TableGen/Constraint.h"
15#include "mlir/TableGen/Format.h"
16#include "mlir/TableGen/Operator.h"
17#include "mlir/Tools/PDLL/AST/Context.h"
18#include "mlir/Tools/PDLL/AST/Diagnostic.h"
19#include "mlir/Tools/PDLL/AST/Nodes.h"
20#include "mlir/Tools/PDLL/AST/Types.h"
21#include "mlir/Tools/PDLL/ODS/Constraint.h"
22#include "mlir/Tools/PDLL/ODS/Context.h"
23#include "mlir/Tools/PDLL/ODS/Operation.h"
24#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
25#include "llvm/ADT/StringExtras.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/FormatVariadic.h"
28#include "llvm/Support/ManagedStatic.h"
29#include "llvm/Support/SaveAndRestore.h"
30#include "llvm/Support/ScopedPrinter.h"
31#include "llvm/TableGen/Error.h"
32#include "llvm/TableGen/Parser.h"
33#include <optional>
34#include <string>
35
36using namespace mlir;
37using namespace mlir::pdll;
38
39//===----------------------------------------------------------------------===//
40// Parser
41//===----------------------------------------------------------------------===//
42
43namespace {
44class Parser {
45public:
46 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
47 bool enableDocumentation, CodeCompleteContext *codeCompleteContext)
48 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
49 curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
50 typeTy(ast::TypeType::get(context&: ctx)), valueTy(ast::ValueType::get(context&: ctx)),
51 typeRangeTy(ast::TypeRangeType::get(context&: ctx)),
52 valueRangeTy(ast::ValueRangeType::get(context&: ctx)),
53 attrTy(ast::AttributeType::get(context&: ctx)),
54 codeCompleteContext(codeCompleteContext) {}
55
56 /// Try to parse a new module. Returns nullptr in the case of failure.
57 FailureOr<ast::Module *> parseModule();
58
59private:
60 /// The current context of the parser. It allows for the parser to know a bit
61 /// about the construct it is nested within during parsing. This is used
62 /// specifically to provide additional verification during parsing, e.g. to
63 /// prevent using rewrites within a match context, matcher constraints within
64 /// a rewrite section, etc.
65 enum class ParserContext {
66 /// The parser is in the global context.
67 Global,
68 /// The parser is currently within a Constraint, which disallows all types
69 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
70 Constraint,
71 /// The parser is currently within the matcher portion of a Pattern, which
72 /// is allows a terminal operation rewrite statement but no other rewrite
73 /// transformations.
74 PatternMatch,
75 /// The parser is currently within a Rewrite, which disallows calls to
76 /// constraints, requires operation expressions to have names, etc.
77 Rewrite,
78 };
79
80 /// The current specification context of an operations result type. This
81 /// indicates how the result types of an operation may be inferred.
82 enum class OpResultTypeContext {
83 /// The result types of the operation are not known to be inferred.
84 Explicit,
85 /// The result types of the operation are inferred from the root input of a
86 /// `replace` statement.
87 Replacement,
88 /// The result types of the operation are inferred by using the
89 /// `InferTypeOpInterface` interface provided by the operation.
90 Interface,
91 };
92
93 //===--------------------------------------------------------------------===//
94 // Parsing
95 //===--------------------------------------------------------------------===//
96
97 /// Push a new decl scope onto the lexer.
98 ast::DeclScope *pushDeclScope() {
99 ast::DeclScope *newScope =
100 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
101 return (curDeclScope = newScope);
102 }
103 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
104
105 /// Pop the last decl scope from the lexer.
106 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
107
108 /// Parse the body of an AST module.
109 LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
110
111 /// Try to convert the given expression to `type`. Returns failure and emits
112 /// an error if a conversion is not viable. On failure, `noteAttachFn` is
113 /// invoked to attach notes to the emitted error diagnostic. On success,
114 /// `expr` is updated to the expression used to convert to `type`.
115 LogicalResult convertExpressionTo(
116 ast::Expr *&expr, ast::Type type,
117 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
118 LogicalResult
119 convertOpExpressionTo(ast::Expr *&expr, ast::OperationType exprType,
120 ast::Type type,
121 function_ref<ast::InFlightDiagnostic()> emitErrorFn);
122 LogicalResult convertTupleExpressionTo(
123 ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
124 function_ref<ast::InFlightDiagnostic()> emitErrorFn,
125 function_ref<void(ast::Diagnostic &diag)> noteAttachFn);
126
127 /// Given an operation expression, convert it to a Value or ValueRange
128 /// typed expression.
129 ast::Expr *convertOpToValue(const ast::Expr *opExpr);
130
131 /// Lookup ODS information for the given operation, returns nullptr if no
132 /// information is found.
133 const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) {
134 return opName ? ctx.getODSContext().lookupOperation(name: *opName) : nullptr;
135 }
136
137 /// Process the given documentation string, or return an empty string if
138 /// documentation isn't enabled.
139 StringRef processDoc(StringRef doc) {
140 return enableDocumentation ? doc : StringRef();
141 }
142
143 /// Process the given documentation string and format it, or return an empty
144 /// string if documentation isn't enabled.
145 std::string processAndFormatDoc(const Twine &doc) {
146 if (!enableDocumentation)
147 return "";
148 std::string docStr;
149 {
150 llvm::raw_string_ostream docOS(docStr);
151 raw_indented_ostream(docOS).printReindented(
152 str: StringRef(docStr).rtrim(Chars: " \t"));
153 }
154 return docStr;
155 }
156
157 //===--------------------------------------------------------------------===//
158 // Directives
159
160 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
161 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
162 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
163 SmallVectorImpl<ast::Decl *> &decls);
164
165 /// Process the records of a parsed tablegen include file.
166 void processTdIncludeRecords(const llvm::RecordKeeper &tdRecords,
167 SmallVectorImpl<ast::Decl *> &decls);
168
169 /// Create a user defined native constraint for a constraint imported from
170 /// ODS.
171 template <typename ConstraintT>
172 ast::Decl *
173 createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
174 SMRange loc, ast::Type type,
175 StringRef nativeType, StringRef docString);
176 template <typename ConstraintT>
177 ast::Decl *
178 createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
179 SMRange loc, ast::Type type,
180 StringRef nativeType);
181
182 //===--------------------------------------------------------------------===//
183 // Decls
184
185 /// This structure contains the set of pattern metadata that may be parsed.
186 struct ParsedPatternMetadata {
187 std::optional<uint16_t> benefit;
188 bool hasBoundedRecursion = false;
189 };
190
191 FailureOr<ast::Decl *> parseTopLevelDecl();
192 FailureOr<ast::NamedAttributeDecl *>
193 parseNamedAttributeDecl(std::optional<StringRef> parentOpName);
194
195 /// Parse an argument variable as part of the signature of a
196 /// UserConstraintDecl or UserRewriteDecl.
197 FailureOr<ast::VariableDecl *> parseArgumentDecl();
198
199 /// Parse a result variable as part of the signature of a UserConstraintDecl
200 /// or UserRewriteDecl.
201 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
202
203 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
204 /// defined in a non-global context.
205 FailureOr<ast::UserConstraintDecl *>
206 parseUserConstraintDecl(bool isInline = false);
207
208 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
209 /// non-global context, such as within a Pattern/Constraint/etc.
210 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
211
212 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
213 /// PDLL constructs.
214 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
215 const ast::Name &name, bool isInline,
216 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
217 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
218
219 /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being
220 /// defined in a non-global context.
221 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
222
223 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
224 /// non-global context, such as within a Pattern/Rewrite/etc.
225 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
226
227 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
228 /// PDLL constructs.
229 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
230 const ast::Name &name, bool isInline,
231 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
232 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
233
234 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
235 /// effectively the same syntax, and only differ on slight semantics (given
236 /// the different parsing contexts).
237 template <typename T, typename ParseUserPDLLDeclFnT>
238 FailureOr<T *> parseUserConstraintOrRewriteDecl(
239 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
240 StringRef anonymousNamePrefix, bool isInline);
241
242 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
243 /// These decls have effectively the same syntax.
244 template <typename T>
245 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
246 const ast::Name &name, bool isInline,
247 ArrayRef<ast::VariableDecl *> arguments,
248 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
249
250 /// Parse the functional signature (i.e. the arguments and results) of a
251 /// UserConstraintDecl or UserRewriteDecl.
252 LogicalResult parseUserConstraintOrRewriteSignature(
253 SmallVectorImpl<ast::VariableDecl *> &arguments,
254 SmallVectorImpl<ast::VariableDecl *> &results,
255 ast::DeclScope *&argumentScope, ast::Type &resultType);
256
257 /// Validate the return (which if present is specified by bodyIt) of a
258 /// UserConstraintDecl or UserRewriteDecl.
259 LogicalResult validateUserConstraintOrRewriteReturn(
260 StringRef declType, ast::CompoundStmt *body,
261 ArrayRef<ast::Stmt *>::iterator bodyIt,
262 ArrayRef<ast::Stmt *>::iterator bodyE,
263 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
264
265 FailureOr<ast::CompoundStmt *>
266 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
267 bool expectTerminalSemicolon = true);
268 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
269 FailureOr<ast::Decl *> parsePatternDecl();
270 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
271
272 /// Check to see if a decl has already been defined with the given name, if
273 /// one has emit and error and return failure. Returns success otherwise.
274 LogicalResult checkDefineNamedDecl(const ast::Name &name);
275
276 /// Try to define a variable decl with the given components, returns the
277 /// variable on success.
278 FailureOr<ast::VariableDecl *>
279 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
280 ast::Expr *initExpr,
281 ArrayRef<ast::ConstraintRef> constraints);
282 FailureOr<ast::VariableDecl *>
283 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
284 ArrayRef<ast::ConstraintRef> constraints);
285
286 /// Parse the constraint reference list for a variable decl.
287 LogicalResult parseVariableDeclConstraintList(
288 SmallVectorImpl<ast::ConstraintRef> &constraints);
289
290 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
291 FailureOr<ast::Expr *> parseTypeConstraintExpr();
292
293 /// Try to parse a single reference to a constraint. `typeConstraint` is the
294 /// location of a previously parsed type constraint for the entity that will
295 /// be constrained by the parsed constraint. `existingConstraints` are any
296 /// existing constraints that have already been parsed for the same entity
297 /// that will be constrained by this constraint. `allowInlineTypeConstraints`
298 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
299 FailureOr<ast::ConstraintRef>
300 parseConstraint(std::optional<SMRange> &typeConstraint,
301 ArrayRef<ast::ConstraintRef> existingConstraints,
302 bool allowInlineTypeConstraints);
303
304 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
305 /// argument or result variable. The constraints for these variables do not
306 /// allow inline type constraints, and only permit a single constraint.
307 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
308
309 //===--------------------------------------------------------------------===//
310 // Exprs
311
312 FailureOr<ast::Expr *> parseExpr();
313
314 /// Identifier expressions.
315 FailureOr<ast::Expr *> parseAttributeExpr();
316 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr,
317 bool isNegated = false);
318 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
319 FailureOr<ast::Expr *> parseIdentifierExpr();
320 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
321 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
322 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
323 FailureOr<ast::Expr *> parseNegatedExpr();
324 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
325 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
326 FailureOr<ast::Expr *>
327 parseOperationExpr(OpResultTypeContext inputResultTypeContext =
328 OpResultTypeContext::Explicit);
329 FailureOr<ast::Expr *> parseTupleExpr();
330 FailureOr<ast::Expr *> parseTypeExpr();
331 FailureOr<ast::Expr *> parseUnderscoreExpr();
332
333 //===--------------------------------------------------------------------===//
334 // Stmts
335
336 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
337 FailureOr<ast::CompoundStmt *> parseCompoundStmt();
338 FailureOr<ast::EraseStmt *> parseEraseStmt();
339 FailureOr<ast::LetStmt *> parseLetStmt();
340 FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
341 FailureOr<ast::ReturnStmt *> parseReturnStmt();
342 FailureOr<ast::RewriteStmt *> parseRewriteStmt();
343
344 //===--------------------------------------------------------------------===//
345 // Creation+Analysis
346 //===--------------------------------------------------------------------===//
347
348 //===--------------------------------------------------------------------===//
349 // Decls
350
351 /// Try to extract a callable from the given AST node. Returns nullptr on
352 /// failure.
353 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
354
355 /// Try to create a pattern decl with the given components, returning the
356 /// Pattern on success.
357 FailureOr<ast::PatternDecl *>
358 createPatternDecl(SMRange loc, const ast::Name *name,
359 const ParsedPatternMetadata &metadata,
360 ast::CompoundStmt *body);
361
362 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
363 /// of results, defined as part of the signature.
364 ast::Type
365 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
366
367 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
368 template <typename T>
369 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
370 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
371 ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
372 ast::CompoundStmt *body);
373
374 /// Try to create a variable decl with the given components, returning the
375 /// Variable on success.
376 FailureOr<ast::VariableDecl *>
377 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
378 ArrayRef<ast::ConstraintRef> constraints);
379
380 /// Create a variable for an argument or result defined as part of the
381 /// signature of a UserConstraintDecl/UserRewriteDecl.
382 FailureOr<ast::VariableDecl *>
383 createArgOrResultVariableDecl(StringRef name, SMRange loc,
384 const ast::ConstraintRef &constraint);
385
386 /// Validate the constraints used to constraint a variable decl.
387 /// `inferredType` is the type of the variable inferred by the constraints
388 /// within the list, and is updated to the most refined type as determined by
389 /// the constraints. Returns success if the constraint list is valid, failure
390 /// otherwise.
391 LogicalResult
392 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
393 ast::Type &inferredType);
394 /// Validate a single reference to a constraint. `inferredType` contains the
395 /// currently inferred variabled type and is refined within the type defined
396 /// by the constraint. Returns success if the constraint is valid, failure
397 /// otherwise.
398 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
399 ast::Type &inferredType);
400 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
401 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
402
403 //===--------------------------------------------------------------------===//
404 // Exprs
405
406 FailureOr<ast::CallExpr *>
407 createCallExpr(SMRange loc, ast::Expr *parentExpr,
408 MutableArrayRef<ast::Expr *> arguments,
409 bool isNegated = false);
410 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
411 FailureOr<ast::DeclRefExpr *>
412 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
413 ArrayRef<ast::ConstraintRef> constraints);
414 FailureOr<ast::MemberAccessExpr *>
415 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
416
417 /// Validate the member access `name` into the given parent expression. On
418 /// success, this also returns the type of the member accessed.
419 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
420 StringRef name, SMRange loc);
421 FailureOr<ast::OperationExpr *>
422 createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
423 OpResultTypeContext resultTypeContext,
424 SmallVectorImpl<ast::Expr *> &operands,
425 MutableArrayRef<ast::NamedAttributeDecl *> attributes,
426 SmallVectorImpl<ast::Expr *> &results);
427 LogicalResult
428 validateOperationOperands(SMRange loc, std::optional<StringRef> name,
429 const ods::Operation *odsOp,
430 SmallVectorImpl<ast::Expr *> &operands);
431 LogicalResult validateOperationResults(SMRange loc,
432 std::optional<StringRef> name,
433 const ods::Operation *odsOp,
434 SmallVectorImpl<ast::Expr *> &results);
435 void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
436 const ods::Operation *odsOp);
437 LogicalResult validateOperationOperandsOrResults(
438 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
439 std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
440 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
441 ast::RangeType rangeTy);
442 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
443 ArrayRef<ast::Expr *> elements,
444 ArrayRef<StringRef> elementNames);
445
446 //===--------------------------------------------------------------------===//
447 // Stmts
448
449 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
450 FailureOr<ast::ReplaceStmt *>
451 createReplaceStmt(SMRange loc, ast::Expr *rootOp,
452 MutableArrayRef<ast::Expr *> replValues);
453 FailureOr<ast::RewriteStmt *>
454 createRewriteStmt(SMRange loc, ast::Expr *rootOp,
455 ast::CompoundStmt *rewriteBody);
456
457 //===--------------------------------------------------------------------===//
458 // Code Completion
459 //===--------------------------------------------------------------------===//
460
461 /// The set of various code completion methods. Every completion method
462 /// returns `failure` to stop the parsing process after providing completion
463 /// results.
464
465 LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
466 LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName);
467 LogicalResult codeCompleteConstraintName(ast::Type inferredType,
468 bool allowInlineTypeConstraints);
469 LogicalResult codeCompleteDialectName();
470 LogicalResult codeCompleteOperationName(StringRef dialectName);
471 LogicalResult codeCompletePatternMetadata();
472 LogicalResult codeCompleteIncludeFilename(StringRef curPath);
473
474 void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
475 void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
476 unsigned currentNumOperands);
477 void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
478 unsigned currentNumResults);
479
480 //===--------------------------------------------------------------------===//
481 // Lexer Utilities
482 //===--------------------------------------------------------------------===//
483
484 /// If the current token has the specified kind, consume it and return true.
485 /// If not, return false.
486 bool consumeIf(Token::Kind kind) {
487 if (curToken.isNot(k: kind))
488 return false;
489 consumeToken(kind);
490 return true;
491 }
492
493 /// Advance the current lexer onto the next token.
494 void consumeToken() {
495 assert(curToken.isNot(Token::eof, Token::error) &&
496 "shouldn't advance past EOF or errors");
497 curToken = lexer.lexToken();
498 }
499
500 /// Advance the current lexer onto the next token, asserting what the expected
501 /// current token is. This is preferred to the above method because it leads
502 /// to more self-documenting code with better checking.
503 void consumeToken(Token::Kind kind) {
504 assert(curToken.is(kind) && "consumed an unexpected token");
505 consumeToken();
506 }
507
508 /// Reset the lexer to the location at the given position.
509 void resetToken(SMRange tokLoc) {
510 lexer.resetPointer(newPointer: tokLoc.Start.getPointer());
511 curToken = lexer.lexToken();
512 }
513
514 /// Consume the specified token if present and return success. On failure,
515 /// output a diagnostic and return failure.
516 LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
517 if (curToken.getKind() != kind)
518 return emitError(loc: curToken.getLoc(), msg);
519 consumeToken();
520 return success();
521 }
522 LogicalResult emitError(SMRange loc, const Twine &msg) {
523 lexer.emitError(loc, msg);
524 return failure();
525 }
526 LogicalResult emitError(const Twine &msg) {
527 return emitError(loc: curToken.getLoc(), msg);
528 }
529 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
530 const Twine &note) {
531 lexer.emitErrorAndNote(loc, msg, noteLoc, note);
532 return failure();
533 }
534
535 //===--------------------------------------------------------------------===//
536 // Fields
537 //===--------------------------------------------------------------------===//
538
539 /// The owning AST context.
540 ast::Context &ctx;
541
542 /// The lexer of this parser.
543 Lexer lexer;
544
545 /// The current token within the lexer.
546 Token curToken;
547
548 /// A flag indicating if the parser should add documentation to AST nodes when
549 /// viable.
550 bool enableDocumentation;
551
552 /// The most recently defined decl scope.
553 ast::DeclScope *curDeclScope = nullptr;
554 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
555
556 /// The current context of the parser.
557 ParserContext parserContext = ParserContext::Global;
558
559 /// Cached types to simplify verification and expression creation.
560 ast::Type typeTy, valueTy;
561 ast::RangeType typeRangeTy, valueRangeTy;
562 ast::Type attrTy;
563
564 /// A counter used when naming anonymous constraints and rewrites.
565 unsigned anonymousDeclNameCounter = 0;
566
567 /// The optional code completion context.
568 CodeCompleteContext *codeCompleteContext;
569};
570} // namespace
571
572FailureOr<ast::Module *> Parser::parseModule() {
573 SMLoc moduleLoc = curToken.getStartLoc();
574 pushDeclScope();
575
576 // Parse the top-level decls of the module.
577 SmallVector<ast::Decl *> decls;
578 if (failed(Result: parseModuleBody(decls)))
579 return popDeclScope(), failure();
580
581 popDeclScope();
582 return ast::Module::create(ctx, loc: moduleLoc, children: decls);
583}
584
585LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
586 while (curToken.isNot(k: Token::eof)) {
587 if (curToken.is(k: Token::directive)) {
588 if (failed(Result: parseDirective(decls)))
589 return failure();
590 continue;
591 }
592
593 FailureOr<ast::Decl *> decl = parseTopLevelDecl();
594 if (failed(Result: decl))
595 return failure();
596 decls.push_back(Elt: *decl);
597 }
598 return success();
599}
600
601ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
602 return ast::AllResultsMemberAccessExpr::create(ctx, loc: opExpr->getLoc(), parentExpr: opExpr,
603 type: valueRangeTy);
604}
605
606LogicalResult Parser::convertExpressionTo(
607 ast::Expr *&expr, ast::Type type,
608 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
609 ast::Type exprType = expr->getType();
610 if (exprType == type)
611 return success();
612
613 auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
614 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
615 loc: expr->getLoc(), msg: llvm::formatv(Fmt: "unable to convert expression of type "
616 "`{0}` to the expected type of "
617 "`{1}`",
618 Vals&: exprType, Vals&: type));
619 if (noteAttachFn)
620 noteAttachFn(*diag);
621 return diag;
622 };
623
624 if (auto exprOpType = dyn_cast<ast::OperationType>(Val&: exprType))
625 return convertOpExpressionTo(expr, exprType: exprOpType, type, emitErrorFn: emitConvertError);
626
627 // FIXME: Decide how to allow/support converting a single result to multiple,
628 // and multiple to a single result. For now, we just allow Single->Range,
629 // but this isn't something really supported in the PDL dialect. We should
630 // figure out some way to support both.
631 if ((exprType == valueTy || exprType == valueRangeTy) &&
632 (type == valueTy || type == valueRangeTy))
633 return success();
634 if ((exprType == typeTy || exprType == typeRangeTy) &&
635 (type == typeTy || type == typeRangeTy))
636 return success();
637
638 // Handle tuple types.
639 if (auto exprTupleType = dyn_cast<ast::TupleType>(Val&: exprType))
640 return convertTupleExpressionTo(expr, exprType: exprTupleType, type, emitErrorFn: emitConvertError,
641 noteAttachFn);
642
643 return emitConvertError();
644}
645
646LogicalResult Parser::convertOpExpressionTo(
647 ast::Expr *&expr, ast::OperationType exprType, ast::Type type,
648 function_ref<ast::InFlightDiagnostic()> emitErrorFn) {
649 // Two operation types are compatible if they have the same name, or if the
650 // expected type is more general.
651 if (auto opType = dyn_cast<ast::OperationType>(Val&: type)) {
652 if (opType.getName())
653 return emitErrorFn();
654 return success();
655 }
656
657 // An operation can always convert to a ValueRange.
658 if (type == valueRangeTy) {
659 expr = ast::AllResultsMemberAccessExpr::create(ctx, loc: expr->getLoc(), parentExpr: expr,
660 type: valueRangeTy);
661 return success();
662 }
663
664 // Allow conversion to a single value by constraining the result range.
665 if (type == valueTy) {
666 // If the operation is registered, we can verify if it can ever have a
667 // single result.
668 if (const ods::Operation *odsOp = exprType.getODSOperation()) {
669 if (odsOp->getResults().empty()) {
670 return emitErrorFn()->attachNote(
671 msg: llvm::formatv(Fmt: "see the definition of `{0}`, which was defined "
672 "with zero results",
673 Vals: odsOp->getName()),
674 noteLoc: odsOp->getLoc());
675 }
676
677 unsigned numSingleResults = llvm::count_if(
678 Range: odsOp->getResults(), P: [](const ods::OperandOrResult &result) {
679 return result.getVariableLengthKind() ==
680 ods::VariableLengthKind::Single;
681 });
682 if (numSingleResults > 1) {
683 return emitErrorFn()->attachNote(
684 msg: llvm::formatv(Fmt: "see the definition of `{0}`, which was defined "
685 "with at least {1} results",
686 Vals: odsOp->getName(), Vals&: numSingleResults),
687 noteLoc: odsOp->getLoc());
688 }
689 }
690
691 expr = ast::AllResultsMemberAccessExpr::create(ctx, loc: expr->getLoc(), parentExpr: expr,
692 type: valueTy);
693 return success();
694 }
695 return emitErrorFn();
696}
697
698LogicalResult Parser::convertTupleExpressionTo(
699 ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
700 function_ref<ast::InFlightDiagnostic()> emitErrorFn,
701 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
702 // Handle conversions between tuples.
703 if (auto tupleType = dyn_cast<ast::TupleType>(Val&: type)) {
704 if (tupleType.size() != exprType.size())
705 return emitErrorFn();
706
707 // Build a new tuple expression using each of the elements of the current
708 // tuple.
709 SmallVector<ast::Expr *> newExprs;
710 for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
711 newExprs.push_back(Elt: ast::MemberAccessExpr::create(
712 ctx, loc: expr->getLoc(), parentExpr: expr, memberName: llvm::to_string(Value: i),
713 type: exprType.getElementTypes()[i]));
714
715 auto diagFn = [&](ast::Diagnostic &diag) {
716 diag.attachNote(msg: llvm::formatv(Fmt: "when converting element #{0} of `{1}`",
717 Vals&: i, Vals&: exprType));
718 if (noteAttachFn)
719 noteAttachFn(diag);
720 };
721 if (failed(Result: convertExpressionTo(expr&: newExprs.back(),
722 type: tupleType.getElementTypes()[i], noteAttachFn: diagFn)))
723 return failure();
724 }
725 expr = ast::TupleExpr::create(ctx, loc: expr->getLoc(), elements: newExprs,
726 elementNames: tupleType.getElementNames());
727 return success();
728 }
729
730 // Handle conversion to a range.
731 auto convertToRange = [&](ArrayRef<ast::Type> allowedElementTypes,
732 ast::RangeType resultTy) -> LogicalResult {
733 // TODO: We currently only allow range conversion within a rewrite context.
734 if (parserContext != ParserContext::Rewrite) {
735 return emitErrorFn()->attachNote(msg: "Tuple to Range conversion is currently "
736 "only allowed within a rewrite context");
737 }
738
739 // All of the tuple elements must be allowed types.
740 for (ast::Type elementType : exprType.getElementTypes())
741 if (!llvm::is_contained(Range&: allowedElementTypes, Element: elementType))
742 return emitErrorFn();
743
744 // Build a new tuple expression using each of the elements of the current
745 // tuple.
746 SmallVector<ast::Expr *> newExprs;
747 for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
748 newExprs.push_back(Elt: ast::MemberAccessExpr::create(
749 ctx, loc: expr->getLoc(), parentExpr: expr, memberName: llvm::to_string(Value: i),
750 type: exprType.getElementTypes()[i]));
751 }
752 expr = ast::RangeExpr::create(ctx, loc: expr->getLoc(), elements: newExprs, type: resultTy);
753 return success();
754 };
755 if (type == valueRangeTy)
756 return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
757 if (type == typeRangeTy)
758 return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
759
760 return emitErrorFn();
761}
762
763//===----------------------------------------------------------------------===//
764// Directives
765//===----------------------------------------------------------------------===//
766
767LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
768 StringRef directive = curToken.getSpelling();
769 if (directive == "#include")
770 return parseInclude(decls);
771
772 return emitError(msg: "unknown directive `" + directive + "`");
773}
774
775LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
776 SMRange loc = curToken.getLoc();
777 consumeToken(kind: Token::directive);
778
779 // Handle code completion of the include file path.
780 if (curToken.is(k: Token::code_complete_string))
781 return codeCompleteIncludeFilename(curPath: curToken.getStringValue());
782
783 // Parse the file being included.
784 if (!curToken.isString())
785 return emitError(loc,
786 msg: "expected string file name after `include` directive");
787 SMRange fileLoc = curToken.getLoc();
788 std::string filenameStr = curToken.getStringValue();
789 StringRef filename = filenameStr;
790 consumeToken();
791
792 // Check the type of include. If ending with `.pdll`, this is another pdl file
793 // to be parsed along with the current module.
794 if (filename.ends_with(Suffix: ".pdll")) {
795 if (failed(Result: lexer.pushInclude(filename, includeLoc: fileLoc)))
796 return emitError(loc: fileLoc,
797 msg: "unable to open include file `" + filename + "`");
798
799 // If we added the include successfully, parse it into the current module.
800 // Make sure to update to the next token after we finish parsing the nested
801 // file.
802 curToken = lexer.lexToken();
803 LogicalResult result = parseModuleBody(decls);
804 curToken = lexer.lexToken();
805 return result;
806 }
807
808 // Otherwise, this must be a `.td` include.
809 if (filename.ends_with(Suffix: ".td"))
810 return parseTdInclude(filename, fileLoc, decls);
811
812 return emitError(loc: fileLoc,
813 msg: "expected include filename to end with `.pdll` or `.td`");
814}
815
816LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
817 SmallVectorImpl<ast::Decl *> &decls) {
818 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
819
820 // Use the source manager to open the file, but don't yet add it.
821 std::string includedFile;
822 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
823 parserSrcMgr.OpenIncludeFile(Filename: filename.str(), IncludedFile&: includedFile);
824 if (!includeBuffer)
825 return emitError(loc: fileLoc, msg: "unable to open include file `" + filename + "`");
826
827 // Setup the source manager for parsing the tablegen file.
828 llvm::SourceMgr tdSrcMgr;
829 tdSrcMgr.AddNewSourceBuffer(F: std::move(*includeBuffer), IncludeLoc: SMLoc());
830 tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
831
832 // This class provides a context argument for the llvm::SourceMgr diagnostic
833 // handler.
834 struct DiagHandlerContext {
835 Parser &parser;
836 StringRef filename;
837 llvm::SMRange loc;
838 } handlerContext{.parser: *this, .filename: filename, .loc: fileLoc};
839
840 // Set the diagnostic handler for the tablegen source manager.
841 tdSrcMgr.setDiagHandler(
842 DH: [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
843 auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
844 (void)ctx->parser.emitError(
845 loc: ctx->loc,
846 msg: llvm::formatv(Fmt: "error while processing include file `{0}`: {1}",
847 Vals&: ctx->filename, Vals: diag.getMessage()));
848 },
849 Ctx: &handlerContext);
850
851 // Parse the tablegen file.
852 llvm::RecordKeeper tdRecords;
853 if (llvm::TableGenParseFile(InputSrcMgr&: tdSrcMgr, Records&: tdRecords))
854 return failure();
855
856 // Process the parsed records.
857 processTdIncludeRecords(tdRecords, decls);
858
859 // After we are done processing, move all of the tablegen source buffers to
860 // the main parser source mgr. This allows for directly using source locations
861 // from the .td files without needing to remap them.
862 parserSrcMgr.takeSourceBuffersFrom(SrcMgr&: tdSrcMgr, MainBufferIncludeLoc: fileLoc.End);
863 return success();
864}
865
866void Parser::processTdIncludeRecords(const llvm::RecordKeeper &tdRecords,
867 SmallVectorImpl<ast::Decl *> &decls) {
868 // Return the length kind of the given value.
869 auto getLengthKind = [](const auto &value) {
870 if (value.isOptional())
871 return ods::VariableLengthKind::Optional;
872 return value.isVariadic() ? ods::VariableLengthKind::Variadic
873 : ods::VariableLengthKind::Single;
874 };
875
876 // Insert a type constraint into the ODS context.
877 ods::Context &odsContext = ctx.getODSContext();
878 auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
879 -> const ods::TypeConstraint & {
880 return odsContext.insertTypeConstraint(
881 name: cst.constraint.getUniqueDefName(),
882 summary: processDoc(doc: cst.constraint.getSummary()), cppClass: cst.constraint.getCppType());
883 };
884 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
885 return {loc, llvm::SMLoc::getFromPointer(Ptr: loc.getPointer() + 1)};
886 };
887
888 // Process the parsed tablegen records to build ODS information.
889 /// Operations.
890 for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions(ClassName: "Op")) {
891 tblgen::Operator op(def);
892
893 // Check to see if this operation is known to support type inferrence.
894 bool supportsResultTypeInferrence =
895 op.getTrait(trait: "::mlir::InferTypeOpInterface::Trait");
896
897 auto [odsOp, inserted] = odsContext.insertOperation(
898 name: op.getOperationName(), summary: processDoc(doc: op.getSummary()),
899 desc: processAndFormatDoc(doc: op.getDescription()), nativeClassName: op.getQualCppClassName(),
900 supportsResultTypeInferrence, loc: op.getLoc().front());
901
902 // Ignore operations that have already been added.
903 if (!inserted)
904 continue;
905
906 for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
907 odsOp->appendAttribute(name: attr.name, optional: attr.attr.isOptional(),
908 constraint: odsContext.insertAttributeConstraint(
909 name: attr.attr.getUniqueDefName(),
910 summary: processDoc(doc: attr.attr.getSummary()),
911 cppClass: attr.attr.getStorageType()));
912 }
913 for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
914 odsOp->appendOperand(name: operand.name, variableLengthKind: getLengthKind(operand),
915 constraint: addTypeConstraint(operand));
916 }
917 for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
918 odsOp->appendResult(name: result.name, variableLengthKind: getLengthKind(result),
919 constraint: addTypeConstraint(result));
920 }
921 }
922
923 auto shouldBeSkipped = [this](const llvm::Record *def) {
924 return def->isAnonymous() || curDeclScope->lookup(name: def->getName()) ||
925 def->isSubClassOf(Name: "DeclareInterfaceMethods");
926 };
927
928 /// Attr constraints.
929 for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions(ClassName: "Attr")) {
930 if (shouldBeSkipped(def))
931 continue;
932
933 tblgen::Attribute constraint(def);
934 decls.push_back(Elt: createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
935 constraint, loc: convertLocToRange(def->getLoc().front()), type: attrTy,
936 nativeType: constraint.getStorageType()));
937 }
938 /// Type constraints.
939 for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions(ClassName: "Type")) {
940 if (shouldBeSkipped(def))
941 continue;
942
943 tblgen::TypeConstraint constraint(def);
944 decls.push_back(Elt: createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
945 constraint, loc: convertLocToRange(def->getLoc().front()), type: typeTy,
946 nativeType: constraint.getCppType()));
947 }
948 /// OpInterfaces.
949 ast::Type opTy = ast::OperationType::get(context&: ctx);
950 for (const llvm::Record *def :
951 tdRecords.getAllDerivedDefinitions(ClassName: "OpInterface")) {
952 if (shouldBeSkipped(def))
953 continue;
954
955 SMRange loc = convertLocToRange(def->getLoc().front());
956
957 std::string cppClassName =
958 llvm::formatv(Fmt: "{0}::{1}", Vals: def->getValueAsString(FieldName: "cppNamespace"),
959 Vals: def->getValueAsString(FieldName: "cppInterfaceName"))
960 .str();
961 std::string codeBlock =
962 llvm::formatv(Fmt: "return ::mlir::success(llvm::isa<{0}>(self));",
963 Vals&: cppClassName)
964 .str();
965
966 std::string desc =
967 processAndFormatDoc(doc: def->getValueAsString(FieldName: "description"));
968 decls.push_back(Elt: createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
969 name: def->getName(), codeBlock, loc, type: opTy, nativeType: cppClassName, docString: desc));
970 }
971}
972
973template <typename ConstraintT>
974ast::Decl *Parser::createODSNativePDLLConstraintDecl(
975 StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
976 StringRef nativeType, StringRef docString) {
977 // Build the single input parameter.
978 ast::DeclScope *argScope = pushDeclScope();
979 auto *paramVar = ast::VariableDecl::create(
980 ctx, name: ast::Name::create(ctx, name: "self", location: loc), type,
981 /*initExpr=*/nullptr, constraints: ast::ConstraintRef(ConstraintT::create(ctx, loc)));
982 argScope->add(decl: paramVar);
983 popDeclScope();
984
985 // Build the native constraint.
986 auto *constraintDecl = ast::UserConstraintDecl::createNative(
987 ctx, name: ast::Name::create(ctx, name, location: loc), inputs: paramVar,
988 /*results=*/{}, codeBlock, resultType: ast::TupleType::get(context&: ctx), nativeInputTypes: nativeType);
989 constraintDecl->setDocComment(ctx, comment: docString);
990 curDeclScope->add(decl: constraintDecl);
991 return constraintDecl;
992}
993
994template <typename ConstraintT>
995ast::Decl *
996Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
997 SMRange loc, ast::Type type,
998 StringRef nativeType) {
999 // Format the condition template.
1000 tblgen::FmtContext fmtContext;
1001 fmtContext.withSelf(subst: "self");
1002 std::string codeBlock = tblgen::tgfmt(
1003 fmt: "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
1004 ctx: &fmtContext);
1005
1006 // If documentation was enabled, build the doc string for the generated
1007 // constraint. It would be nice to do this lazily, but TableGen information is
1008 // destroyed after we finish parsing the file.
1009 std::string docString;
1010 if (enableDocumentation) {
1011 StringRef desc = constraint.getDescription();
1012 docString = processAndFormatDoc(
1013 doc: constraint.getSummary() +
1014 (desc.empty() ? "" : ("\n\n" + constraint.getDescription())));
1015 }
1016
1017 return createODSNativePDLLConstraintDecl<ConstraintT>(
1018 constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
1019 docString);
1020}
1021
1022//===----------------------------------------------------------------------===//
1023// Decls
1024//===----------------------------------------------------------------------===//
1025
1026FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
1027 FailureOr<ast::Decl *> decl;
1028 switch (curToken.getKind()) {
1029 case Token::kw_Constraint:
1030 decl = parseUserConstraintDecl();
1031 break;
1032 case Token::kw_Pattern:
1033 decl = parsePatternDecl();
1034 break;
1035 case Token::kw_Rewrite:
1036 decl = parseUserRewriteDecl();
1037 break;
1038 default:
1039 return emitError(msg: "expected top-level declaration, such as a `Pattern`");
1040 }
1041 if (failed(Result: decl))
1042 return failure();
1043
1044 // If the decl has a name, add it to the current scope.
1045 if (const ast::Name *name = (*decl)->getName()) {
1046 if (failed(Result: checkDefineNamedDecl(name: *name)))
1047 return failure();
1048 curDeclScope->add(decl: *decl);
1049 }
1050 return decl;
1051}
1052
1053FailureOr<ast::NamedAttributeDecl *>
1054Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) {
1055 // Check for name code completion.
1056 if (curToken.is(k: Token::code_complete))
1057 return codeCompleteAttributeName(opName: parentOpName);
1058
1059 std::string attrNameStr;
1060 if (curToken.isString())
1061 attrNameStr = curToken.getStringValue();
1062 else if (curToken.is(k: Token::identifier) || curToken.isKeyword())
1063 attrNameStr = curToken.getSpelling().str();
1064 else
1065 return emitError(msg: "expected identifier or string attribute name");
1066 const auto &name = ast::Name::create(ctx, name: attrNameStr, location: curToken.getLoc());
1067 consumeToken();
1068
1069 // Check for a value of the attribute.
1070 ast::Expr *attrValue = nullptr;
1071 if (consumeIf(kind: Token::equal)) {
1072 FailureOr<ast::Expr *> attrExpr = parseExpr();
1073 if (failed(Result: attrExpr))
1074 return failure();
1075 attrValue = *attrExpr;
1076 } else {
1077 // If there isn't a concrete value, create an expression representing a
1078 // UnitAttr.
1079 attrValue = ast::AttributeExpr::create(ctx, loc: name.getLoc(), value: "unit");
1080 }
1081
1082 return ast::NamedAttributeDecl::create(ctx, name, value: attrValue);
1083}
1084
1085FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1086 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
1087 bool expectTerminalSemicolon) {
1088 consumeToken(kind: Token::equal_arrow);
1089
1090 // Parse the single statement of the lambda body.
1091 SMLoc bodyStartLoc = curToken.getStartLoc();
1092 pushDeclScope();
1093 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1094 bool failedToParse =
1095 failed(Result: singleStatement) || failed(Result: processStatementFn(*singleStatement));
1096 popDeclScope();
1097 if (failedToParse)
1098 return failure();
1099
1100 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1101 return ast::CompoundStmt::create(ctx, location: bodyLoc, children: *singleStatement);
1102}
1103
1104FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1105 // Ensure that the argument is named.
1106 if (curToken.isNot(k: Token::identifier) && !curToken.isDependentKeyword())
1107 return emitError(msg: "expected identifier argument name");
1108
1109 // Parse the argument similarly to a normal variable.
1110 StringRef name = curToken.getSpelling();
1111 SMRange nameLoc = curToken.getLoc();
1112 consumeToken();
1113
1114 if (failed(
1115 Result: parseToken(kind: Token::colon, msg: "expected `:` before argument constraint")))
1116 return failure();
1117
1118 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1119 if (failed(Result: cst))
1120 return failure();
1121
1122 return createArgOrResultVariableDecl(name, loc: nameLoc, constraint: *cst);
1123}
1124
1125FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
1126 // Check to see if this result is named.
1127 if (curToken.is(k: Token::identifier) || curToken.isDependentKeyword()) {
1128 // Check to see if this name actually refers to a Constraint.
1129 if (!curDeclScope->lookup<ast::ConstraintDecl>(name: curToken.getSpelling())) {
1130 // If it wasn't a constraint, parse the result similarly to a variable. If
1131 // there is already an existing decl, we will emit an error when defining
1132 // this variable later.
1133 StringRef name = curToken.getSpelling();
1134 SMRange nameLoc = curToken.getLoc();
1135 consumeToken();
1136
1137 if (failed(Result: parseToken(kind: Token::colon,
1138 msg: "expected `:` before result constraint")))
1139 return failure();
1140
1141 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1142 if (failed(Result: cst))
1143 return failure();
1144
1145 return createArgOrResultVariableDecl(name, loc: nameLoc, constraint: *cst);
1146 }
1147 }
1148
1149 // If it isn't named, we parse the constraint directly and create an unnamed
1150 // result variable.
1151 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1152 if (failed(Result: cst))
1153 return failure();
1154
1155 return createArgOrResultVariableDecl(name: "", loc: cst->referenceLoc, constraint: *cst);
1156}
1157
1158FailureOr<ast::UserConstraintDecl *>
1159Parser::parseUserConstraintDecl(bool isInline) {
1160 // Constraints and rewrites have very similar formats, dispatch to a shared
1161 // interface for parsing.
1162 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1163 parseUserPDLLFn: [&](auto &&...args) {
1164 return this->parseUserPDLLConstraintDecl(name: args...);
1165 },
1166 declContext: ParserContext::Constraint, anonymousNamePrefix: "constraint", isInline);
1167}
1168
1169FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1170 FailureOr<ast::UserConstraintDecl *> decl =
1171 parseUserConstraintDecl(/*isInline=*/true);
1172 if (failed(Result: decl) || failed(Result: checkDefineNamedDecl(name: (*decl)->getName())))
1173 return failure();
1174
1175 curDeclScope->add(decl: *decl);
1176 return decl;
1177}
1178
1179FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1180 const ast::Name &name, bool isInline,
1181 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1182 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1183 // Push the argument scope back onto the list, so that the body can
1184 // reference arguments.
1185 pushDeclScope(scope: argumentScope);
1186
1187 // Parse the body of the constraint. The body is either defined as a compound
1188 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1189 ast::CompoundStmt *body;
1190 if (curToken.is(k: Token::equal_arrow)) {
1191 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1192 processStatementFn: [&](ast::Stmt *&stmt) -> LogicalResult {
1193 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(Val: stmt);
1194 if (!stmtExpr) {
1195 return emitError(loc: stmt->getLoc(),
1196 msg: "expected `Constraint` lambda body to contain a "
1197 "single expression");
1198 }
1199 stmt = ast::ReturnStmt::create(ctx, loc: stmt->getLoc(), resultExpr: stmtExpr);
1200 return success();
1201 },
1202 /*expectTerminalSemicolon=*/!isInline);
1203 if (failed(Result: bodyResult))
1204 return failure();
1205 body = *bodyResult;
1206 } else {
1207 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1208 if (failed(Result: bodyResult))
1209 return failure();
1210 body = *bodyResult;
1211
1212 // Verify the structure of the body.
1213 auto bodyIt = body->begin(), bodyE = body->end();
1214 for (; bodyIt != bodyE; ++bodyIt)
1215 if (isa<ast::ReturnStmt>(Val: *bodyIt))
1216 break;
1217 if (failed(Result: validateUserConstraintOrRewriteReturn(
1218 declType: "Constraint", body, bodyIt, bodyE, results, resultType)))
1219 return failure();
1220 }
1221 popDeclScope();
1222
1223 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1224 name, arguments, results, resultType, body);
1225}
1226
1227FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1228 // Constraints and rewrites have very similar formats, dispatch to a shared
1229 // interface for parsing.
1230 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1231 parseUserPDLLFn: [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(name: args...); },
1232 declContext: ParserContext::Rewrite, anonymousNamePrefix: "rewrite", isInline);
1233}
1234
1235FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1236 FailureOr<ast::UserRewriteDecl *> decl =
1237 parseUserRewriteDecl(/*isInline=*/true);
1238 if (failed(Result: decl) || failed(Result: checkDefineNamedDecl(name: (*decl)->getName())))
1239 return failure();
1240
1241 curDeclScope->add(decl: *decl);
1242 return decl;
1243}
1244
1245FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1246 const ast::Name &name, bool isInline,
1247 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1248 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1249 // Push the argument scope back onto the list, so that the body can
1250 // reference arguments.
1251 curDeclScope = argumentScope;
1252 ast::CompoundStmt *body;
1253 if (curToken.is(k: Token::equal_arrow)) {
1254 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1255 processStatementFn: [&](ast::Stmt *&statement) -> LogicalResult {
1256 if (isa<ast::OpRewriteStmt>(Val: statement))
1257 return success();
1258
1259 ast::Expr *statementExpr = dyn_cast<ast::Expr>(Val: statement);
1260 if (!statementExpr) {
1261 return emitError(
1262 loc: statement->getLoc(),
1263 msg: "expected `Rewrite` lambda body to contain a single expression "
1264 "or an operation rewrite statement; such as `erase`, "
1265 "`replace`, or `rewrite`");
1266 }
1267 statement =
1268 ast::ReturnStmt::create(ctx, loc: statement->getLoc(), resultExpr: statementExpr);
1269 return success();
1270 },
1271 /*expectTerminalSemicolon=*/!isInline);
1272 if (failed(Result: bodyResult))
1273 return failure();
1274 body = *bodyResult;
1275 } else {
1276 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1277 if (failed(Result: bodyResult))
1278 return failure();
1279 body = *bodyResult;
1280 }
1281 popDeclScope();
1282
1283 // Verify the structure of the body.
1284 auto bodyIt = body->begin(), bodyE = body->end();
1285 for (; bodyIt != bodyE; ++bodyIt)
1286 if (isa<ast::ReturnStmt>(Val: *bodyIt))
1287 break;
1288 if (failed(Result: validateUserConstraintOrRewriteReturn(declType: "Rewrite", body, bodyIt,
1289 bodyE, results, resultType)))
1290 return failure();
1291 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1292 name, arguments, results, resultType, body);
1293}
1294
1295template <typename T, typename ParseUserPDLLDeclFnT>
1296FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1297 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1298 StringRef anonymousNamePrefix, bool isInline) {
1299 SMRange loc = curToken.getLoc();
1300 consumeToken();
1301 llvm::SaveAndRestore saveCtx(parserContext, declContext);
1302
1303 // Parse the name of the decl.
1304 const ast::Name *name = nullptr;
1305 if (curToken.isNot(k: Token::identifier)) {
1306 // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1307 // in C++, so being unnamed is fine.
1308 if (!isInline)
1309 return emitError(msg: "expected identifier name");
1310
1311 // Create a unique anonymous name to use, as the name for this decl is not
1312 // important.
1313 std::string anonName =
1314 llvm::formatv(Fmt: "<anonymous_{0}_{1}>", Vals&: anonymousNamePrefix,
1315 Vals: anonymousDeclNameCounter++)
1316 .str();
1317 name = &ast::Name::create(ctx, name: anonName, location: loc);
1318 } else {
1319 // If a name was provided, we can use it directly.
1320 name = &ast::Name::create(ctx, name: curToken.getSpelling(), location: curToken.getLoc());
1321 consumeToken(kind: Token::identifier);
1322 }
1323
1324 // Parse the functional signature of the decl.
1325 SmallVector<ast::VariableDecl *> arguments, results;
1326 ast::DeclScope *argumentScope;
1327 ast::Type resultType;
1328 if (failed(Result: parseUserConstraintOrRewriteSignature(arguments, results,
1329 argumentScope, resultType)))
1330 return failure();
1331
1332 // Check to see which type of constraint this is. If the constraint contains a
1333 // compound body, this is a PDLL decl.
1334 if (curToken.isAny(k1: Token::l_brace, k2: Token::equal_arrow))
1335 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1336 resultType);
1337
1338 // Otherwise, this is a native decl.
1339 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1340 results, resultType);
1341}
1342
1343template <typename T>
1344FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1345 const ast::Name &name, bool isInline,
1346 ArrayRef<ast::VariableDecl *> arguments,
1347 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1348 // If followed by a string, the native code body has also been specified.
1349 std::string codeStrStorage;
1350 std::optional<StringRef> optCodeStr;
1351 if (curToken.isString()) {
1352 codeStrStorage = curToken.getStringValue();
1353 optCodeStr = codeStrStorage;
1354 consumeToken();
1355 } else if (isInline) {
1356 return emitError(loc: name.getLoc(),
1357 msg: "external declarations must be declared in global scope");
1358 } else if (curToken.is(k: Token::error)) {
1359 return failure();
1360 }
1361 if (failed(Result: parseToken(kind: Token::semicolon,
1362 msg: "expected `;` after native declaration")))
1363 return failure();
1364 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1365}
1366
1367LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1368 SmallVectorImpl<ast::VariableDecl *> &arguments,
1369 SmallVectorImpl<ast::VariableDecl *> &results,
1370 ast::DeclScope *&argumentScope, ast::Type &resultType) {
1371 // Parse the argument list of the decl.
1372 if (failed(Result: parseToken(kind: Token::l_paren, msg: "expected `(` to start argument list")))
1373 return failure();
1374
1375 argumentScope = pushDeclScope();
1376 if (curToken.isNot(k: Token::r_paren)) {
1377 do {
1378 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1379 if (failed(Result: argument))
1380 return failure();
1381 arguments.emplace_back(Args&: *argument);
1382 } while (consumeIf(kind: Token::comma));
1383 }
1384 popDeclScope();
1385 if (failed(Result: parseToken(kind: Token::r_paren, msg: "expected `)` to end argument list")))
1386 return failure();
1387
1388 // Parse the results of the decl.
1389 pushDeclScope();
1390 if (consumeIf(kind: Token::arrow)) {
1391 auto parseResultFn = [&]() -> LogicalResult {
1392 FailureOr<ast::VariableDecl *> result = parseResultDecl(resultNum: results.size());
1393 if (failed(Result: result))
1394 return failure();
1395 results.emplace_back(Args&: *result);
1396 return success();
1397 };
1398
1399 // Check for a list of results.
1400 if (consumeIf(kind: Token::l_paren)) {
1401 do {
1402 if (failed(Result: parseResultFn()))
1403 return failure();
1404 } while (consumeIf(kind: Token::comma));
1405 if (failed(Result: parseToken(kind: Token::r_paren, msg: "expected `)` to end result list")))
1406 return failure();
1407
1408 // Otherwise, there is only one result.
1409 } else if (failed(Result: parseResultFn())) {
1410 return failure();
1411 }
1412 }
1413 popDeclScope();
1414
1415 // Compute the result type of the decl.
1416 resultType = createUserConstraintRewriteResultType(results);
1417
1418 // Verify that results are only named if there are more than one.
1419 if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1420 return emitError(
1421 loc: results.front()->getLoc(),
1422 msg: "cannot create a single-element tuple with an element label");
1423 }
1424 return success();
1425}
1426
1427LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1428 StringRef declType, ast::CompoundStmt *body,
1429 ArrayRef<ast::Stmt *>::iterator bodyIt,
1430 ArrayRef<ast::Stmt *>::iterator bodyE,
1431 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1432 // Handle if a `return` was provided.
1433 if (bodyIt != bodyE) {
1434 // Emit an error if we have trailing statements after the return.
1435 if (std::next(x: bodyIt) != bodyE) {
1436 return emitError(
1437 loc: (*std::next(x: bodyIt))->getLoc(),
1438 msg: llvm::formatv(Fmt: "`return` terminated the `{0}` body, but found "
1439 "trailing statements afterwards",
1440 Vals&: declType));
1441 }
1442
1443 // Otherwise if a return wasn't provided, check that no results are
1444 // expected.
1445 } else if (!results.empty()) {
1446 return emitError(
1447 loc: {body->getLoc().End, body->getLoc().End},
1448 msg: llvm::formatv(Fmt: "missing return in a `{0}` expected to return `{1}`",
1449 Vals&: declType, Vals&: resultType));
1450 }
1451 return success();
1452}
1453
1454FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1455 return parseLambdaBody(processStatementFn: [&](ast::Stmt *&statement) -> LogicalResult {
1456 if (isa<ast::OpRewriteStmt>(Val: statement))
1457 return success();
1458 return emitError(
1459 loc: statement->getLoc(),
1460 msg: "expected Pattern lambda body to contain a single operation "
1461 "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1462 });
1463}
1464
1465FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1466 SMRange loc = curToken.getLoc();
1467 consumeToken(kind: Token::kw_Pattern);
1468 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);
1469
1470 // Check for an optional identifier for the pattern name.
1471 const ast::Name *name = nullptr;
1472 if (curToken.is(k: Token::identifier)) {
1473 name = &ast::Name::create(ctx, name: curToken.getSpelling(), location: curToken.getLoc());
1474 consumeToken(kind: Token::identifier);
1475 }
1476
1477 // Parse any pattern metadata.
1478 ParsedPatternMetadata metadata;
1479 if (consumeIf(kind: Token::kw_with) && failed(Result: parsePatternDeclMetadata(metadata)))
1480 return failure();
1481
1482 // Parse the pattern body.
1483 ast::CompoundStmt *body;
1484
1485 // Handle a lambda body.
1486 if (curToken.is(k: Token::equal_arrow)) {
1487 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1488 if (failed(Result: bodyResult))
1489 return failure();
1490 body = *bodyResult;
1491 } else {
1492 if (curToken.isNot(k: Token::l_brace))
1493 return emitError(msg: "expected `{` or `=>` to start pattern body");
1494 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1495 if (failed(Result: bodyResult))
1496 return failure();
1497 body = *bodyResult;
1498
1499 // Verify the body of the pattern.
1500 auto bodyIt = body->begin(), bodyE = body->end();
1501 for (; bodyIt != bodyE; ++bodyIt) {
1502 if (isa<ast::ReturnStmt>(Val: *bodyIt)) {
1503 return emitError(loc: (*bodyIt)->getLoc(),
1504 msg: "`return` statements are only permitted within a "
1505 "`Constraint` or `Rewrite` body");
1506 }
1507 // Break when we've found the rewrite statement.
1508 if (isa<ast::OpRewriteStmt>(Val: *bodyIt))
1509 break;
1510 }
1511 if (bodyIt == bodyE) {
1512 return emitError(loc,
1513 msg: "expected Pattern body to terminate with an operation "
1514 "rewrite statement, such as `erase`");
1515 }
1516 if (std::next(x: bodyIt) != bodyE) {
1517 return emitError(loc: (*std::next(x: bodyIt))->getLoc(),
1518 msg: "Pattern body was terminated by an operation "
1519 "rewrite statement, but found trailing statements");
1520 }
1521 }
1522
1523 return createPatternDecl(loc, name, metadata, body);
1524}
1525
1526LogicalResult
1527Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1528 std::optional<SMRange> benefitLoc;
1529 std::optional<SMRange> hasBoundedRecursionLoc;
1530
1531 do {
1532 // Handle metadata code completion.
1533 if (curToken.is(k: Token::code_complete))
1534 return codeCompletePatternMetadata();
1535
1536 if (curToken.isNot(k: Token::identifier))
1537 return emitError(msg: "expected pattern metadata identifier");
1538 StringRef metadataStr = curToken.getSpelling();
1539 SMRange metadataLoc = curToken.getLoc();
1540 consumeToken(kind: Token::identifier);
1541
1542 // Parse the benefit metadata: benefit(<integer-value>)
1543 if (metadataStr == "benefit") {
1544 if (benefitLoc) {
1545 return emitErrorAndNote(loc: metadataLoc,
1546 msg: "pattern benefit has already been specified",
1547 noteLoc: *benefitLoc, note: "see previous definition here");
1548 }
1549 if (failed(Result: parseToken(kind: Token::l_paren,
1550 msg: "expected `(` before pattern benefit")))
1551 return failure();
1552
1553 uint16_t benefitValue = 0;
1554 if (curToken.isNot(k: Token::integer))
1555 return emitError(msg: "expected integral pattern benefit");
1556 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, Result&: benefitValue))
1557 return emitError(
1558 msg: "expected pattern benefit to fit within a 16-bit integer");
1559 consumeToken(kind: Token::integer);
1560
1561 metadata.benefit = benefitValue;
1562 benefitLoc = metadataLoc;
1563
1564 if (failed(
1565 Result: parseToken(kind: Token::r_paren, msg: "expected `)` after pattern benefit")))
1566 return failure();
1567 continue;
1568 }
1569
1570 // Parse the bounded recursion metadata: recursion
1571 if (metadataStr == "recursion") {
1572 if (hasBoundedRecursionLoc) {
1573 return emitErrorAndNote(
1574 loc: metadataLoc,
1575 msg: "pattern recursion metadata has already been specified",
1576 noteLoc: *hasBoundedRecursionLoc, note: "see previous definition here");
1577 }
1578 metadata.hasBoundedRecursion = true;
1579 hasBoundedRecursionLoc = metadataLoc;
1580 continue;
1581 }
1582
1583 return emitError(loc: metadataLoc, msg: "unknown pattern metadata");
1584 } while (consumeIf(kind: Token::comma));
1585
1586 return success();
1587}
1588
1589FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1590 consumeToken(kind: Token::less);
1591
1592 FailureOr<ast::Expr *> typeExpr = parseExpr();
1593 if (failed(Result: typeExpr) ||
1594 failed(Result: parseToken(kind: Token::greater,
1595 msg: "expected `>` after variable type constraint")))
1596 return failure();
1597 return typeExpr;
1598}
1599
1600LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1601 assert(curDeclScope && "defining decl outside of a decl scope");
1602 if (ast::Decl *lastDecl = curDeclScope->lookup(name: name.getName())) {
1603 return emitErrorAndNote(
1604 loc: name.getLoc(), msg: "`" + name.getName() + "` has already been defined",
1605 noteLoc: lastDecl->getName()->getLoc(), note: "see previous definition here");
1606 }
1607 return success();
1608}
1609
1610FailureOr<ast::VariableDecl *>
1611Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1612 ast::Expr *initExpr,
1613 ArrayRef<ast::ConstraintRef> constraints) {
1614 assert(curDeclScope && "defining variable outside of decl scope");
1615 const ast::Name &nameDecl = ast::Name::create(ctx, name, location: nameLoc);
1616
1617 // If the name of the variable indicates a special variable, we don't add it
1618 // to the scope. This variable is local to the definition point.
1619 if (name.empty() || name == "_") {
1620 return ast::VariableDecl::create(ctx, name: nameDecl, type, initExpr,
1621 constraints);
1622 }
1623 if (failed(Result: checkDefineNamedDecl(name: nameDecl)))
1624 return failure();
1625
1626 auto *varDecl =
1627 ast::VariableDecl::create(ctx, name: nameDecl, type, initExpr, constraints);
1628 curDeclScope->add(decl: varDecl);
1629 return varDecl;
1630}
1631
1632FailureOr<ast::VariableDecl *>
1633Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1634 ArrayRef<ast::ConstraintRef> constraints) {
1635 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1636 constraints);
1637}
1638
1639LogicalResult Parser::parseVariableDeclConstraintList(
1640 SmallVectorImpl<ast::ConstraintRef> &constraints) {
1641 std::optional<SMRange> typeConstraint;
1642 auto parseSingleConstraint = [&] {
1643 FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1644 typeConstraint, existingConstraints: constraints, /*allowInlineTypeConstraints=*/true);
1645 if (failed(Result: constraint))
1646 return failure();
1647 constraints.push_back(Elt: *constraint);
1648 return success();
1649 };
1650
1651 // Check to see if this is a single constraint, or a list.
1652 if (!consumeIf(kind: Token::l_square))
1653 return parseSingleConstraint();
1654
1655 do {
1656 if (failed(Result: parseSingleConstraint()))
1657 return failure();
1658 } while (consumeIf(kind: Token::comma));
1659 return parseToken(kind: Token::r_square, msg: "expected `]` after constraint list");
1660}
1661
1662FailureOr<ast::ConstraintRef>
1663Parser::parseConstraint(std::optional<SMRange> &typeConstraint,
1664 ArrayRef<ast::ConstraintRef> existingConstraints,
1665 bool allowInlineTypeConstraints) {
1666 auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
1667 if (!allowInlineTypeConstraints) {
1668 return emitError(
1669 loc: curToken.getLoc(),
1670 msg: "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1671 "permitted on arguments or results");
1672 }
1673 if (typeConstraint)
1674 return emitErrorAndNote(
1675 loc: curToken.getLoc(),
1676 msg: "the type of this variable has already been constrained",
1677 noteLoc: *typeConstraint, note: "see previous constraint location here");
1678 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1679 if (failed(Result: constraintExpr))
1680 return failure();
1681 typeExpr = *constraintExpr;
1682 typeConstraint = typeExpr->getLoc();
1683 return success();
1684 };
1685
1686 SMRange loc = curToken.getLoc();
1687 switch (curToken.getKind()) {
1688 case Token::kw_Attr: {
1689 consumeToken(kind: Token::kw_Attr);
1690
1691 // Check for a type constraint.
1692 ast::Expr *typeExpr = nullptr;
1693 if (curToken.is(k: Token::less) && failed(Result: parseTypeConstraint(typeExpr)))
1694 return failure();
1695 return ast::ConstraintRef(
1696 ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
1697 }
1698 case Token::kw_Op: {
1699 consumeToken(kind: Token::kw_Op);
1700
1701 // Parse an optional operation name. If the name isn't provided, this refers
1702 // to "any" operation.
1703 FailureOr<ast::OpNameDecl *> opName =
1704 parseWrappedOperationName(/*allowEmptyName=*/true);
1705 if (failed(Result: opName))
1706 return failure();
1707
1708 return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, nameDecl: *opName),
1709 loc);
1710 }
1711 case Token::kw_Type:
1712 consumeToken(kind: Token::kw_Type);
1713 return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
1714 case Token::kw_TypeRange:
1715 consumeToken(kind: Token::kw_TypeRange);
1716 return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
1717 loc);
1718 case Token::kw_Value: {
1719 consumeToken(kind: Token::kw_Value);
1720
1721 // Check for a type constraint.
1722 ast::Expr *typeExpr = nullptr;
1723 if (curToken.is(k: Token::less) && failed(Result: parseTypeConstraint(typeExpr)))
1724 return failure();
1725
1726 return ast::ConstraintRef(
1727 ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
1728 }
1729 case Token::kw_ValueRange: {
1730 consumeToken(kind: Token::kw_ValueRange);
1731
1732 // Check for a type constraint.
1733 ast::Expr *typeExpr = nullptr;
1734 if (curToken.is(k: Token::less) && failed(Result: parseTypeConstraint(typeExpr)))
1735 return failure();
1736
1737 return ast::ConstraintRef(
1738 ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
1739 }
1740
1741 case Token::kw_Constraint: {
1742 // Handle an inline constraint.
1743 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1744 if (failed(Result: decl))
1745 return failure();
1746 return ast::ConstraintRef(*decl, loc);
1747 }
1748 case Token::identifier: {
1749 StringRef constraintName = curToken.getSpelling();
1750 consumeToken(kind: Token::identifier);
1751
1752 // Lookup the referenced constraint.
1753 ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(name: constraintName);
1754 if (!cstDecl) {
1755 return emitError(loc, msg: "unknown reference to constraint `" +
1756 constraintName + "`");
1757 }
1758
1759 // Handle a reference to a proper constraint.
1760 if (auto *cst = dyn_cast<ast::ConstraintDecl>(Val: cstDecl))
1761 return ast::ConstraintRef(cst, loc);
1762
1763 return emitErrorAndNote(
1764 loc, msg: "invalid reference to non-constraint", noteLoc: cstDecl->getLoc(),
1765 note: "see the definition of `" + constraintName + "` here");
1766 }
1767 // Handle single entity constraint code completion.
1768 case Token::code_complete: {
1769 // Try to infer the current type for use by code completion.
1770 ast::Type inferredType;
1771 if (failed(Result: validateVariableConstraints(constraints: existingConstraints, inferredType)))
1772 return failure();
1773
1774 return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
1775 }
1776 default:
1777 break;
1778 }
1779 return emitError(loc, msg: "expected identifier constraint");
1780}
1781
1782FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1783 std::optional<SMRange> typeConstraint;
1784 return parseConstraint(typeConstraint, /*existingConstraints=*/{},
1785 /*allowInlineTypeConstraints=*/false);
1786}
1787
1788//===----------------------------------------------------------------------===//
1789// Exprs
1790//===----------------------------------------------------------------------===//
1791
1792FailureOr<ast::Expr *> Parser::parseExpr() {
1793 if (curToken.is(k: Token::underscore))
1794 return parseUnderscoreExpr();
1795
1796 // Parse the LHS expression.
1797 FailureOr<ast::Expr *> lhsExpr;
1798 switch (curToken.getKind()) {
1799 case Token::kw_attr:
1800 lhsExpr = parseAttributeExpr();
1801 break;
1802 case Token::kw_Constraint:
1803 lhsExpr = parseInlineConstraintLambdaExpr();
1804 break;
1805 case Token::kw_not:
1806 lhsExpr = parseNegatedExpr();
1807 break;
1808 case Token::identifier:
1809 lhsExpr = parseIdentifierExpr();
1810 break;
1811 case Token::kw_op:
1812 lhsExpr = parseOperationExpr();
1813 break;
1814 case Token::kw_Rewrite:
1815 lhsExpr = parseInlineRewriteLambdaExpr();
1816 break;
1817 case Token::kw_type:
1818 lhsExpr = parseTypeExpr();
1819 break;
1820 case Token::l_paren:
1821 lhsExpr = parseTupleExpr();
1822 break;
1823 default:
1824 return emitError(msg: "expected expression");
1825 }
1826 if (failed(Result: lhsExpr))
1827 return failure();
1828
1829 // Check for an operator expression.
1830 while (true) {
1831 switch (curToken.getKind()) {
1832 case Token::dot:
1833 lhsExpr = parseMemberAccessExpr(parentExpr: *lhsExpr);
1834 break;
1835 case Token::l_paren:
1836 lhsExpr = parseCallExpr(parentExpr: *lhsExpr);
1837 break;
1838 default:
1839 return lhsExpr;
1840 }
1841 if (failed(Result: lhsExpr))
1842 return failure();
1843 }
1844}
1845
1846FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1847 SMRange loc = curToken.getLoc();
1848 consumeToken(kind: Token::kw_attr);
1849
1850 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1851 // identifier.
1852 if (!consumeIf(kind: Token::less)) {
1853 resetToken(tokLoc: loc);
1854 return parseIdentifierExpr();
1855 }
1856
1857 if (!curToken.isString())
1858 return emitError(msg: "expected string literal containing MLIR attribute");
1859 std::string attrExpr = curToken.getStringValue();
1860 consumeToken();
1861
1862 loc.End = curToken.getEndLoc();
1863 if (failed(
1864 Result: parseToken(kind: Token::greater, msg: "expected `>` after attribute literal")))
1865 return failure();
1866 return ast::AttributeExpr::create(ctx, loc, value: attrExpr);
1867}
1868
1869FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr,
1870 bool isNegated) {
1871 consumeToken(kind: Token::l_paren);
1872
1873 // Parse the arguments of the call.
1874 SmallVector<ast::Expr *> arguments;
1875 if (curToken.isNot(k: Token::r_paren)) {
1876 do {
1877 // Handle code completion for the call arguments.
1878 if (curToken.is(k: Token::code_complete)) {
1879 codeCompleteCallSignature(parent: parentExpr, currentNumArgs: arguments.size());
1880 return failure();
1881 }
1882
1883 FailureOr<ast::Expr *> argument = parseExpr();
1884 if (failed(Result: argument))
1885 return failure();
1886 arguments.push_back(Elt: *argument);
1887 } while (consumeIf(kind: Token::comma));
1888 }
1889
1890 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1891 if (failed(Result: parseToken(kind: Token::r_paren, msg: "expected `)` after argument list")))
1892 return failure();
1893
1894 return createCallExpr(loc, parentExpr, arguments, isNegated);
1895}
1896
1897FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1898 ast::Decl *decl = curDeclScope->lookup(name);
1899 if (!decl)
1900 return emitError(loc, msg: "undefined reference to `" + name + "`");
1901
1902 return createDeclRefExpr(loc, decl);
1903}
1904
1905FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1906 StringRef name = curToken.getSpelling();
1907 SMRange nameLoc = curToken.getLoc();
1908 consumeToken();
1909
1910 // Check to see if this is a decl ref expression that defines a variable
1911 // inline.
1912 if (consumeIf(kind: Token::colon)) {
1913 SmallVector<ast::ConstraintRef> constraints;
1914 if (failed(Result: parseVariableDeclConstraintList(constraints)))
1915 return failure();
1916 ast::Type type;
1917 if (failed(Result: validateVariableConstraints(constraints, inferredType&: type)))
1918 return failure();
1919 return createInlineVariableExpr(type, name, loc: nameLoc, constraints);
1920 }
1921
1922 return parseDeclRefExpr(name, loc: nameLoc);
1923}
1924
1925FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1926 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1927 if (failed(Result: decl))
1928 return failure();
1929
1930 return ast::DeclRefExpr::create(ctx, loc: (*decl)->getLoc(), decl: *decl,
1931 type: ast::ConstraintType::get(context&: ctx));
1932}
1933
1934FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1935 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1936 if (failed(Result: decl))
1937 return failure();
1938
1939 return ast::DeclRefExpr::create(ctx, loc: (*decl)->getLoc(), decl: *decl,
1940 type: ast::RewriteType::get(context&: ctx));
1941}
1942
1943FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1944 SMRange dotLoc = curToken.getLoc();
1945 consumeToken(kind: Token::dot);
1946
1947 // Check for code completion of the member name.
1948 if (curToken.is(k: Token::code_complete))
1949 return codeCompleteMemberAccess(parentExpr);
1950
1951 // Parse the member name.
1952 Token memberNameTok = curToken;
1953 if (memberNameTok.isNot(k1: Token::identifier, k2: Token::integer) &&
1954 !memberNameTok.isKeyword())
1955 return emitError(loc: dotLoc, msg: "expected identifier or numeric member name");
1956 StringRef memberName = memberNameTok.getSpelling();
1957 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1958 consumeToken();
1959
1960 return createMemberAccessExpr(parentExpr, name: memberName, loc);
1961}
1962
1963FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
1964 consumeToken(kind: Token::kw_not);
1965 // Only native constraints are supported after negation
1966 if (!curToken.is(k: Token::identifier))
1967 return emitError(msg: "expected native constraint");
1968 FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
1969 if (failed(Result: identifierExpr))
1970 return failure();
1971 if (!curToken.is(k: Token::l_paren))
1972 return emitError(msg: "expected `(` after function name");
1973 return parseCallExpr(parentExpr: *identifierExpr, /*isNegated = */ true);
1974}
1975
1976FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
1977 SMRange loc = curToken.getLoc();
1978
1979 // Check for code completion for the dialect name.
1980 if (curToken.is(k: Token::code_complete))
1981 return codeCompleteDialectName();
1982
1983 // Handle the case of an no operation name.
1984 if (curToken.isNot(k: Token::identifier) && !curToken.isKeyword()) {
1985 if (allowEmptyName)
1986 return ast::OpNameDecl::create(ctx, loc: SMRange());
1987 return emitError(msg: "expected dialect namespace");
1988 }
1989 StringRef name = curToken.getSpelling();
1990 consumeToken();
1991
1992 // Otherwise, this is a literal operation name.
1993 if (failed(Result: parseToken(kind: Token::dot, msg: "expected `.` after dialect namespace")))
1994 return failure();
1995
1996 // Check for code completion for the operation name.
1997 if (curToken.is(k: Token::code_complete))
1998 return codeCompleteOperationName(dialectName: name);
1999
2000 if (curToken.isNot(k: Token::identifier) && !curToken.isKeyword())
2001 return emitError(msg: "expected operation name after dialect namespace");
2002
2003 name = StringRef(name.data(), name.size() + 1);
2004 do {
2005 name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
2006 loc.End = curToken.getEndLoc();
2007 consumeToken();
2008 } while (curToken.isAny(k1: Token::identifier, k2: Token::dot) ||
2009 curToken.isKeyword());
2010 return ast::OpNameDecl::create(ctx, name: ast::Name::create(ctx, name, location: loc));
2011}
2012
2013FailureOr<ast::OpNameDecl *>
2014Parser::parseWrappedOperationName(bool allowEmptyName) {
2015 if (!consumeIf(kind: Token::less))
2016 return ast::OpNameDecl::create(ctx, loc: SMRange());
2017
2018 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
2019 if (failed(Result: opNameDecl))
2020 return failure();
2021
2022 if (failed(Result: parseToken(kind: Token::greater, msg: "expected `>` after operation name")))
2023 return failure();
2024 return opNameDecl;
2025}
2026
2027FailureOr<ast::Expr *>
2028Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
2029 SMRange loc = curToken.getLoc();
2030 consumeToken(kind: Token::kw_op);
2031
2032 // If it isn't followed by a `<`, the `op` keyword is treated as a normal
2033 // identifier.
2034 if (curToken.isNot(k: Token::less)) {
2035 resetToken(tokLoc: loc);
2036 return parseIdentifierExpr();
2037 }
2038
2039 // Parse the operation name. The name may be elided, in which case the
2040 // operation refers to "any" operation(i.e. a difference between `MyOp` and
2041 // `Operation*`). Operation names within a rewrite context must be named.
2042 bool allowEmptyName = parserContext != ParserContext::Rewrite;
2043 FailureOr<ast::OpNameDecl *> opNameDecl =
2044 parseWrappedOperationName(allowEmptyName);
2045 if (failed(Result: opNameDecl))
2046 return failure();
2047 std::optional<StringRef> opName = (*opNameDecl)->getName();
2048
2049 // Functor used to create an implicit range variable, used for implicit "all"
2050 // operand or results variables.
2051 auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
2052 FailureOr<ast::VariableDecl *> rangeVar =
2053 defineVariableDecl(name: "_", nameLoc: loc, type, constraints: ast::ConstraintRef(cst, loc));
2054 assert(succeeded(rangeVar) && "expected range variable to be valid");
2055 return ast::DeclRefExpr::create(ctx, loc, decl: *rangeVar, type);
2056 };
2057
2058 // Check for the optional list of operands.
2059 SmallVector<ast::Expr *> operands;
2060 if (!consumeIf(kind: Token::l_paren)) {
2061 // If the operand list isn't specified and we are in a match context, define
2062 // an inplace unconstrained operand range corresponding to all of the
2063 // operands of the operation. This avoids treating zero operands the same
2064 // way as "unconstrained operands".
2065 if (parserContext != ParserContext::Rewrite) {
2066 operands.push_back(Elt: createImplicitRangeVar(
2067 ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
2068 }
2069 } else if (!consumeIf(kind: Token::r_paren)) {
2070 // If the operand list was specified and non-empty, parse the operands.
2071 do {
2072 // Check for operand signature code completion.
2073 if (curToken.is(k: Token::code_complete)) {
2074 codeCompleteOperationOperandsSignature(opName, currentNumOperands: operands.size());
2075 return failure();
2076 }
2077
2078 FailureOr<ast::Expr *> operand = parseExpr();
2079 if (failed(Result: operand))
2080 return failure();
2081 operands.push_back(Elt: *operand);
2082 } while (consumeIf(kind: Token::comma));
2083
2084 if (failed(Result: parseToken(kind: Token::r_paren,
2085 msg: "expected `)` after operation operand list")))
2086 return failure();
2087 }
2088
2089 // Check for the optional list of attributes.
2090 SmallVector<ast::NamedAttributeDecl *> attributes;
2091 if (consumeIf(kind: Token::l_brace)) {
2092 do {
2093 FailureOr<ast::NamedAttributeDecl *> decl =
2094 parseNamedAttributeDecl(parentOpName: opName);
2095 if (failed(Result: decl))
2096 return failure();
2097 attributes.emplace_back(Args&: *decl);
2098 } while (consumeIf(kind: Token::comma));
2099
2100 if (failed(Result: parseToken(kind: Token::r_brace,
2101 msg: "expected `}` after operation attribute list")))
2102 return failure();
2103 }
2104
2105 // Handle the result types of the operation.
2106 SmallVector<ast::Expr *> resultTypes;
2107 OpResultTypeContext resultTypeContext = inputResultTypeContext;
2108
2109 // Check for an explicit list of result types.
2110 if (consumeIf(kind: Token::arrow)) {
2111 if (failed(Result: parseToken(kind: Token::l_paren,
2112 msg: "expected `(` before operation result type list")))
2113 return failure();
2114
2115 // If result types are provided, initially assume that the operation does
2116 // not rely on type inferrence. We don't assert that it isn't, because we
2117 // may be inferring the value of some type/type range variables, but given
2118 // that these variables may be defined in calls we can't always discern when
2119 // this is the case.
2120 resultTypeContext = OpResultTypeContext::Explicit;
2121
2122 // Handle the case of an empty result list.
2123 if (!consumeIf(kind: Token::r_paren)) {
2124 do {
2125 // Check for result signature code completion.
2126 if (curToken.is(k: Token::code_complete)) {
2127 codeCompleteOperationResultsSignature(opName, currentNumResults: resultTypes.size());
2128 return failure();
2129 }
2130
2131 FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
2132 if (failed(Result: resultTypeExpr))
2133 return failure();
2134 resultTypes.push_back(Elt: *resultTypeExpr);
2135 } while (consumeIf(kind: Token::comma));
2136
2137 if (failed(Result: parseToken(kind: Token::r_paren,
2138 msg: "expected `)` after operation result type list")))
2139 return failure();
2140 }
2141 } else if (parserContext != ParserContext::Rewrite) {
2142 // If the result list isn't specified and we are in a match context, define
2143 // an inplace unconstrained result range corresponding to all of the results
2144 // of the operation. This avoids treating zero results the same way as
2145 // "unconstrained results".
2146 resultTypes.push_back(Elt: createImplicitRangeVar(
2147 ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
2148 } else if (resultTypeContext == OpResultTypeContext::Explicit) {
2149 // If the result list isn't specified and we are in a rewrite, try to infer
2150 // them at runtime instead.
2151 resultTypeContext = OpResultTypeContext::Interface;
2152 }
2153
2154 return createOperationExpr(loc, name: *opNameDecl, resultTypeContext, operands,
2155 attributes, results&: resultTypes);
2156}
2157
2158FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2159 SMRange loc = curToken.getLoc();
2160 consumeToken(kind: Token::l_paren);
2161
2162 DenseMap<StringRef, SMRange> usedNames;
2163 SmallVector<StringRef> elementNames;
2164 SmallVector<ast::Expr *> elements;
2165 if (curToken.isNot(k: Token::r_paren)) {
2166 do {
2167 // Check for the optional element name assignment before the value.
2168 StringRef elementName;
2169 if (curToken.is(k: Token::identifier) || curToken.isDependentKeyword()) {
2170 Token elementNameTok = curToken;
2171 consumeToken();
2172
2173 // The element name is only present if followed by an `=`.
2174 if (consumeIf(kind: Token::equal)) {
2175 elementName = elementNameTok.getSpelling();
2176
2177 // Check to see if this name is already used.
2178 auto elementNameIt =
2179 usedNames.try_emplace(Key: elementName, Args: elementNameTok.getLoc());
2180 if (!elementNameIt.second) {
2181 return emitErrorAndNote(
2182 loc: elementNameTok.getLoc(),
2183 msg: llvm::formatv(Fmt: "duplicate tuple element label `{0}`",
2184 Vals&: elementName),
2185 noteLoc: elementNameIt.first->getSecond(),
2186 note: "see previous label use here");
2187 }
2188 } else {
2189 // Otherwise, we treat this as part of an expression so reset the
2190 // lexer.
2191 resetToken(tokLoc: elementNameTok.getLoc());
2192 }
2193 }
2194 elementNames.push_back(Elt: elementName);
2195
2196 // Parse the tuple element value.
2197 FailureOr<ast::Expr *> element = parseExpr();
2198 if (failed(Result: element))
2199 return failure();
2200 elements.push_back(Elt: *element);
2201 } while (consumeIf(kind: Token::comma));
2202 }
2203 loc.End = curToken.getEndLoc();
2204 if (failed(
2205 Result: parseToken(kind: Token::r_paren, msg: "expected `)` after tuple element list")))
2206 return failure();
2207 return createTupleExpr(loc, elements, elementNames);
2208}
2209
2210FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2211 SMRange loc = curToken.getLoc();
2212 consumeToken(kind: Token::kw_type);
2213
2214 // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2215 // identifier.
2216 if (!consumeIf(kind: Token::less)) {
2217 resetToken(tokLoc: loc);
2218 return parseIdentifierExpr();
2219 }
2220
2221 if (!curToken.isString())
2222 return emitError(msg: "expected string literal containing MLIR type");
2223 std::string attrExpr = curToken.getStringValue();
2224 consumeToken();
2225
2226 loc.End = curToken.getEndLoc();
2227 if (failed(Result: parseToken(kind: Token::greater, msg: "expected `>` after type literal")))
2228 return failure();
2229 return ast::TypeExpr::create(ctx, loc, value: attrExpr);
2230}
2231
2232FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2233 StringRef name = curToken.getSpelling();
2234 SMRange nameLoc = curToken.getLoc();
2235 consumeToken(kind: Token::underscore);
2236
2237 // Underscore expressions require a constraint list.
2238 if (failed(Result: parseToken(kind: Token::colon, msg: "expected `:` after `_` variable")))
2239 return failure();
2240
2241 // Parse the constraints for the expression.
2242 SmallVector<ast::ConstraintRef> constraints;
2243 if (failed(Result: parseVariableDeclConstraintList(constraints)))
2244 return failure();
2245
2246 ast::Type type;
2247 if (failed(Result: validateVariableConstraints(constraints, inferredType&: type)))
2248 return failure();
2249 return createInlineVariableExpr(type, name, loc: nameLoc, constraints);
2250}
2251
2252//===----------------------------------------------------------------------===//
2253// Stmts
2254//===----------------------------------------------------------------------===//
2255
2256FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2257 FailureOr<ast::Stmt *> stmt;
2258 switch (curToken.getKind()) {
2259 case Token::kw_erase:
2260 stmt = parseEraseStmt();
2261 break;
2262 case Token::kw_let:
2263 stmt = parseLetStmt();
2264 break;
2265 case Token::kw_replace:
2266 stmt = parseReplaceStmt();
2267 break;
2268 case Token::kw_return:
2269 stmt = parseReturnStmt();
2270 break;
2271 case Token::kw_rewrite:
2272 stmt = parseRewriteStmt();
2273 break;
2274 default:
2275 stmt = parseExpr();
2276 break;
2277 }
2278 if (failed(Result: stmt) ||
2279 (expectTerminalSemicolon &&
2280 failed(Result: parseToken(kind: Token::semicolon, msg: "expected `;` after statement"))))
2281 return failure();
2282 return stmt;
2283}
2284
2285FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2286 SMLoc startLoc = curToken.getStartLoc();
2287 consumeToken(kind: Token::l_brace);
2288
2289 // Push a new block scope and parse any nested statements.
2290 pushDeclScope();
2291 SmallVector<ast::Stmt *> statements;
2292 while (curToken.isNot(k: Token::r_brace)) {
2293 FailureOr<ast::Stmt *> statement = parseStmt();
2294 if (failed(Result: statement))
2295 return popDeclScope(), failure();
2296 statements.push_back(Elt: *statement);
2297 }
2298 popDeclScope();
2299
2300 // Consume the end brace.
2301 SMRange location(startLoc, curToken.getEndLoc());
2302 consumeToken(kind: Token::r_brace);
2303
2304 return ast::CompoundStmt::create(ctx, location, children: statements);
2305}
2306
2307FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2308 if (parserContext == ParserContext::Constraint)
2309 return emitError(msg: "`erase` cannot be used within a Constraint");
2310 SMRange loc = curToken.getLoc();
2311 consumeToken(kind: Token::kw_erase);
2312
2313 // Parse the root operation expression.
2314 FailureOr<ast::Expr *> rootOp = parseExpr();
2315 if (failed(Result: rootOp))
2316 return failure();
2317
2318 return createEraseStmt(loc, rootOp: *rootOp);
2319}
2320
2321FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2322 SMRange loc = curToken.getLoc();
2323 consumeToken(kind: Token::kw_let);
2324
2325 // Parse the name of the new variable.
2326 SMRange varLoc = curToken.getLoc();
2327 if (curToken.isNot(k: Token::identifier) && !curToken.isDependentKeyword()) {
2328 // `_` is a reserved variable name.
2329 if (curToken.is(k: Token::underscore)) {
2330 return emitError(loc: varLoc,
2331 msg: "`_` may only be used to define \"inline\" variables");
2332 }
2333 return emitError(loc: varLoc,
2334 msg: "expected identifier after `let` to name a new variable");
2335 }
2336 StringRef varName = curToken.getSpelling();
2337 consumeToken();
2338
2339 // Parse the optional set of constraints.
2340 SmallVector<ast::ConstraintRef> constraints;
2341 if (consumeIf(kind: Token::colon) &&
2342 failed(Result: parseVariableDeclConstraintList(constraints)))
2343 return failure();
2344
2345 // Parse the optional initializer expression.
2346 ast::Expr *initializer = nullptr;
2347 if (consumeIf(kind: Token::equal)) {
2348 FailureOr<ast::Expr *> initOrFailure = parseExpr();
2349 if (failed(Result: initOrFailure))
2350 return failure();
2351 initializer = *initOrFailure;
2352
2353 // Check that the constraints are compatible with having an initializer,
2354 // e.g. type constraints cannot be used with initializers.
2355 for (ast::ConstraintRef constraint : constraints) {
2356 LogicalResult result =
2357 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
2358 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
2359 ast::ValueRangeConstraintDecl>(caseFn: [&](const auto *cst) {
2360 if (cst->getTypeExpr()) {
2361 return this->emitError(
2362 loc: constraint.referenceLoc,
2363 msg: "type constraints are not permitted on variables with "
2364 "initializers");
2365 }
2366 return success();
2367 })
2368 .Default(defaultResult: success());
2369 if (failed(Result: result))
2370 return failure();
2371 }
2372 }
2373
2374 FailureOr<ast::VariableDecl *> varDecl =
2375 createVariableDecl(name: varName, loc: varLoc, initializer, constraints);
2376 if (failed(Result: varDecl))
2377 return failure();
2378 return ast::LetStmt::create(ctx, loc, varDecl: *varDecl);
2379}
2380
2381FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2382 if (parserContext == ParserContext::Constraint)
2383 return emitError(msg: "`replace` cannot be used within a Constraint");
2384 SMRange loc = curToken.getLoc();
2385 consumeToken(kind: Token::kw_replace);
2386
2387 // Parse the root operation expression.
2388 FailureOr<ast::Expr *> rootOp = parseExpr();
2389 if (failed(Result: rootOp))
2390 return failure();
2391
2392 if (failed(
2393 Result: parseToken(kind: Token::kw_with, msg: "expected `with` after root operation")))
2394 return failure();
2395
2396 // The replacement portion of this statement is within a rewrite context.
2397 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2398
2399 // Parse the replacement values.
2400 SmallVector<ast::Expr *> replValues;
2401 if (consumeIf(kind: Token::l_paren)) {
2402 if (consumeIf(kind: Token::r_paren)) {
2403 return emitError(
2404 loc, msg: "expected at least one replacement value, consider using "
2405 "`erase` if no replacement values are desired");
2406 }
2407
2408 do {
2409 FailureOr<ast::Expr *> replExpr = parseExpr();
2410 if (failed(Result: replExpr))
2411 return failure();
2412 replValues.emplace_back(Args&: *replExpr);
2413 } while (consumeIf(kind: Token::comma));
2414
2415 if (failed(Result: parseToken(kind: Token::r_paren,
2416 msg: "expected `)` after replacement values")))
2417 return failure();
2418 } else {
2419 // Handle replacement with an operation uniquely, as the replacement
2420 // operation supports type inferrence from the root operation.
2421 FailureOr<ast::Expr *> replExpr;
2422 if (curToken.is(k: Token::kw_op))
2423 replExpr = parseOperationExpr(inputResultTypeContext: OpResultTypeContext::Replacement);
2424 else
2425 replExpr = parseExpr();
2426 if (failed(Result: replExpr))
2427 return failure();
2428 replValues.emplace_back(Args&: *replExpr);
2429 }
2430
2431 return createReplaceStmt(loc, rootOp: *rootOp, replValues);
2432}
2433
2434FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2435 SMRange loc = curToken.getLoc();
2436 consumeToken(kind: Token::kw_return);
2437
2438 // Parse the result value.
2439 FailureOr<ast::Expr *> resultExpr = parseExpr();
2440 if (failed(Result: resultExpr))
2441 return failure();
2442
2443 return ast::ReturnStmt::create(ctx, loc, resultExpr: *resultExpr);
2444}
2445
2446FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2447 if (parserContext == ParserContext::Constraint)
2448 return emitError(msg: "`rewrite` cannot be used within a Constraint");
2449 SMRange loc = curToken.getLoc();
2450 consumeToken(kind: Token::kw_rewrite);
2451
2452 // Parse the root operation.
2453 FailureOr<ast::Expr *> rootOp = parseExpr();
2454 if (failed(Result: rootOp))
2455 return failure();
2456
2457 if (failed(Result: parseToken(kind: Token::kw_with, msg: "expected `with` before rewrite body")))
2458 return failure();
2459
2460 if (curToken.isNot(k: Token::l_brace))
2461 return emitError(msg: "expected `{` to start rewrite body");
2462
2463 // The rewrite body of this statement is within a rewrite context.
2464 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2465
2466 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2467 if (failed(Result: rewriteBody))
2468 return failure();
2469
2470 // Verify the rewrite body.
2471 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2472 if (isa<ast::ReturnStmt>(Val: stmt)) {
2473 return emitError(loc: stmt->getLoc(),
2474 msg: "`return` statements are only permitted within a "
2475 "`Constraint` or `Rewrite` body");
2476 }
2477 }
2478
2479 return createRewriteStmt(loc, rootOp: *rootOp, rewriteBody: *rewriteBody);
2480}
2481
2482//===----------------------------------------------------------------------===//
2483// Creation+Analysis
2484//===----------------------------------------------------------------------===//
2485
2486//===----------------------------------------------------------------------===//
2487// Decls
2488//===----------------------------------------------------------------------===//
2489
2490ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2491 // Unwrap reference expressions.
2492 if (auto *init = dyn_cast<ast::DeclRefExpr>(Val: node))
2493 node = init->getDecl();
2494 return dyn_cast<ast::CallableDecl>(Val: node);
2495}
2496
2497FailureOr<ast::PatternDecl *>
2498Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2499 const ParsedPatternMetadata &metadata,
2500 ast::CompoundStmt *body) {
2501 return ast::PatternDecl::create(ctx, location: loc, name, benefit: metadata.benefit,
2502 hasBoundedRecursion: metadata.hasBoundedRecursion, body);
2503}
2504
2505ast::Type Parser::createUserConstraintRewriteResultType(
2506 ArrayRef<ast::VariableDecl *> results) {
2507 // Single result decls use the type of the single result.
2508 if (results.size() == 1)
2509 return results[0]->getType();
2510
2511 // Multiple results use a tuple type, with the types and names grabbed from
2512 // the result variable decls.
2513 auto resultTypes = llvm::map_range(
2514 C&: results, F: [&](const auto *result) { return result->getType(); });
2515 auto resultNames = llvm::map_range(
2516 C&: results, F: [&](const auto *result) { return result->getName().getName(); });
2517 return ast::TupleType::get(context&: ctx, elementTypes: llvm::to_vector(Range&: resultTypes),
2518 elementNames: llvm::to_vector(Range&: resultNames));
2519}
2520
2521template <typename T>
2522FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2523 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2524 ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2525 ast::CompoundStmt *body) {
2526 if (!body->getChildren().empty()) {
2527 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(Val: body->getChildren().back())) {
2528 ast::Expr *resultExpr = retStmt->getResultExpr();
2529
2530 // Process the result of the decl. If no explicit signature results
2531 // were provided, check for return type inference. Otherwise, check that
2532 // the return expression can be converted to the expected type.
2533 if (results.empty())
2534 resultType = resultExpr->getType();
2535 else if (failed(Result: convertExpressionTo(expr&: resultExpr, type: resultType)))
2536 return failure();
2537 else
2538 retStmt->setResultExpr(resultExpr);
2539 }
2540 }
2541 return T::createPDLL(ctx, name, arguments, results, body, resultType);
2542}
2543
2544FailureOr<ast::VariableDecl *>
2545Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2546 ArrayRef<ast::ConstraintRef> constraints) {
2547 // The type of the variable, which is expected to be inferred by either a
2548 // constraint or an initializer expression.
2549 ast::Type type;
2550 if (failed(Result: validateVariableConstraints(constraints, inferredType&: type)))
2551 return failure();
2552
2553 if (initializer) {
2554 // Update the variable type based on the initializer, or try to convert the
2555 // initializer to the existing type.
2556 if (!type)
2557 type = initializer->getType();
2558 else if (ast::Type mergedType = type.refineWith(other: initializer->getType()))
2559 type = mergedType;
2560 else if (failed(Result: convertExpressionTo(expr&: initializer, type)))
2561 return failure();
2562
2563 // Otherwise, if there is no initializer check that the type has already
2564 // been resolved from the constraint list.
2565 } else if (!type) {
2566 return emitErrorAndNote(
2567 loc, msg: "unable to infer type for variable `" + name + "`", noteLoc: loc,
2568 note: "the type of a variable must be inferable from the constraint "
2569 "list or the initializer");
2570 }
2571
2572 // Constraint types cannot be used when defining variables.
2573 if (isa<ast::ConstraintType, ast::RewriteType>(Val: type)) {
2574 return emitError(
2575 loc, msg: llvm::formatv(Fmt: "unable to define variable of `{0}` type", Vals&: type));
2576 }
2577
2578 // Try to define a variable with the given name.
2579 FailureOr<ast::VariableDecl *> varDecl =
2580 defineVariableDecl(name, nameLoc: loc, type, initExpr: initializer, constraints);
2581 if (failed(Result: varDecl))
2582 return failure();
2583
2584 return *varDecl;
2585}
2586
2587FailureOr<ast::VariableDecl *>
2588Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2589 const ast::ConstraintRef &constraint) {
2590 ast::Type argType;
2591 if (failed(Result: validateVariableConstraint(ref: constraint, inferredType&: argType)))
2592 return failure();
2593 return defineVariableDecl(name, nameLoc: loc, type: argType, constraints: constraint);
2594}
2595
2596LogicalResult
2597Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2598 ast::Type &inferredType) {
2599 for (const ast::ConstraintRef &ref : constraints)
2600 if (failed(Result: validateVariableConstraint(ref, inferredType)))
2601 return failure();
2602 return success();
2603}
2604
2605LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2606 ast::Type &inferredType) {
2607 ast::Type constraintType;
2608 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(Val: ref.constraint)) {
2609 if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2610 if (failed(Result: validateTypeConstraintExpr(typeExpr)))
2611 return failure();
2612 }
2613 constraintType = ast::AttributeType::get(context&: ctx);
2614 } else if (const auto *cst =
2615 dyn_cast<ast::OpConstraintDecl>(Val: ref.constraint)) {
2616 constraintType = ast::OperationType::get(
2617 context&: ctx, name: cst->getName(), odsOp: lookupODSOperation(opName: cst->getName()));
2618 } else if (isa<ast::TypeConstraintDecl>(Val: ref.constraint)) {
2619 constraintType = typeTy;
2620 } else if (isa<ast::TypeRangeConstraintDecl>(Val: ref.constraint)) {
2621 constraintType = typeRangeTy;
2622 } else if (const auto *cst =
2623 dyn_cast<ast::ValueConstraintDecl>(Val: ref.constraint)) {
2624 if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2625 if (failed(Result: validateTypeConstraintExpr(typeExpr)))
2626 return failure();
2627 }
2628 constraintType = valueTy;
2629 } else if (const auto *cst =
2630 dyn_cast<ast::ValueRangeConstraintDecl>(Val: ref.constraint)) {
2631 if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2632 if (failed(Result: validateTypeRangeConstraintExpr(typeExpr)))
2633 return failure();
2634 }
2635 constraintType = valueRangeTy;
2636 } else if (const auto *cst =
2637 dyn_cast<ast::UserConstraintDecl>(Val: ref.constraint)) {
2638 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2639 if (inputs.size() != 1) {
2640 return emitErrorAndNote(loc: ref.referenceLoc,
2641 msg: "`Constraint`s applied via a variable constraint "
2642 "list must take a single input, but got " +
2643 Twine(inputs.size()),
2644 noteLoc: cst->getLoc(),
2645 note: "see definition of constraint here");
2646 }
2647 constraintType = inputs.front()->getType();
2648 } else {
2649 llvm_unreachable("unknown constraint type");
2650 }
2651
2652 // Check that the constraint type is compatible with the current inferred
2653 // type.
2654 if (!inferredType) {
2655 inferredType = constraintType;
2656 } else if (ast::Type mergedTy = inferredType.refineWith(other: constraintType)) {
2657 inferredType = mergedTy;
2658 } else {
2659 return emitError(loc: ref.referenceLoc,
2660 msg: llvm::formatv(Fmt: "constraint type `{0}` is incompatible "
2661 "with the previously inferred type `{1}`",
2662 Vals&: constraintType, Vals&: inferredType));
2663 }
2664 return success();
2665}
2666
2667LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2668 ast::Type typeExprType = typeExpr->getType();
2669 if (typeExprType != typeTy) {
2670 return emitError(loc: typeExpr->getLoc(),
2671 msg: "expected expression of `Type` in type constraint");
2672 }
2673 return success();
2674}
2675
2676LogicalResult
2677Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2678 ast::Type typeExprType = typeExpr->getType();
2679 if (typeExprType != typeRangeTy) {
2680 return emitError(loc: typeExpr->getLoc(),
2681 msg: "expected expression of `TypeRange` in type constraint");
2682 }
2683 return success();
2684}
2685
2686//===----------------------------------------------------------------------===//
2687// Exprs
2688//===----------------------------------------------------------------------===//
2689
2690FailureOr<ast::CallExpr *>
2691Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2692 MutableArrayRef<ast::Expr *> arguments, bool isNegated) {
2693 ast::Type parentType = parentExpr->getType();
2694
2695 ast::CallableDecl *callableDecl = tryExtractCallableDecl(node: parentExpr);
2696 if (!callableDecl) {
2697 return emitError(loc,
2698 msg: llvm::formatv(Fmt: "expected a reference to a callable "
2699 "`Constraint` or `Rewrite`, but got: `{0}`",
2700 Vals&: parentType));
2701 }
2702 if (parserContext == ParserContext::Rewrite) {
2703 if (isa<ast::UserConstraintDecl>(Val: callableDecl))
2704 return emitError(
2705 loc, msg: "unable to invoke `Constraint` within a rewrite section");
2706 if (isNegated)
2707 return emitError(loc, msg: "unable to negate a Rewrite");
2708 } else {
2709 if (isa<ast::UserRewriteDecl>(Val: callableDecl))
2710 return emitError(loc,
2711 msg: "unable to invoke `Rewrite` within a match section");
2712 if (isNegated && cast<ast::UserConstraintDecl>(Val: callableDecl)->getBody())
2713 return emitError(loc, msg: "unable to negate non native constraints");
2714 }
2715
2716 // Verify the arguments of the call.
2717 /// Handle size mismatch.
2718 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2719 if (callArgs.size() != arguments.size()) {
2720 return emitErrorAndNote(
2721 loc,
2722 msg: llvm::formatv(Fmt: "invalid number of arguments for {0} call; expected "
2723 "{1}, but got {2}",
2724 Vals: callableDecl->getCallableType(), Vals: callArgs.size(),
2725 Vals: arguments.size()),
2726 noteLoc: callableDecl->getLoc(),
2727 note: llvm::formatv(Fmt: "see the definition of {0} here",
2728 Vals: callableDecl->getName()->getName()));
2729 }
2730
2731 /// Handle argument type mismatch.
2732 auto attachDiagFn = [&](ast::Diagnostic &diag) {
2733 diag.attachNote(msg: llvm::formatv(Fmt: "see the definition of `{0}` here",
2734 Vals: callableDecl->getName()->getName()),
2735 noteLoc: callableDecl->getLoc());
2736 };
2737 for (auto it : llvm::zip(t&: callArgs, u&: arguments)) {
2738 if (failed(Result: convertExpressionTo(expr&: std::get<1>(t&: it), type: std::get<0>(t&: it)->getType(),
2739 noteAttachFn: attachDiagFn)))
2740 return failure();
2741 }
2742
2743 return ast::CallExpr::create(ctx, loc, callable: parentExpr, arguments,
2744 resultType: callableDecl->getResultType(), isNegated);
2745}
2746
2747FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2748 ast::Decl *decl) {
2749 // Check the type of decl being referenced.
2750 ast::Type declType;
2751 if (isa<ast::ConstraintDecl>(Val: decl))
2752 declType = ast::ConstraintType::get(context&: ctx);
2753 else if (isa<ast::UserRewriteDecl>(Val: decl))
2754 declType = ast::RewriteType::get(context&: ctx);
2755 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(Val: decl))
2756 declType = varDecl->getType();
2757 else
2758 return emitError(loc, msg: "invalid reference to `" +
2759 decl->getName()->getName() + "`");
2760
2761 return ast::DeclRefExpr::create(ctx, loc, decl, type: declType);
2762}
2763
2764FailureOr<ast::DeclRefExpr *>
2765Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2766 ArrayRef<ast::ConstraintRef> constraints) {
2767 FailureOr<ast::VariableDecl *> decl =
2768 defineVariableDecl(name, nameLoc: loc, type, constraints);
2769 if (failed(Result: decl))
2770 return failure();
2771 return ast::DeclRefExpr::create(ctx, loc, decl: *decl, type);
2772}
2773
2774FailureOr<ast::MemberAccessExpr *>
2775Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2776 SMRange loc) {
2777 // Validate the member name for the given parent expression.
2778 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2779 if (failed(Result: memberType))
2780 return failure();
2781
2782 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, memberName: name, type: *memberType);
2783}
2784
2785FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2786 StringRef name, SMRange loc) {
2787 ast::Type parentType = parentExpr->getType();
2788 if (ast::OperationType opType = dyn_cast<ast::OperationType>(Val&: parentType)) {
2789 if (name == ast::AllResultsMemberAccessExpr::getMemberName())
2790 return valueRangeTy;
2791
2792 // Verify member access based on the operation type.
2793 if (const ods::Operation *odsOp = opType.getODSOperation()) {
2794 auto results = odsOp->getResults();
2795
2796 // Handle indexed results.
2797 unsigned index = 0;
2798 if (llvm::isDigit(C: name[0]) && !name.getAsInteger(/*Radix=*/10, Result&: index) &&
2799 index < results.size()) {
2800 return results[index].isVariadic() ? valueRangeTy : valueTy;
2801 }
2802
2803 // Handle named results.
2804 const auto *it = llvm::find_if(Range&: results, P: [&](const auto &result) {
2805 return result.getName() == name;
2806 });
2807 if (it != results.end())
2808 return it->isVariadic() ? valueRangeTy : valueTy;
2809 } else if (llvm::isDigit(C: name[0])) {
2810 // Allow unchecked numeric indexing of the results of unregistered
2811 // operations. It returns a single value.
2812 return valueTy;
2813 }
2814 } else if (auto tupleType = dyn_cast<ast::TupleType>(Val&: parentType)) {
2815 // Handle indexed results.
2816 unsigned index = 0;
2817 if (llvm::isDigit(C: name[0]) && !name.getAsInteger(/*Radix=*/10, Result&: index) &&
2818 index < tupleType.size()) {
2819 return tupleType.getElementTypes()[index];
2820 }
2821
2822 // Handle named results.
2823 auto elementNames = tupleType.getElementNames();
2824 const auto *it = llvm::find(Range&: elementNames, Val: name);
2825 if (it != elementNames.end())
2826 return tupleType.getElementTypes()[it - elementNames.begin()];
2827 }
2828 return emitError(
2829 loc,
2830 msg: llvm::formatv(Fmt: "invalid member access `{0}` on expression of type `{1}`",
2831 Vals&: name, Vals&: parentType));
2832}
2833
2834FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2835 SMRange loc, const ast::OpNameDecl *name,
2836 OpResultTypeContext resultTypeContext,
2837 SmallVectorImpl<ast::Expr *> &operands,
2838 MutableArrayRef<ast::NamedAttributeDecl *> attributes,
2839 SmallVectorImpl<ast::Expr *> &results) {
2840 std::optional<StringRef> opNameRef = name->getName();
2841 const ods::Operation *odsOp = lookupODSOperation(opName: opNameRef);
2842
2843 // Verify the inputs operands.
2844 if (failed(Result: validateOperationOperands(loc, name: opNameRef, odsOp, operands)))
2845 return failure();
2846
2847 // Verify the attribute list.
2848 for (ast::NamedAttributeDecl *attr : attributes) {
2849 // Check for an attribute type, or a type awaiting resolution.
2850 ast::Type attrType = attr->getValue()->getType();
2851 if (!isa<ast::AttributeType>(Val: attrType)) {
2852 return emitError(
2853 loc: attr->getValue()->getLoc(),
2854 msg: llvm::formatv(Fmt: "expected `Attr` expression, but got `{0}`", Vals&: attrType));
2855 }
2856 }
2857
2858 assert(
2859 (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2860 "unexpected inferrence when results were explicitly specified");
2861
2862 // If we aren't relying on type inferrence, or explicit results were provided,
2863 // validate them.
2864 if (resultTypeContext == OpResultTypeContext::Explicit) {
2865 if (failed(Result: validateOperationResults(loc, name: opNameRef, odsOp, results)))
2866 return failure();
2867
2868 // Validate the use of interface based type inferrence for this operation.
2869 } else if (resultTypeContext == OpResultTypeContext::Interface) {
2870 assert(opNameRef &&
2871 "expected valid operation name when inferring operation results");
2872 checkOperationResultTypeInferrence(loc, name: *opNameRef, odsOp);
2873 }
2874
2875 return ast::OperationExpr::create(ctx, loc, odsOp, nameDecl: name, operands, resultTypes: results,
2876 attributes);
2877}
2878
2879LogicalResult
2880Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name,
2881 const ods::Operation *odsOp,
2882 SmallVectorImpl<ast::Expr *> &operands) {
2883 return validateOperationOperandsOrResults(
2884 groupName: "operand", loc, odsOpLoc: odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2885 values&: operands,
2886 odsValues: odsOp ? odsOp->getOperands() : ArrayRef<pdll::ods::OperandOrResult>(),
2887 singleTy: valueTy, rangeTy: valueRangeTy);
2888}
2889
2890LogicalResult
2891Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name,
2892 const ods::Operation *odsOp,
2893 SmallVectorImpl<ast::Expr *> &results) {
2894 return validateOperationOperandsOrResults(
2895 groupName: "result", loc, odsOpLoc: odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2896 values&: results,
2897 odsValues: odsOp ? odsOp->getResults() : ArrayRef<pdll::ods::OperandOrResult>(),
2898 singleTy: typeTy, rangeTy: typeRangeTy);
2899}
2900
2901void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2902 const ods::Operation *odsOp) {
2903 // If the operation might not have inferrence support, emit a warning to the
2904 // user. We don't emit an error because the interface might be added to the
2905 // operation at runtime. It's rare, but it could still happen. We emit a
2906 // warning here instead.
2907
2908 // Handle inferrence warnings for unknown operations.
2909 if (!odsOp) {
2910 ctx.getDiagEngine().emitWarning(
2911 loc, msg: llvm::formatv(
2912 Fmt: "operation result types are marked to be inferred, but "
2913 "`{0}` is unknown. Ensure that `{0}` supports zero "
2914 "results or implements `InferTypeOpInterface`. Include "
2915 "the ODS definition of this operation to remove this warning.",
2916 Vals&: opName));
2917 return;
2918 }
2919
2920 // Handle inferrence warnings for known operations that expected at least one
2921 // result, but don't have inference support. An elided results list can mean
2922 // "zero-results", and we don't want to warn when that is the expected
2923 // behavior.
2924 bool requiresInferrence =
2925 llvm::any_of(Range: odsOp->getResults(), P: [](const ods::OperandOrResult &result) {
2926 return !result.isVariableLength();
2927 });
2928 if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
2929 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning(
2930 loc,
2931 msg: llvm::formatv(Fmt: "operation result types are marked to be inferred, but "
2932 "`{0}` does not provide an implementation of "
2933 "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2934 "`InferTypeOpInterface` at runtime, or add support to "
2935 "the ODS definition to remove this warning.",
2936 Vals&: opName));
2937 diag->attachNote(msg: llvm::formatv(Fmt: "see the definition of `{0}` here", Vals&: opName),
2938 noteLoc: odsOp->getLoc());
2939 return;
2940 }
2941}
2942
2943LogicalResult Parser::validateOperationOperandsOrResults(
2944 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
2945 std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
2946 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2947 ast::RangeType rangeTy) {
2948 // All operation types accept a single range parameter.
2949 if (values.size() == 1) {
2950 if (failed(Result: convertExpressionTo(expr&: values[0], type: rangeTy)))
2951 return failure();
2952 return success();
2953 }
2954
2955 /// If the operation has ODS information, we can more accurately verify the
2956 /// values.
2957 if (odsOpLoc) {
2958 auto emitSizeMismatchError = [&] {
2959 return emitErrorAndNote(
2960 loc,
2961 msg: llvm::formatv(Fmt: "invalid number of {0} groups for `{1}`; expected "
2962 "{2}, but got {3}",
2963 Vals&: groupName, Vals&: *name, Vals: odsValues.size(), Vals: values.size()),
2964 noteLoc: *odsOpLoc, note: llvm::formatv(Fmt: "see the definition of `{0}` here", Vals&: *name));
2965 };
2966
2967 // Handle the case where no values were provided.
2968 if (values.empty()) {
2969 // If we don't expect any on the ODS side, we are done.
2970 if (odsValues.empty())
2971 return success();
2972
2973 // If we do, check if we actually need to provide values (i.e. if any of
2974 // the values are actually required).
2975 unsigned numVariadic = 0;
2976 for (const auto &odsValue : odsValues) {
2977 if (!odsValue.isVariableLength())
2978 return emitSizeMismatchError();
2979 ++numVariadic;
2980 }
2981
2982 // If we are in a non-rewrite context, we don't need to do anything more.
2983 // Zero-values is a valid constraint on the operation.
2984 if (parserContext != ParserContext::Rewrite)
2985 return success();
2986
2987 // Otherwise, when in a rewrite we may need to provide values to match the
2988 // ODS signature of the operation to create.
2989
2990 // If we only have one variadic value, just use an empty list.
2991 if (numVariadic == 1)
2992 return success();
2993
2994 // Otherwise, create dummy values for each of the entries so that we
2995 // adhere to the ODS signature.
2996 for (unsigned i = 0, e = odsValues.size(); i < e; ++i) {
2997 values.push_back(
2998 Elt: ast::RangeExpr::create(ctx, loc, /*elements=*/{}, type: rangeTy));
2999 }
3000 return success();
3001 }
3002
3003 // Verify that the number of values provided matches the number of value
3004 // groups ODS expects.
3005 if (odsValues.size() != values.size())
3006 return emitSizeMismatchError();
3007
3008 auto diagFn = [&](ast::Diagnostic &diag) {
3009 diag.attachNote(msg: llvm::formatv(Fmt: "see the definition of `{0}` here", Vals&: *name),
3010 noteLoc: *odsOpLoc);
3011 };
3012 for (unsigned i = 0, e = values.size(); i < e; ++i) {
3013 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
3014 if (failed(Result: convertExpressionTo(expr&: values[i], type: expectedType, noteAttachFn: diagFn)))
3015 return failure();
3016 }
3017 return success();
3018 }
3019
3020 // Otherwise, accept the value groups as they have been defined and just
3021 // ensure they are one of the expected types.
3022 for (ast::Expr *&valueExpr : values) {
3023 ast::Type valueExprType = valueExpr->getType();
3024
3025 // Check if this is one of the expected types.
3026 if (valueExprType == rangeTy || valueExprType == singleTy)
3027 continue;
3028
3029 // If the operand is an Operation, allow converting to a Value or
3030 // ValueRange. This situations arises quite often with nested operation
3031 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
3032 if (singleTy == valueTy) {
3033 if (isa<ast::OperationType>(Val: valueExprType)) {
3034 valueExpr = convertOpToValue(opExpr: valueExpr);
3035 continue;
3036 }
3037 }
3038
3039 // Otherwise, try to convert the expression to a range.
3040 if (succeeded(Result: convertExpressionTo(expr&: valueExpr, type: rangeTy)))
3041 continue;
3042
3043 return emitError(
3044 loc: valueExpr->getLoc(),
3045 msg: llvm::formatv(
3046 Fmt: "expected `{0}` or `{1}` convertible expression, but got `{2}`",
3047 Vals&: singleTy, Vals&: rangeTy, Vals&: valueExprType));
3048 }
3049 return success();
3050}
3051
3052FailureOr<ast::TupleExpr *>
3053Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
3054 ArrayRef<StringRef> elementNames) {
3055 for (const ast::Expr *element : elements) {
3056 ast::Type eleTy = element->getType();
3057 if (isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>(Val: eleTy)) {
3058 return emitError(
3059 loc: element->getLoc(),
3060 msg: llvm::formatv(Fmt: "unable to build a tuple with `{0}` element", Vals&: eleTy));
3061 }
3062 }
3063 return ast::TupleExpr::create(ctx, loc, elements, elementNames);
3064}
3065
3066//===----------------------------------------------------------------------===//
3067// Stmts
3068//===----------------------------------------------------------------------===//
3069
3070FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
3071 ast::Expr *rootOp) {
3072 // Check that root is an Operation.
3073 ast::Type rootType = rootOp->getType();
3074 if (!isa<ast::OperationType>(Val: rootType))
3075 return emitError(loc: rootOp->getLoc(), msg: "expected `Op` expression");
3076
3077 return ast::EraseStmt::create(ctx, loc, rootOp);
3078}
3079
3080FailureOr<ast::ReplaceStmt *>
3081Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
3082 MutableArrayRef<ast::Expr *> replValues) {
3083 // Check that root is an Operation.
3084 ast::Type rootType = rootOp->getType();
3085 if (!isa<ast::OperationType>(Val: rootType)) {
3086 return emitError(
3087 loc: rootOp->getLoc(),
3088 msg: llvm::formatv(Fmt: "expected `Op` expression, but got `{0}`", Vals&: rootType));
3089 }
3090
3091 // If there are multiple replacement values, we implicitly convert any Op
3092 // expressions to the value form.
3093 bool shouldConvertOpToValues = replValues.size() > 1;
3094 for (ast::Expr *&replExpr : replValues) {
3095 ast::Type replType = replExpr->getType();
3096
3097 // Check that replExpr is an Operation, Value, or ValueRange.
3098 if (isa<ast::OperationType>(Val: replType)) {
3099 if (shouldConvertOpToValues)
3100 replExpr = convertOpToValue(opExpr: replExpr);
3101 continue;
3102 }
3103
3104 if (replType != valueTy && replType != valueRangeTy) {
3105 return emitError(loc: replExpr->getLoc(),
3106 msg: llvm::formatv(Fmt: "expected `Op`, `Value` or `ValueRange` "
3107 "expression, but got `{0}`",
3108 Vals&: replType));
3109 }
3110 }
3111
3112 return ast::ReplaceStmt::create(ctx, loc, rootOp, replExprs: replValues);
3113}
3114
3115FailureOr<ast::RewriteStmt *>
3116Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
3117 ast::CompoundStmt *rewriteBody) {
3118 // Check that root is an Operation.
3119 ast::Type rootType = rootOp->getType();
3120 if (!isa<ast::OperationType>(Val: rootType)) {
3121 return emitError(
3122 loc: rootOp->getLoc(),
3123 msg: llvm::formatv(Fmt: "expected `Op` expression, but got `{0}`", Vals&: rootType));
3124 }
3125
3126 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
3127}
3128
3129//===----------------------------------------------------------------------===//
3130// Code Completion
3131//===----------------------------------------------------------------------===//
3132
3133LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
3134 ast::Type parentType = parentExpr->getType();
3135 if (ast::OperationType opType = dyn_cast<ast::OperationType>(Val&: parentType))
3136 codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3137 else if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(Val&: parentType))
3138 codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3139 return failure();
3140}
3141
3142LogicalResult
3143Parser::codeCompleteAttributeName(std::optional<StringRef> opName) {
3144 if (opName)
3145 codeCompleteContext->codeCompleteOperationAttributeName(opName: *opName);
3146 return failure();
3147}
3148
3149LogicalResult
3150Parser::codeCompleteConstraintName(ast::Type inferredType,
3151 bool allowInlineTypeConstraints) {
3152 codeCompleteContext->codeCompleteConstraintName(
3153 currentType: inferredType, allowInlineTypeConstraints, scope: curDeclScope);
3154 return failure();
3155}
3156
3157LogicalResult Parser::codeCompleteDialectName() {
3158 codeCompleteContext->codeCompleteDialectName();
3159 return failure();
3160}
3161
3162LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3163 codeCompleteContext->codeCompleteOperationName(dialectName);
3164 return failure();
3165}
3166
3167LogicalResult Parser::codeCompletePatternMetadata() {
3168 codeCompleteContext->codeCompletePatternMetadata();
3169 return failure();
3170}
3171
3172LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3173 codeCompleteContext->codeCompleteIncludeFilename(curPath);
3174 return failure();
3175}
3176
3177void Parser::codeCompleteCallSignature(ast::Node *parent,
3178 unsigned currentNumArgs) {
3179 ast::CallableDecl *callableDecl = tryExtractCallableDecl(node: parent);
3180 if (!callableDecl)
3181 return;
3182
3183 codeCompleteContext->codeCompleteCallSignature(callable: callableDecl, currentNumArgs);
3184}
3185
3186void Parser::codeCompleteOperationOperandsSignature(
3187 std::optional<StringRef> opName, unsigned currentNumOperands) {
3188 codeCompleteContext->codeCompleteOperationOperandsSignature(
3189 opName, currentNumOperands);
3190}
3191
3192void Parser::codeCompleteOperationResultsSignature(
3193 std::optional<StringRef> opName, unsigned currentNumResults) {
3194 codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3195 currentNumResults);
3196}
3197
3198//===----------------------------------------------------------------------===//
3199// Parser
3200//===----------------------------------------------------------------------===//
3201
3202FailureOr<ast::Module *>
3203mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
3204 bool enableDocumentation,
3205 CodeCompleteContext *codeCompleteContext) {
3206 Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3207 return parser.parseModule();
3208}
3209

source code of mlir/lib/Tools/PDLL/Parser/Parser.cpp