1//===- Parser.cpp ---------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Tools/PDLL/Parser/Parser.h"
10#include "Lexer.h"
11#include "mlir/Support/IndentedOstream.h"
12#include "mlir/Support/LogicalResult.h"
13#include "mlir/TableGen/Argument.h"
14#include "mlir/TableGen/Attribute.h"
15#include "mlir/TableGen/Constraint.h"
16#include "mlir/TableGen/Format.h"
17#include "mlir/TableGen/Operator.h"
18#include "mlir/Tools/PDLL/AST/Context.h"
19#include "mlir/Tools/PDLL/AST/Diagnostic.h"
20#include "mlir/Tools/PDLL/AST/Nodes.h"
21#include "mlir/Tools/PDLL/AST/Types.h"
22#include "mlir/Tools/PDLL/ODS/Constraint.h"
23#include "mlir/Tools/PDLL/ODS/Context.h"
24#include "mlir/Tools/PDLL/ODS/Operation.h"
25#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/ADT/TypeSwitch.h"
28#include "llvm/Support/FormatVariadic.h"
29#include "llvm/Support/ManagedStatic.h"
30#include "llvm/Support/SaveAndRestore.h"
31#include "llvm/Support/ScopedPrinter.h"
32#include "llvm/TableGen/Error.h"
33#include "llvm/TableGen/Parser.h"
34#include <optional>
35#include <string>
36
37using namespace mlir;
38using namespace mlir::pdll;
39
40//===----------------------------------------------------------------------===//
41// Parser
42//===----------------------------------------------------------------------===//
43
44namespace {
45class Parser {
46public:
47 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
48 bool enableDocumentation, CodeCompleteContext *codeCompleteContext)
49 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
50 curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
51 typeTy(ast::TypeType::get(context&: ctx)), valueTy(ast::ValueType::get(context&: ctx)),
52 typeRangeTy(ast::TypeRangeType::get(context&: ctx)),
53 valueRangeTy(ast::ValueRangeType::get(context&: ctx)),
54 attrTy(ast::AttributeType::get(context&: ctx)),
55 codeCompleteContext(codeCompleteContext) {}
56
57 /// Try to parse a new module. Returns nullptr in the case of failure.
58 FailureOr<ast::Module *> parseModule();
59
60private:
61 /// The current context of the parser. It allows for the parser to know a bit
62 /// about the construct it is nested within during parsing. This is used
63 /// specifically to provide additional verification during parsing, e.g. to
64 /// prevent using rewrites within a match context, matcher constraints within
65 /// a rewrite section, etc.
66 enum class ParserContext {
67 /// The parser is in the global context.
68 Global,
69 /// The parser is currently within a Constraint, which disallows all types
70 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
71 Constraint,
72 /// The parser is currently within the matcher portion of a Pattern, which
73 /// is allows a terminal operation rewrite statement but no other rewrite
74 /// transformations.
75 PatternMatch,
76 /// The parser is currently within a Rewrite, which disallows calls to
77 /// constraints, requires operation expressions to have names, etc.
78 Rewrite,
79 };
80
81 /// The current specification context of an operations result type. This
82 /// indicates how the result types of an operation may be inferred.
83 enum class OpResultTypeContext {
84 /// The result types of the operation are not known to be inferred.
85 Explicit,
86 /// The result types of the operation are inferred from the root input of a
87 /// `replace` statement.
88 Replacement,
89 /// The result types of the operation are inferred by using the
90 /// `InferTypeOpInterface` interface provided by the operation.
91 Interface,
92 };
93
94 //===--------------------------------------------------------------------===//
95 // Parsing
96 //===--------------------------------------------------------------------===//
97
98 /// Push a new decl scope onto the lexer.
99 ast::DeclScope *pushDeclScope() {
100 ast::DeclScope *newScope =
101 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
102 return (curDeclScope = newScope);
103 }
104 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
105
106 /// Pop the last decl scope from the lexer.
107 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
108
109 /// Parse the body of an AST module.
110 LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
111
112 /// Try to convert the given expression to `type`. Returns failure and emits
113 /// an error if a conversion is not viable. On failure, `noteAttachFn` is
114 /// invoked to attach notes to the emitted error diagnostic. On success,
115 /// `expr` is updated to the expression used to convert to `type`.
116 LogicalResult convertExpressionTo(
117 ast::Expr *&expr, ast::Type type,
118 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
119 LogicalResult
120 convertOpExpressionTo(ast::Expr *&expr, ast::OperationType exprType,
121 ast::Type type,
122 function_ref<ast::InFlightDiagnostic()> emitErrorFn);
123 LogicalResult convertTupleExpressionTo(
124 ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
125 function_ref<ast::InFlightDiagnostic()> emitErrorFn,
126 function_ref<void(ast::Diagnostic &diag)> noteAttachFn);
127
128 /// Given an operation expression, convert it to a Value or ValueRange
129 /// typed expression.
130 ast::Expr *convertOpToValue(const ast::Expr *opExpr);
131
132 /// Lookup ODS information for the given operation, returns nullptr if no
133 /// information is found.
134 const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) {
135 return opName ? ctx.getODSContext().lookupOperation(name: *opName) : nullptr;
136 }
137
138 /// Process the given documentation string, or return an empty string if
139 /// documentation isn't enabled.
140 StringRef processDoc(StringRef doc) {
141 return enableDocumentation ? doc : StringRef();
142 }
143
144 /// Process the given documentation string and format it, or return an empty
145 /// string if documentation isn't enabled.
146 std::string processAndFormatDoc(const Twine &doc) {
147 if (!enableDocumentation)
148 return "";
149 std::string docStr;
150 {
151 llvm::raw_string_ostream docOS(docStr);
152 std::string tmpDocStr = doc.str();
153 raw_indented_ostream(docOS).printReindented(
154 str: StringRef(tmpDocStr).rtrim(Chars: " \t"));
155 }
156 return docStr;
157 }
158
159 //===--------------------------------------------------------------------===//
160 // Directives
161
162 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
163 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
164 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
165 SmallVectorImpl<ast::Decl *> &decls);
166
167 /// Process the records of a parsed tablegen include file.
168 void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
169 SmallVectorImpl<ast::Decl *> &decls);
170
171 /// Create a user defined native constraint for a constraint imported from
172 /// ODS.
173 template <typename ConstraintT>
174 ast::Decl *
175 createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
176 SMRange loc, ast::Type type,
177 StringRef nativeType, StringRef docString);
178 template <typename ConstraintT>
179 ast::Decl *
180 createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
181 SMRange loc, ast::Type type,
182 StringRef nativeType);
183
184 //===--------------------------------------------------------------------===//
185 // Decls
186
187 /// This structure contains the set of pattern metadata that may be parsed.
188 struct ParsedPatternMetadata {
189 std::optional<uint16_t> benefit;
190 bool hasBoundedRecursion = false;
191 };
192
193 FailureOr<ast::Decl *> parseTopLevelDecl();
194 FailureOr<ast::NamedAttributeDecl *>
195 parseNamedAttributeDecl(std::optional<StringRef> parentOpName);
196
197 /// Parse an argument variable as part of the signature of a
198 /// UserConstraintDecl or UserRewriteDecl.
199 FailureOr<ast::VariableDecl *> parseArgumentDecl();
200
201 /// Parse a result variable as part of the signature of a UserConstraintDecl
202 /// or UserRewriteDecl.
203 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
204
205 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
206 /// defined in a non-global context.
207 FailureOr<ast::UserConstraintDecl *>
208 parseUserConstraintDecl(bool isInline = false);
209
210 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
211 /// non-global context, such as within a Pattern/Constraint/etc.
212 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
213
214 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
215 /// PDLL constructs.
216 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
217 const ast::Name &name, bool isInline,
218 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
219 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
220
221 /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being
222 /// defined in a non-global context.
223 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
224
225 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
226 /// non-global context, such as within a Pattern/Rewrite/etc.
227 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
228
229 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
230 /// PDLL constructs.
231 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
232 const ast::Name &name, bool isInline,
233 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
234 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
235
236 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
237 /// effectively the same syntax, and only differ on slight semantics (given
238 /// the different parsing contexts).
239 template <typename T, typename ParseUserPDLLDeclFnT>
240 FailureOr<T *> parseUserConstraintOrRewriteDecl(
241 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
242 StringRef anonymousNamePrefix, bool isInline);
243
244 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
245 /// These decls have effectively the same syntax.
246 template <typename T>
247 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
248 const ast::Name &name, bool isInline,
249 ArrayRef<ast::VariableDecl *> arguments,
250 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
251
252 /// Parse the functional signature (i.e. the arguments and results) of a
253 /// UserConstraintDecl or UserRewriteDecl.
254 LogicalResult parseUserConstraintOrRewriteSignature(
255 SmallVectorImpl<ast::VariableDecl *> &arguments,
256 SmallVectorImpl<ast::VariableDecl *> &results,
257 ast::DeclScope *&argumentScope, ast::Type &resultType);
258
259 /// Validate the return (which if present is specified by bodyIt) of a
260 /// UserConstraintDecl or UserRewriteDecl.
261 LogicalResult validateUserConstraintOrRewriteReturn(
262 StringRef declType, ast::CompoundStmt *body,
263 ArrayRef<ast::Stmt *>::iterator bodyIt,
264 ArrayRef<ast::Stmt *>::iterator bodyE,
265 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
266
267 FailureOr<ast::CompoundStmt *>
268 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
269 bool expectTerminalSemicolon = true);
270 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
271 FailureOr<ast::Decl *> parsePatternDecl();
272 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
273
274 /// Check to see if a decl has already been defined with the given name, if
275 /// one has emit and error and return failure. Returns success otherwise.
276 LogicalResult checkDefineNamedDecl(const ast::Name &name);
277
278 /// Try to define a variable decl with the given components, returns the
279 /// variable on success.
280 FailureOr<ast::VariableDecl *>
281 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
282 ast::Expr *initExpr,
283 ArrayRef<ast::ConstraintRef> constraints);
284 FailureOr<ast::VariableDecl *>
285 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
286 ArrayRef<ast::ConstraintRef> constraints);
287
288 /// Parse the constraint reference list for a variable decl.
289 LogicalResult parseVariableDeclConstraintList(
290 SmallVectorImpl<ast::ConstraintRef> &constraints);
291
292 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
293 FailureOr<ast::Expr *> parseTypeConstraintExpr();
294
295 /// Try to parse a single reference to a constraint. `typeConstraint` is the
296 /// location of a previously parsed type constraint for the entity that will
297 /// be constrained by the parsed constraint. `existingConstraints` are any
298 /// existing constraints that have already been parsed for the same entity
299 /// that will be constrained by this constraint. `allowInlineTypeConstraints`
300 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
301 FailureOr<ast::ConstraintRef>
302 parseConstraint(std::optional<SMRange> &typeConstraint,
303 ArrayRef<ast::ConstraintRef> existingConstraints,
304 bool allowInlineTypeConstraints);
305
306 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
307 /// argument or result variable. The constraints for these variables do not
308 /// allow inline type constraints, and only permit a single constraint.
309 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
310
311 //===--------------------------------------------------------------------===//
312 // Exprs
313
314 FailureOr<ast::Expr *> parseExpr();
315
316 /// Identifier expressions.
317 FailureOr<ast::Expr *> parseAttributeExpr();
318 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr,
319 bool isNegated = false);
320 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
321 FailureOr<ast::Expr *> parseIdentifierExpr();
322 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
323 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
324 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
325 FailureOr<ast::Expr *> parseNegatedExpr();
326 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
327 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
328 FailureOr<ast::Expr *>
329 parseOperationExpr(OpResultTypeContext inputResultTypeContext =
330 OpResultTypeContext::Explicit);
331 FailureOr<ast::Expr *> parseTupleExpr();
332 FailureOr<ast::Expr *> parseTypeExpr();
333 FailureOr<ast::Expr *> parseUnderscoreExpr();
334
335 //===--------------------------------------------------------------------===//
336 // Stmts
337
338 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
339 FailureOr<ast::CompoundStmt *> parseCompoundStmt();
340 FailureOr<ast::EraseStmt *> parseEraseStmt();
341 FailureOr<ast::LetStmt *> parseLetStmt();
342 FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
343 FailureOr<ast::ReturnStmt *> parseReturnStmt();
344 FailureOr<ast::RewriteStmt *> parseRewriteStmt();
345
346 //===--------------------------------------------------------------------===//
347 // Creation+Analysis
348 //===--------------------------------------------------------------------===//
349
350 //===--------------------------------------------------------------------===//
351 // Decls
352
353 /// Try to extract a callable from the given AST node. Returns nullptr on
354 /// failure.
355 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
356
357 /// Try to create a pattern decl with the given components, returning the
358 /// Pattern on success.
359 FailureOr<ast::PatternDecl *>
360 createPatternDecl(SMRange loc, const ast::Name *name,
361 const ParsedPatternMetadata &metadata,
362 ast::CompoundStmt *body);
363
364 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
365 /// of results, defined as part of the signature.
366 ast::Type
367 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
368
369 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
370 template <typename T>
371 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
372 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
373 ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
374 ast::CompoundStmt *body);
375
376 /// Try to create a variable decl with the given components, returning the
377 /// Variable on success.
378 FailureOr<ast::VariableDecl *>
379 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
380 ArrayRef<ast::ConstraintRef> constraints);
381
382 /// Create a variable for an argument or result defined as part of the
383 /// signature of a UserConstraintDecl/UserRewriteDecl.
384 FailureOr<ast::VariableDecl *>
385 createArgOrResultVariableDecl(StringRef name, SMRange loc,
386 const ast::ConstraintRef &constraint);
387
388 /// Validate the constraints used to constraint a variable decl.
389 /// `inferredType` is the type of the variable inferred by the constraints
390 /// within the list, and is updated to the most refined type as determined by
391 /// the constraints. Returns success if the constraint list is valid, failure
392 /// otherwise.
393 LogicalResult
394 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
395 ast::Type &inferredType);
396 /// Validate a single reference to a constraint. `inferredType` contains the
397 /// currently inferred variabled type and is refined within the type defined
398 /// by the constraint. Returns success if the constraint is valid, failure
399 /// otherwise.
400 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
401 ast::Type &inferredType);
402 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
403 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
404
405 //===--------------------------------------------------------------------===//
406 // Exprs
407
408 FailureOr<ast::CallExpr *>
409 createCallExpr(SMRange loc, ast::Expr *parentExpr,
410 MutableArrayRef<ast::Expr *> arguments,
411 bool isNegated = false);
412 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
413 FailureOr<ast::DeclRefExpr *>
414 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
415 ArrayRef<ast::ConstraintRef> constraints);
416 FailureOr<ast::MemberAccessExpr *>
417 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
418
419 /// Validate the member access `name` into the given parent expression. On
420 /// success, this also returns the type of the member accessed.
421 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
422 StringRef name, SMRange loc);
423 FailureOr<ast::OperationExpr *>
424 createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
425 OpResultTypeContext resultTypeContext,
426 SmallVectorImpl<ast::Expr *> &operands,
427 MutableArrayRef<ast::NamedAttributeDecl *> attributes,
428 SmallVectorImpl<ast::Expr *> &results);
429 LogicalResult
430 validateOperationOperands(SMRange loc, std::optional<StringRef> name,
431 const ods::Operation *odsOp,
432 SmallVectorImpl<ast::Expr *> &operands);
433 LogicalResult validateOperationResults(SMRange loc,
434 std::optional<StringRef> name,
435 const ods::Operation *odsOp,
436 SmallVectorImpl<ast::Expr *> &results);
437 void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
438 const ods::Operation *odsOp);
439 LogicalResult validateOperationOperandsOrResults(
440 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
441 std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
442 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
443 ast::RangeType rangeTy);
444 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
445 ArrayRef<ast::Expr *> elements,
446 ArrayRef<StringRef> elementNames);
447
448 //===--------------------------------------------------------------------===//
449 // Stmts
450
451 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
452 FailureOr<ast::ReplaceStmt *>
453 createReplaceStmt(SMRange loc, ast::Expr *rootOp,
454 MutableArrayRef<ast::Expr *> replValues);
455 FailureOr<ast::RewriteStmt *>
456 createRewriteStmt(SMRange loc, ast::Expr *rootOp,
457 ast::CompoundStmt *rewriteBody);
458
459 //===--------------------------------------------------------------------===//
460 // Code Completion
461 //===--------------------------------------------------------------------===//
462
463 /// The set of various code completion methods. Every completion method
464 /// returns `failure` to stop the parsing process after providing completion
465 /// results.
466
467 LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
468 LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName);
469 LogicalResult codeCompleteConstraintName(ast::Type inferredType,
470 bool allowInlineTypeConstraints);
471 LogicalResult codeCompleteDialectName();
472 LogicalResult codeCompleteOperationName(StringRef dialectName);
473 LogicalResult codeCompletePatternMetadata();
474 LogicalResult codeCompleteIncludeFilename(StringRef curPath);
475
476 void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
477 void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
478 unsigned currentNumOperands);
479 void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
480 unsigned currentNumResults);
481
482 //===--------------------------------------------------------------------===//
483 // Lexer Utilities
484 //===--------------------------------------------------------------------===//
485
486 /// If the current token has the specified kind, consume it and return true.
487 /// If not, return false.
488 bool consumeIf(Token::Kind kind) {
489 if (curToken.isNot(k: kind))
490 return false;
491 consumeToken(kind);
492 return true;
493 }
494
495 /// Advance the current lexer onto the next token.
496 void consumeToken() {
497 assert(curToken.isNot(Token::eof, Token::error) &&
498 "shouldn't advance past EOF or errors");
499 curToken = lexer.lexToken();
500 }
501
502 /// Advance the current lexer onto the next token, asserting what the expected
503 /// current token is. This is preferred to the above method because it leads
504 /// to more self-documenting code with better checking.
505 void consumeToken(Token::Kind kind) {
506 assert(curToken.is(kind) && "consumed an unexpected token");
507 consumeToken();
508 }
509
510 /// Reset the lexer to the location at the given position.
511 void resetToken(SMRange tokLoc) {
512 lexer.resetPointer(newPointer: tokLoc.Start.getPointer());
513 curToken = lexer.lexToken();
514 }
515
516 /// Consume the specified token if present and return success. On failure,
517 /// output a diagnostic and return failure.
518 LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
519 if (curToken.getKind() != kind)
520 return emitError(loc: curToken.getLoc(), msg);
521 consumeToken();
522 return success();
523 }
524 LogicalResult emitError(SMRange loc, const Twine &msg) {
525 lexer.emitError(loc, msg);
526 return failure();
527 }
528 LogicalResult emitError(const Twine &msg) {
529 return emitError(loc: curToken.getLoc(), msg);
530 }
531 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
532 const Twine &note) {
533 lexer.emitErrorAndNote(loc, msg, noteLoc, note);
534 return failure();
535 }
536
537 //===--------------------------------------------------------------------===//
538 // Fields
539 //===--------------------------------------------------------------------===//
540
541 /// The owning AST context.
542 ast::Context &ctx;
543
544 /// The lexer of this parser.
545 Lexer lexer;
546
547 /// The current token within the lexer.
548 Token curToken;
549
550 /// A flag indicating if the parser should add documentation to AST nodes when
551 /// viable.
552 bool enableDocumentation;
553
554 /// The most recently defined decl scope.
555 ast::DeclScope *curDeclScope = nullptr;
556 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
557
558 /// The current context of the parser.
559 ParserContext parserContext = ParserContext::Global;
560
561 /// Cached types to simplify verification and expression creation.
562 ast::Type typeTy, valueTy;
563 ast::RangeType typeRangeTy, valueRangeTy;
564 ast::Type attrTy;
565
566 /// A counter used when naming anonymous constraints and rewrites.
567 unsigned anonymousDeclNameCounter = 0;
568
569 /// The optional code completion context.
570 CodeCompleteContext *codeCompleteContext;
571};
572} // namespace
573
574FailureOr<ast::Module *> Parser::parseModule() {
575 SMLoc moduleLoc = curToken.getStartLoc();
576 pushDeclScope();
577
578 // Parse the top-level decls of the module.
579 SmallVector<ast::Decl *> decls;
580 if (failed(result: parseModuleBody(decls)))
581 return popDeclScope(), failure();
582
583 popDeclScope();
584 return ast::Module::create(ctx, loc: moduleLoc, children: decls);
585}
586
587LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
588 while (curToken.isNot(k: Token::eof)) {
589 if (curToken.is(k: Token::directive)) {
590 if (failed(result: parseDirective(decls)))
591 return failure();
592 continue;
593 }
594
595 FailureOr<ast::Decl *> decl = parseTopLevelDecl();
596 if (failed(result: decl))
597 return failure();
598 decls.push_back(Elt: *decl);
599 }
600 return success();
601}
602
603ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
604 return ast::AllResultsMemberAccessExpr::create(ctx, loc: opExpr->getLoc(), parentExpr: opExpr,
605 type: valueRangeTy);
606}
607
608LogicalResult Parser::convertExpressionTo(
609 ast::Expr *&expr, ast::Type type,
610 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
611 ast::Type exprType = expr->getType();
612 if (exprType == type)
613 return success();
614
615 auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
616 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
617 loc: expr->getLoc(), msg: llvm::formatv(Fmt: "unable to convert expression of type "
618 "`{0}` to the expected type of "
619 "`{1}`",
620 Vals&: exprType, Vals&: type));
621 if (noteAttachFn)
622 noteAttachFn(*diag);
623 return diag;
624 };
625
626 if (auto exprOpType = exprType.dyn_cast<ast::OperationType>())
627 return convertOpExpressionTo(expr, exprType: exprOpType, type, emitErrorFn: emitConvertError);
628
629 // FIXME: Decide how to allow/support converting a single result to multiple,
630 // and multiple to a single result. For now, we just allow Single->Range,
631 // but this isn't something really supported in the PDL dialect. We should
632 // figure out some way to support both.
633 if ((exprType == valueTy || exprType == valueRangeTy) &&
634 (type == valueTy || type == valueRangeTy))
635 return success();
636 if ((exprType == typeTy || exprType == typeRangeTy) &&
637 (type == typeTy || type == typeRangeTy))
638 return success();
639
640 // Handle tuple types.
641 if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>())
642 return convertTupleExpressionTo(expr, exprType: exprTupleType, type, emitErrorFn: emitConvertError,
643 noteAttachFn);
644
645 return emitConvertError();
646}
647
648LogicalResult Parser::convertOpExpressionTo(
649 ast::Expr *&expr, ast::OperationType exprType, ast::Type type,
650 function_ref<ast::InFlightDiagnostic()> emitErrorFn) {
651 // Two operation types are compatible if they have the same name, or if the
652 // expected type is more general.
653 if (auto opType = type.dyn_cast<ast::OperationType>()) {
654 if (opType.getName())
655 return emitErrorFn();
656 return success();
657 }
658
659 // An operation can always convert to a ValueRange.
660 if (type == valueRangeTy) {
661 expr = ast::AllResultsMemberAccessExpr::create(ctx, loc: expr->getLoc(), parentExpr: expr,
662 type: valueRangeTy);
663 return success();
664 }
665
666 // Allow conversion to a single value by constraining the result range.
667 if (type == valueTy) {
668 // If the operation is registered, we can verify if it can ever have a
669 // single result.
670 if (const ods::Operation *odsOp = exprType.getODSOperation()) {
671 if (odsOp->getResults().empty()) {
672 return emitErrorFn()->attachNote(
673 msg: llvm::formatv(Fmt: "see the definition of `{0}`, which was defined "
674 "with zero results",
675 Vals: odsOp->getName()),
676 noteLoc: odsOp->getLoc());
677 }
678
679 unsigned numSingleResults = llvm::count_if(
680 Range: odsOp->getResults(), P: [](const ods::OperandOrResult &result) {
681 return result.getVariableLengthKind() ==
682 ods::VariableLengthKind::Single;
683 });
684 if (numSingleResults > 1) {
685 return emitErrorFn()->attachNote(
686 msg: llvm::formatv(Fmt: "see the definition of `{0}`, which was defined "
687 "with at least {1} results",
688 Vals: odsOp->getName(), Vals&: numSingleResults),
689 noteLoc: odsOp->getLoc());
690 }
691 }
692
693 expr = ast::AllResultsMemberAccessExpr::create(ctx, loc: expr->getLoc(), parentExpr: expr,
694 type: valueTy);
695 return success();
696 }
697 return emitErrorFn();
698}
699
700LogicalResult Parser::convertTupleExpressionTo(
701 ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
702 function_ref<ast::InFlightDiagnostic()> emitErrorFn,
703 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
704 // Handle conversions between tuples.
705 if (auto tupleType = type.dyn_cast<ast::TupleType>()) {
706 if (tupleType.size() != exprType.size())
707 return emitErrorFn();
708
709 // Build a new tuple expression using each of the elements of the current
710 // tuple.
711 SmallVector<ast::Expr *> newExprs;
712 for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
713 newExprs.push_back(Elt: ast::MemberAccessExpr::create(
714 ctx, loc: expr->getLoc(), parentExpr: expr, memberName: llvm::to_string(Value: i),
715 type: exprType.getElementTypes()[i]));
716
717 auto diagFn = [&](ast::Diagnostic &diag) {
718 diag.attachNote(msg: llvm::formatv(Fmt: "when converting element #{0} of `{1}`",
719 Vals&: i, Vals&: exprType));
720 if (noteAttachFn)
721 noteAttachFn(diag);
722 };
723 if (failed(result: convertExpressionTo(expr&: newExprs.back(),
724 type: tupleType.getElementTypes()[i], noteAttachFn: diagFn)))
725 return failure();
726 }
727 expr = ast::TupleExpr::create(ctx, loc: expr->getLoc(), elements: newExprs,
728 elementNames: tupleType.getElementNames());
729 return success();
730 }
731
732 // Handle conversion to a range.
733 auto convertToRange = [&](ArrayRef<ast::Type> allowedElementTypes,
734 ast::RangeType resultTy) -> LogicalResult {
735 // TODO: We currently only allow range conversion within a rewrite context.
736 if (parserContext != ParserContext::Rewrite) {
737 return emitErrorFn()->attachNote(msg: "Tuple to Range conversion is currently "
738 "only allowed within a rewrite context");
739 }
740
741 // All of the tuple elements must be allowed types.
742 for (ast::Type elementType : exprType.getElementTypes())
743 if (!llvm::is_contained(Range&: allowedElementTypes, Element: elementType))
744 return emitErrorFn();
745
746 // Build a new tuple expression using each of the elements of the current
747 // tuple.
748 SmallVector<ast::Expr *> newExprs;
749 for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
750 newExprs.push_back(Elt: ast::MemberAccessExpr::create(
751 ctx, loc: expr->getLoc(), parentExpr: expr, memberName: llvm::to_string(Value: i),
752 type: exprType.getElementTypes()[i]));
753 }
754 expr = ast::RangeExpr::create(ctx, loc: expr->getLoc(), elements: newExprs, type: resultTy);
755 return success();
756 };
757 if (type == valueRangeTy)
758 return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
759 if (type == typeRangeTy)
760 return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
761
762 return emitErrorFn();
763}
764
765//===----------------------------------------------------------------------===//
766// Directives
767
768LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
769 StringRef directive = curToken.getSpelling();
770 if (directive == "#include")
771 return parseInclude(decls);
772
773 return emitError(msg: "unknown directive `" + directive + "`");
774}
775
776LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
777 SMRange loc = curToken.getLoc();
778 consumeToken(kind: Token::directive);
779
780 // Handle code completion of the include file path.
781 if (curToken.is(k: Token::code_complete_string))
782 return codeCompleteIncludeFilename(curPath: curToken.getStringValue());
783
784 // Parse the file being included.
785 if (!curToken.isString())
786 return emitError(loc,
787 msg: "expected string file name after `include` directive");
788 SMRange fileLoc = curToken.getLoc();
789 std::string filenameStr = curToken.getStringValue();
790 StringRef filename = filenameStr;
791 consumeToken();
792
793 // Check the type of include. If ending with `.pdll`, this is another pdl file
794 // to be parsed along with the current module.
795 if (filename.ends_with(Suffix: ".pdll")) {
796 if (failed(result: lexer.pushInclude(filename, includeLoc: fileLoc)))
797 return emitError(loc: fileLoc,
798 msg: "unable to open include file `" + filename + "`");
799
800 // If we added the include successfully, parse it into the current module.
801 // Make sure to update to the next token after we finish parsing the nested
802 // file.
803 curToken = lexer.lexToken();
804 LogicalResult result = parseModuleBody(decls);
805 curToken = lexer.lexToken();
806 return result;
807 }
808
809 // Otherwise, this must be a `.td` include.
810 if (filename.ends_with(Suffix: ".td"))
811 return parseTdInclude(filename, fileLoc, decls);
812
813 return emitError(loc: fileLoc,
814 msg: "expected include filename to end with `.pdll` or `.td`");
815}
816
817LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
818 SmallVectorImpl<ast::Decl *> &decls) {
819 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
820
821 // Use the source manager to open the file, but don't yet add it.
822 std::string includedFile;
823 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
824 parserSrcMgr.OpenIncludeFile(Filename: filename.str(), IncludedFile&: includedFile);
825 if (!includeBuffer)
826 return emitError(loc: fileLoc, msg: "unable to open include file `" + filename + "`");
827
828 // Setup the source manager for parsing the tablegen file.
829 llvm::SourceMgr tdSrcMgr;
830 tdSrcMgr.AddNewSourceBuffer(F: std::move(*includeBuffer), IncludeLoc: SMLoc());
831 tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
832
833 // This class provides a context argument for the llvm::SourceMgr diagnostic
834 // handler.
835 struct DiagHandlerContext {
836 Parser &parser;
837 StringRef filename;
838 llvm::SMRange loc;
839 } handlerContext{.parser: *this, .filename: filename, .loc: fileLoc};
840
841 // Set the diagnostic handler for the tablegen source manager.
842 tdSrcMgr.setDiagHandler(
843 DH: [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
844 auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
845 (void)ctx->parser.emitError(
846 loc: ctx->loc,
847 msg: llvm::formatv(Fmt: "error while processing include file `{0}`: {1}",
848 Vals&: ctx->filename, Vals: diag.getMessage()));
849 },
850 Ctx: &handlerContext);
851
852 // Parse the tablegen file.
853 llvm::RecordKeeper tdRecords;
854 if (llvm::TableGenParseFile(InputSrcMgr&: tdSrcMgr, Records&: tdRecords))
855 return failure();
856
857 // Process the parsed records.
858 processTdIncludeRecords(tdRecords, decls);
859
860 // After we are done processing, move all of the tablegen source buffers to
861 // the main parser source mgr. This allows for directly using source locations
862 // from the .td files without needing to remap them.
863 parserSrcMgr.takeSourceBuffersFrom(SrcMgr&: tdSrcMgr, MainBufferIncludeLoc: fileLoc.End);
864 return success();
865}
866
867void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
868 SmallVectorImpl<ast::Decl *> &decls) {
869 // Return the length kind of the given value.
870 auto getLengthKind = [](const auto &value) {
871 if (value.isOptional())
872 return ods::VariableLengthKind::Optional;
873 return value.isVariadic() ? ods::VariableLengthKind::Variadic
874 : ods::VariableLengthKind::Single;
875 };
876
877 // Insert a type constraint into the ODS context.
878 ods::Context &odsContext = ctx.getODSContext();
879 auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
880 -> const ods::TypeConstraint & {
881 return odsContext.insertTypeConstraint(
882 name: cst.constraint.getUniqueDefName(),
883 summary: processDoc(doc: cst.constraint.getSummary()),
884 cppClass: cst.constraint.getCPPClassName());
885 };
886 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
887 return {loc, llvm::SMLoc::getFromPointer(Ptr: loc.getPointer() + 1)};
888 };
889
890 // Process the parsed tablegen records to build ODS information.
891 /// Operations.
892 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions(ClassName: "Op")) {
893 tblgen::Operator op(def);
894
895 // Check to see if this operation is known to support type inferrence.
896 bool supportsResultTypeInferrence =
897 op.getTrait(trait: "::mlir::InferTypeOpInterface::Trait");
898
899 auto [odsOp, inserted] = odsContext.insertOperation(
900 name: op.getOperationName(), summary: processDoc(doc: op.getSummary()),
901 desc: processAndFormatDoc(doc: op.getDescription()), nativeClassName: op.getQualCppClassName(),
902 supportsResultTypeInferrence, loc: op.getLoc().front());
903
904 // Ignore operations that have already been added.
905 if (!inserted)
906 continue;
907
908 for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
909 odsOp->appendAttribute(name: attr.name, optional: attr.attr.isOptional(),
910 constraint: odsContext.insertAttributeConstraint(
911 name: attr.attr.getUniqueDefName(),
912 summary: processDoc(doc: attr.attr.getSummary()),
913 cppClass: attr.attr.getStorageType()));
914 }
915 for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
916 odsOp->appendOperand(name: operand.name, variableLengthKind: getLengthKind(operand),
917 constraint: addTypeConstraint(operand));
918 }
919 for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
920 odsOp->appendResult(name: result.name, variableLengthKind: getLengthKind(result),
921 constraint: addTypeConstraint(result));
922 }
923 }
924
925 auto shouldBeSkipped = [this](llvm::Record *def) {
926 return def->isAnonymous() || curDeclScope->lookup(name: def->getName()) ||
927 def->isSubClassOf(Name: "DeclareInterfaceMethods");
928 };
929
930 /// Attr constraints.
931 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions(ClassName: "Attr")) {
932 if (shouldBeSkipped(def))
933 continue;
934
935 tblgen::Attribute constraint(def);
936 decls.push_back(Elt: createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
937 constraint, loc: convertLocToRange(def->getLoc().front()), type: attrTy,
938 nativeType: constraint.getStorageType()));
939 }
940 /// Type constraints.
941 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions(ClassName: "Type")) {
942 if (shouldBeSkipped(def))
943 continue;
944
945 tblgen::TypeConstraint constraint(def);
946 decls.push_back(Elt: createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
947 constraint, loc: convertLocToRange(def->getLoc().front()), type: typeTy,
948 nativeType: constraint.getCPPClassName()));
949 }
950 /// OpInterfaces.
951 ast::Type opTy = ast::OperationType::get(context&: ctx);
952 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions(ClassName: "OpInterface")) {
953 if (shouldBeSkipped(def))
954 continue;
955
956 SMRange loc = convertLocToRange(def->getLoc().front());
957
958 std::string cppClassName =
959 llvm::formatv(Fmt: "{0}::{1}", Vals: def->getValueAsString(FieldName: "cppNamespace"),
960 Vals: def->getValueAsString(FieldName: "cppInterfaceName"))
961 .str();
962 std::string codeBlock =
963 llvm::formatv(Fmt: "return ::mlir::success(llvm::isa<{0}>(self));",
964 Vals&: cppClassName)
965 .str();
966
967 std::string desc =
968 processAndFormatDoc(doc: def->getValueAsString(FieldName: "description"));
969 decls.push_back(Elt: createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
970 name: def->getName(), codeBlock, loc, type: opTy, nativeType: cppClassName, docString: desc));
971 }
972}
973
974template <typename ConstraintT>
975ast::Decl *Parser::createODSNativePDLLConstraintDecl(
976 StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
977 StringRef nativeType, StringRef docString) {
978 // Build the single input parameter.
979 ast::DeclScope *argScope = pushDeclScope();
980 auto *paramVar = ast::VariableDecl::create(
981 ctx, name: ast::Name::create(ctx, name: "self", location: loc), type,
982 /*initExpr=*/nullptr, constraints: ast::ConstraintRef(ConstraintT::create(ctx, loc)));
983 argScope->add(decl: paramVar);
984 popDeclScope();
985
986 // Build the native constraint.
987 auto *constraintDecl = ast::UserConstraintDecl::createNative(
988 ctx, name: ast::Name::create(ctx, name, location: loc), inputs: paramVar,
989 /*results=*/std::nullopt, codeBlock, resultType: ast::TupleType::get(context&: ctx),
990 nativeInputTypes: nativeType);
991 constraintDecl->setDocComment(ctx, comment: docString);
992 curDeclScope->add(decl: constraintDecl);
993 return constraintDecl;
994}
995
996template <typename ConstraintT>
997ast::Decl *
998Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
999 SMRange loc, ast::Type type,
1000 StringRef nativeType) {
1001 // Format the condition template.
1002 tblgen::FmtContext fmtContext;
1003 fmtContext.withSelf(subst: "self");
1004 std::string codeBlock = tblgen::tgfmt(
1005 fmt: "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
1006 ctx: &fmtContext);
1007
1008 // If documentation was enabled, build the doc string for the generated
1009 // constraint. It would be nice to do this lazily, but TableGen information is
1010 // destroyed after we finish parsing the file.
1011 std::string docString;
1012 if (enableDocumentation) {
1013 StringRef desc = constraint.getDescription();
1014 docString = processAndFormatDoc(
1015 doc: constraint.getSummary() +
1016 (desc.empty() ? "" : ("\n\n" + constraint.getDescription())));
1017 }
1018
1019 return createODSNativePDLLConstraintDecl<ConstraintT>(
1020 constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
1021 docString);
1022}
1023
1024//===----------------------------------------------------------------------===//
1025// Decls
1026
1027FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
1028 FailureOr<ast::Decl *> decl;
1029 switch (curToken.getKind()) {
1030 case Token::kw_Constraint:
1031 decl = parseUserConstraintDecl();
1032 break;
1033 case Token::kw_Pattern:
1034 decl = parsePatternDecl();
1035 break;
1036 case Token::kw_Rewrite:
1037 decl = parseUserRewriteDecl();
1038 break;
1039 default:
1040 return emitError(msg: "expected top-level declaration, such as a `Pattern`");
1041 }
1042 if (failed(result: decl))
1043 return failure();
1044
1045 // If the decl has a name, add it to the current scope.
1046 if (const ast::Name *name = (*decl)->getName()) {
1047 if (failed(result: checkDefineNamedDecl(name: *name)))
1048 return failure();
1049 curDeclScope->add(decl: *decl);
1050 }
1051 return decl;
1052}
1053
1054FailureOr<ast::NamedAttributeDecl *>
1055Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) {
1056 // Check for name code completion.
1057 if (curToken.is(k: Token::code_complete))
1058 return codeCompleteAttributeName(opName: parentOpName);
1059
1060 std::string attrNameStr;
1061 if (curToken.isString())
1062 attrNameStr = curToken.getStringValue();
1063 else if (curToken.is(k: Token::identifier) || curToken.isKeyword())
1064 attrNameStr = curToken.getSpelling().str();
1065 else
1066 return emitError(msg: "expected identifier or string attribute name");
1067 const auto &name = ast::Name::create(ctx, name: attrNameStr, location: curToken.getLoc());
1068 consumeToken();
1069
1070 // Check for a value of the attribute.
1071 ast::Expr *attrValue = nullptr;
1072 if (consumeIf(kind: Token::equal)) {
1073 FailureOr<ast::Expr *> attrExpr = parseExpr();
1074 if (failed(result: attrExpr))
1075 return failure();
1076 attrValue = *attrExpr;
1077 } else {
1078 // If there isn't a concrete value, create an expression representing a
1079 // UnitAttr.
1080 attrValue = ast::AttributeExpr::create(ctx, loc: name.getLoc(), value: "unit");
1081 }
1082
1083 return ast::NamedAttributeDecl::create(ctx, name, value: attrValue);
1084}
1085
1086FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1087 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
1088 bool expectTerminalSemicolon) {
1089 consumeToken(kind: Token::equal_arrow);
1090
1091 // Parse the single statement of the lambda body.
1092 SMLoc bodyStartLoc = curToken.getStartLoc();
1093 pushDeclScope();
1094 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1095 bool failedToParse =
1096 failed(result: singleStatement) || failed(result: processStatementFn(*singleStatement));
1097 popDeclScope();
1098 if (failedToParse)
1099 return failure();
1100
1101 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1102 return ast::CompoundStmt::create(ctx, location: bodyLoc, children: *singleStatement);
1103}
1104
1105FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1106 // Ensure that the argument is named.
1107 if (curToken.isNot(k: Token::identifier) && !curToken.isDependentKeyword())
1108 return emitError(msg: "expected identifier argument name");
1109
1110 // Parse the argument similarly to a normal variable.
1111 StringRef name = curToken.getSpelling();
1112 SMRange nameLoc = curToken.getLoc();
1113 consumeToken();
1114
1115 if (failed(
1116 result: parseToken(kind: Token::colon, msg: "expected `:` before argument constraint")))
1117 return failure();
1118
1119 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1120 if (failed(result: cst))
1121 return failure();
1122
1123 return createArgOrResultVariableDecl(name, loc: nameLoc, constraint: *cst);
1124}
1125
1126FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
1127 // Check to see if this result is named.
1128 if (curToken.is(k: Token::identifier) || curToken.isDependentKeyword()) {
1129 // Check to see if this name actually refers to a Constraint.
1130 if (!curDeclScope->lookup<ast::ConstraintDecl>(name: curToken.getSpelling())) {
1131 // If it wasn't a constraint, parse the result similarly to a variable. If
1132 // there is already an existing decl, we will emit an error when defining
1133 // this variable later.
1134 StringRef name = curToken.getSpelling();
1135 SMRange nameLoc = curToken.getLoc();
1136 consumeToken();
1137
1138 if (failed(result: parseToken(kind: Token::colon,
1139 msg: "expected `:` before result constraint")))
1140 return failure();
1141
1142 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1143 if (failed(result: cst))
1144 return failure();
1145
1146 return createArgOrResultVariableDecl(name, loc: nameLoc, constraint: *cst);
1147 }
1148 }
1149
1150 // If it isn't named, we parse the constraint directly and create an unnamed
1151 // result variable.
1152 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1153 if (failed(result: cst))
1154 return failure();
1155
1156 return createArgOrResultVariableDecl(name: "", loc: cst->referenceLoc, constraint: *cst);
1157}
1158
1159FailureOr<ast::UserConstraintDecl *>
1160Parser::parseUserConstraintDecl(bool isInline) {
1161 // Constraints and rewrites have very similar formats, dispatch to a shared
1162 // interface for parsing.
1163 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1164 parseUserPDLLFn: [&](auto &&...args) {
1165 return this->parseUserPDLLConstraintDecl(name: args...);
1166 },
1167 declContext: ParserContext::Constraint, anonymousNamePrefix: "constraint", isInline);
1168}
1169
1170FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1171 FailureOr<ast::UserConstraintDecl *> decl =
1172 parseUserConstraintDecl(/*isInline=*/true);
1173 if (failed(result: decl) || failed(result: checkDefineNamedDecl(name: (*decl)->getName())))
1174 return failure();
1175
1176 curDeclScope->add(decl: *decl);
1177 return decl;
1178}
1179
1180FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1181 const ast::Name &name, bool isInline,
1182 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1183 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1184 // Push the argument scope back onto the list, so that the body can
1185 // reference arguments.
1186 pushDeclScope(scope: argumentScope);
1187
1188 // Parse the body of the constraint. The body is either defined as a compound
1189 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1190 ast::CompoundStmt *body;
1191 if (curToken.is(k: Token::equal_arrow)) {
1192 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1193 processStatementFn: [&](ast::Stmt *&stmt) -> LogicalResult {
1194 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(Val: stmt);
1195 if (!stmtExpr) {
1196 return emitError(loc: stmt->getLoc(),
1197 msg: "expected `Constraint` lambda body to contain a "
1198 "single expression");
1199 }
1200 stmt = ast::ReturnStmt::create(ctx, loc: stmt->getLoc(), resultExpr: stmtExpr);
1201 return success();
1202 },
1203 /*expectTerminalSemicolon=*/!isInline);
1204 if (failed(result: bodyResult))
1205 return failure();
1206 body = *bodyResult;
1207 } else {
1208 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1209 if (failed(result: bodyResult))
1210 return failure();
1211 body = *bodyResult;
1212
1213 // Verify the structure of the body.
1214 auto bodyIt = body->begin(), bodyE = body->end();
1215 for (; bodyIt != bodyE; ++bodyIt)
1216 if (isa<ast::ReturnStmt>(Val: *bodyIt))
1217 break;
1218 if (failed(result: validateUserConstraintOrRewriteReturn(
1219 declType: "Constraint", body, bodyIt, bodyE, results, resultType)))
1220 return failure();
1221 }
1222 popDeclScope();
1223
1224 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1225 name, arguments, results, resultType, body);
1226}
1227
1228FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1229 // Constraints and rewrites have very similar formats, dispatch to a shared
1230 // interface for parsing.
1231 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1232 parseUserPDLLFn: [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(name: args...); },
1233 declContext: ParserContext::Rewrite, anonymousNamePrefix: "rewrite", isInline);
1234}
1235
1236FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1237 FailureOr<ast::UserRewriteDecl *> decl =
1238 parseUserRewriteDecl(/*isInline=*/true);
1239 if (failed(result: decl) || failed(result: checkDefineNamedDecl(name: (*decl)->getName())))
1240 return failure();
1241
1242 curDeclScope->add(decl: *decl);
1243 return decl;
1244}
1245
1246FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1247 const ast::Name &name, bool isInline,
1248 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1249 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1250 // Push the argument scope back onto the list, so that the body can
1251 // reference arguments.
1252 curDeclScope = argumentScope;
1253 ast::CompoundStmt *body;
1254 if (curToken.is(k: Token::equal_arrow)) {
1255 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1256 processStatementFn: [&](ast::Stmt *&statement) -> LogicalResult {
1257 if (isa<ast::OpRewriteStmt>(Val: statement))
1258 return success();
1259
1260 ast::Expr *statementExpr = dyn_cast<ast::Expr>(Val: statement);
1261 if (!statementExpr) {
1262 return emitError(
1263 loc: statement->getLoc(),
1264 msg: "expected `Rewrite` lambda body to contain a single expression "
1265 "or an operation rewrite statement; such as `erase`, "
1266 "`replace`, or `rewrite`");
1267 }
1268 statement =
1269 ast::ReturnStmt::create(ctx, loc: statement->getLoc(), resultExpr: statementExpr);
1270 return success();
1271 },
1272 /*expectTerminalSemicolon=*/!isInline);
1273 if (failed(result: bodyResult))
1274 return failure();
1275 body = *bodyResult;
1276 } else {
1277 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1278 if (failed(result: bodyResult))
1279 return failure();
1280 body = *bodyResult;
1281 }
1282 popDeclScope();
1283
1284 // Verify the structure of the body.
1285 auto bodyIt = body->begin(), bodyE = body->end();
1286 for (; bodyIt != bodyE; ++bodyIt)
1287 if (isa<ast::ReturnStmt>(Val: *bodyIt))
1288 break;
1289 if (failed(result: validateUserConstraintOrRewriteReturn(declType: "Rewrite", body, bodyIt,
1290 bodyE, results, resultType)))
1291 return failure();
1292 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1293 name, arguments, results, resultType, body);
1294}
1295
1296template <typename T, typename ParseUserPDLLDeclFnT>
1297FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1298 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1299 StringRef anonymousNamePrefix, bool isInline) {
1300 SMRange loc = curToken.getLoc();
1301 consumeToken();
1302 llvm::SaveAndRestore saveCtx(parserContext, declContext);
1303
1304 // Parse the name of the decl.
1305 const ast::Name *name = nullptr;
1306 if (curToken.isNot(k: Token::identifier)) {
1307 // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1308 // in C++, so being unnamed is fine.
1309 if (!isInline)
1310 return emitError(msg: "expected identifier name");
1311
1312 // Create a unique anonymous name to use, as the name for this decl is not
1313 // important.
1314 std::string anonName =
1315 llvm::formatv(Fmt: "<anonymous_{0}_{1}>", Vals&: anonymousNamePrefix,
1316 Vals: anonymousDeclNameCounter++)
1317 .str();
1318 name = &ast::Name::create(ctx, name: anonName, location: loc);
1319 } else {
1320 // If a name was provided, we can use it directly.
1321 name = &ast::Name::create(ctx, name: curToken.getSpelling(), location: curToken.getLoc());
1322 consumeToken(kind: Token::identifier);
1323 }
1324
1325 // Parse the functional signature of the decl.
1326 SmallVector<ast::VariableDecl *> arguments, results;
1327 ast::DeclScope *argumentScope;
1328 ast::Type resultType;
1329 if (failed(result: parseUserConstraintOrRewriteSignature(arguments, results,
1330 argumentScope, resultType)))
1331 return failure();
1332
1333 // Check to see which type of constraint this is. If the constraint contains a
1334 // compound body, this is a PDLL decl.
1335 if (curToken.isAny(k1: Token::l_brace, k2: Token::equal_arrow))
1336 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1337 resultType);
1338
1339 // Otherwise, this is a native decl.
1340 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1341 results, resultType);
1342}
1343
1344template <typename T>
1345FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1346 const ast::Name &name, bool isInline,
1347 ArrayRef<ast::VariableDecl *> arguments,
1348 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1349 // If followed by a string, the native code body has also been specified.
1350 std::string codeStrStorage;
1351 std::optional<StringRef> optCodeStr;
1352 if (curToken.isString()) {
1353 codeStrStorage = curToken.getStringValue();
1354 optCodeStr = codeStrStorage;
1355 consumeToken();
1356 } else if (isInline) {
1357 return emitError(loc: name.getLoc(),
1358 msg: "external declarations must be declared in global scope");
1359 } else if (curToken.is(k: Token::error)) {
1360 return failure();
1361 }
1362 if (failed(result: parseToken(kind: Token::semicolon,
1363 msg: "expected `;` after native declaration")))
1364 return failure();
1365 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1366}
1367
1368LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1369 SmallVectorImpl<ast::VariableDecl *> &arguments,
1370 SmallVectorImpl<ast::VariableDecl *> &results,
1371 ast::DeclScope *&argumentScope, ast::Type &resultType) {
1372 // Parse the argument list of the decl.
1373 if (failed(result: parseToken(kind: Token::l_paren, msg: "expected `(` to start argument list")))
1374 return failure();
1375
1376 argumentScope = pushDeclScope();
1377 if (curToken.isNot(k: Token::r_paren)) {
1378 do {
1379 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1380 if (failed(result: argument))
1381 return failure();
1382 arguments.emplace_back(Args&: *argument);
1383 } while (consumeIf(kind: Token::comma));
1384 }
1385 popDeclScope();
1386 if (failed(result: parseToken(kind: Token::r_paren, msg: "expected `)` to end argument list")))
1387 return failure();
1388
1389 // Parse the results of the decl.
1390 pushDeclScope();
1391 if (consumeIf(kind: Token::arrow)) {
1392 auto parseResultFn = [&]() -> LogicalResult {
1393 FailureOr<ast::VariableDecl *> result = parseResultDecl(resultNum: results.size());
1394 if (failed(result))
1395 return failure();
1396 results.emplace_back(Args&: *result);
1397 return success();
1398 };
1399
1400 // Check for a list of results.
1401 if (consumeIf(kind: Token::l_paren)) {
1402 do {
1403 if (failed(result: parseResultFn()))
1404 return failure();
1405 } while (consumeIf(kind: Token::comma));
1406 if (failed(result: parseToken(kind: Token::r_paren, msg: "expected `)` to end result list")))
1407 return failure();
1408
1409 // Otherwise, there is only one result.
1410 } else if (failed(result: parseResultFn())) {
1411 return failure();
1412 }
1413 }
1414 popDeclScope();
1415
1416 // Compute the result type of the decl.
1417 resultType = createUserConstraintRewriteResultType(results);
1418
1419 // Verify that results are only named if there are more than one.
1420 if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1421 return emitError(
1422 loc: results.front()->getLoc(),
1423 msg: "cannot create a single-element tuple with an element label");
1424 }
1425 return success();
1426}
1427
1428LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1429 StringRef declType, ast::CompoundStmt *body,
1430 ArrayRef<ast::Stmt *>::iterator bodyIt,
1431 ArrayRef<ast::Stmt *>::iterator bodyE,
1432 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1433 // Handle if a `return` was provided.
1434 if (bodyIt != bodyE) {
1435 // Emit an error if we have trailing statements after the return.
1436 if (std::next(x: bodyIt) != bodyE) {
1437 return emitError(
1438 loc: (*std::next(x: bodyIt))->getLoc(),
1439 msg: llvm::formatv(Fmt: "`return` terminated the `{0}` body, but found "
1440 "trailing statements afterwards",
1441 Vals&: declType));
1442 }
1443
1444 // Otherwise if a return wasn't provided, check that no results are
1445 // expected.
1446 } else if (!results.empty()) {
1447 return emitError(
1448 loc: {body->getLoc().End, body->getLoc().End},
1449 msg: llvm::formatv(Fmt: "missing return in a `{0}` expected to return `{1}`",
1450 Vals&: declType, Vals&: resultType));
1451 }
1452 return success();
1453}
1454
1455FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1456 return parseLambdaBody(processStatementFn: [&](ast::Stmt *&statement) -> LogicalResult {
1457 if (isa<ast::OpRewriteStmt>(Val: statement))
1458 return success();
1459 return emitError(
1460 loc: statement->getLoc(),
1461 msg: "expected Pattern lambda body to contain a single operation "
1462 "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1463 });
1464}
1465
1466FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1467 SMRange loc = curToken.getLoc();
1468 consumeToken(kind: Token::kw_Pattern);
1469 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);
1470
1471 // Check for an optional identifier for the pattern name.
1472 const ast::Name *name = nullptr;
1473 if (curToken.is(k: Token::identifier)) {
1474 name = &ast::Name::create(ctx, name: curToken.getSpelling(), location: curToken.getLoc());
1475 consumeToken(kind: Token::identifier);
1476 }
1477
1478 // Parse any pattern metadata.
1479 ParsedPatternMetadata metadata;
1480 if (consumeIf(kind: Token::kw_with) && failed(result: parsePatternDeclMetadata(metadata)))
1481 return failure();
1482
1483 // Parse the pattern body.
1484 ast::CompoundStmt *body;
1485
1486 // Handle a lambda body.
1487 if (curToken.is(k: Token::equal_arrow)) {
1488 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1489 if (failed(result: bodyResult))
1490 return failure();
1491 body = *bodyResult;
1492 } else {
1493 if (curToken.isNot(k: Token::l_brace))
1494 return emitError(msg: "expected `{` or `=>` to start pattern body");
1495 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1496 if (failed(result: bodyResult))
1497 return failure();
1498 body = *bodyResult;
1499
1500 // Verify the body of the pattern.
1501 auto bodyIt = body->begin(), bodyE = body->end();
1502 for (; bodyIt != bodyE; ++bodyIt) {
1503 if (isa<ast::ReturnStmt>(Val: *bodyIt)) {
1504 return emitError(loc: (*bodyIt)->getLoc(),
1505 msg: "`return` statements are only permitted within a "
1506 "`Constraint` or `Rewrite` body");
1507 }
1508 // Break when we've found the rewrite statement.
1509 if (isa<ast::OpRewriteStmt>(Val: *bodyIt))
1510 break;
1511 }
1512 if (bodyIt == bodyE) {
1513 return emitError(loc,
1514 msg: "expected Pattern body to terminate with an operation "
1515 "rewrite statement, such as `erase`");
1516 }
1517 if (std::next(x: bodyIt) != bodyE) {
1518 return emitError(loc: (*std::next(x: bodyIt))->getLoc(),
1519 msg: "Pattern body was terminated by an operation "
1520 "rewrite statement, but found trailing statements");
1521 }
1522 }
1523
1524 return createPatternDecl(loc, name, metadata, body);
1525}
1526
1527LogicalResult
1528Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1529 std::optional<SMRange> benefitLoc;
1530 std::optional<SMRange> hasBoundedRecursionLoc;
1531
1532 do {
1533 // Handle metadata code completion.
1534 if (curToken.is(k: Token::code_complete))
1535 return codeCompletePatternMetadata();
1536
1537 if (curToken.isNot(k: Token::identifier))
1538 return emitError(msg: "expected pattern metadata identifier");
1539 StringRef metadataStr = curToken.getSpelling();
1540 SMRange metadataLoc = curToken.getLoc();
1541 consumeToken(kind: Token::identifier);
1542
1543 // Parse the benefit metadata: benefit(<integer-value>)
1544 if (metadataStr == "benefit") {
1545 if (benefitLoc) {
1546 return emitErrorAndNote(loc: metadataLoc,
1547 msg: "pattern benefit has already been specified",
1548 noteLoc: *benefitLoc, note: "see previous definition here");
1549 }
1550 if (failed(result: parseToken(kind: Token::l_paren,
1551 msg: "expected `(` before pattern benefit")))
1552 return failure();
1553
1554 uint16_t benefitValue = 0;
1555 if (curToken.isNot(k: Token::integer))
1556 return emitError(msg: "expected integral pattern benefit");
1557 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, Result&: benefitValue))
1558 return emitError(
1559 msg: "expected pattern benefit to fit within a 16-bit integer");
1560 consumeToken(kind: Token::integer);
1561
1562 metadata.benefit = benefitValue;
1563 benefitLoc = metadataLoc;
1564
1565 if (failed(
1566 result: parseToken(kind: Token::r_paren, msg: "expected `)` after pattern benefit")))
1567 return failure();
1568 continue;
1569 }
1570
1571 // Parse the bounded recursion metadata: recursion
1572 if (metadataStr == "recursion") {
1573 if (hasBoundedRecursionLoc) {
1574 return emitErrorAndNote(
1575 loc: metadataLoc,
1576 msg: "pattern recursion metadata has already been specified",
1577 noteLoc: *hasBoundedRecursionLoc, note: "see previous definition here");
1578 }
1579 metadata.hasBoundedRecursion = true;
1580 hasBoundedRecursionLoc = metadataLoc;
1581 continue;
1582 }
1583
1584 return emitError(loc: metadataLoc, msg: "unknown pattern metadata");
1585 } while (consumeIf(kind: Token::comma));
1586
1587 return success();
1588}
1589
1590FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1591 consumeToken(kind: Token::less);
1592
1593 FailureOr<ast::Expr *> typeExpr = parseExpr();
1594 if (failed(result: typeExpr) ||
1595 failed(result: parseToken(kind: Token::greater,
1596 msg: "expected `>` after variable type constraint")))
1597 return failure();
1598 return typeExpr;
1599}
1600
1601LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1602 assert(curDeclScope && "defining decl outside of a decl scope");
1603 if (ast::Decl *lastDecl = curDeclScope->lookup(name: name.getName())) {
1604 return emitErrorAndNote(
1605 loc: name.getLoc(), msg: "`" + name.getName() + "` has already been defined",
1606 noteLoc: lastDecl->getName()->getLoc(), note: "see previous definition here");
1607 }
1608 return success();
1609}
1610
1611FailureOr<ast::VariableDecl *>
1612Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1613 ast::Expr *initExpr,
1614 ArrayRef<ast::ConstraintRef> constraints) {
1615 assert(curDeclScope && "defining variable outside of decl scope");
1616 const ast::Name &nameDecl = ast::Name::create(ctx, name, location: nameLoc);
1617
1618 // If the name of the variable indicates a special variable, we don't add it
1619 // to the scope. This variable is local to the definition point.
1620 if (name.empty() || name == "_") {
1621 return ast::VariableDecl::create(ctx, name: nameDecl, type, initExpr,
1622 constraints);
1623 }
1624 if (failed(result: checkDefineNamedDecl(name: nameDecl)))
1625 return failure();
1626
1627 auto *varDecl =
1628 ast::VariableDecl::create(ctx, name: nameDecl, type, initExpr, constraints);
1629 curDeclScope->add(decl: varDecl);
1630 return varDecl;
1631}
1632
1633FailureOr<ast::VariableDecl *>
1634Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1635 ArrayRef<ast::ConstraintRef> constraints) {
1636 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1637 constraints);
1638}
1639
1640LogicalResult Parser::parseVariableDeclConstraintList(
1641 SmallVectorImpl<ast::ConstraintRef> &constraints) {
1642 std::optional<SMRange> typeConstraint;
1643 auto parseSingleConstraint = [&] {
1644 FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1645 typeConstraint, existingConstraints: constraints, /*allowInlineTypeConstraints=*/true);
1646 if (failed(result: constraint))
1647 return failure();
1648 constraints.push_back(Elt: *constraint);
1649 return success();
1650 };
1651
1652 // Check to see if this is a single constraint, or a list.
1653 if (!consumeIf(kind: Token::l_square))
1654 return parseSingleConstraint();
1655
1656 do {
1657 if (failed(result: parseSingleConstraint()))
1658 return failure();
1659 } while (consumeIf(kind: Token::comma));
1660 return parseToken(kind: Token::r_square, msg: "expected `]` after constraint list");
1661}
1662
/// Parse a single constraint reference. Accepted forms are:
///   * `Attr`, optionally with an inline type constraint `Attr<expr>`
///   * `Op`, optionally with a wrapped name `Op<dialect.name>`
///   * `Type`, `TypeRange`
///   * `Value` / `ValueRange`, optionally with an inline type constraint
///   * an inline `Constraint` definition
///   * an identifier naming a previously defined constraint
///
/// `typeConstraint` holds the location of an inline type constraint already
/// parsed for the same variable. It is set when a new one is parsed, and a
/// second one is diagnosed. `existingConstraints` are the constraints already
/// parsed for the variable. In this function they are used only to infer a
/// type for code completion. `allowInlineTypeConstraints` controls whether
/// `<...>` type constraints are accepted at all.
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(std::optional<SMRange> &typeConstraint,
                        ArrayRef<ast::ConstraintRef> existingConstraints,
                        bool allowInlineTypeConstraints) {
  // Parse a `<expr>` type constraint into `typeExpr`. This enforces that inline
  // type constraints are permitted here and that only one is specified.
  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
    if (!allowInlineTypeConstraints) {
      return emitError(
          loc: curToken.getLoc(),
          msg: "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
          "permitted on arguments or results");
    }
    if (typeConstraint)
      return emitErrorAndNote(
          loc: curToken.getLoc(),
          msg: "the type of this variable has already been constrained",
          noteLoc: *typeConstraint, note: "see previous constraint location here");
    FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
    if (failed(result: constraintExpr))
      return failure();
    typeExpr = *constraintExpr;
    // Remember where the type was constrained for later duplicate diagnostics.
    typeConstraint = typeExpr->getLoc();
    return success();
  };

  SMRange loc = curToken.getLoc();
  switch (curToken.getKind()) {
  case Token::kw_Attr: {
    consumeToken(kind: Token::kw_Attr);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(k: Token::less) && failed(result: parseTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_Op: {
    consumeToken(kind: Token::kw_Op);

    // Parse an optional operation name. If the name isn't provided, this refers
    // to "any" operation.
    FailureOr<ast::OpNameDecl *> opName =
        parseWrappedOperationName(/*allowEmptyName=*/true);
    if (failed(result: opName))
      return failure();

    return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, nameDecl: *opName),
                              loc);
  }
  case Token::kw_Type:
    consumeToken(kind: Token::kw_Type);
    return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
  case Token::kw_TypeRange:
    consumeToken(kind: Token::kw_TypeRange);
    return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
                              loc);
  case Token::kw_Value: {
    consumeToken(kind: Token::kw_Value);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(k: Token::less) && failed(result: parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_ValueRange: {
    consumeToken(kind: Token::kw_ValueRange);

    // Check for a type constraint.
    ast::Expr *typeExpr = nullptr;
    if (curToken.is(k: Token::less) && failed(result: parseTypeConstraint(typeExpr)))
      return failure();

    return ast::ConstraintRef(
        ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
  }

  case Token::kw_Constraint: {
    // Handle an inline constraint.
    FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
    if (failed(result: decl))
      return failure();
    return ast::ConstraintRef(*decl, loc);
  }
  case Token::identifier: {
    StringRef constraintName = curToken.getSpelling();
    consumeToken(kind: Token::identifier);

    // Lookup the referenced constraint.
    ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(name: constraintName);
    if (!cstDecl) {
      return emitError(loc, msg: "unknown reference to constraint `" +
                                constraintName + "`");
    }

    // Handle a reference to a proper constraint.
    if (auto *cst = dyn_cast<ast::ConstraintDecl>(Val: cstDecl))
      return ast::ConstraintRef(cst, loc);

    // The name resolved, but to something that is not a constraint (for
    // example, a variable).
    return emitErrorAndNote(
        loc, msg: "invalid reference to non-constraint", noteLoc: cstDecl->getLoc(),
        note: "see the definition of `" + constraintName + "` here");
  }
  // Handle single entity constraint code completion.
  case Token::code_complete: {
    // Try to infer the current type for use by code completion.
    ast::Type inferredType;
    if (failed(result: validateVariableConstraints(constraints: existingConstraints, inferredType)))
      return failure();

    return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
  }
  default:
    break;
  }
  return emitError(loc, msg: "expected identifier constraint");
}
1782
1783FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1784 std::optional<SMRange> typeConstraint;
1785 return parseConstraint(typeConstraint, /*existingConstraints=*/std::nullopt,
1786 /*allowInlineTypeConstraints=*/false);
1787}
1788
1789//===----------------------------------------------------------------------===//
1790// Exprs
1791
1792FailureOr<ast::Expr *> Parser::parseExpr() {
1793 if (curToken.is(k: Token::underscore))
1794 return parseUnderscoreExpr();
1795
1796 // Parse the LHS expression.
1797 FailureOr<ast::Expr *> lhsExpr;
1798 switch (curToken.getKind()) {
1799 case Token::kw_attr:
1800 lhsExpr = parseAttributeExpr();
1801 break;
1802 case Token::kw_Constraint:
1803 lhsExpr = parseInlineConstraintLambdaExpr();
1804 break;
1805 case Token::kw_not:
1806 lhsExpr = parseNegatedExpr();
1807 break;
1808 case Token::identifier:
1809 lhsExpr = parseIdentifierExpr();
1810 break;
1811 case Token::kw_op:
1812 lhsExpr = parseOperationExpr();
1813 break;
1814 case Token::kw_Rewrite:
1815 lhsExpr = parseInlineRewriteLambdaExpr();
1816 break;
1817 case Token::kw_type:
1818 lhsExpr = parseTypeExpr();
1819 break;
1820 case Token::l_paren:
1821 lhsExpr = parseTupleExpr();
1822 break;
1823 default:
1824 return emitError(msg: "expected expression");
1825 }
1826 if (failed(result: lhsExpr))
1827 return failure();
1828
1829 // Check for an operator expression.
1830 while (true) {
1831 switch (curToken.getKind()) {
1832 case Token::dot:
1833 lhsExpr = parseMemberAccessExpr(parentExpr: *lhsExpr);
1834 break;
1835 case Token::l_paren:
1836 lhsExpr = parseCallExpr(parentExpr: *lhsExpr);
1837 break;
1838 default:
1839 return lhsExpr;
1840 }
1841 if (failed(result: lhsExpr))
1842 return failure();
1843 }
1844}
1845
1846FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1847 SMRange loc = curToken.getLoc();
1848 consumeToken(kind: Token::kw_attr);
1849
1850 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1851 // identifier.
1852 if (!consumeIf(kind: Token::less)) {
1853 resetToken(tokLoc: loc);
1854 return parseIdentifierExpr();
1855 }
1856
1857 if (!curToken.isString())
1858 return emitError(msg: "expected string literal containing MLIR attribute");
1859 std::string attrExpr = curToken.getStringValue();
1860 consumeToken();
1861
1862 loc.End = curToken.getEndLoc();
1863 if (failed(
1864 result: parseToken(kind: Token::greater, msg: "expected `>` after attribute literal")))
1865 return failure();
1866 return ast::AttributeExpr::create(ctx, loc, value: attrExpr);
1867}
1868
1869FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr,
1870 bool isNegated) {
1871 consumeToken(kind: Token::l_paren);
1872
1873 // Parse the arguments of the call.
1874 SmallVector<ast::Expr *> arguments;
1875 if (curToken.isNot(k: Token::r_paren)) {
1876 do {
1877 // Handle code completion for the call arguments.
1878 if (curToken.is(k: Token::code_complete)) {
1879 codeCompleteCallSignature(parent: parentExpr, currentNumArgs: arguments.size());
1880 return failure();
1881 }
1882
1883 FailureOr<ast::Expr *> argument = parseExpr();
1884 if (failed(result: argument))
1885 return failure();
1886 arguments.push_back(Elt: *argument);
1887 } while (consumeIf(kind: Token::comma));
1888 }
1889
1890 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1891 if (failed(result: parseToken(kind: Token::r_paren, msg: "expected `)` after argument list")))
1892 return failure();
1893
1894 return createCallExpr(loc, parentExpr, arguments, isNegated);
1895}
1896
1897FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1898 ast::Decl *decl = curDeclScope->lookup(name);
1899 if (!decl)
1900 return emitError(loc, msg: "undefined reference to `" + name + "`");
1901
1902 return createDeclRefExpr(loc, decl);
1903}
1904
1905FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1906 StringRef name = curToken.getSpelling();
1907 SMRange nameLoc = curToken.getLoc();
1908 consumeToken();
1909
1910 // Check to see if this is a decl ref expression that defines a variable
1911 // inline.
1912 if (consumeIf(kind: Token::colon)) {
1913 SmallVector<ast::ConstraintRef> constraints;
1914 if (failed(result: parseVariableDeclConstraintList(constraints)))
1915 return failure();
1916 ast::Type type;
1917 if (failed(result: validateVariableConstraints(constraints, inferredType&: type)))
1918 return failure();
1919 return createInlineVariableExpr(type, name, loc: nameLoc, constraints);
1920 }
1921
1922 return parseDeclRefExpr(name, loc: nameLoc);
1923}
1924
1925FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1926 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1927 if (failed(result: decl))
1928 return failure();
1929
1930 return ast::DeclRefExpr::create(ctx, loc: (*decl)->getLoc(), decl: *decl,
1931 type: ast::ConstraintType::get(context&: ctx));
1932}
1933
1934FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1935 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1936 if (failed(result: decl))
1937 return failure();
1938
1939 return ast::DeclRefExpr::create(ctx, loc: (*decl)->getLoc(), decl: *decl,
1940 type: ast::RewriteType::get(context&: ctx));
1941}
1942
1943FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1944 SMRange dotLoc = curToken.getLoc();
1945 consumeToken(kind: Token::dot);
1946
1947 // Check for code completion of the member name.
1948 if (curToken.is(k: Token::code_complete))
1949 return codeCompleteMemberAccess(parentExpr);
1950
1951 // Parse the member name.
1952 Token memberNameTok = curToken;
1953 if (memberNameTok.isNot(k1: Token::identifier, k2: Token::integer) &&
1954 !memberNameTok.isKeyword())
1955 return emitError(loc: dotLoc, msg: "expected identifier or numeric member name");
1956 StringRef memberName = memberNameTok.getSpelling();
1957 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1958 consumeToken();
1959
1960 return createMemberAccessExpr(parentExpr, name: memberName, loc);
1961}
1962
1963FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
1964 consumeToken(kind: Token::kw_not);
1965 // Only native constraints are supported after negation
1966 if (!curToken.is(k: Token::identifier))
1967 return emitError(msg: "expected native constraint");
1968 FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
1969 if (failed(result: identifierExpr))
1970 return failure();
1971 if (!curToken.is(k: Token::l_paren))
1972 return emitError(msg: "expected `(` after function name");
1973 return parseCallExpr(parentExpr: *identifierExpr, /*isNegated = */ true);
1974}
1975
/// Parse an operation name of the form `dialect.op[.more...]`. Keywords are
/// accepted as name components alongside identifiers.
///
/// If the current token cannot start a name:
///   * when `allowEmptyName` is set, an unnamed `OpNameDecl` with an empty
///     location is returned and no tokens are consumed;
///   * otherwise an error is emitted.
///
/// NOTE(review): the full name is formed by widening a `StringRef` that starts
/// at the dialect namespace by 1 (for the `.`) and then by each following
/// token's spelling size. This assumes the tokens are adjacent in the source
/// buffer. With intervening whitespace (e.g. `test . op`), the resulting name
/// would not match what was written. Confirm whether such input is rejected
/// elsewhere.
FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
  SMRange loc = curToken.getLoc();

  // Check for code completion for the dialect name.
  if (curToken.is(k: Token::code_complete))
    return codeCompleteDialectName();

  // Handle the case of no operation name.
  if (curToken.isNot(k: Token::identifier) && !curToken.isKeyword()) {
    if (allowEmptyName)
      return ast::OpNameDecl::create(ctx, loc: SMRange());
    return emitError(msg: "expected dialect namespace");
  }
  StringRef name = curToken.getSpelling();
  consumeToken();

  // Otherwise, this is a literal operation name.
  if (failed(result: parseToken(kind: Token::dot, msg: "expected `.` after dialect namespace")))
    return failure();

  // Check for code completion for the operation name.
  if (curToken.is(k: Token::code_complete))
    return codeCompleteOperationName(dialectName: name);

  if (curToken.isNot(k: Token::identifier) && !curToken.isKeyword())
    return emitError(msg: "expected operation name after dialect namespace");

  // Widen `name` (currently just the dialect namespace) over the `.` that was
  // consumed above.
  name = StringRef(name.data(), name.size() + 1);
  // Widen `name` over each subsequent identifier, keyword, or `.` token, and
  // extend `loc` to the end of the last token consumed.
  do {
    name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
    loc.End = curToken.getEndLoc();
    consumeToken();
  } while (curToken.isAny(k1: Token::identifier, k2: Token::dot) ||
           curToken.isKeyword());
  return ast::OpNameDecl::create(ctx, name: ast::Name::create(ctx, name, location: loc));
}
2012
2013FailureOr<ast::OpNameDecl *>
2014Parser::parseWrappedOperationName(bool allowEmptyName) {
2015 if (!consumeIf(kind: Token::less))
2016 return ast::OpNameDecl::create(ctx, loc: SMRange());
2017
2018 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
2019 if (failed(result: opNameDecl))
2020 return failure();
2021
2022 if (failed(result: parseToken(kind: Token::greater, msg: "expected `>` after operation name")))
2023 return failure();
2024 return opNameDecl;
2025}
2026
/// Parse an operation expression:
///   op<dialect.name>(operands) {attributes} -> (result types)
/// The operand list, attribute dictionary, and result type list are optional.
/// If `op` is not followed by `<`, the lexer is rewound and `op` is parsed as
/// a normal identifier.
///
/// In a match (non-rewrite) context, an omitted operand list or result type
/// list is replaced by an implicit, unconstrained `_` range variable, so that
/// "unspecified" is not treated the same as "empty". In a rewrite context:
///   * the operation name is required;
///   * an omitted result list with an `Explicit` context switches to
///     `Interface` result type inference.
/// `inputResultTypeContext` is the initial result type context. It is
/// overridden to `Explicit` whenever a result type list is written.
FailureOr<ast::Expr *>
Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
  SMRange loc = curToken.getLoc();
  consumeToken(kind: Token::kw_op);

  // If it isn't followed by a `<`, the `op` keyword is treated as a normal
  // identifier.
  if (curToken.isNot(k: Token::less)) {
    resetToken(tokLoc: loc);
    return parseIdentifierExpr();
  }

  // Parse the operation name. The name may be elided, in which case the
  // operation refers to "any" operation (i.e. a difference between `MyOp` and
  // `Operation*`). Operation names within a rewrite context must be named.
  bool allowEmptyName = parserContext != ParserContext::Rewrite;
  FailureOr<ast::OpNameDecl *> opNameDecl =
      parseWrappedOperationName(allowEmptyName);
  if (failed(result: opNameDecl))
    return failure();
  std::optional<StringRef> opName = (*opNameDecl)->getName();

  // Functor used to create an implicit range variable, used for implicit "all"
  // operand or results variables.
  auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
    FailureOr<ast::VariableDecl *> rangeVar =
        defineVariableDecl(name: "_", nameLoc: loc, type, constraints: ast::ConstraintRef(cst, loc));
    // `defineVariableDecl` skips the redefinition check for `_`, so this is
    // not expected to fail.
    assert(succeeded(rangeVar) && "expected range variable to be valid");
    return ast::DeclRefExpr::create(ctx, loc, decl: *rangeVar, type);
  };

  // Check for the optional list of operands.
  SmallVector<ast::Expr *> operands;
  if (!consumeIf(kind: Token::l_paren)) {
    // If the operand list isn't specified and we are in a match context, define
    // an inplace unconstrained operand range corresponding to all of the
    // operands of the operation. This avoids treating zero operands the same
    // way as "unconstrained operands".
    if (parserContext != ParserContext::Rewrite) {
      operands.push_back(Elt: createImplicitRangeVar(
          ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
    }
  } else if (!consumeIf(kind: Token::r_paren)) {
    // If the operand list was specified and non-empty, parse the operands.
    do {
      // Check for operand signature code completion.
      if (curToken.is(k: Token::code_complete)) {
        codeCompleteOperationOperandsSignature(opName, currentNumOperands: operands.size());
        return failure();
      }

      FailureOr<ast::Expr *> operand = parseExpr();
      if (failed(result: operand))
        return failure();
      operands.push_back(Elt: *operand);
    } while (consumeIf(kind: Token::comma));

    if (failed(result: parseToken(kind: Token::r_paren,
                          msg: "expected `)` after operation operand list")))
      return failure();
  }

  // Check for the optional list of attributes.
  SmallVector<ast::NamedAttributeDecl *> attributes;
  if (consumeIf(kind: Token::l_brace)) {
    do {
      FailureOr<ast::NamedAttributeDecl *> decl =
          parseNamedAttributeDecl(parentOpName: opName);
      if (failed(result: decl))
        return failure();
      attributes.emplace_back(Args&: *decl);
    } while (consumeIf(kind: Token::comma));

    if (failed(result: parseToken(kind: Token::r_brace,
                          msg: "expected `}` after operation attribute list")))
      return failure();
  }

  // Handle the result types of the operation.
  SmallVector<ast::Expr *> resultTypes;
  OpResultTypeContext resultTypeContext = inputResultTypeContext;

  // Check for an explicit list of result types.
  if (consumeIf(kind: Token::arrow)) {
    if (failed(result: parseToken(kind: Token::l_paren,
                          msg: "expected `(` before operation result type list")))
      return failure();

    // If result types are provided, initially assume that the operation does
    // not rely on type inference. We don't assert that it isn't, because we
    // may be inferring the value of some type/type range variables, but given
    // that these variables may be defined in calls we can't always discern when
    // this is the case.
    resultTypeContext = OpResultTypeContext::Explicit;

    // Handle the case of an empty result list.
    if (!consumeIf(kind: Token::r_paren)) {
      do {
        // Check for result signature code completion.
        if (curToken.is(k: Token::code_complete)) {
          codeCompleteOperationResultsSignature(opName, currentNumResults: resultTypes.size());
          return failure();
        }

        FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
        if (failed(result: resultTypeExpr))
          return failure();
        resultTypes.push_back(Elt: *resultTypeExpr);
      } while (consumeIf(kind: Token::comma));

      if (failed(result: parseToken(kind: Token::r_paren,
                            msg: "expected `)` after operation result type list")))
        return failure();
    }
  } else if (parserContext != ParserContext::Rewrite) {
    // If the result list isn't specified and we are in a match context, define
    // an inplace unconstrained result range corresponding to all of the results
    // of the operation. This avoids treating zero results the same way as
    // "unconstrained results".
    resultTypes.push_back(Elt: createImplicitRangeVar(
        ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
  } else if (resultTypeContext == OpResultTypeContext::Explicit) {
    // If the result list isn't specified and we are in a rewrite, try to infer
    // them at runtime instead.
    resultTypeContext = OpResultTypeContext::Interface;
  }

  return createOperationExpr(loc, name: *opNameDecl, resultTypeContext, operands,
                             attributes, results&: resultTypes);
}
2157
2158FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2159 SMRange loc = curToken.getLoc();
2160 consumeToken(kind: Token::l_paren);
2161
2162 DenseMap<StringRef, SMRange> usedNames;
2163 SmallVector<StringRef> elementNames;
2164 SmallVector<ast::Expr *> elements;
2165 if (curToken.isNot(k: Token::r_paren)) {
2166 do {
2167 // Check for the optional element name assignment before the value.
2168 StringRef elementName;
2169 if (curToken.is(k: Token::identifier) || curToken.isDependentKeyword()) {
2170 Token elementNameTok = curToken;
2171 consumeToken();
2172
2173 // The element name is only present if followed by an `=`.
2174 if (consumeIf(kind: Token::equal)) {
2175 elementName = elementNameTok.getSpelling();
2176
2177 // Check to see if this name is already used.
2178 auto elementNameIt =
2179 usedNames.try_emplace(Key: elementName, Args: elementNameTok.getLoc());
2180 if (!elementNameIt.second) {
2181 return emitErrorAndNote(
2182 loc: elementNameTok.getLoc(),
2183 msg: llvm::formatv(Fmt: "duplicate tuple element label `{0}`",
2184 Vals&: elementName),
2185 noteLoc: elementNameIt.first->getSecond(),
2186 note: "see previous label use here");
2187 }
2188 } else {
2189 // Otherwise, we treat this as part of an expression so reset the
2190 // lexer.
2191 resetToken(tokLoc: elementNameTok.getLoc());
2192 }
2193 }
2194 elementNames.push_back(Elt: elementName);
2195
2196 // Parse the tuple element value.
2197 FailureOr<ast::Expr *> element = parseExpr();
2198 if (failed(result: element))
2199 return failure();
2200 elements.push_back(Elt: *element);
2201 } while (consumeIf(kind: Token::comma));
2202 }
2203 loc.End = curToken.getEndLoc();
2204 if (failed(
2205 result: parseToken(kind: Token::r_paren, msg: "expected `)` after tuple element list")))
2206 return failure();
2207 return createTupleExpr(loc, elements, elementNames);
2208}
2209
2210FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2211 SMRange loc = curToken.getLoc();
2212 consumeToken(kind: Token::kw_type);
2213
2214 // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2215 // identifier.
2216 if (!consumeIf(kind: Token::less)) {
2217 resetToken(tokLoc: loc);
2218 return parseIdentifierExpr();
2219 }
2220
2221 if (!curToken.isString())
2222 return emitError(msg: "expected string literal containing MLIR type");
2223 std::string attrExpr = curToken.getStringValue();
2224 consumeToken();
2225
2226 loc.End = curToken.getEndLoc();
2227 if (failed(result: parseToken(kind: Token::greater, msg: "expected `>` after type literal")))
2228 return failure();
2229 return ast::TypeExpr::create(ctx, loc, value: attrExpr);
2230}
2231
2232FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2233 StringRef name = curToken.getSpelling();
2234 SMRange nameLoc = curToken.getLoc();
2235 consumeToken(kind: Token::underscore);
2236
2237 // Underscore expressions require a constraint list.
2238 if (failed(result: parseToken(kind: Token::colon, msg: "expected `:` after `_` variable")))
2239 return failure();
2240
2241 // Parse the constraints for the expression.
2242 SmallVector<ast::ConstraintRef> constraints;
2243 if (failed(result: parseVariableDeclConstraintList(constraints)))
2244 return failure();
2245
2246 ast::Type type;
2247 if (failed(result: validateVariableConstraints(constraints, inferredType&: type)))
2248 return failure();
2249 return createInlineVariableExpr(type, name, loc: nameLoc, constraints);
2250}
2251
2252//===----------------------------------------------------------------------===//
2253// Stmts
2254
2255FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2256 FailureOr<ast::Stmt *> stmt;
2257 switch (curToken.getKind()) {
2258 case Token::kw_erase:
2259 stmt = parseEraseStmt();
2260 break;
2261 case Token::kw_let:
2262 stmt = parseLetStmt();
2263 break;
2264 case Token::kw_replace:
2265 stmt = parseReplaceStmt();
2266 break;
2267 case Token::kw_return:
2268 stmt = parseReturnStmt();
2269 break;
2270 case Token::kw_rewrite:
2271 stmt = parseRewriteStmt();
2272 break;
2273 default:
2274 stmt = parseExpr();
2275 break;
2276 }
2277 if (failed(result: stmt) ||
2278 (expectTerminalSemicolon &&
2279 failed(result: parseToken(kind: Token::semicolon, msg: "expected `;` after statement"))))
2280 return failure();
2281 return stmt;
2282}
2283
2284FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2285 SMLoc startLoc = curToken.getStartLoc();
2286 consumeToken(kind: Token::l_brace);
2287
2288 // Push a new block scope and parse any nested statements.
2289 pushDeclScope();
2290 SmallVector<ast::Stmt *> statements;
2291 while (curToken.isNot(k: Token::r_brace)) {
2292 FailureOr<ast::Stmt *> statement = parseStmt();
2293 if (failed(result: statement))
2294 return popDeclScope(), failure();
2295 statements.push_back(Elt: *statement);
2296 }
2297 popDeclScope();
2298
2299 // Consume the end brace.
2300 SMRange location(startLoc, curToken.getEndLoc());
2301 consumeToken(kind: Token::r_brace);
2302
2303 return ast::CompoundStmt::create(ctx, location, children: statements);
2304}
2305
2306FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2307 if (parserContext == ParserContext::Constraint)
2308 return emitError(msg: "`erase` cannot be used within a Constraint");
2309 SMRange loc = curToken.getLoc();
2310 consumeToken(kind: Token::kw_erase);
2311
2312 // Parse the root operation expression.
2313 FailureOr<ast::Expr *> rootOp = parseExpr();
2314 if (failed(result: rootOp))
2315 return failure();
2316
2317 return createEraseStmt(loc, rootOp: *rootOp);
2318}
2319
2320FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2321 SMRange loc = curToken.getLoc();
2322 consumeToken(kind: Token::kw_let);
2323
2324 // Parse the name of the new variable.
2325 SMRange varLoc = curToken.getLoc();
2326 if (curToken.isNot(k: Token::identifier) && !curToken.isDependentKeyword()) {
2327 // `_` is a reserved variable name.
2328 if (curToken.is(k: Token::underscore)) {
2329 return emitError(loc: varLoc,
2330 msg: "`_` may only be used to define \"inline\" variables");
2331 }
2332 return emitError(loc: varLoc,
2333 msg: "expected identifier after `let` to name a new variable");
2334 }
2335 StringRef varName = curToken.getSpelling();
2336 consumeToken();
2337
2338 // Parse the optional set of constraints.
2339 SmallVector<ast::ConstraintRef> constraints;
2340 if (consumeIf(kind: Token::colon) &&
2341 failed(result: parseVariableDeclConstraintList(constraints)))
2342 return failure();
2343
2344 // Parse the optional initializer expression.
2345 ast::Expr *initializer = nullptr;
2346 if (consumeIf(kind: Token::equal)) {
2347 FailureOr<ast::Expr *> initOrFailure = parseExpr();
2348 if (failed(result: initOrFailure))
2349 return failure();
2350 initializer = *initOrFailure;
2351
2352 // Check that the constraints are compatible with having an initializer,
2353 // e.g. type constraints cannot be used with initializers.
2354 for (ast::ConstraintRef constraint : constraints) {
2355 LogicalResult result =
2356 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
2357 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
2358 ast::ValueRangeConstraintDecl>(caseFn: [&](const auto *cst) {
2359 if (cst->getTypeExpr()) {
2360 return this->emitError(
2361 loc: constraint.referenceLoc,
2362 msg: "type constraints are not permitted on variables with "
2363 "initializers");
2364 }
2365 return success();
2366 })
2367 .Default(defaultResult: success());
2368 if (failed(result))
2369 return failure();
2370 }
2371 }
2372
2373 FailureOr<ast::VariableDecl *> varDecl =
2374 createVariableDecl(name: varName, loc: varLoc, initializer, constraints);
2375 if (failed(result: varDecl))
2376 return failure();
2377 return ast::LetStmt::create(ctx, loc, varDecl: *varDecl);
2378}
2379
2380FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2381 if (parserContext == ParserContext::Constraint)
2382 return emitError(msg: "`replace` cannot be used within a Constraint");
2383 SMRange loc = curToken.getLoc();
2384 consumeToken(kind: Token::kw_replace);
2385
2386 // Parse the root operation expression.
2387 FailureOr<ast::Expr *> rootOp = parseExpr();
2388 if (failed(result: rootOp))
2389 return failure();
2390
2391 if (failed(
2392 result: parseToken(kind: Token::kw_with, msg: "expected `with` after root operation")))
2393 return failure();
2394
2395 // The replacement portion of this statement is within a rewrite context.
2396 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2397
2398 // Parse the replacement values.
2399 SmallVector<ast::Expr *> replValues;
2400 if (consumeIf(kind: Token::l_paren)) {
2401 if (consumeIf(kind: Token::r_paren)) {
2402 return emitError(
2403 loc, msg: "expected at least one replacement value, consider using "
2404 "`erase` if no replacement values are desired");
2405 }
2406
2407 do {
2408 FailureOr<ast::Expr *> replExpr = parseExpr();
2409 if (failed(result: replExpr))
2410 return failure();
2411 replValues.emplace_back(Args&: *replExpr);
2412 } while (consumeIf(kind: Token::comma));
2413
2414 if (failed(result: parseToken(kind: Token::r_paren,
2415 msg: "expected `)` after replacement values")))
2416 return failure();
2417 } else {
2418 // Handle replacement with an operation uniquely, as the replacement
2419 // operation supports type inferrence from the root operation.
2420 FailureOr<ast::Expr *> replExpr;
2421 if (curToken.is(k: Token::kw_op))
2422 replExpr = parseOperationExpr(inputResultTypeContext: OpResultTypeContext::Replacement);
2423 else
2424 replExpr = parseExpr();
2425 if (failed(result: replExpr))
2426 return failure();
2427 replValues.emplace_back(Args&: *replExpr);
2428 }
2429
2430 return createReplaceStmt(loc, rootOp: *rootOp, replValues);
2431}
2432
2433FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2434 SMRange loc = curToken.getLoc();
2435 consumeToken(kind: Token::kw_return);
2436
2437 // Parse the result value.
2438 FailureOr<ast::Expr *> resultExpr = parseExpr();
2439 if (failed(result: resultExpr))
2440 return failure();
2441
2442 return ast::ReturnStmt::create(ctx, loc, resultExpr: *resultExpr);
2443}
2444
2445FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2446 if (parserContext == ParserContext::Constraint)
2447 return emitError(msg: "`rewrite` cannot be used within a Constraint");
2448 SMRange loc = curToken.getLoc();
2449 consumeToken(kind: Token::kw_rewrite);
2450
2451 // Parse the root operation.
2452 FailureOr<ast::Expr *> rootOp = parseExpr();
2453 if (failed(result: rootOp))
2454 return failure();
2455
2456 if (failed(result: parseToken(kind: Token::kw_with, msg: "expected `with` before rewrite body")))
2457 return failure();
2458
2459 if (curToken.isNot(k: Token::l_brace))
2460 return emitError(msg: "expected `{` to start rewrite body");
2461
2462 // The rewrite body of this statement is within a rewrite context.
2463 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2464
2465 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2466 if (failed(result: rewriteBody))
2467 return failure();
2468
2469 // Verify the rewrite body.
2470 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2471 if (isa<ast::ReturnStmt>(Val: stmt)) {
2472 return emitError(loc: stmt->getLoc(),
2473 msg: "`return` statements are only permitted within a "
2474 "`Constraint` or `Rewrite` body");
2475 }
2476 }
2477
2478 return createRewriteStmt(loc, rootOp: *rootOp, rewriteBody: *rewriteBody);
2479}
2480
2481//===----------------------------------------------------------------------===//
2482// Creation+Analysis
2483//===----------------------------------------------------------------------===//
2484
2485//===----------------------------------------------------------------------===//
2486// Decls
2487
2488ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2489 // Unwrap reference expressions.
2490 if (auto *init = dyn_cast<ast::DeclRefExpr>(Val: node))
2491 node = init->getDecl();
2492 return dyn_cast<ast::CallableDecl>(Val: node);
2493}
2494
2495FailureOr<ast::PatternDecl *>
2496Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2497 const ParsedPatternMetadata &metadata,
2498 ast::CompoundStmt *body) {
2499 return ast::PatternDecl::create(ctx, location: loc, name, benefit: metadata.benefit,
2500 hasBoundedRecursion: metadata.hasBoundedRecursion, body);
2501}
2502
2503ast::Type Parser::createUserConstraintRewriteResultType(
2504 ArrayRef<ast::VariableDecl *> results) {
2505 // Single result decls use the type of the single result.
2506 if (results.size() == 1)
2507 return results[0]->getType();
2508
2509 // Multiple results use a tuple type, with the types and names grabbed from
2510 // the result variable decls.
2511 auto resultTypes = llvm::map_range(
2512 C&: results, F: [&](const auto *result) { return result->getType(); });
2513 auto resultNames = llvm::map_range(
2514 C&: results, F: [&](const auto *result) { return result->getName().getName(); });
2515 return ast::TupleType::get(context&: ctx, elementTypes: llvm::to_vector(Range&: resultTypes),
2516 elementNames: llvm::to_vector(Range&: resultNames));
2517}
2518
2519template <typename T>
2520FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2521 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2522 ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2523 ast::CompoundStmt *body) {
2524 if (!body->getChildren().empty()) {
2525 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(Val: body->getChildren().back())) {
2526 ast::Expr *resultExpr = retStmt->getResultExpr();
2527
2528 // Process the result of the decl. If no explicit signature results
2529 // were provided, check for return type inference. Otherwise, check that
2530 // the return expression can be converted to the expected type.
2531 if (results.empty())
2532 resultType = resultExpr->getType();
2533 else if (failed(result: convertExpressionTo(expr&: resultExpr, type: resultType)))
2534 return failure();
2535 else
2536 retStmt->setResultExpr(resultExpr);
2537 }
2538 }
2539 return T::createPDLL(ctx, name, arguments, results, body, resultType);
2540}
2541
2542FailureOr<ast::VariableDecl *>
2543Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2544 ArrayRef<ast::ConstraintRef> constraints) {
2545 // The type of the variable, which is expected to be inferred by either a
2546 // constraint or an initializer expression.
2547 ast::Type type;
2548 if (failed(result: validateVariableConstraints(constraints, inferredType&: type)))
2549 return failure();
2550
2551 if (initializer) {
2552 // Update the variable type based on the initializer, or try to convert the
2553 // initializer to the existing type.
2554 if (!type)
2555 type = initializer->getType();
2556 else if (ast::Type mergedType = type.refineWith(other: initializer->getType()))
2557 type = mergedType;
2558 else if (failed(result: convertExpressionTo(expr&: initializer, type)))
2559 return failure();
2560
2561 // Otherwise, if there is no initializer check that the type has already
2562 // been resolved from the constraint list.
2563 } else if (!type) {
2564 return emitErrorAndNote(
2565 loc, msg: "unable to infer type for variable `" + name + "`", noteLoc: loc,
2566 note: "the type of a variable must be inferable from the constraint "
2567 "list or the initializer");
2568 }
2569
2570 // Constraint types cannot be used when defining variables.
2571 if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
2572 return emitError(
2573 loc, msg: llvm::formatv(Fmt: "unable to define variable of `{0}` type", Vals&: type));
2574 }
2575
2576 // Try to define a variable with the given name.
2577 FailureOr<ast::VariableDecl *> varDecl =
2578 defineVariableDecl(name, nameLoc: loc, type, initExpr: initializer, constraints);
2579 if (failed(result: varDecl))
2580 return failure();
2581
2582 return *varDecl;
2583}
2584
2585FailureOr<ast::VariableDecl *>
2586Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2587 const ast::ConstraintRef &constraint) {
2588 ast::Type argType;
2589 if (failed(result: validateVariableConstraint(ref: constraint, inferredType&: argType)))
2590 return failure();
2591 return defineVariableDecl(name, nameLoc: loc, type: argType, constraints: constraint);
2592}
2593
2594LogicalResult
2595Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2596 ast::Type &inferredType) {
2597 for (const ast::ConstraintRef &ref : constraints)
2598 if (failed(result: validateVariableConstraint(ref, inferredType)))
2599 return failure();
2600 return success();
2601}
2602
2603LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2604 ast::Type &inferredType) {
2605 ast::Type constraintType;
2606 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(Val: ref.constraint)) {
2607 if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2608 if (failed(result: validateTypeConstraintExpr(typeExpr)))
2609 return failure();
2610 }
2611 constraintType = ast::AttributeType::get(context&: ctx);
2612 } else if (const auto *cst =
2613 dyn_cast<ast::OpConstraintDecl>(Val: ref.constraint)) {
2614 constraintType = ast::OperationType::get(
2615 context&: ctx, name: cst->getName(), odsOp: lookupODSOperation(opName: cst->getName()));
2616 } else if (isa<ast::TypeConstraintDecl>(Val: ref.constraint)) {
2617 constraintType = typeTy;
2618 } else if (isa<ast::TypeRangeConstraintDecl>(Val: ref.constraint)) {
2619 constraintType = typeRangeTy;
2620 } else if (const auto *cst =
2621 dyn_cast<ast::ValueConstraintDecl>(Val: ref.constraint)) {
2622 if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2623 if (failed(result: validateTypeConstraintExpr(typeExpr)))
2624 return failure();
2625 }
2626 constraintType = valueTy;
2627 } else if (const auto *cst =
2628 dyn_cast<ast::ValueRangeConstraintDecl>(Val: ref.constraint)) {
2629 if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2630 if (failed(result: validateTypeRangeConstraintExpr(typeExpr)))
2631 return failure();
2632 }
2633 constraintType = valueRangeTy;
2634 } else if (const auto *cst =
2635 dyn_cast<ast::UserConstraintDecl>(Val: ref.constraint)) {
2636 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2637 if (inputs.size() != 1) {
2638 return emitErrorAndNote(loc: ref.referenceLoc,
2639 msg: "`Constraint`s applied via a variable constraint "
2640 "list must take a single input, but got " +
2641 Twine(inputs.size()),
2642 noteLoc: cst->getLoc(),
2643 note: "see definition of constraint here");
2644 }
2645 constraintType = inputs.front()->getType();
2646 } else {
2647 llvm_unreachable("unknown constraint type");
2648 }
2649
2650 // Check that the constraint type is compatible with the current inferred
2651 // type.
2652 if (!inferredType) {
2653 inferredType = constraintType;
2654 } else if (ast::Type mergedTy = inferredType.refineWith(other: constraintType)) {
2655 inferredType = mergedTy;
2656 } else {
2657 return emitError(loc: ref.referenceLoc,
2658 msg: llvm::formatv(Fmt: "constraint type `{0}` is incompatible "
2659 "with the previously inferred type `{1}`",
2660 Vals&: constraintType, Vals&: inferredType));
2661 }
2662 return success();
2663}
2664
2665LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2666 ast::Type typeExprType = typeExpr->getType();
2667 if (typeExprType != typeTy) {
2668 return emitError(loc: typeExpr->getLoc(),
2669 msg: "expected expression of `Type` in type constraint");
2670 }
2671 return success();
2672}
2673
2674LogicalResult
2675Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2676 ast::Type typeExprType = typeExpr->getType();
2677 if (typeExprType != typeRangeTy) {
2678 return emitError(loc: typeExpr->getLoc(),
2679 msg: "expected expression of `TypeRange` in type constraint");
2680 }
2681 return success();
2682}
2683
2684//===----------------------------------------------------------------------===//
2685// Exprs
2686
2687FailureOr<ast::CallExpr *>
2688Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2689 MutableArrayRef<ast::Expr *> arguments, bool isNegated) {
2690 ast::Type parentType = parentExpr->getType();
2691
2692 ast::CallableDecl *callableDecl = tryExtractCallableDecl(node: parentExpr);
2693 if (!callableDecl) {
2694 return emitError(loc,
2695 msg: llvm::formatv(Fmt: "expected a reference to a callable "
2696 "`Constraint` or `Rewrite`, but got: `{0}`",
2697 Vals&: parentType));
2698 }
2699 if (parserContext == ParserContext::Rewrite) {
2700 if (isa<ast::UserConstraintDecl>(Val: callableDecl))
2701 return emitError(
2702 loc, msg: "unable to invoke `Constraint` within a rewrite section");
2703 if (isNegated)
2704 return emitError(loc, msg: "unable to negate a Rewrite");
2705 } else {
2706 if (isa<ast::UserRewriteDecl>(Val: callableDecl))
2707 return emitError(loc,
2708 msg: "unable to invoke `Rewrite` within a match section");
2709 if (isNegated && cast<ast::UserConstraintDecl>(Val: callableDecl)->getBody())
2710 return emitError(loc, msg: "unable to negate non native constraints");
2711 }
2712
2713 // Verify the arguments of the call.
2714 /// Handle size mismatch.
2715 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2716 if (callArgs.size() != arguments.size()) {
2717 return emitErrorAndNote(
2718 loc,
2719 msg: llvm::formatv(Fmt: "invalid number of arguments for {0} call; expected "
2720 "{1}, but got {2}",
2721 Vals: callableDecl->getCallableType(), Vals: callArgs.size(),
2722 Vals: arguments.size()),
2723 noteLoc: callableDecl->getLoc(),
2724 note: llvm::formatv(Fmt: "see the definition of {0} here",
2725 Vals: callableDecl->getName()->getName()));
2726 }
2727
2728 /// Handle argument type mismatch.
2729 auto attachDiagFn = [&](ast::Diagnostic &diag) {
2730 diag.attachNote(msg: llvm::formatv(Fmt: "see the definition of `{0}` here",
2731 Vals: callableDecl->getName()->getName()),
2732 noteLoc: callableDecl->getLoc());
2733 };
2734 for (auto it : llvm::zip(t&: callArgs, u&: arguments)) {
2735 if (failed(result: convertExpressionTo(expr&: std::get<1>(t&: it), type: std::get<0>(t&: it)->getType(),
2736 noteAttachFn: attachDiagFn)))
2737 return failure();
2738 }
2739
2740 return ast::CallExpr::create(ctx, loc, callable: parentExpr, arguments,
2741 resultType: callableDecl->getResultType(), isNegated);
2742}
2743
2744FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2745 ast::Decl *decl) {
2746 // Check the type of decl being referenced.
2747 ast::Type declType;
2748 if (isa<ast::ConstraintDecl>(Val: decl))
2749 declType = ast::ConstraintType::get(context&: ctx);
2750 else if (isa<ast::UserRewriteDecl>(Val: decl))
2751 declType = ast::RewriteType::get(context&: ctx);
2752 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(Val: decl))
2753 declType = varDecl->getType();
2754 else
2755 return emitError(loc, msg: "invalid reference to `" +
2756 decl->getName()->getName() + "`");
2757
2758 return ast::DeclRefExpr::create(ctx, loc, decl, type: declType);
2759}
2760
2761FailureOr<ast::DeclRefExpr *>
2762Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2763 ArrayRef<ast::ConstraintRef> constraints) {
2764 FailureOr<ast::VariableDecl *> decl =
2765 defineVariableDecl(name, nameLoc: loc, type, constraints);
2766 if (failed(result: decl))
2767 return failure();
2768 return ast::DeclRefExpr::create(ctx, loc, decl: *decl, type);
2769}
2770
2771FailureOr<ast::MemberAccessExpr *>
2772Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2773 SMRange loc) {
2774 // Validate the member name for the given parent expression.
2775 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2776 if (failed(result: memberType))
2777 return failure();
2778
2779 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, memberName: name, type: *memberType);
2780}
2781
2782FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2783 StringRef name, SMRange loc) {
2784 ast::Type parentType = parentExpr->getType();
2785 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
2786 if (name == ast::AllResultsMemberAccessExpr::getMemberName())
2787 return valueRangeTy;
2788
2789 // Verify member access based on the operation type.
2790 if (const ods::Operation *odsOp = opType.getODSOperation()) {
2791 auto results = odsOp->getResults();
2792
2793 // Handle indexed results.
2794 unsigned index = 0;
2795 if (llvm::isDigit(C: name[0]) && !name.getAsInteger(/*Radix=*/10, Result&: index) &&
2796 index < results.size()) {
2797 return results[index].isVariadic() ? valueRangeTy : valueTy;
2798 }
2799
2800 // Handle named results.
2801 const auto *it = llvm::find_if(Range&: results, P: [&](const auto &result) {
2802 return result.getName() == name;
2803 });
2804 if (it != results.end())
2805 return it->isVariadic() ? valueRangeTy : valueTy;
2806 } else if (llvm::isDigit(C: name[0])) {
2807 // Allow unchecked numeric indexing of the results of unregistered
2808 // operations. It returns a single value.
2809 return valueTy;
2810 }
2811 } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
2812 // Handle indexed results.
2813 unsigned index = 0;
2814 if (llvm::isDigit(C: name[0]) && !name.getAsInteger(/*Radix=*/10, Result&: index) &&
2815 index < tupleType.size()) {
2816 return tupleType.getElementTypes()[index];
2817 }
2818
2819 // Handle named results.
2820 auto elementNames = tupleType.getElementNames();
2821 const auto *it = llvm::find(Range&: elementNames, Val: name);
2822 if (it != elementNames.end())
2823 return tupleType.getElementTypes()[it - elementNames.begin()];
2824 }
2825 return emitError(
2826 loc,
2827 msg: llvm::formatv(Fmt: "invalid member access `{0}` on expression of type `{1}`",
2828 Vals&: name, Vals&: parentType));
2829}
2830
2831FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2832 SMRange loc, const ast::OpNameDecl *name,
2833 OpResultTypeContext resultTypeContext,
2834 SmallVectorImpl<ast::Expr *> &operands,
2835 MutableArrayRef<ast::NamedAttributeDecl *> attributes,
2836 SmallVectorImpl<ast::Expr *> &results) {
2837 std::optional<StringRef> opNameRef = name->getName();
2838 const ods::Operation *odsOp = lookupODSOperation(opName: opNameRef);
2839
2840 // Verify the inputs operands.
2841 if (failed(result: validateOperationOperands(loc, name: opNameRef, odsOp, operands)))
2842 return failure();
2843
2844 // Verify the attribute list.
2845 for (ast::NamedAttributeDecl *attr : attributes) {
2846 // Check for an attribute type, or a type awaiting resolution.
2847 ast::Type attrType = attr->getValue()->getType();
2848 if (!attrType.isa<ast::AttributeType>()) {
2849 return emitError(
2850 loc: attr->getValue()->getLoc(),
2851 msg: llvm::formatv(Fmt: "expected `Attr` expression, but got `{0}`", Vals&: attrType));
2852 }
2853 }
2854
2855 assert(
2856 (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2857 "unexpected inferrence when results were explicitly specified");
2858
2859 // If we aren't relying on type inferrence, or explicit results were provided,
2860 // validate them.
2861 if (resultTypeContext == OpResultTypeContext::Explicit) {
2862 if (failed(result: validateOperationResults(loc, name: opNameRef, odsOp, results)))
2863 return failure();
2864
2865 // Validate the use of interface based type inferrence for this operation.
2866 } else if (resultTypeContext == OpResultTypeContext::Interface) {
2867 assert(opNameRef &&
2868 "expected valid operation name when inferring operation results");
2869 checkOperationResultTypeInferrence(loc, name: *opNameRef, odsOp);
2870 }
2871
2872 return ast::OperationExpr::create(ctx, loc, odsOp, nameDecl: name, operands, resultTypes: results,
2873 attributes);
2874}
2875
2876LogicalResult
2877Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name,
2878 const ods::Operation *odsOp,
2879 SmallVectorImpl<ast::Expr *> &operands) {
2880 return validateOperationOperandsOrResults(
2881 groupName: "operand", loc, odsOpLoc: odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2882 values&: operands, odsValues: odsOp ? odsOp->getOperands() : std::nullopt, singleTy: valueTy,
2883 rangeTy: valueRangeTy);
2884}
2885
2886LogicalResult
2887Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name,
2888 const ods::Operation *odsOp,
2889 SmallVectorImpl<ast::Expr *> &results) {
2890 return validateOperationOperandsOrResults(
2891 groupName: "result", loc, odsOpLoc: odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2892 values&: results, odsValues: odsOp ? odsOp->getResults() : std::nullopt, singleTy: typeTy, rangeTy: typeRangeTy);
2893}
2894
2895void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2896 const ods::Operation *odsOp) {
2897 // If the operation might not have inferrence support, emit a warning to the
2898 // user. We don't emit an error because the interface might be added to the
2899 // operation at runtime. It's rare, but it could still happen. We emit a
2900 // warning here instead.
2901
2902 // Handle inferrence warnings for unknown operations.
2903 if (!odsOp) {
2904 ctx.getDiagEngine().emitWarning(
2905 loc, msg: llvm::formatv(
2906 Fmt: "operation result types are marked to be inferred, but "
2907 "`{0}` is unknown. Ensure that `{0}` supports zero "
2908 "results or implements `InferTypeOpInterface`. Include "
2909 "the ODS definition of this operation to remove this warning.",
2910 Vals&: opName));
2911 return;
2912 }
2913
2914 // Handle inferrence warnings for known operations that expected at least one
2915 // result, but don't have inference support. An elided results list can mean
2916 // "zero-results", and we don't want to warn when that is the expected
2917 // behavior.
2918 bool requiresInferrence =
2919 llvm::any_of(Range: odsOp->getResults(), P: [](const ods::OperandOrResult &result) {
2920 return !result.isVariableLength();
2921 });
2922 if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
2923 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning(
2924 loc,
2925 msg: llvm::formatv(Fmt: "operation result types are marked to be inferred, but "
2926 "`{0}` does not provide an implementation of "
2927 "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2928 "`InferTypeOpInterface` at runtime, or add support to "
2929 "the ODS definition to remove this warning.",
2930 Vals&: opName));
2931 diag->attachNote(msg: llvm::formatv(Fmt: "see the definition of `{0}` here", Vals&: opName),
2932 noteLoc: odsOp->getLoc());
2933 return;
2934 }
2935}
2936
2937LogicalResult Parser::validateOperationOperandsOrResults(
2938 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
2939 std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
2940 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2941 ast::RangeType rangeTy) {
2942 // All operation types accept a single range parameter.
2943 if (values.size() == 1) {
2944 if (failed(result: convertExpressionTo(expr&: values[0], type: rangeTy)))
2945 return failure();
2946 return success();
2947 }
2948
2949 /// If the operation has ODS information, we can more accurately verify the
2950 /// values.
2951 if (odsOpLoc) {
2952 auto emitSizeMismatchError = [&] {
2953 return emitErrorAndNote(
2954 loc,
2955 msg: llvm::formatv(Fmt: "invalid number of {0} groups for `{1}`; expected "
2956 "{2}, but got {3}",
2957 Vals&: groupName, Vals&: *name, Vals: odsValues.size(), Vals: values.size()),
2958 noteLoc: *odsOpLoc, note: llvm::formatv(Fmt: "see the definition of `{0}` here", Vals&: *name));
2959 };
2960
2961 // Handle the case where no values were provided.
2962 if (values.empty()) {
2963 // If we don't expect any on the ODS side, we are done.
2964 if (odsValues.empty())
2965 return success();
2966
2967 // If we do, check if we actually need to provide values (i.e. if any of
2968 // the values are actually required).
2969 unsigned numVariadic = 0;
2970 for (const auto &odsValue : odsValues) {
2971 if (!odsValue.isVariableLength())
2972 return emitSizeMismatchError();
2973 ++numVariadic;
2974 }
2975
2976 // If we are in a non-rewrite context, we don't need to do anything more.
2977 // Zero-values is a valid constraint on the operation.
2978 if (parserContext != ParserContext::Rewrite)
2979 return success();
2980
2981 // Otherwise, when in a rewrite we may need to provide values to match the
2982 // ODS signature of the operation to create.
2983
2984 // If we only have one variadic value, just use an empty list.
2985 if (numVariadic == 1)
2986 return success();
2987
2988 // Otherwise, create dummy values for each of the entries so that we
2989 // adhere to the ODS signature.
2990 for (unsigned i = 0, e = odsValues.size(); i < e; ++i) {
2991 values.push_back(Elt: ast::RangeExpr::create(
2992 ctx, loc, /*elements=*/std::nullopt, type: rangeTy));
2993 }
2994 return success();
2995 }
2996
2997 // Verify that the number of values provided matches the number of value
2998 // groups ODS expects.
2999 if (odsValues.size() != values.size())
3000 return emitSizeMismatchError();
3001
3002 auto diagFn = [&](ast::Diagnostic &diag) {
3003 diag.attachNote(msg: llvm::formatv(Fmt: "see the definition of `{0}` here", Vals&: *name),
3004 noteLoc: *odsOpLoc);
3005 };
3006 for (unsigned i = 0, e = values.size(); i < e; ++i) {
3007 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
3008 if (failed(result: convertExpressionTo(expr&: values[i], type: expectedType, noteAttachFn: diagFn)))
3009 return failure();
3010 }
3011 return success();
3012 }
3013
3014 // Otherwise, accept the value groups as they have been defined and just
3015 // ensure they are one of the expected types.
3016 for (ast::Expr *&valueExpr : values) {
3017 ast::Type valueExprType = valueExpr->getType();
3018
3019 // Check if this is one of the expected types.
3020 if (valueExprType == rangeTy || valueExprType == singleTy)
3021 continue;
3022
3023 // If the operand is an Operation, allow converting to a Value or
3024 // ValueRange. This situations arises quite often with nested operation
3025 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
3026 if (singleTy == valueTy) {
3027 if (valueExprType.isa<ast::OperationType>()) {
3028 valueExpr = convertOpToValue(opExpr: valueExpr);
3029 continue;
3030 }
3031 }
3032
3033 // Otherwise, try to convert the expression to a range.
3034 if (succeeded(result: convertExpressionTo(expr&: valueExpr, type: rangeTy)))
3035 continue;
3036
3037 return emitError(
3038 loc: valueExpr->getLoc(),
3039 msg: llvm::formatv(
3040 Fmt: "expected `{0}` or `{1}` convertible expression, but got `{2}`",
3041 Vals&: singleTy, Vals&: rangeTy, Vals&: valueExprType));
3042 }
3043 return success();
3044}
3045
3046FailureOr<ast::TupleExpr *>
3047Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
3048 ArrayRef<StringRef> elementNames) {
3049 for (const ast::Expr *element : elements) {
3050 ast::Type eleTy = element->getType();
3051 if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
3052 return emitError(
3053 loc: element->getLoc(),
3054 msg: llvm::formatv(Fmt: "unable to build a tuple with `{0}` element", Vals&: eleTy));
3055 }
3056 }
3057 return ast::TupleExpr::create(ctx, loc, elements, elementNames);
3058}
3059
3060//===----------------------------------------------------------------------===//
3061// Stmts
3062
3063FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
3064 ast::Expr *rootOp) {
3065 // Check that root is an Operation.
3066 ast::Type rootType = rootOp->getType();
3067 if (!rootType.isa<ast::OperationType>())
3068 return emitError(loc: rootOp->getLoc(), msg: "expected `Op` expression");
3069
3070 return ast::EraseStmt::create(ctx, loc, rootOp);
3071}
3072
3073FailureOr<ast::ReplaceStmt *>
3074Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
3075 MutableArrayRef<ast::Expr *> replValues) {
3076 // Check that root is an Operation.
3077 ast::Type rootType = rootOp->getType();
3078 if (!rootType.isa<ast::OperationType>()) {
3079 return emitError(
3080 loc: rootOp->getLoc(),
3081 msg: llvm::formatv(Fmt: "expected `Op` expression, but got `{0}`", Vals&: rootType));
3082 }
3083
3084 // If there are multiple replacement values, we implicitly convert any Op
3085 // expressions to the value form.
3086 bool shouldConvertOpToValues = replValues.size() > 1;
3087 for (ast::Expr *&replExpr : replValues) {
3088 ast::Type replType = replExpr->getType();
3089
3090 // Check that replExpr is an Operation, Value, or ValueRange.
3091 if (replType.isa<ast::OperationType>()) {
3092 if (shouldConvertOpToValues)
3093 replExpr = convertOpToValue(opExpr: replExpr);
3094 continue;
3095 }
3096
3097 if (replType != valueTy && replType != valueRangeTy) {
3098 return emitError(loc: replExpr->getLoc(),
3099 msg: llvm::formatv(Fmt: "expected `Op`, `Value` or `ValueRange` "
3100 "expression, but got `{0}`",
3101 Vals&: replType));
3102 }
3103 }
3104
3105 return ast::ReplaceStmt::create(ctx, loc, rootOp, replExprs: replValues);
3106}
3107
3108FailureOr<ast::RewriteStmt *>
3109Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
3110 ast::CompoundStmt *rewriteBody) {
3111 // Check that root is an Operation.
3112 ast::Type rootType = rootOp->getType();
3113 if (!rootType.isa<ast::OperationType>()) {
3114 return emitError(
3115 loc: rootOp->getLoc(),
3116 msg: llvm::formatv(Fmt: "expected `Op` expression, but got `{0}`", Vals&: rootType));
3117 }
3118
3119 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
3120}
3121
3122//===----------------------------------------------------------------------===//
3123// Code Completion
3124//===----------------------------------------------------------------------===//
3125
3126LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
3127 ast::Type parentType = parentExpr->getType();
3128 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>())
3129 codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3130 else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>())
3131 codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3132 return failure();
3133}
3134
3135LogicalResult
3136Parser::codeCompleteAttributeName(std::optional<StringRef> opName) {
3137 if (opName)
3138 codeCompleteContext->codeCompleteOperationAttributeName(opName: *opName);
3139 return failure();
3140}
3141
3142LogicalResult
3143Parser::codeCompleteConstraintName(ast::Type inferredType,
3144 bool allowInlineTypeConstraints) {
3145 codeCompleteContext->codeCompleteConstraintName(
3146 currentType: inferredType, allowInlineTypeConstraints, scope: curDeclScope);
3147 return failure();
3148}
3149
3150LogicalResult Parser::codeCompleteDialectName() {
3151 codeCompleteContext->codeCompleteDialectName();
3152 return failure();
3153}
3154
3155LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3156 codeCompleteContext->codeCompleteOperationName(dialectName);
3157 return failure();
3158}
3159
3160LogicalResult Parser::codeCompletePatternMetadata() {
3161 codeCompleteContext->codeCompletePatternMetadata();
3162 return failure();
3163}
3164
3165LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3166 codeCompleteContext->codeCompleteIncludeFilename(curPath);
3167 return failure();
3168}
3169
3170void Parser::codeCompleteCallSignature(ast::Node *parent,
3171 unsigned currentNumArgs) {
3172 ast::CallableDecl *callableDecl = tryExtractCallableDecl(node: parent);
3173 if (!callableDecl)
3174 return;
3175
3176 codeCompleteContext->codeCompleteCallSignature(callable: callableDecl, currentNumArgs);
3177}
3178
3179void Parser::codeCompleteOperationOperandsSignature(
3180 std::optional<StringRef> opName, unsigned currentNumOperands) {
3181 codeCompleteContext->codeCompleteOperationOperandsSignature(
3182 opName, currentNumOperands);
3183}
3184
3185void Parser::codeCompleteOperationResultsSignature(
3186 std::optional<StringRef> opName, unsigned currentNumResults) {
3187 codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3188 currentNumResults);
3189}
3190
3191//===----------------------------------------------------------------------===//
3192// Parser
3193//===----------------------------------------------------------------------===//
3194
3195FailureOr<ast::Module *>
3196mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
3197 bool enableDocumentation,
3198 CodeCompleteContext *codeCompleteContext) {
3199 Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3200 return parser.parseModule();
3201}
3202

source code of mlir/lib/Tools/PDLL/Parser/Parser.cpp