1 | //===- Parser.cpp ---------------------------------------------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Tools/PDLL/Parser/Parser.h" |
10 | #include "Lexer.h" |
11 | #include "mlir/Support/IndentedOstream.h" |
12 | #include "mlir/TableGen/Argument.h" |
13 | #include "mlir/TableGen/Attribute.h" |
14 | #include "mlir/TableGen/Constraint.h" |
15 | #include "mlir/TableGen/Format.h" |
16 | #include "mlir/TableGen/Operator.h" |
17 | #include "mlir/Tools/PDLL/AST/Context.h" |
18 | #include "mlir/Tools/PDLL/AST/Diagnostic.h" |
19 | #include "mlir/Tools/PDLL/AST/Nodes.h" |
20 | #include "mlir/Tools/PDLL/AST/Types.h" |
21 | #include "mlir/Tools/PDLL/ODS/Constraint.h" |
22 | #include "mlir/Tools/PDLL/ODS/Context.h" |
23 | #include "mlir/Tools/PDLL/ODS/Operation.h" |
24 | #include "mlir/Tools/PDLL/Parser/CodeComplete.h" |
25 | #include "llvm/ADT/StringExtras.h" |
26 | #include "llvm/ADT/TypeSwitch.h" |
27 | #include "llvm/Support/FormatVariadic.h" |
28 | #include "llvm/Support/ManagedStatic.h" |
29 | #include "llvm/Support/SaveAndRestore.h" |
30 | #include "llvm/Support/ScopedPrinter.h" |
31 | #include "llvm/TableGen/Error.h" |
32 | #include "llvm/TableGen/Parser.h" |
33 | #include <optional> |
34 | #include <string> |
35 | |
36 | using namespace mlir; |
37 | using namespace mlir::pdll; |
38 | |
39 | //===----------------------------------------------------------------------===// |
40 | // Parser |
41 | //===----------------------------------------------------------------------===// |
42 | |
43 | namespace { |
44 | class Parser { |
45 | public: |
46 | Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr, |
47 | bool enableDocumentation, CodeCompleteContext *codeCompleteContext) |
48 | : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext), |
49 | curToken(lexer.lexToken()), enableDocumentation(enableDocumentation), |
50 | typeTy(ast::TypeType::get(context&: ctx)), valueTy(ast::ValueType::get(context&: ctx)), |
51 | typeRangeTy(ast::TypeRangeType::get(context&: ctx)), |
52 | valueRangeTy(ast::ValueRangeType::get(context&: ctx)), |
53 | attrTy(ast::AttributeType::get(context&: ctx)), |
54 | codeCompleteContext(codeCompleteContext) {} |
55 | |
56 | /// Try to parse a new module. Returns nullptr in the case of failure. |
57 | FailureOr<ast::Module *> parseModule(); |
58 | |
59 | private: |
60 | /// The current context of the parser. It allows for the parser to know a bit |
61 | /// about the construct it is nested within during parsing. This is used |
62 | /// specifically to provide additional verification during parsing, e.g. to |
63 | /// prevent using rewrites within a match context, matcher constraints within |
64 | /// a rewrite section, etc. |
65 | enum class ParserContext { |
66 | /// The parser is in the global context. |
67 | Global, |
68 | /// The parser is currently within a Constraint, which disallows all types |
69 | /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.). |
70 | Constraint, |
71 | /// The parser is currently within the matcher portion of a Pattern, which |
72 | /// is allows a terminal operation rewrite statement but no other rewrite |
73 | /// transformations. |
74 | PatternMatch, |
75 | /// The parser is currently within a Rewrite, which disallows calls to |
76 | /// constraints, requires operation expressions to have names, etc. |
77 | Rewrite, |
78 | }; |
79 | |
80 | /// The current specification context of an operations result type. This |
81 | /// indicates how the result types of an operation may be inferred. |
82 | enum class OpResultTypeContext { |
83 | /// The result types of the operation are not known to be inferred. |
84 | Explicit, |
85 | /// The result types of the operation are inferred from the root input of a |
86 | /// `replace` statement. |
87 | Replacement, |
88 | /// The result types of the operation are inferred by using the |
89 | /// `InferTypeOpInterface` interface provided by the operation. |
90 | Interface, |
91 | }; |
92 | |
93 | //===--------------------------------------------------------------------===// |
94 | // Parsing |
95 | //===--------------------------------------------------------------------===// |
96 | |
97 | /// Push a new decl scope onto the lexer. |
98 | ast::DeclScope *pushDeclScope() { |
99 | ast::DeclScope *newScope = |
100 | new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); |
101 | return (curDeclScope = newScope); |
102 | } |
103 | void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } |
104 | |
105 | /// Pop the last decl scope from the lexer. |
106 | void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } |
107 | |
108 | /// Parse the body of an AST module. |
109 | LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls); |
110 | |
111 | /// Try to convert the given expression to `type`. Returns failure and emits |
112 | /// an error if a conversion is not viable. On failure, `noteAttachFn` is |
113 | /// invoked to attach notes to the emitted error diagnostic. On success, |
114 | /// `expr` is updated to the expression used to convert to `type`. |
115 | LogicalResult convertExpressionTo( |
116 | ast::Expr *&expr, ast::Type type, |
117 | function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); |
118 | LogicalResult |
119 | convertOpExpressionTo(ast::Expr *&expr, ast::OperationType exprType, |
120 | ast::Type type, |
121 | function_ref<ast::InFlightDiagnostic()> emitErrorFn); |
122 | LogicalResult convertTupleExpressionTo( |
123 | ast::Expr *&expr, ast::TupleType exprType, ast::Type type, |
124 | function_ref<ast::InFlightDiagnostic()> emitErrorFn, |
125 | function_ref<void(ast::Diagnostic &diag)> noteAttachFn); |
126 | |
127 | /// Given an operation expression, convert it to a Value or ValueRange |
128 | /// typed expression. |
129 | ast::Expr *convertOpToValue(const ast::Expr *opExpr); |
130 | |
131 | /// Lookup ODS information for the given operation, returns nullptr if no |
132 | /// information is found. |
133 | const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) { |
134 | return opName ? ctx.getODSContext().lookupOperation(name: *opName) : nullptr; |
135 | } |
136 | |
137 | /// Process the given documentation string, or return an empty string if |
138 | /// documentation isn't enabled. |
139 | StringRef processDoc(StringRef doc) { |
140 | return enableDocumentation ? doc : StringRef(); |
141 | } |
142 | |
143 | /// Process the given documentation string and format it, or return an empty |
144 | /// string if documentation isn't enabled. |
145 | std::string processAndFormatDoc(const Twine &doc) { |
146 | if (!enableDocumentation) |
147 | return ""; |
148 | std::string docStr; |
149 | { |
150 | llvm::raw_string_ostream docOS(docStr); |
151 | raw_indented_ostream(docOS).printReindented( |
152 | str: StringRef(docStr).rtrim(Chars: " \t")); |
153 | } |
154 | return docStr; |
155 | } |
156 | |
157 | //===--------------------------------------------------------------------===// |
158 | // Directives |
159 | |
160 | LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls); |
161 | LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls); |
162 | LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc, |
163 | SmallVectorImpl<ast::Decl *> &decls); |
164 | |
165 | /// Process the records of a parsed tablegen include file. |
166 | void processTdIncludeRecords(const llvm::RecordKeeper &tdRecords, |
167 | SmallVectorImpl<ast::Decl *> &decls); |
168 | |
169 | /// Create a user defined native constraint for a constraint imported from |
170 | /// ODS. |
171 | template <typename ConstraintT> |
172 | ast::Decl * |
173 | createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock, |
174 | SMRange loc, ast::Type type, |
175 | StringRef nativeType, StringRef docString); |
176 | template <typename ConstraintT> |
177 | ast::Decl * |
178 | createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, |
179 | SMRange loc, ast::Type type, |
180 | StringRef nativeType); |
181 | |
182 | //===--------------------------------------------------------------------===// |
183 | // Decls |
184 | |
185 | /// This structure contains the set of pattern metadata that may be parsed. |
186 | struct ParsedPatternMetadata { |
187 | std::optional<uint16_t> benefit; |
188 | bool hasBoundedRecursion = false; |
189 | }; |
190 | |
191 | FailureOr<ast::Decl *> parseTopLevelDecl(); |
192 | FailureOr<ast::NamedAttributeDecl *> |
193 | parseNamedAttributeDecl(std::optional<StringRef> parentOpName); |
194 | |
195 | /// Parse an argument variable as part of the signature of a |
196 | /// UserConstraintDecl or UserRewriteDecl. |
197 | FailureOr<ast::VariableDecl *> parseArgumentDecl(); |
198 | |
199 | /// Parse a result variable as part of the signature of a UserConstraintDecl |
200 | /// or UserRewriteDecl. |
201 | FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum); |
202 | |
203 | /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being |
204 | /// defined in a non-global context. |
205 | FailureOr<ast::UserConstraintDecl *> |
206 | parseUserConstraintDecl(bool isInline = false); |
207 | |
208 | /// Parse an inline UserConstraintDecl. An inline decl is one defined in a |
209 | /// non-global context, such as within a Pattern/Constraint/etc. |
210 | FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl(); |
211 | |
212 | /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using |
213 | /// PDLL constructs. |
214 | FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl( |
215 | const ast::Name &name, bool isInline, |
216 | ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, |
217 | ArrayRef<ast::VariableDecl *> results, ast::Type resultType); |
218 | |
219 | /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being |
220 | /// defined in a non-global context. |
221 | FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false); |
222 | |
223 | /// Parse an inline UserRewriteDecl. An inline decl is one defined in a |
224 | /// non-global context, such as within a Pattern/Rewrite/etc. |
225 | FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl(); |
226 | |
227 | /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using |
228 | /// PDLL constructs. |
229 | FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl( |
230 | const ast::Name &name, bool isInline, |
231 | ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, |
232 | ArrayRef<ast::VariableDecl *> results, ast::Type resultType); |
233 | |
234 | /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have |
235 | /// effectively the same syntax, and only differ on slight semantics (given |
236 | /// the different parsing contexts). |
237 | template <typename T, typename ParseUserPDLLDeclFnT> |
238 | FailureOr<T *> parseUserConstraintOrRewriteDecl( |
239 | ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, |
240 | StringRef anonymousNamePrefix, bool isInline); |
241 | |
242 | /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl. |
243 | /// These decls have effectively the same syntax. |
244 | template <typename T> |
245 | FailureOr<T *> parseUserNativeConstraintOrRewriteDecl( |
246 | const ast::Name &name, bool isInline, |
247 | ArrayRef<ast::VariableDecl *> arguments, |
248 | ArrayRef<ast::VariableDecl *> results, ast::Type resultType); |
249 | |
250 | /// Parse the functional signature (i.e. the arguments and results) of a |
251 | /// UserConstraintDecl or UserRewriteDecl. |
252 | LogicalResult parseUserConstraintOrRewriteSignature( |
253 | SmallVectorImpl<ast::VariableDecl *> &arguments, |
254 | SmallVectorImpl<ast::VariableDecl *> &results, |
255 | ast::DeclScope *&argumentScope, ast::Type &resultType); |
256 | |
257 | /// Validate the return (which if present is specified by bodyIt) of a |
258 | /// UserConstraintDecl or UserRewriteDecl. |
259 | LogicalResult validateUserConstraintOrRewriteReturn( |
260 | StringRef declType, ast::CompoundStmt *body, |
261 | ArrayRef<ast::Stmt *>::iterator bodyIt, |
262 | ArrayRef<ast::Stmt *>::iterator bodyE, |
263 | ArrayRef<ast::VariableDecl *> results, ast::Type &resultType); |
264 | |
265 | FailureOr<ast::CompoundStmt *> |
266 | parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, |
267 | bool expectTerminalSemicolon = true); |
268 | FailureOr<ast::CompoundStmt *> parsePatternLambdaBody(); |
269 | FailureOr<ast::Decl *> parsePatternDecl(); |
270 | LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); |
271 | |
272 | /// Check to see if a decl has already been defined with the given name, if |
273 | /// one has emit and error and return failure. Returns success otherwise. |
274 | LogicalResult checkDefineNamedDecl(const ast::Name &name); |
275 | |
276 | /// Try to define a variable decl with the given components, returns the |
277 | /// variable on success. |
278 | FailureOr<ast::VariableDecl *> |
279 | defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, |
280 | ast::Expr *initExpr, |
281 | ArrayRef<ast::ConstraintRef> constraints); |
282 | FailureOr<ast::VariableDecl *> |
283 | defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, |
284 | ArrayRef<ast::ConstraintRef> constraints); |
285 | |
286 | /// Parse the constraint reference list for a variable decl. |
287 | LogicalResult parseVariableDeclConstraintList( |
288 | SmallVectorImpl<ast::ConstraintRef> &constraints); |
289 | |
290 | /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. |
291 | FailureOr<ast::Expr *> parseTypeConstraintExpr(); |
292 | |
293 | /// Try to parse a single reference to a constraint. `typeConstraint` is the |
294 | /// location of a previously parsed type constraint for the entity that will |
295 | /// be constrained by the parsed constraint. `existingConstraints` are any |
296 | /// existing constraints that have already been parsed for the same entity |
297 | /// that will be constrained by this constraint. `allowInlineTypeConstraints` |
298 | /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`. |
299 | FailureOr<ast::ConstraintRef> |
300 | parseConstraint(std::optional<SMRange> &typeConstraint, |
301 | ArrayRef<ast::ConstraintRef> existingConstraints, |
302 | bool allowInlineTypeConstraints); |
303 | |
304 | /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl |
305 | /// argument or result variable. The constraints for these variables do not |
306 | /// allow inline type constraints, and only permit a single constraint. |
307 | FailureOr<ast::ConstraintRef> parseArgOrResultConstraint(); |
308 | |
309 | //===--------------------------------------------------------------------===// |
310 | // Exprs |
311 | |
312 | FailureOr<ast::Expr *> parseExpr(); |
313 | |
314 | /// Identifier expressions. |
315 | FailureOr<ast::Expr *> parseAttributeExpr(); |
316 | FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr, |
317 | bool isNegated = false); |
318 | FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc); |
319 | FailureOr<ast::Expr *> parseIdentifierExpr(); |
320 | FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr(); |
321 | FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr(); |
322 | FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); |
323 | FailureOr<ast::Expr *> parseNegatedExpr(); |
324 | FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); |
325 | FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); |
326 | FailureOr<ast::Expr *> |
327 | parseOperationExpr(OpResultTypeContext inputResultTypeContext = |
328 | OpResultTypeContext::Explicit); |
329 | FailureOr<ast::Expr *> parseTupleExpr(); |
330 | FailureOr<ast::Expr *> parseTypeExpr(); |
331 | FailureOr<ast::Expr *> parseUnderscoreExpr(); |
332 | |
333 | //===--------------------------------------------------------------------===// |
334 | // Stmts |
335 | |
336 | FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); |
337 | FailureOr<ast::CompoundStmt *> parseCompoundStmt(); |
338 | FailureOr<ast::EraseStmt *> parseEraseStmt(); |
339 | FailureOr<ast::LetStmt *> parseLetStmt(); |
340 | FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); |
341 | FailureOr<ast::ReturnStmt *> parseReturnStmt(); |
342 | FailureOr<ast::RewriteStmt *> parseRewriteStmt(); |
343 | |
344 | //===--------------------------------------------------------------------===// |
345 | // Creation+Analysis |
346 | //===--------------------------------------------------------------------===// |
347 | |
348 | //===--------------------------------------------------------------------===// |
349 | // Decls |
350 | |
351 | /// Try to extract a callable from the given AST node. Returns nullptr on |
352 | /// failure. |
353 | ast::CallableDecl *tryExtractCallableDecl(ast::Node *node); |
354 | |
355 | /// Try to create a pattern decl with the given components, returning the |
356 | /// Pattern on success. |
357 | FailureOr<ast::PatternDecl *> |
358 | createPatternDecl(SMRange loc, const ast::Name *name, |
359 | const ParsedPatternMetadata &metadata, |
360 | ast::CompoundStmt *body); |
361 | |
362 | /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set |
363 | /// of results, defined as part of the signature. |
364 | ast::Type |
365 | createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results); |
366 | |
367 | /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl. |
368 | template <typename T> |
369 | FailureOr<T *> createUserPDLLConstraintOrRewriteDecl( |
370 | const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, |
371 | ArrayRef<ast::VariableDecl *> results, ast::Type resultType, |
372 | ast::CompoundStmt *body); |
373 | |
374 | /// Try to create a variable decl with the given components, returning the |
375 | /// Variable on success. |
376 | FailureOr<ast::VariableDecl *> |
377 | createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, |
378 | ArrayRef<ast::ConstraintRef> constraints); |
379 | |
380 | /// Create a variable for an argument or result defined as part of the |
381 | /// signature of a UserConstraintDecl/UserRewriteDecl. |
382 | FailureOr<ast::VariableDecl *> |
383 | createArgOrResultVariableDecl(StringRef name, SMRange loc, |
384 | const ast::ConstraintRef &constraint); |
385 | |
386 | /// Validate the constraints used to constraint a variable decl. |
387 | /// `inferredType` is the type of the variable inferred by the constraints |
388 | /// within the list, and is updated to the most refined type as determined by |
389 | /// the constraints. Returns success if the constraint list is valid, failure |
390 | /// otherwise. |
391 | LogicalResult |
392 | validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, |
393 | ast::Type &inferredType); |
394 | /// Validate a single reference to a constraint. `inferredType` contains the |
395 | /// currently inferred variabled type and is refined within the type defined |
396 | /// by the constraint. Returns success if the constraint is valid, failure |
397 | /// otherwise. |
398 | LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, |
399 | ast::Type &inferredType); |
400 | LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); |
401 | LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); |
402 | |
403 | //===--------------------------------------------------------------------===// |
404 | // Exprs |
405 | |
406 | FailureOr<ast::CallExpr *> |
407 | createCallExpr(SMRange loc, ast::Expr *parentExpr, |
408 | MutableArrayRef<ast::Expr *> arguments, |
409 | bool isNegated = false); |
410 | FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl); |
411 | FailureOr<ast::DeclRefExpr *> |
412 | createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, |
413 | ArrayRef<ast::ConstraintRef> constraints); |
414 | FailureOr<ast::MemberAccessExpr *> |
415 | createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); |
416 | |
417 | /// Validate the member access `name` into the given parent expression. On |
418 | /// success, this also returns the type of the member accessed. |
419 | FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, |
420 | StringRef name, SMRange loc); |
421 | FailureOr<ast::OperationExpr *> |
422 | createOperationExpr(SMRange loc, const ast::OpNameDecl *name, |
423 | OpResultTypeContext resultTypeContext, |
424 | SmallVectorImpl<ast::Expr *> &operands, |
425 | MutableArrayRef<ast::NamedAttributeDecl *> attributes, |
426 | SmallVectorImpl<ast::Expr *> &results); |
427 | LogicalResult |
428 | validateOperationOperands(SMRange loc, std::optional<StringRef> name, |
429 | const ods::Operation *odsOp, |
430 | SmallVectorImpl<ast::Expr *> &operands); |
431 | LogicalResult validateOperationResults(SMRange loc, |
432 | std::optional<StringRef> name, |
433 | const ods::Operation *odsOp, |
434 | SmallVectorImpl<ast::Expr *> &results); |
435 | void checkOperationResultTypeInferrence(SMRange loc, StringRef name, |
436 | const ods::Operation *odsOp); |
437 | LogicalResult validateOperationOperandsOrResults( |
438 | StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc, |
439 | std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values, |
440 | ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, |
441 | ast::RangeType rangeTy); |
442 | FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, |
443 | ArrayRef<ast::Expr *> elements, |
444 | ArrayRef<StringRef> elementNames); |
445 | |
446 | //===--------------------------------------------------------------------===// |
447 | // Stmts |
448 | |
449 | FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp); |
450 | FailureOr<ast::ReplaceStmt *> |
451 | createReplaceStmt(SMRange loc, ast::Expr *rootOp, |
452 | MutableArrayRef<ast::Expr *> replValues); |
453 | FailureOr<ast::RewriteStmt *> |
454 | createRewriteStmt(SMRange loc, ast::Expr *rootOp, |
455 | ast::CompoundStmt *rewriteBody); |
456 | |
457 | //===--------------------------------------------------------------------===// |
458 | // Code Completion |
459 | //===--------------------------------------------------------------------===// |
460 | |
461 | /// The set of various code completion methods. Every completion method |
462 | /// returns `failure` to stop the parsing process after providing completion |
463 | /// results. |
464 | |
465 | LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr); |
466 | LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName); |
467 | LogicalResult codeCompleteConstraintName(ast::Type inferredType, |
468 | bool allowInlineTypeConstraints); |
469 | LogicalResult codeCompleteDialectName(); |
470 | LogicalResult codeCompleteOperationName(StringRef dialectName); |
471 | LogicalResult codeCompletePatternMetadata(); |
472 | LogicalResult codeCompleteIncludeFilename(StringRef curPath); |
473 | |
474 | void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs); |
475 | void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName, |
476 | unsigned currentNumOperands); |
477 | void codeCompleteOperationResultsSignature(std::optional<StringRef> opName, |
478 | unsigned currentNumResults); |
479 | |
480 | //===--------------------------------------------------------------------===// |
481 | // Lexer Utilities |
482 | //===--------------------------------------------------------------------===// |
483 | |
484 | /// If the current token has the specified kind, consume it and return true. |
485 | /// If not, return false. |
486 | bool consumeIf(Token::Kind kind) { |
487 | if (curToken.isNot(k: kind)) |
488 | return false; |
489 | consumeToken(kind); |
490 | return true; |
491 | } |
492 | |
493 | /// Advance the current lexer onto the next token. |
494 | void consumeToken() { |
495 | assert(curToken.isNot(Token::eof, Token::error) && |
496 | "shouldn't advance past EOF or errors"); |
497 | curToken = lexer.lexToken(); |
498 | } |
499 | |
500 | /// Advance the current lexer onto the next token, asserting what the expected |
501 | /// current token is. This is preferred to the above method because it leads |
502 | /// to more self-documenting code with better checking. |
503 | void consumeToken(Token::Kind kind) { |
504 | assert(curToken.is(kind) && "consumed an unexpected token"); |
505 | consumeToken(); |
506 | } |
507 | |
508 | /// Reset the lexer to the location at the given position. |
509 | void resetToken(SMRange tokLoc) { |
510 | lexer.resetPointer(newPointer: tokLoc.Start.getPointer()); |
511 | curToken = lexer.lexToken(); |
512 | } |
513 | |
514 | /// Consume the specified token if present and return success. On failure, |
515 | /// output a diagnostic and return failure. |
516 | LogicalResult parseToken(Token::Kind kind, const Twine &msg) { |
517 | if (curToken.getKind() != kind) |
518 | return emitError(loc: curToken.getLoc(), msg); |
519 | consumeToken(); |
520 | return success(); |
521 | } |
522 | LogicalResult emitError(SMRange loc, const Twine &msg) { |
523 | lexer.emitError(loc, msg); |
524 | return failure(); |
525 | } |
526 | LogicalResult emitError(const Twine &msg) { |
527 | return emitError(loc: curToken.getLoc(), msg); |
528 | } |
529 | LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, |
530 | const Twine ¬e) { |
531 | lexer.emitErrorAndNote(loc, msg, noteLoc, note); |
532 | return failure(); |
533 | } |
534 | |
535 | //===--------------------------------------------------------------------===// |
536 | // Fields |
537 | //===--------------------------------------------------------------------===// |
538 | |
539 | /// The owning AST context. |
540 | ast::Context &ctx; |
541 | |
542 | /// The lexer of this parser. |
543 | Lexer lexer; |
544 | |
545 | /// The current token within the lexer. |
546 | Token curToken; |
547 | |
548 | /// A flag indicating if the parser should add documentation to AST nodes when |
549 | /// viable. |
550 | bool enableDocumentation; |
551 | |
552 | /// The most recently defined decl scope. |
553 | ast::DeclScope *curDeclScope = nullptr; |
554 | llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; |
555 | |
556 | /// The current context of the parser. |
557 | ParserContext parserContext = ParserContext::Global; |
558 | |
559 | /// Cached types to simplify verification and expression creation. |
560 | ast::Type typeTy, valueTy; |
561 | ast::RangeType typeRangeTy, valueRangeTy; |
562 | ast::Type attrTy; |
563 | |
564 | /// A counter used when naming anonymous constraints and rewrites. |
565 | unsigned anonymousDeclNameCounter = 0; |
566 | |
567 | /// The optional code completion context. |
568 | CodeCompleteContext *codeCompleteContext; |
569 | }; |
570 | } // namespace |
571 | |
572 | FailureOr<ast::Module *> Parser::parseModule() { |
573 | SMLoc moduleLoc = curToken.getStartLoc(); |
574 | pushDeclScope(); |
575 | |
576 | // Parse the top-level decls of the module. |
577 | SmallVector<ast::Decl *> decls; |
578 | if (failed(Result: parseModuleBody(decls))) |
579 | return popDeclScope(), failure(); |
580 | |
581 | popDeclScope(); |
582 | return ast::Module::create(ctx, loc: moduleLoc, children: decls); |
583 | } |
584 | |
585 | LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) { |
586 | while (curToken.isNot(k: Token::eof)) { |
587 | if (curToken.is(k: Token::directive)) { |
588 | if (failed(Result: parseDirective(decls))) |
589 | return failure(); |
590 | continue; |
591 | } |
592 | |
593 | FailureOr<ast::Decl *> decl = parseTopLevelDecl(); |
594 | if (failed(Result: decl)) |
595 | return failure(); |
596 | decls.push_back(Elt: *decl); |
597 | } |
598 | return success(); |
599 | } |
600 | |
601 | ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { |
602 | return ast::AllResultsMemberAccessExpr::create(ctx, loc: opExpr->getLoc(), parentExpr: opExpr, |
603 | type: valueRangeTy); |
604 | } |
605 | |
606 | LogicalResult Parser::convertExpressionTo( |
607 | ast::Expr *&expr, ast::Type type, |
608 | function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { |
609 | ast::Type exprType = expr->getType(); |
610 | if (exprType == type) |
611 | return success(); |
612 | |
613 | auto emitConvertError = [&]() -> ast::InFlightDiagnostic { |
614 | ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( |
615 | loc: expr->getLoc(), msg: llvm::formatv(Fmt: "unable to convert expression of type " |
616 | "`{0}` to the expected type of " |
617 | "`{1}`", |
618 | Vals&: exprType, Vals&: type)); |
619 | if (noteAttachFn) |
620 | noteAttachFn(*diag); |
621 | return diag; |
622 | }; |
623 | |
624 | if (auto exprOpType = dyn_cast<ast::OperationType>(Val&: exprType)) |
625 | return convertOpExpressionTo(expr, exprType: exprOpType, type, emitErrorFn: emitConvertError); |
626 | |
627 | // FIXME: Decide how to allow/support converting a single result to multiple, |
628 | // and multiple to a single result. For now, we just allow Single->Range, |
629 | // but this isn't something really supported in the PDL dialect. We should |
630 | // figure out some way to support both. |
631 | if ((exprType == valueTy || exprType == valueRangeTy) && |
632 | (type == valueTy || type == valueRangeTy)) |
633 | return success(); |
634 | if ((exprType == typeTy || exprType == typeRangeTy) && |
635 | (type == typeTy || type == typeRangeTy)) |
636 | return success(); |
637 | |
638 | // Handle tuple types. |
639 | if (auto exprTupleType = dyn_cast<ast::TupleType>(Val&: exprType)) |
640 | return convertTupleExpressionTo(expr, exprType: exprTupleType, type, emitErrorFn: emitConvertError, |
641 | noteAttachFn); |
642 | |
643 | return emitConvertError(); |
644 | } |
645 | |
646 | LogicalResult Parser::convertOpExpressionTo( |
647 | ast::Expr *&expr, ast::OperationType exprType, ast::Type type, |
648 | function_ref<ast::InFlightDiagnostic()> emitErrorFn) { |
649 | // Two operation types are compatible if they have the same name, or if the |
650 | // expected type is more general. |
651 | if (auto opType = dyn_cast<ast::OperationType>(Val&: type)) { |
652 | if (opType.getName()) |
653 | return emitErrorFn(); |
654 | return success(); |
655 | } |
656 | |
657 | // An operation can always convert to a ValueRange. |
658 | if (type == valueRangeTy) { |
659 | expr = ast::AllResultsMemberAccessExpr::create(ctx, loc: expr->getLoc(), parentExpr: expr, |
660 | type: valueRangeTy); |
661 | return success(); |
662 | } |
663 | |
664 | // Allow conversion to a single value by constraining the result range. |
665 | if (type == valueTy) { |
666 | // If the operation is registered, we can verify if it can ever have a |
667 | // single result. |
668 | if (const ods::Operation *odsOp = exprType.getODSOperation()) { |
669 | if (odsOp->getResults().empty()) { |
670 | return emitErrorFn()->attachNote( |
671 | msg: llvm::formatv(Fmt: "see the definition of `{0}`, which was defined " |
672 | "with zero results", |
673 | Vals: odsOp->getName()), |
674 | noteLoc: odsOp->getLoc()); |
675 | } |
676 | |
677 | unsigned numSingleResults = llvm::count_if( |
678 | Range: odsOp->getResults(), P: [](const ods::OperandOrResult &result) { |
679 | return result.getVariableLengthKind() == |
680 | ods::VariableLengthKind::Single; |
681 | }); |
682 | if (numSingleResults > 1) { |
683 | return emitErrorFn()->attachNote( |
684 | msg: llvm::formatv(Fmt: "see the definition of `{0}`, which was defined " |
685 | "with at least {1} results", |
686 | Vals: odsOp->getName(), Vals&: numSingleResults), |
687 | noteLoc: odsOp->getLoc()); |
688 | } |
689 | } |
690 | |
691 | expr = ast::AllResultsMemberAccessExpr::create(ctx, loc: expr->getLoc(), parentExpr: expr, |
692 | type: valueTy); |
693 | return success(); |
694 | } |
695 | return emitErrorFn(); |
696 | } |
697 | |
698 | LogicalResult Parser::convertTupleExpressionTo( |
699 | ast::Expr *&expr, ast::TupleType exprType, ast::Type type, |
700 | function_ref<ast::InFlightDiagnostic()> emitErrorFn, |
701 | function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { |
702 | // Handle conversions between tuples. |
703 | if (auto tupleType = dyn_cast<ast::TupleType>(Val&: type)) { |
704 | if (tupleType.size() != exprType.size()) |
705 | return emitErrorFn(); |
706 | |
707 | // Build a new tuple expression using each of the elements of the current |
708 | // tuple. |
709 | SmallVector<ast::Expr *> newExprs; |
710 | for (unsigned i = 0, e = exprType.size(); i < e; ++i) { |
711 | newExprs.push_back(Elt: ast::MemberAccessExpr::create( |
712 | ctx, loc: expr->getLoc(), parentExpr: expr, memberName: llvm::to_string(Value: i), |
713 | type: exprType.getElementTypes()[i])); |
714 | |
715 | auto diagFn = [&](ast::Diagnostic &diag) { |
716 | diag.attachNote(msg: llvm::formatv(Fmt: "when converting element #{0} of `{1}`", |
717 | Vals&: i, Vals&: exprType)); |
718 | if (noteAttachFn) |
719 | noteAttachFn(diag); |
720 | }; |
721 | if (failed(Result: convertExpressionTo(expr&: newExprs.back(), |
722 | type: tupleType.getElementTypes()[i], noteAttachFn: diagFn))) |
723 | return failure(); |
724 | } |
725 | expr = ast::TupleExpr::create(ctx, loc: expr->getLoc(), elements: newExprs, |
726 | elementNames: tupleType.getElementNames()); |
727 | return success(); |
728 | } |
729 | |
730 | // Handle conversion to a range. |
731 | auto convertToRange = [&](ArrayRef<ast::Type> allowedElementTypes, |
732 | ast::RangeType resultTy) -> LogicalResult { |
733 | // TODO: We currently only allow range conversion within a rewrite context. |
734 | if (parserContext != ParserContext::Rewrite) { |
735 | return emitErrorFn()->attachNote(msg: "Tuple to Range conversion is currently " |
736 | "only allowed within a rewrite context"); |
737 | } |
738 | |
739 | // All of the tuple elements must be allowed types. |
740 | for (ast::Type elementType : exprType.getElementTypes()) |
741 | if (!llvm::is_contained(Range&: allowedElementTypes, Element: elementType)) |
742 | return emitErrorFn(); |
743 | |
744 | // Build a new tuple expression using each of the elements of the current |
745 | // tuple. |
746 | SmallVector<ast::Expr *> newExprs; |
747 | for (unsigned i = 0, e = exprType.size(); i < e; ++i) { |
748 | newExprs.push_back(Elt: ast::MemberAccessExpr::create( |
749 | ctx, loc: expr->getLoc(), parentExpr: expr, memberName: llvm::to_string(Value: i), |
750 | type: exprType.getElementTypes()[i])); |
751 | } |
752 | expr = ast::RangeExpr::create(ctx, loc: expr->getLoc(), elements: newExprs, type: resultTy); |
753 | return success(); |
754 | }; |
755 | if (type == valueRangeTy) |
756 | return convertToRange({valueTy, valueRangeTy}, valueRangeTy); |
757 | if (type == typeRangeTy) |
758 | return convertToRange({typeTy, typeRangeTy}, typeRangeTy); |
759 | |
760 | return emitErrorFn(); |
761 | } |
762 | |
763 | //===----------------------------------------------------------------------===// |
764 | // Directives |
765 | //===----------------------------------------------------------------------===// |
766 | |
767 | LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) { |
768 | StringRef directive = curToken.getSpelling(); |
769 | if (directive == "#include") |
770 | return parseInclude(decls); |
771 | |
772 | return emitError(msg: "unknown directive `"+ directive + "`"); |
773 | } |
774 | |
775 | LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) { |
776 | SMRange loc = curToken.getLoc(); |
777 | consumeToken(kind: Token::directive); |
778 | |
779 | // Handle code completion of the include file path. |
780 | if (curToken.is(k: Token::code_complete_string)) |
781 | return codeCompleteIncludeFilename(curPath: curToken.getStringValue()); |
782 | |
783 | // Parse the file being included. |
784 | if (!curToken.isString()) |
785 | return emitError(loc, |
786 | msg: "expected string file name after `include` directive"); |
787 | SMRange fileLoc = curToken.getLoc(); |
788 | std::string filenameStr = curToken.getStringValue(); |
789 | StringRef filename = filenameStr; |
790 | consumeToken(); |
791 | |
792 | // Check the type of include. If ending with `.pdll`, this is another pdl file |
793 | // to be parsed along with the current module. |
794 | if (filename.ends_with(Suffix: ".pdll")) { |
795 | if (failed(Result: lexer.pushInclude(filename, includeLoc: fileLoc))) |
796 | return emitError(loc: fileLoc, |
797 | msg: "unable to open include file `"+ filename + "`"); |
798 | |
799 | // If we added the include successfully, parse it into the current module. |
800 | // Make sure to update to the next token after we finish parsing the nested |
801 | // file. |
802 | curToken = lexer.lexToken(); |
803 | LogicalResult result = parseModuleBody(decls); |
804 | curToken = lexer.lexToken(); |
805 | return result; |
806 | } |
807 | |
808 | // Otherwise, this must be a `.td` include. |
809 | if (filename.ends_with(Suffix: ".td")) |
810 | return parseTdInclude(filename, fileLoc, decls); |
811 | |
812 | return emitError(loc: fileLoc, |
813 | msg: "expected include filename to end with `.pdll` or `.td`"); |
814 | } |
815 | |
816 | LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc, |
817 | SmallVectorImpl<ast::Decl *> &decls) { |
818 | llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr(); |
819 | |
820 | // Use the source manager to open the file, but don't yet add it. |
821 | std::string includedFile; |
822 | llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer = |
823 | parserSrcMgr.OpenIncludeFile(Filename: filename.str(), IncludedFile&: includedFile); |
824 | if (!includeBuffer) |
825 | return emitError(loc: fileLoc, msg: "unable to open include file `"+ filename + "`"); |
826 | |
827 | // Setup the source manager for parsing the tablegen file. |
828 | llvm::SourceMgr tdSrcMgr; |
829 | tdSrcMgr.AddNewSourceBuffer(F: std::move(*includeBuffer), IncludeLoc: SMLoc()); |
830 | tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs()); |
831 | |
832 | // This class provides a context argument for the llvm::SourceMgr diagnostic |
833 | // handler. |
834 | struct DiagHandlerContext { |
835 | Parser &parser; |
836 | StringRef filename; |
837 | llvm::SMRange loc; |
838 | } handlerContext{.parser: *this, .filename: filename, .loc: fileLoc}; |
839 | |
840 | // Set the diagnostic handler for the tablegen source manager. |
841 | tdSrcMgr.setDiagHandler( |
842 | DH: [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) { |
843 | auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext); |
844 | (void)ctx->parser.emitError( |
845 | loc: ctx->loc, |
846 | msg: llvm::formatv(Fmt: "error while processing include file `{0}`: {1}", |
847 | Vals&: ctx->filename, Vals: diag.getMessage())); |
848 | }, |
849 | Ctx: &handlerContext); |
850 | |
851 | // Parse the tablegen file. |
852 | llvm::RecordKeeper tdRecords; |
853 | if (llvm::TableGenParseFile(InputSrcMgr&: tdSrcMgr, Records&: tdRecords)) |
854 | return failure(); |
855 | |
856 | // Process the parsed records. |
857 | processTdIncludeRecords(tdRecords, decls); |
858 | |
859 | // After we are done processing, move all of the tablegen source buffers to |
860 | // the main parser source mgr. This allows for directly using source locations |
861 | // from the .td files without needing to remap them. |
862 | parserSrcMgr.takeSourceBuffersFrom(SrcMgr&: tdSrcMgr, MainBufferIncludeLoc: fileLoc.End); |
863 | return success(); |
864 | } |
865 | |
866 | void Parser::processTdIncludeRecords(const llvm::RecordKeeper &tdRecords, |
867 | SmallVectorImpl<ast::Decl *> &decls) { |
868 | // Return the length kind of the given value. |
869 | auto getLengthKind = [](const auto &value) { |
870 | if (value.isOptional()) |
871 | return ods::VariableLengthKind::Optional; |
872 | return value.isVariadic() ? ods::VariableLengthKind::Variadic |
873 | : ods::VariableLengthKind::Single; |
874 | }; |
875 | |
876 | // Insert a type constraint into the ODS context. |
877 | ods::Context &odsContext = ctx.getODSContext(); |
878 | auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst) |
879 | -> const ods::TypeConstraint & { |
880 | return odsContext.insertTypeConstraint( |
881 | name: cst.constraint.getUniqueDefName(), |
882 | summary: processDoc(doc: cst.constraint.getSummary()), cppClass: cst.constraint.getCppType()); |
883 | }; |
884 | auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange { |
885 | return {loc, llvm::SMLoc::getFromPointer(Ptr: loc.getPointer() + 1)}; |
886 | }; |
887 | |
888 | // Process the parsed tablegen records to build ODS information. |
889 | /// Operations. |
890 | for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions(ClassName: "Op")) { |
891 | tblgen::Operator op(def); |
892 | |
893 | // Check to see if this operation is known to support type inferrence. |
894 | bool supportsResultTypeInferrence = |
895 | op.getTrait(trait: "::mlir::InferTypeOpInterface::Trait"); |
896 | |
897 | auto [odsOp, inserted] = odsContext.insertOperation( |
898 | name: op.getOperationName(), summary: processDoc(doc: op.getSummary()), |
899 | desc: processAndFormatDoc(doc: op.getDescription()), nativeClassName: op.getQualCppClassName(), |
900 | supportsResultTypeInferrence, loc: op.getLoc().front()); |
901 | |
902 | // Ignore operations that have already been added. |
903 | if (!inserted) |
904 | continue; |
905 | |
906 | for (const tblgen::NamedAttribute &attr : op.getAttributes()) { |
907 | odsOp->appendAttribute(name: attr.name, optional: attr.attr.isOptional(), |
908 | constraint: odsContext.insertAttributeConstraint( |
909 | name: attr.attr.getUniqueDefName(), |
910 | summary: processDoc(doc: attr.attr.getSummary()), |
911 | cppClass: attr.attr.getStorageType())); |
912 | } |
913 | for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) { |
914 | odsOp->appendOperand(name: operand.name, variableLengthKind: getLengthKind(operand), |
915 | constraint: addTypeConstraint(operand)); |
916 | } |
917 | for (const tblgen::NamedTypeConstraint &result : op.getResults()) { |
918 | odsOp->appendResult(name: result.name, variableLengthKind: getLengthKind(result), |
919 | constraint: addTypeConstraint(result)); |
920 | } |
921 | } |
922 | |
923 | auto shouldBeSkipped = [this](const llvm::Record *def) { |
924 | return def->isAnonymous() || curDeclScope->lookup(name: def->getName()) || |
925 | def->isSubClassOf(Name: "DeclareInterfaceMethods"); |
926 | }; |
927 | |
928 | /// Attr constraints. |
929 | for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions(ClassName: "Attr")) { |
930 | if (shouldBeSkipped(def)) |
931 | continue; |
932 | |
933 | tblgen::Attribute constraint(def); |
934 | decls.push_back(Elt: createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( |
935 | constraint, loc: convertLocToRange(def->getLoc().front()), type: attrTy, |
936 | nativeType: constraint.getStorageType())); |
937 | } |
938 | /// Type constraints. |
939 | for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions(ClassName: "Type")) { |
940 | if (shouldBeSkipped(def)) |
941 | continue; |
942 | |
943 | tblgen::TypeConstraint constraint(def); |
944 | decls.push_back(Elt: createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( |
945 | constraint, loc: convertLocToRange(def->getLoc().front()), type: typeTy, |
946 | nativeType: constraint.getCppType())); |
947 | } |
948 | /// OpInterfaces. |
949 | ast::Type opTy = ast::OperationType::get(context&: ctx); |
950 | for (const llvm::Record *def : |
951 | tdRecords.getAllDerivedDefinitions(ClassName: "OpInterface")) { |
952 | if (shouldBeSkipped(def)) |
953 | continue; |
954 | |
955 | SMRange loc = convertLocToRange(def->getLoc().front()); |
956 | |
957 | std::string cppClassName = |
958 | llvm::formatv(Fmt: "{0}::{1}", Vals: def->getValueAsString(FieldName: "cppNamespace"), |
959 | Vals: def->getValueAsString(FieldName: "cppInterfaceName")) |
960 | .str(); |
961 | std::string codeBlock = |
962 | llvm::formatv(Fmt: "return ::mlir::success(llvm::isa<{0}>(self));", |
963 | Vals&: cppClassName) |
964 | .str(); |
965 | |
966 | std::string desc = |
967 | processAndFormatDoc(doc: def->getValueAsString(FieldName: "description")); |
968 | decls.push_back(Elt: createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>( |
969 | name: def->getName(), codeBlock, loc, type: opTy, nativeType: cppClassName, docString: desc)); |
970 | } |
971 | } |
972 | |
973 | template <typename ConstraintT> |
974 | ast::Decl *Parser::createODSNativePDLLConstraintDecl( |
975 | StringRef name, StringRef codeBlock, SMRange loc, ast::Type type, |
976 | StringRef nativeType, StringRef docString) { |
977 | // Build the single input parameter. |
978 | ast::DeclScope *argScope = pushDeclScope(); |
979 | auto *paramVar = ast::VariableDecl::create( |
980 | ctx, name: ast::Name::create(ctx, name: "self", location: loc), type, |
981 | /*initExpr=*/nullptr, constraints: ast::ConstraintRef(ConstraintT::create(ctx, loc))); |
982 | argScope->add(decl: paramVar); |
983 | popDeclScope(); |
984 | |
985 | // Build the native constraint. |
986 | auto *constraintDecl = ast::UserConstraintDecl::createNative( |
987 | ctx, name: ast::Name::create(ctx, name, location: loc), inputs: paramVar, |
988 | /*results=*/std::nullopt, codeBlock, resultType: ast::TupleType::get(context&: ctx), |
989 | nativeInputTypes: nativeType); |
990 | constraintDecl->setDocComment(ctx, comment: docString); |
991 | curDeclScope->add(decl: constraintDecl); |
992 | return constraintDecl; |
993 | } |
994 | |
995 | template <typename ConstraintT> |
996 | ast::Decl * |
997 | Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, |
998 | SMRange loc, ast::Type type, |
999 | StringRef nativeType) { |
1000 | // Format the condition template. |
1001 | tblgen::FmtContext fmtContext; |
1002 | fmtContext.withSelf(subst: "self"); |
1003 | std::string codeBlock = tblgen::tgfmt( |
1004 | fmt: "return ::mlir::success("+ constraint.getConditionTemplate() + ");", |
1005 | ctx: &fmtContext); |
1006 | |
1007 | // If documentation was enabled, build the doc string for the generated |
1008 | // constraint. It would be nice to do this lazily, but TableGen information is |
1009 | // destroyed after we finish parsing the file. |
1010 | std::string docString; |
1011 | if (enableDocumentation) { |
1012 | StringRef desc = constraint.getDescription(); |
1013 | docString = processAndFormatDoc( |
1014 | doc: constraint.getSummary() + |
1015 | (desc.empty() ? "": ( "\n\n"+ constraint.getDescription()))); |
1016 | } |
1017 | |
1018 | return createODSNativePDLLConstraintDecl<ConstraintT>( |
1019 | constraint.getUniqueDefName(), codeBlock, loc, type, nativeType, |
1020 | docString); |
1021 | } |
1022 | |
1023 | //===----------------------------------------------------------------------===// |
1024 | // Decls |
1025 | //===----------------------------------------------------------------------===// |
1026 | |
1027 | FailureOr<ast::Decl *> Parser::parseTopLevelDecl() { |
1028 | FailureOr<ast::Decl *> decl; |
1029 | switch (curToken.getKind()) { |
1030 | case Token::kw_Constraint: |
1031 | decl = parseUserConstraintDecl(); |
1032 | break; |
1033 | case Token::kw_Pattern: |
1034 | decl = parsePatternDecl(); |
1035 | break; |
1036 | case Token::kw_Rewrite: |
1037 | decl = parseUserRewriteDecl(); |
1038 | break; |
1039 | default: |
1040 | return emitError(msg: "expected top-level declaration, such as a `Pattern`"); |
1041 | } |
1042 | if (failed(Result: decl)) |
1043 | return failure(); |
1044 | |
1045 | // If the decl has a name, add it to the current scope. |
1046 | if (const ast::Name *name = (*decl)->getName()) { |
1047 | if (failed(Result: checkDefineNamedDecl(name: *name))) |
1048 | return failure(); |
1049 | curDeclScope->add(decl: *decl); |
1050 | } |
1051 | return decl; |
1052 | } |
1053 | |
1054 | FailureOr<ast::NamedAttributeDecl *> |
1055 | Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) { |
1056 | // Check for name code completion. |
1057 | if (curToken.is(k: Token::code_complete)) |
1058 | return codeCompleteAttributeName(opName: parentOpName); |
1059 | |
1060 | std::string attrNameStr; |
1061 | if (curToken.isString()) |
1062 | attrNameStr = curToken.getStringValue(); |
1063 | else if (curToken.is(k: Token::identifier) || curToken.isKeyword()) |
1064 | attrNameStr = curToken.getSpelling().str(); |
1065 | else |
1066 | return emitError(msg: "expected identifier or string attribute name"); |
1067 | const auto &name = ast::Name::create(ctx, name: attrNameStr, location: curToken.getLoc()); |
1068 | consumeToken(); |
1069 | |
1070 | // Check for a value of the attribute. |
1071 | ast::Expr *attrValue = nullptr; |
1072 | if (consumeIf(kind: Token::equal)) { |
1073 | FailureOr<ast::Expr *> attrExpr = parseExpr(); |
1074 | if (failed(Result: attrExpr)) |
1075 | return failure(); |
1076 | attrValue = *attrExpr; |
1077 | } else { |
1078 | // If there isn't a concrete value, create an expression representing a |
1079 | // UnitAttr. |
1080 | attrValue = ast::AttributeExpr::create(ctx, loc: name.getLoc(), value: "unit"); |
1081 | } |
1082 | |
1083 | return ast::NamedAttributeDecl::create(ctx, name, value: attrValue); |
1084 | } |
1085 | |
1086 | FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody( |
1087 | function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, |
1088 | bool expectTerminalSemicolon) { |
1089 | consumeToken(kind: Token::equal_arrow); |
1090 | |
1091 | // Parse the single statement of the lambda body. |
1092 | SMLoc bodyStartLoc = curToken.getStartLoc(); |
1093 | pushDeclScope(); |
1094 | FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon); |
1095 | bool failedToParse = |
1096 | failed(Result: singleStatement) || failed(Result: processStatementFn(*singleStatement)); |
1097 | popDeclScope(); |
1098 | if (failedToParse) |
1099 | return failure(); |
1100 | |
1101 | SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc()); |
1102 | return ast::CompoundStmt::create(ctx, location: bodyLoc, children: *singleStatement); |
1103 | } |
1104 | |
1105 | FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() { |
1106 | // Ensure that the argument is named. |
1107 | if (curToken.isNot(k: Token::identifier) && !curToken.isDependentKeyword()) |
1108 | return emitError(msg: "expected identifier argument name"); |
1109 | |
1110 | // Parse the argument similarly to a normal variable. |
1111 | StringRef name = curToken.getSpelling(); |
1112 | SMRange nameLoc = curToken.getLoc(); |
1113 | consumeToken(); |
1114 | |
1115 | if (failed( |
1116 | Result: parseToken(kind: Token::colon, msg: "expected `:` before argument constraint"))) |
1117 | return failure(); |
1118 | |
1119 | FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); |
1120 | if (failed(Result: cst)) |
1121 | return failure(); |
1122 | |
1123 | return createArgOrResultVariableDecl(name, loc: nameLoc, constraint: *cst); |
1124 | } |
1125 | |
1126 | FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) { |
1127 | // Check to see if this result is named. |
1128 | if (curToken.is(k: Token::identifier) || curToken.isDependentKeyword()) { |
1129 | // Check to see if this name actually refers to a Constraint. |
1130 | if (!curDeclScope->lookup<ast::ConstraintDecl>(name: curToken.getSpelling())) { |
1131 | // If it wasn't a constraint, parse the result similarly to a variable. If |
1132 | // there is already an existing decl, we will emit an error when defining |
1133 | // this variable later. |
1134 | StringRef name = curToken.getSpelling(); |
1135 | SMRange nameLoc = curToken.getLoc(); |
1136 | consumeToken(); |
1137 | |
1138 | if (failed(Result: parseToken(kind: Token::colon, |
1139 | msg: "expected `:` before result constraint"))) |
1140 | return failure(); |
1141 | |
1142 | FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); |
1143 | if (failed(Result: cst)) |
1144 | return failure(); |
1145 | |
1146 | return createArgOrResultVariableDecl(name, loc: nameLoc, constraint: *cst); |
1147 | } |
1148 | } |
1149 | |
1150 | // If it isn't named, we parse the constraint directly and create an unnamed |
1151 | // result variable. |
1152 | FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); |
1153 | if (failed(Result: cst)) |
1154 | return failure(); |
1155 | |
1156 | return createArgOrResultVariableDecl(name: "", loc: cst->referenceLoc, constraint: *cst); |
1157 | } |
1158 | |
1159 | FailureOr<ast::UserConstraintDecl *> |
1160 | Parser::parseUserConstraintDecl(bool isInline) { |
1161 | // Constraints and rewrites have very similar formats, dispatch to a shared |
1162 | // interface for parsing. |
1163 | return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>( |
1164 | parseUserPDLLFn: [&](auto &&...args) { |
1165 | return this->parseUserPDLLConstraintDecl(name: args...); |
1166 | }, |
1167 | declContext: ParserContext::Constraint, anonymousNamePrefix: "constraint", isInline); |
1168 | } |
1169 | |
1170 | FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() { |
1171 | FailureOr<ast::UserConstraintDecl *> decl = |
1172 | parseUserConstraintDecl(/*isInline=*/true); |
1173 | if (failed(Result: decl) || failed(Result: checkDefineNamedDecl(name: (*decl)->getName()))) |
1174 | return failure(); |
1175 | |
1176 | curDeclScope->add(decl: *decl); |
1177 | return decl; |
1178 | } |
1179 | |
1180 | FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl( |
1181 | const ast::Name &name, bool isInline, |
1182 | ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, |
1183 | ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { |
1184 | // Push the argument scope back onto the list, so that the body can |
1185 | // reference arguments. |
1186 | pushDeclScope(scope: argumentScope); |
1187 | |
1188 | // Parse the body of the constraint. The body is either defined as a compound |
1189 | // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`. |
1190 | ast::CompoundStmt *body; |
1191 | if (curToken.is(k: Token::equal_arrow)) { |
1192 | FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( |
1193 | processStatementFn: [&](ast::Stmt *&stmt) -> LogicalResult { |
1194 | ast::Expr *stmtExpr = dyn_cast<ast::Expr>(Val: stmt); |
1195 | if (!stmtExpr) { |
1196 | return emitError(loc: stmt->getLoc(), |
1197 | msg: "expected `Constraint` lambda body to contain a " |
1198 | "single expression"); |
1199 | } |
1200 | stmt = ast::ReturnStmt::create(ctx, loc: stmt->getLoc(), resultExpr: stmtExpr); |
1201 | return success(); |
1202 | }, |
1203 | /*expectTerminalSemicolon=*/!isInline); |
1204 | if (failed(Result: bodyResult)) |
1205 | return failure(); |
1206 | body = *bodyResult; |
1207 | } else { |
1208 | FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); |
1209 | if (failed(Result: bodyResult)) |
1210 | return failure(); |
1211 | body = *bodyResult; |
1212 | |
1213 | // Verify the structure of the body. |
1214 | auto bodyIt = body->begin(), bodyE = body->end(); |
1215 | for (; bodyIt != bodyE; ++bodyIt) |
1216 | if (isa<ast::ReturnStmt>(Val: *bodyIt)) |
1217 | break; |
1218 | if (failed(Result: validateUserConstraintOrRewriteReturn( |
1219 | declType: "Constraint", body, bodyIt, bodyE, results, resultType))) |
1220 | return failure(); |
1221 | } |
1222 | popDeclScope(); |
1223 | |
1224 | return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>( |
1225 | name, arguments, results, resultType, body); |
1226 | } |
1227 | |
1228 | FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) { |
1229 | // Constraints and rewrites have very similar formats, dispatch to a shared |
1230 | // interface for parsing. |
1231 | return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>( |
1232 | parseUserPDLLFn: [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(name: args...); }, |
1233 | declContext: ParserContext::Rewrite, anonymousNamePrefix: "rewrite", isInline); |
1234 | } |
1235 | |
1236 | FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() { |
1237 | FailureOr<ast::UserRewriteDecl *> decl = |
1238 | parseUserRewriteDecl(/*isInline=*/true); |
1239 | if (failed(Result: decl) || failed(Result: checkDefineNamedDecl(name: (*decl)->getName()))) |
1240 | return failure(); |
1241 | |
1242 | curDeclScope->add(decl: *decl); |
1243 | return decl; |
1244 | } |
1245 | |
1246 | FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl( |
1247 | const ast::Name &name, bool isInline, |
1248 | ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, |
1249 | ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { |
1250 | // Push the argument scope back onto the list, so that the body can |
1251 | // reference arguments. |
1252 | curDeclScope = argumentScope; |
1253 | ast::CompoundStmt *body; |
1254 | if (curToken.is(k: Token::equal_arrow)) { |
1255 | FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( |
1256 | processStatementFn: [&](ast::Stmt *&statement) -> LogicalResult { |
1257 | if (isa<ast::OpRewriteStmt>(Val: statement)) |
1258 | return success(); |
1259 | |
1260 | ast::Expr *statementExpr = dyn_cast<ast::Expr>(Val: statement); |
1261 | if (!statementExpr) { |
1262 | return emitError( |
1263 | loc: statement->getLoc(), |
1264 | msg: "expected `Rewrite` lambda body to contain a single expression " |
1265 | "or an operation rewrite statement; such as `erase`, " |
1266 | "`replace`, or `rewrite`"); |
1267 | } |
1268 | statement = |
1269 | ast::ReturnStmt::create(ctx, loc: statement->getLoc(), resultExpr: statementExpr); |
1270 | return success(); |
1271 | }, |
1272 | /*expectTerminalSemicolon=*/!isInline); |
1273 | if (failed(Result: bodyResult)) |
1274 | return failure(); |
1275 | body = *bodyResult; |
1276 | } else { |
1277 | FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); |
1278 | if (failed(Result: bodyResult)) |
1279 | return failure(); |
1280 | body = *bodyResult; |
1281 | } |
1282 | popDeclScope(); |
1283 | |
1284 | // Verify the structure of the body. |
1285 | auto bodyIt = body->begin(), bodyE = body->end(); |
1286 | for (; bodyIt != bodyE; ++bodyIt) |
1287 | if (isa<ast::ReturnStmt>(Val: *bodyIt)) |
1288 | break; |
1289 | if (failed(Result: validateUserConstraintOrRewriteReturn(declType: "Rewrite", body, bodyIt, |
1290 | bodyE, results, resultType))) |
1291 | return failure(); |
1292 | return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>( |
1293 | name, arguments, results, resultType, body); |
1294 | } |
1295 | |
1296 | template <typename T, typename ParseUserPDLLDeclFnT> |
1297 | FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl( |
1298 | ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, |
1299 | StringRef anonymousNamePrefix, bool isInline) { |
1300 | SMRange loc = curToken.getLoc(); |
1301 | consumeToken(); |
1302 | llvm::SaveAndRestore saveCtx(parserContext, declContext); |
1303 | |
1304 | // Parse the name of the decl. |
1305 | const ast::Name *name = nullptr; |
1306 | if (curToken.isNot(k: Token::identifier)) { |
1307 | // Only inline decls can be un-named. Inline decls are similar to "lambdas" |
1308 | // in C++, so being unnamed is fine. |
1309 | if (!isInline) |
1310 | return emitError(msg: "expected identifier name"); |
1311 | |
1312 | // Create a unique anonymous name to use, as the name for this decl is not |
1313 | // important. |
1314 | std::string anonName = |
1315 | llvm::formatv(Fmt: "<anonymous_{0}_{1}>", Vals&: anonymousNamePrefix, |
1316 | Vals: anonymousDeclNameCounter++) |
1317 | .str(); |
1318 | name = &ast::Name::create(ctx, name: anonName, location: loc); |
1319 | } else { |
1320 | // If a name was provided, we can use it directly. |
1321 | name = &ast::Name::create(ctx, name: curToken.getSpelling(), location: curToken.getLoc()); |
1322 | consumeToken(kind: Token::identifier); |
1323 | } |
1324 | |
1325 | // Parse the functional signature of the decl. |
1326 | SmallVector<ast::VariableDecl *> arguments, results; |
1327 | ast::DeclScope *argumentScope; |
1328 | ast::Type resultType; |
1329 | if (failed(Result: parseUserConstraintOrRewriteSignature(arguments, results, |
1330 | argumentScope, resultType))) |
1331 | return failure(); |
1332 | |
1333 | // Check to see which type of constraint this is. If the constraint contains a |
1334 | // compound body, this is a PDLL decl. |
1335 | if (curToken.isAny(k1: Token::l_brace, k2: Token::equal_arrow)) |
1336 | return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results, |
1337 | resultType); |
1338 | |
1339 | // Otherwise, this is a native decl. |
1340 | return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments, |
1341 | results, resultType); |
1342 | } |
1343 | |
1344 | template <typename T> |
1345 | FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl( |
1346 | const ast::Name &name, bool isInline, |
1347 | ArrayRef<ast::VariableDecl *> arguments, |
1348 | ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { |
1349 | // If followed by a string, the native code body has also been specified. |
1350 | std::string codeStrStorage; |
1351 | std::optional<StringRef> optCodeStr; |
1352 | if (curToken.isString()) { |
1353 | codeStrStorage = curToken.getStringValue(); |
1354 | optCodeStr = codeStrStorage; |
1355 | consumeToken(); |
1356 | } else if (isInline) { |
1357 | return emitError(loc: name.getLoc(), |
1358 | msg: "external declarations must be declared in global scope"); |
1359 | } else if (curToken.is(k: Token::error)) { |
1360 | return failure(); |
1361 | } |
1362 | if (failed(Result: parseToken(kind: Token::semicolon, |
1363 | msg: "expected `;` after native declaration"))) |
1364 | return failure(); |
1365 | return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); |
1366 | } |
1367 | |
1368 | LogicalResult Parser::parseUserConstraintOrRewriteSignature( |
1369 | SmallVectorImpl<ast::VariableDecl *> &arguments, |
1370 | SmallVectorImpl<ast::VariableDecl *> &results, |
1371 | ast::DeclScope *&argumentScope, ast::Type &resultType) { |
1372 | // Parse the argument list of the decl. |
1373 | if (failed(Result: parseToken(kind: Token::l_paren, msg: "expected `(` to start argument list"))) |
1374 | return failure(); |
1375 | |
1376 | argumentScope = pushDeclScope(); |
1377 | if (curToken.isNot(k: Token::r_paren)) { |
1378 | do { |
1379 | FailureOr<ast::VariableDecl *> argument = parseArgumentDecl(); |
1380 | if (failed(Result: argument)) |
1381 | return failure(); |
1382 | arguments.emplace_back(Args&: *argument); |
1383 | } while (consumeIf(kind: Token::comma)); |
1384 | } |
1385 | popDeclScope(); |
1386 | if (failed(Result: parseToken(kind: Token::r_paren, msg: "expected `)` to end argument list"))) |
1387 | return failure(); |
1388 | |
1389 | // Parse the results of the decl. |
1390 | pushDeclScope(); |
1391 | if (consumeIf(kind: Token::arrow)) { |
1392 | auto parseResultFn = [&]() -> LogicalResult { |
1393 | FailureOr<ast::VariableDecl *> result = parseResultDecl(resultNum: results.size()); |
1394 | if (failed(Result: result)) |
1395 | return failure(); |
1396 | results.emplace_back(Args&: *result); |
1397 | return success(); |
1398 | }; |
1399 | |
1400 | // Check for a list of results. |
1401 | if (consumeIf(kind: Token::l_paren)) { |
1402 | do { |
1403 | if (failed(Result: parseResultFn())) |
1404 | return failure(); |
1405 | } while (consumeIf(kind: Token::comma)); |
1406 | if (failed(Result: parseToken(kind: Token::r_paren, msg: "expected `)` to end result list"))) |
1407 | return failure(); |
1408 | |
1409 | // Otherwise, there is only one result. |
1410 | } else if (failed(Result: parseResultFn())) { |
1411 | return failure(); |
1412 | } |
1413 | } |
1414 | popDeclScope(); |
1415 | |
1416 | // Compute the result type of the decl. |
1417 | resultType = createUserConstraintRewriteResultType(results); |
1418 | |
1419 | // Verify that results are only named if there are more than one. |
1420 | if (results.size() == 1 && !results.front()->getName().getName().empty()) { |
1421 | return emitError( |
1422 | loc: results.front()->getLoc(), |
1423 | msg: "cannot create a single-element tuple with an element label"); |
1424 | } |
1425 | return success(); |
1426 | } |
1427 | |
1428 | LogicalResult Parser::validateUserConstraintOrRewriteReturn( |
1429 | StringRef declType, ast::CompoundStmt *body, |
1430 | ArrayRef<ast::Stmt *>::iterator bodyIt, |
1431 | ArrayRef<ast::Stmt *>::iterator bodyE, |
1432 | ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) { |
1433 | // Handle if a `return` was provided. |
1434 | if (bodyIt != bodyE) { |
1435 | // Emit an error if we have trailing statements after the return. |
1436 | if (std::next(x: bodyIt) != bodyE) { |
1437 | return emitError( |
1438 | loc: (*std::next(x: bodyIt))->getLoc(), |
1439 | msg: llvm::formatv(Fmt: "`return` terminated the `{0}` body, but found " |
1440 | "trailing statements afterwards", |
1441 | Vals&: declType)); |
1442 | } |
1443 | |
1444 | // Otherwise if a return wasn't provided, check that no results are |
1445 | // expected. |
1446 | } else if (!results.empty()) { |
1447 | return emitError( |
1448 | loc: {body->getLoc().End, body->getLoc().End}, |
1449 | msg: llvm::formatv(Fmt: "missing return in a `{0}` expected to return `{1}`", |
1450 | Vals&: declType, Vals&: resultType)); |
1451 | } |
1452 | return success(); |
1453 | } |
1454 | |
1455 | FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() { |
1456 | return parseLambdaBody(processStatementFn: [&](ast::Stmt *&statement) -> LogicalResult { |
1457 | if (isa<ast::OpRewriteStmt>(Val: statement)) |
1458 | return success(); |
1459 | return emitError( |
1460 | loc: statement->getLoc(), |
1461 | msg: "expected Pattern lambda body to contain a single operation " |
1462 | "rewrite statement, such as `erase`, `replace`, or `rewrite`"); |
1463 | }); |
1464 | } |
1465 | |
1466 | FailureOr<ast::Decl *> Parser::parsePatternDecl() { |
1467 | SMRange loc = curToken.getLoc(); |
1468 | consumeToken(kind: Token::kw_Pattern); |
1469 | llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch); |
1470 | |
1471 | // Check for an optional identifier for the pattern name. |
1472 | const ast::Name *name = nullptr; |
1473 | if (curToken.is(k: Token::identifier)) { |
1474 | name = &ast::Name::create(ctx, name: curToken.getSpelling(), location: curToken.getLoc()); |
1475 | consumeToken(kind: Token::identifier); |
1476 | } |
1477 | |
1478 | // Parse any pattern metadata. |
1479 | ParsedPatternMetadata metadata; |
1480 | if (consumeIf(kind: Token::kw_with) && failed(Result: parsePatternDeclMetadata(metadata))) |
1481 | return failure(); |
1482 | |
1483 | // Parse the pattern body. |
1484 | ast::CompoundStmt *body; |
1485 | |
1486 | // Handle a lambda body. |
1487 | if (curToken.is(k: Token::equal_arrow)) { |
1488 | FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody(); |
1489 | if (failed(Result: bodyResult)) |
1490 | return failure(); |
1491 | body = *bodyResult; |
1492 | } else { |
1493 | if (curToken.isNot(k: Token::l_brace)) |
1494 | return emitError(msg: "expected `{` or `=>` to start pattern body"); |
1495 | FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); |
1496 | if (failed(Result: bodyResult)) |
1497 | return failure(); |
1498 | body = *bodyResult; |
1499 | |
1500 | // Verify the body of the pattern. |
1501 | auto bodyIt = body->begin(), bodyE = body->end(); |
1502 | for (; bodyIt != bodyE; ++bodyIt) { |
1503 | if (isa<ast::ReturnStmt>(Val: *bodyIt)) { |
1504 | return emitError(loc: (*bodyIt)->getLoc(), |
1505 | msg: "`return` statements are only permitted within a " |
1506 | "`Constraint` or `Rewrite` body"); |
1507 | } |
1508 | // Break when we've found the rewrite statement. |
1509 | if (isa<ast::OpRewriteStmt>(Val: *bodyIt)) |
1510 | break; |
1511 | } |
1512 | if (bodyIt == bodyE) { |
1513 | return emitError(loc, |
1514 | msg: "expected Pattern body to terminate with an operation " |
1515 | "rewrite statement, such as `erase`"); |
1516 | } |
1517 | if (std::next(x: bodyIt) != bodyE) { |
1518 | return emitError(loc: (*std::next(x: bodyIt))->getLoc(), |
1519 | msg: "Pattern body was terminated by an operation " |
1520 | "rewrite statement, but found trailing statements"); |
1521 | } |
1522 | } |
1523 | |
1524 | return createPatternDecl(loc, name, metadata, body); |
1525 | } |
1526 | |
1527 | LogicalResult |
1528 | Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) { |
1529 | std::optional<SMRange> benefitLoc; |
1530 | std::optional<SMRange> hasBoundedRecursionLoc; |
1531 | |
1532 | do { |
1533 | // Handle metadata code completion. |
1534 | if (curToken.is(k: Token::code_complete)) |
1535 | return codeCompletePatternMetadata(); |
1536 | |
1537 | if (curToken.isNot(k: Token::identifier)) |
1538 | return emitError(msg: "expected pattern metadata identifier"); |
1539 | StringRef metadataStr = curToken.getSpelling(); |
1540 | SMRange metadataLoc = curToken.getLoc(); |
1541 | consumeToken(kind: Token::identifier); |
1542 | |
1543 | // Parse the benefit metadata: benefit(<integer-value>) |
1544 | if (metadataStr == "benefit") { |
1545 | if (benefitLoc) { |
1546 | return emitErrorAndNote(loc: metadataLoc, |
1547 | msg: "pattern benefit has already been specified", |
1548 | noteLoc: *benefitLoc, note: "see previous definition here"); |
1549 | } |
1550 | if (failed(Result: parseToken(kind: Token::l_paren, |
1551 | msg: "expected `(` before pattern benefit"))) |
1552 | return failure(); |
1553 | |
1554 | uint16_t benefitValue = 0; |
1555 | if (curToken.isNot(k: Token::integer)) |
1556 | return emitError(msg: "expected integral pattern benefit"); |
1557 | if (curToken.getSpelling().getAsInteger(/*Radix=*/10, Result&: benefitValue)) |
1558 | return emitError( |
1559 | msg: "expected pattern benefit to fit within a 16-bit integer"); |
1560 | consumeToken(kind: Token::integer); |
1561 | |
1562 | metadata.benefit = benefitValue; |
1563 | benefitLoc = metadataLoc; |
1564 | |
1565 | if (failed( |
1566 | Result: parseToken(kind: Token::r_paren, msg: "expected `)` after pattern benefit"))) |
1567 | return failure(); |
1568 | continue; |
1569 | } |
1570 | |
1571 | // Parse the bounded recursion metadata: recursion |
1572 | if (metadataStr == "recursion") { |
1573 | if (hasBoundedRecursionLoc) { |
1574 | return emitErrorAndNote( |
1575 | loc: metadataLoc, |
1576 | msg: "pattern recursion metadata has already been specified", |
1577 | noteLoc: *hasBoundedRecursionLoc, note: "see previous definition here"); |
1578 | } |
1579 | metadata.hasBoundedRecursion = true; |
1580 | hasBoundedRecursionLoc = metadataLoc; |
1581 | continue; |
1582 | } |
1583 | |
1584 | return emitError(loc: metadataLoc, msg: "unknown pattern metadata"); |
1585 | } while (consumeIf(kind: Token::comma)); |
1586 | |
1587 | return success(); |
1588 | } |
1589 | |
1590 | FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() { |
1591 | consumeToken(kind: Token::less); |
1592 | |
1593 | FailureOr<ast::Expr *> typeExpr = parseExpr(); |
1594 | if (failed(Result: typeExpr) || |
1595 | failed(Result: parseToken(kind: Token::greater, |
1596 | msg: "expected `>` after variable type constraint"))) |
1597 | return failure(); |
1598 | return typeExpr; |
1599 | } |
1600 | |
1601 | LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) { |
1602 | assert(curDeclScope && "defining decl outside of a decl scope"); |
1603 | if (ast::Decl *lastDecl = curDeclScope->lookup(name: name.getName())) { |
1604 | return emitErrorAndNote( |
1605 | loc: name.getLoc(), msg: "`"+ name.getName() + "` has already been defined", |
1606 | noteLoc: lastDecl->getName()->getLoc(), note: "see previous definition here"); |
1607 | } |
1608 | return success(); |
1609 | } |
1610 | |
1611 | FailureOr<ast::VariableDecl *> |
1612 | Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, |
1613 | ast::Expr *initExpr, |
1614 | ArrayRef<ast::ConstraintRef> constraints) { |
1615 | assert(curDeclScope && "defining variable outside of decl scope"); |
1616 | const ast::Name &nameDecl = ast::Name::create(ctx, name, location: nameLoc); |
1617 | |
1618 | // If the name of the variable indicates a special variable, we don't add it |
1619 | // to the scope. This variable is local to the definition point. |
1620 | if (name.empty() || name == "_") { |
1621 | return ast::VariableDecl::create(ctx, name: nameDecl, type, initExpr, |
1622 | constraints); |
1623 | } |
1624 | if (failed(Result: checkDefineNamedDecl(name: nameDecl))) |
1625 | return failure(); |
1626 | |
1627 | auto *varDecl = |
1628 | ast::VariableDecl::create(ctx, name: nameDecl, type, initExpr, constraints); |
1629 | curDeclScope->add(decl: varDecl); |
1630 | return varDecl; |
1631 | } |
1632 | |
1633 | FailureOr<ast::VariableDecl *> |
1634 | Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, |
1635 | ArrayRef<ast::ConstraintRef> constraints) { |
1636 | return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, |
1637 | constraints); |
1638 | } |
1639 | |
1640 | LogicalResult Parser::parseVariableDeclConstraintList( |
1641 | SmallVectorImpl<ast::ConstraintRef> &constraints) { |
1642 | std::optional<SMRange> typeConstraint; |
1643 | auto parseSingleConstraint = [&] { |
1644 | FailureOr<ast::ConstraintRef> constraint = parseConstraint( |
1645 | typeConstraint, existingConstraints: constraints, /*allowInlineTypeConstraints=*/true); |
1646 | if (failed(Result: constraint)) |
1647 | return failure(); |
1648 | constraints.push_back(Elt: *constraint); |
1649 | return success(); |
1650 | }; |
1651 | |
1652 | // Check to see if this is a single constraint, or a list. |
1653 | if (!consumeIf(kind: Token::l_square)) |
1654 | return parseSingleConstraint(); |
1655 | |
1656 | do { |
1657 | if (failed(Result: parseSingleConstraint())) |
1658 | return failure(); |
1659 | } while (consumeIf(kind: Token::comma)); |
1660 | return parseToken(kind: Token::r_square, msg: "expected `]` after constraint list"); |
1661 | } |
1662 | |
1663 | FailureOr<ast::ConstraintRef> |
1664 | Parser::parseConstraint(std::optional<SMRange> &typeConstraint, |
1665 | ArrayRef<ast::ConstraintRef> existingConstraints, |
1666 | bool allowInlineTypeConstraints) { |
1667 | auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult { |
1668 | if (!allowInlineTypeConstraints) { |
1669 | return emitError( |
1670 | loc: curToken.getLoc(), |
1671 | msg: "inline `Attr`, `Value`, and `ValueRange` type constraints are not " |
1672 | "permitted on arguments or results"); |
1673 | } |
1674 | if (typeConstraint) |
1675 | return emitErrorAndNote( |
1676 | loc: curToken.getLoc(), |
1677 | msg: "the type of this variable has already been constrained", |
1678 | noteLoc: *typeConstraint, note: "see previous constraint location here"); |
1679 | FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr(); |
1680 | if (failed(Result: constraintExpr)) |
1681 | return failure(); |
1682 | typeExpr = *constraintExpr; |
1683 | typeConstraint = typeExpr->getLoc(); |
1684 | return success(); |
1685 | }; |
1686 | |
1687 | SMRange loc = curToken.getLoc(); |
1688 | switch (curToken.getKind()) { |
1689 | case Token::kw_Attr: { |
1690 | consumeToken(kind: Token::kw_Attr); |
1691 | |
1692 | // Check for a type constraint. |
1693 | ast::Expr *typeExpr = nullptr; |
1694 | if (curToken.is(k: Token::less) && failed(Result: parseTypeConstraint(typeExpr))) |
1695 | return failure(); |
1696 | return ast::ConstraintRef( |
1697 | ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc); |
1698 | } |
1699 | case Token::kw_Op: { |
1700 | consumeToken(kind: Token::kw_Op); |
1701 | |
1702 | // Parse an optional operation name. If the name isn't provided, this refers |
1703 | // to "any" operation. |
1704 | FailureOr<ast::OpNameDecl *> opName = |
1705 | parseWrappedOperationName(/*allowEmptyName=*/true); |
1706 | if (failed(Result: opName)) |
1707 | return failure(); |
1708 | |
1709 | return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, nameDecl: *opName), |
1710 | loc); |
1711 | } |
1712 | case Token::kw_Type: |
1713 | consumeToken(kind: Token::kw_Type); |
1714 | return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc); |
1715 | case Token::kw_TypeRange: |
1716 | consumeToken(kind: Token::kw_TypeRange); |
1717 | return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc), |
1718 | loc); |
1719 | case Token::kw_Value: { |
1720 | consumeToken(kind: Token::kw_Value); |
1721 | |
1722 | // Check for a type constraint. |
1723 | ast::Expr *typeExpr = nullptr; |
1724 | if (curToken.is(k: Token::less) && failed(Result: parseTypeConstraint(typeExpr))) |
1725 | return failure(); |
1726 | |
1727 | return ast::ConstraintRef( |
1728 | ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc); |
1729 | } |
1730 | case Token::kw_ValueRange: { |
1731 | consumeToken(kind: Token::kw_ValueRange); |
1732 | |
1733 | // Check for a type constraint. |
1734 | ast::Expr *typeExpr = nullptr; |
1735 | if (curToken.is(k: Token::less) && failed(Result: parseTypeConstraint(typeExpr))) |
1736 | return failure(); |
1737 | |
1738 | return ast::ConstraintRef( |
1739 | ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc); |
1740 | } |
1741 | |
1742 | case Token::kw_Constraint: { |
1743 | // Handle an inline constraint. |
1744 | FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); |
1745 | if (failed(Result: decl)) |
1746 | return failure(); |
1747 | return ast::ConstraintRef(*decl, loc); |
1748 | } |
1749 | case Token::identifier: { |
1750 | StringRef constraintName = curToken.getSpelling(); |
1751 | consumeToken(kind: Token::identifier); |
1752 | |
1753 | // Lookup the referenced constraint. |
1754 | ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(name: constraintName); |
1755 | if (!cstDecl) { |
1756 | return emitError(loc, msg: "unknown reference to constraint `"+ |
1757 | constraintName + "`"); |
1758 | } |
1759 | |
1760 | // Handle a reference to a proper constraint. |
1761 | if (auto *cst = dyn_cast<ast::ConstraintDecl>(Val: cstDecl)) |
1762 | return ast::ConstraintRef(cst, loc); |
1763 | |
1764 | return emitErrorAndNote( |
1765 | loc, msg: "invalid reference to non-constraint", noteLoc: cstDecl->getLoc(), |
1766 | note: "see the definition of `"+ constraintName + "` here"); |
1767 | } |
1768 | // Handle single entity constraint code completion. |
1769 | case Token::code_complete: { |
1770 | // Try to infer the current type for use by code completion. |
1771 | ast::Type inferredType; |
1772 | if (failed(Result: validateVariableConstraints(constraints: existingConstraints, inferredType))) |
1773 | return failure(); |
1774 | |
1775 | return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints); |
1776 | } |
1777 | default: |
1778 | break; |
1779 | } |
1780 | return emitError(loc, msg: "expected identifier constraint"); |
1781 | } |
1782 | |
1783 | FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() { |
1784 | std::optional<SMRange> typeConstraint; |
1785 | return parseConstraint(typeConstraint, /*existingConstraints=*/std::nullopt, |
1786 | /*allowInlineTypeConstraints=*/false); |
1787 | } |
1788 | |
1789 | //===----------------------------------------------------------------------===// |
1790 | // Exprs |
1791 | //===----------------------------------------------------------------------===// |
1792 | |
1793 | FailureOr<ast::Expr *> Parser::parseExpr() { |
1794 | if (curToken.is(k: Token::underscore)) |
1795 | return parseUnderscoreExpr(); |
1796 | |
1797 | // Parse the LHS expression. |
1798 | FailureOr<ast::Expr *> lhsExpr; |
1799 | switch (curToken.getKind()) { |
1800 | case Token::kw_attr: |
1801 | lhsExpr = parseAttributeExpr(); |
1802 | break; |
1803 | case Token::kw_Constraint: |
1804 | lhsExpr = parseInlineConstraintLambdaExpr(); |
1805 | break; |
1806 | case Token::kw_not: |
1807 | lhsExpr = parseNegatedExpr(); |
1808 | break; |
1809 | case Token::identifier: |
1810 | lhsExpr = parseIdentifierExpr(); |
1811 | break; |
1812 | case Token::kw_op: |
1813 | lhsExpr = parseOperationExpr(); |
1814 | break; |
1815 | case Token::kw_Rewrite: |
1816 | lhsExpr = parseInlineRewriteLambdaExpr(); |
1817 | break; |
1818 | case Token::kw_type: |
1819 | lhsExpr = parseTypeExpr(); |
1820 | break; |
1821 | case Token::l_paren: |
1822 | lhsExpr = parseTupleExpr(); |
1823 | break; |
1824 | default: |
1825 | return emitError(msg: "expected expression"); |
1826 | } |
1827 | if (failed(Result: lhsExpr)) |
1828 | return failure(); |
1829 | |
1830 | // Check for an operator expression. |
1831 | while (true) { |
1832 | switch (curToken.getKind()) { |
1833 | case Token::dot: |
1834 | lhsExpr = parseMemberAccessExpr(parentExpr: *lhsExpr); |
1835 | break; |
1836 | case Token::l_paren: |
1837 | lhsExpr = parseCallExpr(parentExpr: *lhsExpr); |
1838 | break; |
1839 | default: |
1840 | return lhsExpr; |
1841 | } |
1842 | if (failed(Result: lhsExpr)) |
1843 | return failure(); |
1844 | } |
1845 | } |
1846 | |
1847 | FailureOr<ast::Expr *> Parser::parseAttributeExpr() { |
1848 | SMRange loc = curToken.getLoc(); |
1849 | consumeToken(kind: Token::kw_attr); |
1850 | |
1851 | // If we aren't followed by a `<`, the `attr` keyword is treated as a normal |
1852 | // identifier. |
1853 | if (!consumeIf(kind: Token::less)) { |
1854 | resetToken(tokLoc: loc); |
1855 | return parseIdentifierExpr(); |
1856 | } |
1857 | |
1858 | if (!curToken.isString()) |
1859 | return emitError(msg: "expected string literal containing MLIR attribute"); |
1860 | std::string attrExpr = curToken.getStringValue(); |
1861 | consumeToken(); |
1862 | |
1863 | loc.End = curToken.getEndLoc(); |
1864 | if (failed( |
1865 | Result: parseToken(kind: Token::greater, msg: "expected `>` after attribute literal"))) |
1866 | return failure(); |
1867 | return ast::AttributeExpr::create(ctx, loc, value: attrExpr); |
1868 | } |
1869 | |
1870 | FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr, |
1871 | bool isNegated) { |
1872 | consumeToken(kind: Token::l_paren); |
1873 | |
1874 | // Parse the arguments of the call. |
1875 | SmallVector<ast::Expr *> arguments; |
1876 | if (curToken.isNot(k: Token::r_paren)) { |
1877 | do { |
1878 | // Handle code completion for the call arguments. |
1879 | if (curToken.is(k: Token::code_complete)) { |
1880 | codeCompleteCallSignature(parent: parentExpr, currentNumArgs: arguments.size()); |
1881 | return failure(); |
1882 | } |
1883 | |
1884 | FailureOr<ast::Expr *> argument = parseExpr(); |
1885 | if (failed(Result: argument)) |
1886 | return failure(); |
1887 | arguments.push_back(Elt: *argument); |
1888 | } while (consumeIf(kind: Token::comma)); |
1889 | } |
1890 | |
1891 | SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc()); |
1892 | if (failed(Result: parseToken(kind: Token::r_paren, msg: "expected `)` after argument list"))) |
1893 | return failure(); |
1894 | |
1895 | return createCallExpr(loc, parentExpr, arguments, isNegated); |
1896 | } |
1897 | |
1898 | FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) { |
1899 | ast::Decl *decl = curDeclScope->lookup(name); |
1900 | if (!decl) |
1901 | return emitError(loc, msg: "undefined reference to `"+ name + "`"); |
1902 | |
1903 | return createDeclRefExpr(loc, decl); |
1904 | } |
1905 | |
1906 | FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { |
1907 | StringRef name = curToken.getSpelling(); |
1908 | SMRange nameLoc = curToken.getLoc(); |
1909 | consumeToken(); |
1910 | |
1911 | // Check to see if this is a decl ref expression that defines a variable |
1912 | // inline. |
1913 | if (consumeIf(kind: Token::colon)) { |
1914 | SmallVector<ast::ConstraintRef> constraints; |
1915 | if (failed(Result: parseVariableDeclConstraintList(constraints))) |
1916 | return failure(); |
1917 | ast::Type type; |
1918 | if (failed(Result: validateVariableConstraints(constraints, inferredType&: type))) |
1919 | return failure(); |
1920 | return createInlineVariableExpr(type, name, loc: nameLoc, constraints); |
1921 | } |
1922 | |
1923 | return parseDeclRefExpr(name, loc: nameLoc); |
1924 | } |
1925 | |
1926 | FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() { |
1927 | FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); |
1928 | if (failed(Result: decl)) |
1929 | return failure(); |
1930 | |
1931 | return ast::DeclRefExpr::create(ctx, loc: (*decl)->getLoc(), decl: *decl, |
1932 | type: ast::ConstraintType::get(context&: ctx)); |
1933 | } |
1934 | |
1935 | FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() { |
1936 | FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl(); |
1937 | if (failed(Result: decl)) |
1938 | return failure(); |
1939 | |
1940 | return ast::DeclRefExpr::create(ctx, loc: (*decl)->getLoc(), decl: *decl, |
1941 | type: ast::RewriteType::get(context&: ctx)); |
1942 | } |
1943 | |
1944 | FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { |
1945 | SMRange dotLoc = curToken.getLoc(); |
1946 | consumeToken(kind: Token::dot); |
1947 | |
1948 | // Check for code completion of the member name. |
1949 | if (curToken.is(k: Token::code_complete)) |
1950 | return codeCompleteMemberAccess(parentExpr); |
1951 | |
1952 | // Parse the member name. |
1953 | Token memberNameTok = curToken; |
1954 | if (memberNameTok.isNot(k1: Token::identifier, k2: Token::integer) && |
1955 | !memberNameTok.isKeyword()) |
1956 | return emitError(loc: dotLoc, msg: "expected identifier or numeric member name"); |
1957 | StringRef memberName = memberNameTok.getSpelling(); |
1958 | SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc()); |
1959 | consumeToken(); |
1960 | |
1961 | return createMemberAccessExpr(parentExpr, name: memberName, loc); |
1962 | } |
1963 | |
1964 | FailureOr<ast::Expr *> Parser::parseNegatedExpr() { |
1965 | consumeToken(kind: Token::kw_not); |
1966 | // Only native constraints are supported after negation |
1967 | if (!curToken.is(k: Token::identifier)) |
1968 | return emitError(msg: "expected native constraint"); |
1969 | FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr(); |
1970 | if (failed(Result: identifierExpr)) |
1971 | return failure(); |
1972 | if (!curToken.is(k: Token::l_paren)) |
1973 | return emitError(msg: "expected `(` after function name"); |
1974 | return parseCallExpr(parentExpr: *identifierExpr, /*isNegated = */ true); |
1975 | } |
1976 | |
1977 | FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { |
1978 | SMRange loc = curToken.getLoc(); |
1979 | |
1980 | // Check for code completion for the dialect name. |
1981 | if (curToken.is(k: Token::code_complete)) |
1982 | return codeCompleteDialectName(); |
1983 | |
1984 | // Handle the case of an no operation name. |
1985 | if (curToken.isNot(k: Token::identifier) && !curToken.isKeyword()) { |
1986 | if (allowEmptyName) |
1987 | return ast::OpNameDecl::create(ctx, loc: SMRange()); |
1988 | return emitError(msg: "expected dialect namespace"); |
1989 | } |
1990 | StringRef name = curToken.getSpelling(); |
1991 | consumeToken(); |
1992 | |
1993 | // Otherwise, this is a literal operation name. |
1994 | if (failed(Result: parseToken(kind: Token::dot, msg: "expected `.` after dialect namespace"))) |
1995 | return failure(); |
1996 | |
1997 | // Check for code completion for the operation name. |
1998 | if (curToken.is(k: Token::code_complete)) |
1999 | return codeCompleteOperationName(dialectName: name); |
2000 | |
2001 | if (curToken.isNot(k: Token::identifier) && !curToken.isKeyword()) |
2002 | return emitError(msg: "expected operation name after dialect namespace"); |
2003 | |
2004 | name = StringRef(name.data(), name.size() + 1); |
2005 | do { |
2006 | name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); |
2007 | loc.End = curToken.getEndLoc(); |
2008 | consumeToken(); |
2009 | } while (curToken.isAny(k1: Token::identifier, k2: Token::dot) || |
2010 | curToken.isKeyword()); |
2011 | return ast::OpNameDecl::create(ctx, name: ast::Name::create(ctx, name, location: loc)); |
2012 | } |
2013 | |
2014 | FailureOr<ast::OpNameDecl *> |
2015 | Parser::parseWrappedOperationName(bool allowEmptyName) { |
2016 | if (!consumeIf(kind: Token::less)) |
2017 | return ast::OpNameDecl::create(ctx, loc: SMRange()); |
2018 | |
2019 | FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); |
2020 | if (failed(Result: opNameDecl)) |
2021 | return failure(); |
2022 | |
2023 | if (failed(Result: parseToken(kind: Token::greater, msg: "expected `>` after operation name"))) |
2024 | return failure(); |
2025 | return opNameDecl; |
2026 | } |
2027 | |
2028 | FailureOr<ast::Expr *> |
2029 | Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) { |
2030 | SMRange loc = curToken.getLoc(); |
2031 | consumeToken(kind: Token::kw_op); |
2032 | |
2033 | // If it isn't followed by a `<`, the `op` keyword is treated as a normal |
2034 | // identifier. |
2035 | if (curToken.isNot(k: Token::less)) { |
2036 | resetToken(tokLoc: loc); |
2037 | return parseIdentifierExpr(); |
2038 | } |
2039 | |
2040 | // Parse the operation name. The name may be elided, in which case the |
2041 | // operation refers to "any" operation(i.e. a difference between `MyOp` and |
2042 | // `Operation*`). Operation names within a rewrite context must be named. |
2043 | bool allowEmptyName = parserContext != ParserContext::Rewrite; |
2044 | FailureOr<ast::OpNameDecl *> opNameDecl = |
2045 | parseWrappedOperationName(allowEmptyName); |
2046 | if (failed(Result: opNameDecl)) |
2047 | return failure(); |
2048 | std::optional<StringRef> opName = (*opNameDecl)->getName(); |
2049 | |
2050 | // Functor used to create an implicit range variable, used for implicit "all" |
2051 | // operand or results variables. |
2052 | auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) { |
2053 | FailureOr<ast::VariableDecl *> rangeVar = |
2054 | defineVariableDecl(name: "_", nameLoc: loc, type, constraints: ast::ConstraintRef(cst, loc)); |
2055 | assert(succeeded(rangeVar) && "expected range variable to be valid"); |
2056 | return ast::DeclRefExpr::create(ctx, loc, decl: *rangeVar, type); |
2057 | }; |
2058 | |
2059 | // Check for the optional list of operands. |
2060 | SmallVector<ast::Expr *> operands; |
2061 | if (!consumeIf(kind: Token::l_paren)) { |
2062 | // If the operand list isn't specified and we are in a match context, define |
2063 | // an inplace unconstrained operand range corresponding to all of the |
2064 | // operands of the operation. This avoids treating zero operands the same |
2065 | // way as "unconstrained operands". |
2066 | if (parserContext != ParserContext::Rewrite) { |
2067 | operands.push_back(Elt: createImplicitRangeVar( |
2068 | ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy)); |
2069 | } |
2070 | } else if (!consumeIf(kind: Token::r_paren)) { |
2071 | // If the operand list was specified and non-empty, parse the operands. |
2072 | do { |
2073 | // Check for operand signature code completion. |
2074 | if (curToken.is(k: Token::code_complete)) { |
2075 | codeCompleteOperationOperandsSignature(opName, currentNumOperands: operands.size()); |
2076 | return failure(); |
2077 | } |
2078 | |
2079 | FailureOr<ast::Expr *> operand = parseExpr(); |
2080 | if (failed(Result: operand)) |
2081 | return failure(); |
2082 | operands.push_back(Elt: *operand); |
2083 | } while (consumeIf(kind: Token::comma)); |
2084 | |
2085 | if (failed(Result: parseToken(kind: Token::r_paren, |
2086 | msg: "expected `)` after operation operand list"))) |
2087 | return failure(); |
2088 | } |
2089 | |
2090 | // Check for the optional list of attributes. |
2091 | SmallVector<ast::NamedAttributeDecl *> attributes; |
2092 | if (consumeIf(kind: Token::l_brace)) { |
2093 | do { |
2094 | FailureOr<ast::NamedAttributeDecl *> decl = |
2095 | parseNamedAttributeDecl(parentOpName: opName); |
2096 | if (failed(Result: decl)) |
2097 | return failure(); |
2098 | attributes.emplace_back(Args&: *decl); |
2099 | } while (consumeIf(kind: Token::comma)); |
2100 | |
2101 | if (failed(Result: parseToken(kind: Token::r_brace, |
2102 | msg: "expected `}` after operation attribute list"))) |
2103 | return failure(); |
2104 | } |
2105 | |
2106 | // Handle the result types of the operation. |
2107 | SmallVector<ast::Expr *> resultTypes; |
2108 | OpResultTypeContext resultTypeContext = inputResultTypeContext; |
2109 | |
2110 | // Check for an explicit list of result types. |
2111 | if (consumeIf(kind: Token::arrow)) { |
2112 | if (failed(Result: parseToken(kind: Token::l_paren, |
2113 | msg: "expected `(` before operation result type list"))) |
2114 | return failure(); |
2115 | |
2116 | // If result types are provided, initially assume that the operation does |
2117 | // not rely on type inferrence. We don't assert that it isn't, because we |
2118 | // may be inferring the value of some type/type range variables, but given |
2119 | // that these variables may be defined in calls we can't always discern when |
2120 | // this is the case. |
2121 | resultTypeContext = OpResultTypeContext::Explicit; |
2122 | |
2123 | // Handle the case of an empty result list. |
2124 | if (!consumeIf(kind: Token::r_paren)) { |
2125 | do { |
2126 | // Check for result signature code completion. |
2127 | if (curToken.is(k: Token::code_complete)) { |
2128 | codeCompleteOperationResultsSignature(opName, currentNumResults: resultTypes.size()); |
2129 | return failure(); |
2130 | } |
2131 | |
2132 | FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); |
2133 | if (failed(Result: resultTypeExpr)) |
2134 | return failure(); |
2135 | resultTypes.push_back(Elt: *resultTypeExpr); |
2136 | } while (consumeIf(kind: Token::comma)); |
2137 | |
2138 | if (failed(Result: parseToken(kind: Token::r_paren, |
2139 | msg: "expected `)` after operation result type list"))) |
2140 | return failure(); |
2141 | } |
2142 | } else if (parserContext != ParserContext::Rewrite) { |
2143 | // If the result list isn't specified and we are in a match context, define |
2144 | // an inplace unconstrained result range corresponding to all of the results |
2145 | // of the operation. This avoids treating zero results the same way as |
2146 | // "unconstrained results". |
2147 | resultTypes.push_back(Elt: createImplicitRangeVar( |
2148 | ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy)); |
2149 | } else if (resultTypeContext == OpResultTypeContext::Explicit) { |
2150 | // If the result list isn't specified and we are in a rewrite, try to infer |
2151 | // them at runtime instead. |
2152 | resultTypeContext = OpResultTypeContext::Interface; |
2153 | } |
2154 | |
2155 | return createOperationExpr(loc, name: *opNameDecl, resultTypeContext, operands, |
2156 | attributes, results&: resultTypes); |
2157 | } |
2158 | |
2159 | FailureOr<ast::Expr *> Parser::parseTupleExpr() { |
2160 | SMRange loc = curToken.getLoc(); |
2161 | consumeToken(kind: Token::l_paren); |
2162 | |
2163 | DenseMap<StringRef, SMRange> usedNames; |
2164 | SmallVector<StringRef> elementNames; |
2165 | SmallVector<ast::Expr *> elements; |
2166 | if (curToken.isNot(k: Token::r_paren)) { |
2167 | do { |
2168 | // Check for the optional element name assignment before the value. |
2169 | StringRef elementName; |
2170 | if (curToken.is(k: Token::identifier) || curToken.isDependentKeyword()) { |
2171 | Token elementNameTok = curToken; |
2172 | consumeToken(); |
2173 | |
2174 | // The element name is only present if followed by an `=`. |
2175 | if (consumeIf(kind: Token::equal)) { |
2176 | elementName = elementNameTok.getSpelling(); |
2177 | |
2178 | // Check to see if this name is already used. |
2179 | auto elementNameIt = |
2180 | usedNames.try_emplace(Key: elementName, Args: elementNameTok.getLoc()); |
2181 | if (!elementNameIt.second) { |
2182 | return emitErrorAndNote( |
2183 | loc: elementNameTok.getLoc(), |
2184 | msg: llvm::formatv(Fmt: "duplicate tuple element label `{0}`", |
2185 | Vals&: elementName), |
2186 | noteLoc: elementNameIt.first->getSecond(), |
2187 | note: "see previous label use here"); |
2188 | } |
2189 | } else { |
2190 | // Otherwise, we treat this as part of an expression so reset the |
2191 | // lexer. |
2192 | resetToken(tokLoc: elementNameTok.getLoc()); |
2193 | } |
2194 | } |
2195 | elementNames.push_back(Elt: elementName); |
2196 | |
2197 | // Parse the tuple element value. |
2198 | FailureOr<ast::Expr *> element = parseExpr(); |
2199 | if (failed(Result: element)) |
2200 | return failure(); |
2201 | elements.push_back(Elt: *element); |
2202 | } while (consumeIf(kind: Token::comma)); |
2203 | } |
2204 | loc.End = curToken.getEndLoc(); |
2205 | if (failed( |
2206 | Result: parseToken(kind: Token::r_paren, msg: "expected `)` after tuple element list"))) |
2207 | return failure(); |
2208 | return createTupleExpr(loc, elements, elementNames); |
2209 | } |
2210 | |
2211 | FailureOr<ast::Expr *> Parser::parseTypeExpr() { |
2212 | SMRange loc = curToken.getLoc(); |
2213 | consumeToken(kind: Token::kw_type); |
2214 | |
2215 | // If we aren't followed by a `<`, the `type` keyword is treated as a normal |
2216 | // identifier. |
2217 | if (!consumeIf(kind: Token::less)) { |
2218 | resetToken(tokLoc: loc); |
2219 | return parseIdentifierExpr(); |
2220 | } |
2221 | |
2222 | if (!curToken.isString()) |
2223 | return emitError(msg: "expected string literal containing MLIR type"); |
2224 | std::string attrExpr = curToken.getStringValue(); |
2225 | consumeToken(); |
2226 | |
2227 | loc.End = curToken.getEndLoc(); |
2228 | if (failed(Result: parseToken(kind: Token::greater, msg: "expected `>` after type literal"))) |
2229 | return failure(); |
2230 | return ast::TypeExpr::create(ctx, loc, value: attrExpr); |
2231 | } |
2232 | |
2233 | FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { |
2234 | StringRef name = curToken.getSpelling(); |
2235 | SMRange nameLoc = curToken.getLoc(); |
2236 | consumeToken(kind: Token::underscore); |
2237 | |
2238 | // Underscore expressions require a constraint list. |
2239 | if (failed(Result: parseToken(kind: Token::colon, msg: "expected `:` after `_` variable"))) |
2240 | return failure(); |
2241 | |
2242 | // Parse the constraints for the expression. |
2243 | SmallVector<ast::ConstraintRef> constraints; |
2244 | if (failed(Result: parseVariableDeclConstraintList(constraints))) |
2245 | return failure(); |
2246 | |
2247 | ast::Type type; |
2248 | if (failed(Result: validateVariableConstraints(constraints, inferredType&: type))) |
2249 | return failure(); |
2250 | return createInlineVariableExpr(type, name, loc: nameLoc, constraints); |
2251 | } |
2252 | |
2253 | //===----------------------------------------------------------------------===// |
2254 | // Stmts |
2255 | //===----------------------------------------------------------------------===// |
2256 | |
2257 | FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { |
2258 | FailureOr<ast::Stmt *> stmt; |
2259 | switch (curToken.getKind()) { |
2260 | case Token::kw_erase: |
2261 | stmt = parseEraseStmt(); |
2262 | break; |
2263 | case Token::kw_let: |
2264 | stmt = parseLetStmt(); |
2265 | break; |
2266 | case Token::kw_replace: |
2267 | stmt = parseReplaceStmt(); |
2268 | break; |
2269 | case Token::kw_return: |
2270 | stmt = parseReturnStmt(); |
2271 | break; |
2272 | case Token::kw_rewrite: |
2273 | stmt = parseRewriteStmt(); |
2274 | break; |
2275 | default: |
2276 | stmt = parseExpr(); |
2277 | break; |
2278 | } |
2279 | if (failed(Result: stmt) || |
2280 | (expectTerminalSemicolon && |
2281 | failed(Result: parseToken(kind: Token::semicolon, msg: "expected `;` after statement")))) |
2282 | return failure(); |
2283 | return stmt; |
2284 | } |
2285 | |
2286 | FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { |
2287 | SMLoc startLoc = curToken.getStartLoc(); |
2288 | consumeToken(kind: Token::l_brace); |
2289 | |
2290 | // Push a new block scope and parse any nested statements. |
2291 | pushDeclScope(); |
2292 | SmallVector<ast::Stmt *> statements; |
2293 | while (curToken.isNot(k: Token::r_brace)) { |
2294 | FailureOr<ast::Stmt *> statement = parseStmt(); |
2295 | if (failed(Result: statement)) |
2296 | return popDeclScope(), failure(); |
2297 | statements.push_back(Elt: *statement); |
2298 | } |
2299 | popDeclScope(); |
2300 | |
2301 | // Consume the end brace. |
2302 | SMRange location(startLoc, curToken.getEndLoc()); |
2303 | consumeToken(kind: Token::r_brace); |
2304 | |
2305 | return ast::CompoundStmt::create(ctx, location, children: statements); |
2306 | } |
2307 | |
2308 | FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { |
2309 | if (parserContext == ParserContext::Constraint) |
2310 | return emitError(msg: "`erase` cannot be used within a Constraint"); |
2311 | SMRange loc = curToken.getLoc(); |
2312 | consumeToken(kind: Token::kw_erase); |
2313 | |
2314 | // Parse the root operation expression. |
2315 | FailureOr<ast::Expr *> rootOp = parseExpr(); |
2316 | if (failed(Result: rootOp)) |
2317 | return failure(); |
2318 | |
2319 | return createEraseStmt(loc, rootOp: *rootOp); |
2320 | } |
2321 | |
2322 | FailureOr<ast::LetStmt *> Parser::parseLetStmt() { |
2323 | SMRange loc = curToken.getLoc(); |
2324 | consumeToken(kind: Token::kw_let); |
2325 | |
2326 | // Parse the name of the new variable. |
2327 | SMRange varLoc = curToken.getLoc(); |
2328 | if (curToken.isNot(k: Token::identifier) && !curToken.isDependentKeyword()) { |
2329 | // `_` is a reserved variable name. |
2330 | if (curToken.is(k: Token::underscore)) { |
2331 | return emitError(loc: varLoc, |
2332 | msg: "`_` may only be used to define \"inline\" variables"); |
2333 | } |
2334 | return emitError(loc: varLoc, |
2335 | msg: "expected identifier after `let` to name a new variable"); |
2336 | } |
2337 | StringRef varName = curToken.getSpelling(); |
2338 | consumeToken(); |
2339 | |
2340 | // Parse the optional set of constraints. |
2341 | SmallVector<ast::ConstraintRef> constraints; |
2342 | if (consumeIf(kind: Token::colon) && |
2343 | failed(Result: parseVariableDeclConstraintList(constraints))) |
2344 | return failure(); |
2345 | |
2346 | // Parse the optional initializer expression. |
2347 | ast::Expr *initializer = nullptr; |
2348 | if (consumeIf(kind: Token::equal)) { |
2349 | FailureOr<ast::Expr *> initOrFailure = parseExpr(); |
2350 | if (failed(Result: initOrFailure)) |
2351 | return failure(); |
2352 | initializer = *initOrFailure; |
2353 | |
2354 | // Check that the constraints are compatible with having an initializer, |
2355 | // e.g. type constraints cannot be used with initializers. |
2356 | for (ast::ConstraintRef constraint : constraints) { |
2357 | LogicalResult result = |
2358 | TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) |
2359 | .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, |
2360 | ast::ValueRangeConstraintDecl>(caseFn: [&](const auto *cst) { |
2361 | if (cst->getTypeExpr()) { |
2362 | return this->emitError( |
2363 | loc: constraint.referenceLoc, |
2364 | msg: "type constraints are not permitted on variables with " |
2365 | "initializers"); |
2366 | } |
2367 | return success(); |
2368 | }) |
2369 | .Default(defaultResult: success()); |
2370 | if (failed(Result: result)) |
2371 | return failure(); |
2372 | } |
2373 | } |
2374 | |
2375 | FailureOr<ast::VariableDecl *> varDecl = |
2376 | createVariableDecl(name: varName, loc: varLoc, initializer, constraints); |
2377 | if (failed(Result: varDecl)) |
2378 | return failure(); |
2379 | return ast::LetStmt::create(ctx, loc, varDecl: *varDecl); |
2380 | } |
2381 | |
2382 | FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { |
2383 | if (parserContext == ParserContext::Constraint) |
2384 | return emitError(msg: "`replace` cannot be used within a Constraint"); |
2385 | SMRange loc = curToken.getLoc(); |
2386 | consumeToken(kind: Token::kw_replace); |
2387 | |
2388 | // Parse the root operation expression. |
2389 | FailureOr<ast::Expr *> rootOp = parseExpr(); |
2390 | if (failed(Result: rootOp)) |
2391 | return failure(); |
2392 | |
2393 | if (failed( |
2394 | Result: parseToken(kind: Token::kw_with, msg: "expected `with` after root operation"))) |
2395 | return failure(); |
2396 | |
2397 | // The replacement portion of this statement is within a rewrite context. |
2398 | llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite); |
2399 | |
2400 | // Parse the replacement values. |
2401 | SmallVector<ast::Expr *> replValues; |
2402 | if (consumeIf(kind: Token::l_paren)) { |
2403 | if (consumeIf(kind: Token::r_paren)) { |
2404 | return emitError( |
2405 | loc, msg: "expected at least one replacement value, consider using " |
2406 | "`erase` if no replacement values are desired"); |
2407 | } |
2408 | |
2409 | do { |
2410 | FailureOr<ast::Expr *> replExpr = parseExpr(); |
2411 | if (failed(Result: replExpr)) |
2412 | return failure(); |
2413 | replValues.emplace_back(Args&: *replExpr); |
2414 | } while (consumeIf(kind: Token::comma)); |
2415 | |
2416 | if (failed(Result: parseToken(kind: Token::r_paren, |
2417 | msg: "expected `)` after replacement values"))) |
2418 | return failure(); |
2419 | } else { |
2420 | // Handle replacement with an operation uniquely, as the replacement |
2421 | // operation supports type inferrence from the root operation. |
2422 | FailureOr<ast::Expr *> replExpr; |
2423 | if (curToken.is(k: Token::kw_op)) |
2424 | replExpr = parseOperationExpr(inputResultTypeContext: OpResultTypeContext::Replacement); |
2425 | else |
2426 | replExpr = parseExpr(); |
2427 | if (failed(Result: replExpr)) |
2428 | return failure(); |
2429 | replValues.emplace_back(Args&: *replExpr); |
2430 | } |
2431 | |
2432 | return createReplaceStmt(loc, rootOp: *rootOp, replValues); |
2433 | } |
2434 | |
2435 | FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() { |
2436 | SMRange loc = curToken.getLoc(); |
2437 | consumeToken(kind: Token::kw_return); |
2438 | |
2439 | // Parse the result value. |
2440 | FailureOr<ast::Expr *> resultExpr = parseExpr(); |
2441 | if (failed(Result: resultExpr)) |
2442 | return failure(); |
2443 | |
2444 | return ast::ReturnStmt::create(ctx, loc, resultExpr: *resultExpr); |
2445 | } |
2446 | |
2447 | FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { |
2448 | if (parserContext == ParserContext::Constraint) |
2449 | return emitError(msg: "`rewrite` cannot be used within a Constraint"); |
2450 | SMRange loc = curToken.getLoc(); |
2451 | consumeToken(kind: Token::kw_rewrite); |
2452 | |
2453 | // Parse the root operation. |
2454 | FailureOr<ast::Expr *> rootOp = parseExpr(); |
2455 | if (failed(Result: rootOp)) |
2456 | return failure(); |
2457 | |
2458 | if (failed(Result: parseToken(kind: Token::kw_with, msg: "expected `with` before rewrite body"))) |
2459 | return failure(); |
2460 | |
2461 | if (curToken.isNot(k: Token::l_brace)) |
2462 | return emitError(msg: "expected `{` to start rewrite body"); |
2463 | |
2464 | // The rewrite body of this statement is within a rewrite context. |
2465 | llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite); |
2466 | |
2467 | FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); |
2468 | if (failed(Result: rewriteBody)) |
2469 | return failure(); |
2470 | |
2471 | // Verify the rewrite body. |
2472 | for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) { |
2473 | if (isa<ast::ReturnStmt>(Val: stmt)) { |
2474 | return emitError(loc: stmt->getLoc(), |
2475 | msg: "`return` statements are only permitted within a " |
2476 | "`Constraint` or `Rewrite` body"); |
2477 | } |
2478 | } |
2479 | |
2480 | return createRewriteStmt(loc, rootOp: *rootOp, rewriteBody: *rewriteBody); |
2481 | } |
2482 | |
2483 | //===----------------------------------------------------------------------===// |
2484 | // Creation+Analysis |
2485 | //===----------------------------------------------------------------------===// |
2486 | |
2487 | //===----------------------------------------------------------------------===// |
2488 | // Decls |
2489 | //===----------------------------------------------------------------------===// |
2490 | |
2491 | ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) { |
2492 | // Unwrap reference expressions. |
2493 | if (auto *init = dyn_cast<ast::DeclRefExpr>(Val: node)) |
2494 | node = init->getDecl(); |
2495 | return dyn_cast<ast::CallableDecl>(Val: node); |
2496 | } |
2497 | |
2498 | FailureOr<ast::PatternDecl *> |
2499 | Parser::createPatternDecl(SMRange loc, const ast::Name *name, |
2500 | const ParsedPatternMetadata &metadata, |
2501 | ast::CompoundStmt *body) { |
2502 | return ast::PatternDecl::create(ctx, location: loc, name, benefit: metadata.benefit, |
2503 | hasBoundedRecursion: metadata.hasBoundedRecursion, body); |
2504 | } |
2505 | |
2506 | ast::Type Parser::createUserConstraintRewriteResultType( |
2507 | ArrayRef<ast::VariableDecl *> results) { |
2508 | // Single result decls use the type of the single result. |
2509 | if (results.size() == 1) |
2510 | return results[0]->getType(); |
2511 | |
2512 | // Multiple results use a tuple type, with the types and names grabbed from |
2513 | // the result variable decls. |
2514 | auto resultTypes = llvm::map_range( |
2515 | C&: results, F: [&](const auto *result) { return result->getType(); }); |
2516 | auto resultNames = llvm::map_range( |
2517 | C&: results, F: [&](const auto *result) { return result->getName().getName(); }); |
2518 | return ast::TupleType::get(context&: ctx, elementTypes: llvm::to_vector(Range&: resultTypes), |
2519 | elementNames: llvm::to_vector(Range&: resultNames)); |
2520 | } |
2521 | |
2522 | template <typename T> |
2523 | FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl( |
2524 | const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, |
2525 | ArrayRef<ast::VariableDecl *> results, ast::Type resultType, |
2526 | ast::CompoundStmt *body) { |
2527 | if (!body->getChildren().empty()) { |
2528 | if (auto *retStmt = dyn_cast<ast::ReturnStmt>(Val: body->getChildren().back())) { |
2529 | ast::Expr *resultExpr = retStmt->getResultExpr(); |
2530 | |
2531 | // Process the result of the decl. If no explicit signature results |
2532 | // were provided, check for return type inference. Otherwise, check that |
2533 | // the return expression can be converted to the expected type. |
2534 | if (results.empty()) |
2535 | resultType = resultExpr->getType(); |
2536 | else if (failed(Result: convertExpressionTo(expr&: resultExpr, type: resultType))) |
2537 | return failure(); |
2538 | else |
2539 | retStmt->setResultExpr(resultExpr); |
2540 | } |
2541 | } |
2542 | return T::createPDLL(ctx, name, arguments, results, body, resultType); |
2543 | } |
2544 | |
2545 | FailureOr<ast::VariableDecl *> |
2546 | Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, |
2547 | ArrayRef<ast::ConstraintRef> constraints) { |
2548 | // The type of the variable, which is expected to be inferred by either a |
2549 | // constraint or an initializer expression. |
2550 | ast::Type type; |
2551 | if (failed(Result: validateVariableConstraints(constraints, inferredType&: type))) |
2552 | return failure(); |
2553 | |
2554 | if (initializer) { |
2555 | // Update the variable type based on the initializer, or try to convert the |
2556 | // initializer to the existing type. |
2557 | if (!type) |
2558 | type = initializer->getType(); |
2559 | else if (ast::Type mergedType = type.refineWith(other: initializer->getType())) |
2560 | type = mergedType; |
2561 | else if (failed(Result: convertExpressionTo(expr&: initializer, type))) |
2562 | return failure(); |
2563 | |
2564 | // Otherwise, if there is no initializer check that the type has already |
2565 | // been resolved from the constraint list. |
2566 | } else if (!type) { |
2567 | return emitErrorAndNote( |
2568 | loc, msg: "unable to infer type for variable `"+ name + "`", noteLoc: loc, |
2569 | note: "the type of a variable must be inferable from the constraint " |
2570 | "list or the initializer"); |
2571 | } |
2572 | |
2573 | // Constraint types cannot be used when defining variables. |
2574 | if (isa<ast::ConstraintType, ast::RewriteType>(Val: type)) { |
2575 | return emitError( |
2576 | loc, msg: llvm::formatv(Fmt: "unable to define variable of `{0}` type", Vals&: type)); |
2577 | } |
2578 | |
2579 | // Try to define a variable with the given name. |
2580 | FailureOr<ast::VariableDecl *> varDecl = |
2581 | defineVariableDecl(name, nameLoc: loc, type, initExpr: initializer, constraints); |
2582 | if (failed(Result: varDecl)) |
2583 | return failure(); |
2584 | |
2585 | return *varDecl; |
2586 | } |
2587 | |
2588 | FailureOr<ast::VariableDecl *> |
2589 | Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc, |
2590 | const ast::ConstraintRef &constraint) { |
2591 | ast::Type argType; |
2592 | if (failed(Result: validateVariableConstraint(ref: constraint, inferredType&: argType))) |
2593 | return failure(); |
2594 | return defineVariableDecl(name, nameLoc: loc, type: argType, constraints: constraint); |
2595 | } |
2596 | |
2597 | LogicalResult |
2598 | Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, |
2599 | ast::Type &inferredType) { |
2600 | for (const ast::ConstraintRef &ref : constraints) |
2601 | if (failed(Result: validateVariableConstraint(ref, inferredType))) |
2602 | return failure(); |
2603 | return success(); |
2604 | } |
2605 | |
2606 | LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, |
2607 | ast::Type &inferredType) { |
2608 | ast::Type constraintType; |
2609 | if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(Val: ref.constraint)) { |
2610 | if (const ast::Expr *typeExpr = cst->getTypeExpr()) { |
2611 | if (failed(Result: validateTypeConstraintExpr(typeExpr))) |
2612 | return failure(); |
2613 | } |
2614 | constraintType = ast::AttributeType::get(context&: ctx); |
2615 | } else if (const auto *cst = |
2616 | dyn_cast<ast::OpConstraintDecl>(Val: ref.constraint)) { |
2617 | constraintType = ast::OperationType::get( |
2618 | context&: ctx, name: cst->getName(), odsOp: lookupODSOperation(opName: cst->getName())); |
2619 | } else if (isa<ast::TypeConstraintDecl>(Val: ref.constraint)) { |
2620 | constraintType = typeTy; |
2621 | } else if (isa<ast::TypeRangeConstraintDecl>(Val: ref.constraint)) { |
2622 | constraintType = typeRangeTy; |
2623 | } else if (const auto *cst = |
2624 | dyn_cast<ast::ValueConstraintDecl>(Val: ref.constraint)) { |
2625 | if (const ast::Expr *typeExpr = cst->getTypeExpr()) { |
2626 | if (failed(Result: validateTypeConstraintExpr(typeExpr))) |
2627 | return failure(); |
2628 | } |
2629 | constraintType = valueTy; |
2630 | } else if (const auto *cst = |
2631 | dyn_cast<ast::ValueRangeConstraintDecl>(Val: ref.constraint)) { |
2632 | if (const ast::Expr *typeExpr = cst->getTypeExpr()) { |
2633 | if (failed(Result: validateTypeRangeConstraintExpr(typeExpr))) |
2634 | return failure(); |
2635 | } |
2636 | constraintType = valueRangeTy; |
2637 | } else if (const auto *cst = |
2638 | dyn_cast<ast::UserConstraintDecl>(Val: ref.constraint)) { |
2639 | ArrayRef<ast::VariableDecl *> inputs = cst->getInputs(); |
2640 | if (inputs.size() != 1) { |
2641 | return emitErrorAndNote(loc: ref.referenceLoc, |
2642 | msg: "`Constraint`s applied via a variable constraint " |
2643 | "list must take a single input, but got "+ |
2644 | Twine(inputs.size()), |
2645 | noteLoc: cst->getLoc(), |
2646 | note: "see definition of constraint here"); |
2647 | } |
2648 | constraintType = inputs.front()->getType(); |
2649 | } else { |
2650 | llvm_unreachable("unknown constraint type"); |
2651 | } |
2652 | |
2653 | // Check that the constraint type is compatible with the current inferred |
2654 | // type. |
2655 | if (!inferredType) { |
2656 | inferredType = constraintType; |
2657 | } else if (ast::Type mergedTy = inferredType.refineWith(other: constraintType)) { |
2658 | inferredType = mergedTy; |
2659 | } else { |
2660 | return emitError(loc: ref.referenceLoc, |
2661 | msg: llvm::formatv(Fmt: "constraint type `{0}` is incompatible " |
2662 | "with the previously inferred type `{1}`", |
2663 | Vals&: constraintType, Vals&: inferredType)); |
2664 | } |
2665 | return success(); |
2666 | } |
2667 | |
2668 | LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { |
2669 | ast::Type typeExprType = typeExpr->getType(); |
2670 | if (typeExprType != typeTy) { |
2671 | return emitError(loc: typeExpr->getLoc(), |
2672 | msg: "expected expression of `Type` in type constraint"); |
2673 | } |
2674 | return success(); |
2675 | } |
2676 | |
2677 | LogicalResult |
2678 | Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { |
2679 | ast::Type typeExprType = typeExpr->getType(); |
2680 | if (typeExprType != typeRangeTy) { |
2681 | return emitError(loc: typeExpr->getLoc(), |
2682 | msg: "expected expression of `TypeRange` in type constraint"); |
2683 | } |
2684 | return success(); |
2685 | } |
2686 | |
2687 | //===----------------------------------------------------------------------===// |
2688 | // Exprs |
2689 | //===----------------------------------------------------------------------===// |
2690 | |
2691 | FailureOr<ast::CallExpr *> |
2692 | Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, |
2693 | MutableArrayRef<ast::Expr *> arguments, bool isNegated) { |
2694 | ast::Type parentType = parentExpr->getType(); |
2695 | |
2696 | ast::CallableDecl *callableDecl = tryExtractCallableDecl(node: parentExpr); |
2697 | if (!callableDecl) { |
2698 | return emitError(loc, |
2699 | msg: llvm::formatv(Fmt: "expected a reference to a callable " |
2700 | "`Constraint` or `Rewrite`, but got: `{0}`", |
2701 | Vals&: parentType)); |
2702 | } |
2703 | if (parserContext == ParserContext::Rewrite) { |
2704 | if (isa<ast::UserConstraintDecl>(Val: callableDecl)) |
2705 | return emitError( |
2706 | loc, msg: "unable to invoke `Constraint` within a rewrite section"); |
2707 | if (isNegated) |
2708 | return emitError(loc, msg: "unable to negate a Rewrite"); |
2709 | } else { |
2710 | if (isa<ast::UserRewriteDecl>(Val: callableDecl)) |
2711 | return emitError(loc, |
2712 | msg: "unable to invoke `Rewrite` within a match section"); |
2713 | if (isNegated && cast<ast::UserConstraintDecl>(Val: callableDecl)->getBody()) |
2714 | return emitError(loc, msg: "unable to negate non native constraints"); |
2715 | } |
2716 | |
2717 | // Verify the arguments of the call. |
2718 | /// Handle size mismatch. |
2719 | ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs(); |
2720 | if (callArgs.size() != arguments.size()) { |
2721 | return emitErrorAndNote( |
2722 | loc, |
2723 | msg: llvm::formatv(Fmt: "invalid number of arguments for {0} call; expected " |
2724 | "{1}, but got {2}", |
2725 | Vals: callableDecl->getCallableType(), Vals: callArgs.size(), |
2726 | Vals: arguments.size()), |
2727 | noteLoc: callableDecl->getLoc(), |
2728 | note: llvm::formatv(Fmt: "see the definition of {0} here", |
2729 | Vals: callableDecl->getName()->getName())); |
2730 | } |
2731 | |
2732 | /// Handle argument type mismatch. |
2733 | auto attachDiagFn = [&](ast::Diagnostic &diag) { |
2734 | diag.attachNote(msg: llvm::formatv(Fmt: "see the definition of `{0}` here", |
2735 | Vals: callableDecl->getName()->getName()), |
2736 | noteLoc: callableDecl->getLoc()); |
2737 | }; |
2738 | for (auto it : llvm::zip(t&: callArgs, u&: arguments)) { |
2739 | if (failed(Result: convertExpressionTo(expr&: std::get<1>(t&: it), type: std::get<0>(t&: it)->getType(), |
2740 | noteAttachFn: attachDiagFn))) |
2741 | return failure(); |
2742 | } |
2743 | |
2744 | return ast::CallExpr::create(ctx, loc, callable: parentExpr, arguments, |
2745 | resultType: callableDecl->getResultType(), isNegated); |
2746 | } |
2747 | |
2748 | FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc, |
2749 | ast::Decl *decl) { |
2750 | // Check the type of decl being referenced. |
2751 | ast::Type declType; |
2752 | if (isa<ast::ConstraintDecl>(Val: decl)) |
2753 | declType = ast::ConstraintType::get(context&: ctx); |
2754 | else if (isa<ast::UserRewriteDecl>(Val: decl)) |
2755 | declType = ast::RewriteType::get(context&: ctx); |
2756 | else if (auto *varDecl = dyn_cast<ast::VariableDecl>(Val: decl)) |
2757 | declType = varDecl->getType(); |
2758 | else |
2759 | return emitError(loc, msg: "invalid reference to `"+ |
2760 | decl->getName()->getName() + "`"); |
2761 | |
2762 | return ast::DeclRefExpr::create(ctx, loc, decl, type: declType); |
2763 | } |
2764 | |
2765 | FailureOr<ast::DeclRefExpr *> |
2766 | Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, |
2767 | ArrayRef<ast::ConstraintRef> constraints) { |
2768 | FailureOr<ast::VariableDecl *> decl = |
2769 | defineVariableDecl(name, nameLoc: loc, type, constraints); |
2770 | if (failed(Result: decl)) |
2771 | return failure(); |
2772 | return ast::DeclRefExpr::create(ctx, loc, decl: *decl, type); |
2773 | } |
2774 | |
2775 | FailureOr<ast::MemberAccessExpr *> |
2776 | Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, |
2777 | SMRange loc) { |
2778 | // Validate the member name for the given parent expression. |
2779 | FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc); |
2780 | if (failed(Result: memberType)) |
2781 | return failure(); |
2782 | |
2783 | return ast::MemberAccessExpr::create(ctx, loc, parentExpr, memberName: name, type: *memberType); |
2784 | } |
2785 | |
2786 | FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, |
2787 | StringRef name, SMRange loc) { |
2788 | ast::Type parentType = parentExpr->getType(); |
2789 | if (ast::OperationType opType = dyn_cast<ast::OperationType>(Val&: parentType)) { |
2790 | if (name == ast::AllResultsMemberAccessExpr::getMemberName()) |
2791 | return valueRangeTy; |
2792 | |
2793 | // Verify member access based on the operation type. |
2794 | if (const ods::Operation *odsOp = opType.getODSOperation()) { |
2795 | auto results = odsOp->getResults(); |
2796 | |
2797 | // Handle indexed results. |
2798 | unsigned index = 0; |
2799 | if (llvm::isDigit(C: name[0]) && !name.getAsInteger(/*Radix=*/10, Result&: index) && |
2800 | index < results.size()) { |
2801 | return results[index].isVariadic() ? valueRangeTy : valueTy; |
2802 | } |
2803 | |
2804 | // Handle named results. |
2805 | const auto *it = llvm::find_if(Range&: results, P: [&](const auto &result) { |
2806 | return result.getName() == name; |
2807 | }); |
2808 | if (it != results.end()) |
2809 | return it->isVariadic() ? valueRangeTy : valueTy; |
2810 | } else if (llvm::isDigit(C: name[0])) { |
2811 | // Allow unchecked numeric indexing of the results of unregistered |
2812 | // operations. It returns a single value. |
2813 | return valueTy; |
2814 | } |
2815 | } else if (auto tupleType = dyn_cast<ast::TupleType>(Val&: parentType)) { |
2816 | // Handle indexed results. |
2817 | unsigned index = 0; |
2818 | if (llvm::isDigit(C: name[0]) && !name.getAsInteger(/*Radix=*/10, Result&: index) && |
2819 | index < tupleType.size()) { |
2820 | return tupleType.getElementTypes()[index]; |
2821 | } |
2822 | |
2823 | // Handle named results. |
2824 | auto elementNames = tupleType.getElementNames(); |
2825 | const auto *it = llvm::find(Range&: elementNames, Val: name); |
2826 | if (it != elementNames.end()) |
2827 | return tupleType.getElementTypes()[it - elementNames.begin()]; |
2828 | } |
2829 | return emitError( |
2830 | loc, |
2831 | msg: llvm::formatv(Fmt: "invalid member access `{0}` on expression of type `{1}`", |
2832 | Vals&: name, Vals&: parentType)); |
2833 | } |
2834 | |
2835 | FailureOr<ast::OperationExpr *> Parser::createOperationExpr( |
2836 | SMRange loc, const ast::OpNameDecl *name, |
2837 | OpResultTypeContext resultTypeContext, |
2838 | SmallVectorImpl<ast::Expr *> &operands, |
2839 | MutableArrayRef<ast::NamedAttributeDecl *> attributes, |
2840 | SmallVectorImpl<ast::Expr *> &results) { |
2841 | std::optional<StringRef> opNameRef = name->getName(); |
2842 | const ods::Operation *odsOp = lookupODSOperation(opName: opNameRef); |
2843 | |
2844 | // Verify the inputs operands. |
2845 | if (failed(Result: validateOperationOperands(loc, name: opNameRef, odsOp, operands))) |
2846 | return failure(); |
2847 | |
2848 | // Verify the attribute list. |
2849 | for (ast::NamedAttributeDecl *attr : attributes) { |
2850 | // Check for an attribute type, or a type awaiting resolution. |
2851 | ast::Type attrType = attr->getValue()->getType(); |
2852 | if (!isa<ast::AttributeType>(Val: attrType)) { |
2853 | return emitError( |
2854 | loc: attr->getValue()->getLoc(), |
2855 | msg: llvm::formatv(Fmt: "expected `Attr` expression, but got `{0}`", Vals&: attrType)); |
2856 | } |
2857 | } |
2858 | |
2859 | assert( |
2860 | (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) && |
2861 | "unexpected inferrence when results were explicitly specified"); |
2862 | |
2863 | // If we aren't relying on type inferrence, or explicit results were provided, |
2864 | // validate them. |
2865 | if (resultTypeContext == OpResultTypeContext::Explicit) { |
2866 | if (failed(Result: validateOperationResults(loc, name: opNameRef, odsOp, results))) |
2867 | return failure(); |
2868 | |
2869 | // Validate the use of interface based type inferrence for this operation. |
2870 | } else if (resultTypeContext == OpResultTypeContext::Interface) { |
2871 | assert(opNameRef && |
2872 | "expected valid operation name when inferring operation results"); |
2873 | checkOperationResultTypeInferrence(loc, name: *opNameRef, odsOp); |
2874 | } |
2875 | |
2876 | return ast::OperationExpr::create(ctx, loc, odsOp, nameDecl: name, operands, resultTypes: results, |
2877 | attributes); |
2878 | } |
2879 | |
2880 | LogicalResult |
2881 | Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name, |
2882 | const ods::Operation *odsOp, |
2883 | SmallVectorImpl<ast::Expr *> &operands) { |
2884 | return validateOperationOperandsOrResults( |
2885 | groupName: "operand", loc, odsOpLoc: odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name, |
2886 | values&: operands, odsValues: odsOp ? odsOp->getOperands() : std::nullopt, singleTy: valueTy, |
2887 | rangeTy: valueRangeTy); |
2888 | } |
2889 | |
2890 | LogicalResult |
2891 | Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name, |
2892 | const ods::Operation *odsOp, |
2893 | SmallVectorImpl<ast::Expr *> &results) { |
2894 | return validateOperationOperandsOrResults( |
2895 | groupName: "result", loc, odsOpLoc: odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name, |
2896 | values&: results, odsValues: odsOp ? odsOp->getResults() : std::nullopt, singleTy: typeTy, rangeTy: typeRangeTy); |
2897 | } |
2898 | |
2899 | void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName, |
2900 | const ods::Operation *odsOp) { |
2901 | // If the operation might not have inferrence support, emit a warning to the |
2902 | // user. We don't emit an error because the interface might be added to the |
2903 | // operation at runtime. It's rare, but it could still happen. We emit a |
2904 | // warning here instead. |
2905 | |
2906 | // Handle inferrence warnings for unknown operations. |
2907 | if (!odsOp) { |
2908 | ctx.getDiagEngine().emitWarning( |
2909 | loc, msg: llvm::formatv( |
2910 | Fmt: "operation result types are marked to be inferred, but " |
2911 | "`{0}` is unknown. Ensure that `{0}` supports zero " |
2912 | "results or implements `InferTypeOpInterface`. Include " |
2913 | "the ODS definition of this operation to remove this warning.", |
2914 | Vals&: opName)); |
2915 | return; |
2916 | } |
2917 | |
2918 | // Handle inferrence warnings for known operations that expected at least one |
2919 | // result, but don't have inference support. An elided results list can mean |
2920 | // "zero-results", and we don't want to warn when that is the expected |
2921 | // behavior. |
2922 | bool requiresInferrence = |
2923 | llvm::any_of(Range: odsOp->getResults(), P: [](const ods::OperandOrResult &result) { |
2924 | return !result.isVariableLength(); |
2925 | }); |
2926 | if (requiresInferrence && !odsOp->hasResultTypeInferrence()) { |
2927 | ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning( |
2928 | loc, |
2929 | msg: llvm::formatv(Fmt: "operation result types are marked to be inferred, but " |
2930 | "`{0}` does not provide an implementation of " |
2931 | "`InferTypeOpInterface`. Ensure that `{0}` attaches " |
2932 | "`InferTypeOpInterface` at runtime, or add support to " |
2933 | "the ODS definition to remove this warning.", |
2934 | Vals&: opName)); |
2935 | diag->attachNote(msg: llvm::formatv(Fmt: "see the definition of `{0}` here", Vals&: opName), |
2936 | noteLoc: odsOp->getLoc()); |
2937 | return; |
2938 | } |
2939 | } |
2940 | |
2941 | LogicalResult Parser::validateOperationOperandsOrResults( |
2942 | StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc, |
2943 | std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values, |
2944 | ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, |
2945 | ast::RangeType rangeTy) { |
2946 | // All operation types accept a single range parameter. |
2947 | if (values.size() == 1) { |
2948 | if (failed(Result: convertExpressionTo(expr&: values[0], type: rangeTy))) |
2949 | return failure(); |
2950 | return success(); |
2951 | } |
2952 | |
2953 | /// If the operation has ODS information, we can more accurately verify the |
2954 | /// values. |
2955 | if (odsOpLoc) { |
2956 | auto emitSizeMismatchError = [&] { |
2957 | return emitErrorAndNote( |
2958 | loc, |
2959 | msg: llvm::formatv(Fmt: "invalid number of {0} groups for `{1}`; expected " |
2960 | "{2}, but got {3}", |
2961 | Vals&: groupName, Vals&: *name, Vals: odsValues.size(), Vals: values.size()), |
2962 | noteLoc: *odsOpLoc, note: llvm::formatv(Fmt: "see the definition of `{0}` here", Vals&: *name)); |
2963 | }; |
2964 | |
2965 | // Handle the case where no values were provided. |
2966 | if (values.empty()) { |
2967 | // If we don't expect any on the ODS side, we are done. |
2968 | if (odsValues.empty()) |
2969 | return success(); |
2970 | |
2971 | // If we do, check if we actually need to provide values (i.e. if any of |
2972 | // the values are actually required). |
2973 | unsigned numVariadic = 0; |
2974 | for (const auto &odsValue : odsValues) { |
2975 | if (!odsValue.isVariableLength()) |
2976 | return emitSizeMismatchError(); |
2977 | ++numVariadic; |
2978 | } |
2979 | |
2980 | // If we are in a non-rewrite context, we don't need to do anything more. |
2981 | // Zero-values is a valid constraint on the operation. |
2982 | if (parserContext != ParserContext::Rewrite) |
2983 | return success(); |
2984 | |
2985 | // Otherwise, when in a rewrite we may need to provide values to match the |
2986 | // ODS signature of the operation to create. |
2987 | |
2988 | // If we only have one variadic value, just use an empty list. |
2989 | if (numVariadic == 1) |
2990 | return success(); |
2991 | |
2992 | // Otherwise, create dummy values for each of the entries so that we |
2993 | // adhere to the ODS signature. |
2994 | for (unsigned i = 0, e = odsValues.size(); i < e; ++i) { |
2995 | values.push_back(Elt: ast::RangeExpr::create( |
2996 | ctx, loc, /*elements=*/std::nullopt, type: rangeTy)); |
2997 | } |
2998 | return success(); |
2999 | } |
3000 | |
3001 | // Verify that the number of values provided matches the number of value |
3002 | // groups ODS expects. |
3003 | if (odsValues.size() != values.size()) |
3004 | return emitSizeMismatchError(); |
3005 | |
3006 | auto diagFn = [&](ast::Diagnostic &diag) { |
3007 | diag.attachNote(msg: llvm::formatv(Fmt: "see the definition of `{0}` here", Vals&: *name), |
3008 | noteLoc: *odsOpLoc); |
3009 | }; |
3010 | for (unsigned i = 0, e = values.size(); i < e; ++i) { |
3011 | ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy; |
3012 | if (failed(Result: convertExpressionTo(expr&: values[i], type: expectedType, noteAttachFn: diagFn))) |
3013 | return failure(); |
3014 | } |
3015 | return success(); |
3016 | } |
3017 | |
3018 | // Otherwise, accept the value groups as they have been defined and just |
3019 | // ensure they are one of the expected types. |
3020 | for (ast::Expr *&valueExpr : values) { |
3021 | ast::Type valueExprType = valueExpr->getType(); |
3022 | |
3023 | // Check if this is one of the expected types. |
3024 | if (valueExprType == rangeTy || valueExprType == singleTy) |
3025 | continue; |
3026 | |
3027 | // If the operand is an Operation, allow converting to a Value or |
3028 | // ValueRange. This situations arises quite often with nested operation |
3029 | // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)` |
3030 | if (singleTy == valueTy) { |
3031 | if (isa<ast::OperationType>(Val: valueExprType)) { |
3032 | valueExpr = convertOpToValue(opExpr: valueExpr); |
3033 | continue; |
3034 | } |
3035 | } |
3036 | |
3037 | // Otherwise, try to convert the expression to a range. |
3038 | if (succeeded(Result: convertExpressionTo(expr&: valueExpr, type: rangeTy))) |
3039 | continue; |
3040 | |
3041 | return emitError( |
3042 | loc: valueExpr->getLoc(), |
3043 | msg: llvm::formatv( |
3044 | Fmt: "expected `{0}` or `{1}` convertible expression, but got `{2}`", |
3045 | Vals&: singleTy, Vals&: rangeTy, Vals&: valueExprType)); |
3046 | } |
3047 | return success(); |
3048 | } |
3049 | |
3050 | FailureOr<ast::TupleExpr *> |
3051 | Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements, |
3052 | ArrayRef<StringRef> elementNames) { |
3053 | for (const ast::Expr *element : elements) { |
3054 | ast::Type eleTy = element->getType(); |
3055 | if (isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>(Val: eleTy)) { |
3056 | return emitError( |
3057 | loc: element->getLoc(), |
3058 | msg: llvm::formatv(Fmt: "unable to build a tuple with `{0}` element", Vals&: eleTy)); |
3059 | } |
3060 | } |
3061 | return ast::TupleExpr::create(ctx, loc, elements, elementNames); |
3062 | } |
3063 | |
3064 | //===----------------------------------------------------------------------===// |
3065 | // Stmts |
3066 | //===----------------------------------------------------------------------===// |
3067 | |
3068 | FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc, |
3069 | ast::Expr *rootOp) { |
3070 | // Check that root is an Operation. |
3071 | ast::Type rootType = rootOp->getType(); |
3072 | if (!isa<ast::OperationType>(Val: rootType)) |
3073 | return emitError(loc: rootOp->getLoc(), msg: "expected `Op` expression"); |
3074 | |
3075 | return ast::EraseStmt::create(ctx, loc, rootOp); |
3076 | } |
3077 | |
3078 | FailureOr<ast::ReplaceStmt *> |
3079 | Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, |
3080 | MutableArrayRef<ast::Expr *> replValues) { |
3081 | // Check that root is an Operation. |
3082 | ast::Type rootType = rootOp->getType(); |
3083 | if (!isa<ast::OperationType>(Val: rootType)) { |
3084 | return emitError( |
3085 | loc: rootOp->getLoc(), |
3086 | msg: llvm::formatv(Fmt: "expected `Op` expression, but got `{0}`", Vals&: rootType)); |
3087 | } |
3088 | |
3089 | // If there are multiple replacement values, we implicitly convert any Op |
3090 | // expressions to the value form. |
3091 | bool shouldConvertOpToValues = replValues.size() > 1; |
3092 | for (ast::Expr *&replExpr : replValues) { |
3093 | ast::Type replType = replExpr->getType(); |
3094 | |
3095 | // Check that replExpr is an Operation, Value, or ValueRange. |
3096 | if (isa<ast::OperationType>(Val: replType)) { |
3097 | if (shouldConvertOpToValues) |
3098 | replExpr = convertOpToValue(opExpr: replExpr); |
3099 | continue; |
3100 | } |
3101 | |
3102 | if (replType != valueTy && replType != valueRangeTy) { |
3103 | return emitError(loc: replExpr->getLoc(), |
3104 | msg: llvm::formatv(Fmt: "expected `Op`, `Value` or `ValueRange` " |
3105 | "expression, but got `{0}`", |
3106 | Vals&: replType)); |
3107 | } |
3108 | } |
3109 | |
3110 | return ast::ReplaceStmt::create(ctx, loc, rootOp, replExprs: replValues); |
3111 | } |
3112 | |
3113 | FailureOr<ast::RewriteStmt *> |
3114 | Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, |
3115 | ast::CompoundStmt *rewriteBody) { |
3116 | // Check that root is an Operation. |
3117 | ast::Type rootType = rootOp->getType(); |
3118 | if (!isa<ast::OperationType>(Val: rootType)) { |
3119 | return emitError( |
3120 | loc: rootOp->getLoc(), |
3121 | msg: llvm::formatv(Fmt: "expected `Op` expression, but got `{0}`", Vals&: rootType)); |
3122 | } |
3123 | |
3124 | return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); |
3125 | } |
3126 | |
3127 | //===----------------------------------------------------------------------===// |
3128 | // Code Completion |
3129 | //===----------------------------------------------------------------------===// |
3130 | |
3131 | LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) { |
3132 | ast::Type parentType = parentExpr->getType(); |
3133 | if (ast::OperationType opType = dyn_cast<ast::OperationType>(Val&: parentType)) |
3134 | codeCompleteContext->codeCompleteOperationMemberAccess(opType); |
3135 | else if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(Val&: parentType)) |
3136 | codeCompleteContext->codeCompleteTupleMemberAccess(tupleType); |
3137 | return failure(); |
3138 | } |
3139 | |
3140 | LogicalResult |
3141 | Parser::codeCompleteAttributeName(std::optional<StringRef> opName) { |
3142 | if (opName) |
3143 | codeCompleteContext->codeCompleteOperationAttributeName(opName: *opName); |
3144 | return failure(); |
3145 | } |
3146 | |
3147 | LogicalResult |
3148 | Parser::codeCompleteConstraintName(ast::Type inferredType, |
3149 | bool allowInlineTypeConstraints) { |
3150 | codeCompleteContext->codeCompleteConstraintName( |
3151 | currentType: inferredType, allowInlineTypeConstraints, scope: curDeclScope); |
3152 | return failure(); |
3153 | } |
3154 | |
3155 | LogicalResult Parser::codeCompleteDialectName() { |
3156 | codeCompleteContext->codeCompleteDialectName(); |
3157 | return failure(); |
3158 | } |
3159 | |
3160 | LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) { |
3161 | codeCompleteContext->codeCompleteOperationName(dialectName); |
3162 | return failure(); |
3163 | } |
3164 | |
3165 | LogicalResult Parser::codeCompletePatternMetadata() { |
3166 | codeCompleteContext->codeCompletePatternMetadata(); |
3167 | return failure(); |
3168 | } |
3169 | |
3170 | LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) { |
3171 | codeCompleteContext->codeCompleteIncludeFilename(curPath); |
3172 | return failure(); |
3173 | } |
3174 | |
3175 | void Parser::codeCompleteCallSignature(ast::Node *parent, |
3176 | unsigned currentNumArgs) { |
3177 | ast::CallableDecl *callableDecl = tryExtractCallableDecl(node: parent); |
3178 | if (!callableDecl) |
3179 | return; |
3180 | |
3181 | codeCompleteContext->codeCompleteCallSignature(callable: callableDecl, currentNumArgs); |
3182 | } |
3183 | |
3184 | void Parser::codeCompleteOperationOperandsSignature( |
3185 | std::optional<StringRef> opName, unsigned currentNumOperands) { |
3186 | codeCompleteContext->codeCompleteOperationOperandsSignature( |
3187 | opName, currentNumOperands); |
3188 | } |
3189 | |
3190 | void Parser::codeCompleteOperationResultsSignature( |
3191 | std::optional<StringRef> opName, unsigned currentNumResults) { |
3192 | codeCompleteContext->codeCompleteOperationResultsSignature(opName, |
3193 | currentNumResults); |
3194 | } |
3195 | |
3196 | //===----------------------------------------------------------------------===// |
3197 | // Parser |
3198 | //===----------------------------------------------------------------------===// |
3199 | |
3200 | FailureOr<ast::Module *> |
3201 | mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, |
3202 | bool enableDocumentation, |
3203 | CodeCompleteContext *codeCompleteContext) { |
3204 | Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext); |
3205 | return parser.parseModule(); |
3206 | } |
3207 |
Definitions
- Parser
- Parser
- ParserContext
- OpResultTypeContext
- pushDeclScope
- pushDeclScope
- popDeclScope
- lookupODSOperation
- processDoc
- processAndFormatDoc
- ParsedPatternMetadata
- consumeIf
- consumeToken
- consumeToken
- resetToken
- parseToken
- emitError
- emitError
- emitErrorAndNote
- parseModule
- parseModuleBody
- convertOpToValue
- convertExpressionTo
- convertOpExpressionTo
- convertTupleExpressionTo
- parseDirective
- parseInclude
- parseTdInclude
- processTdIncludeRecords
- createODSNativePDLLConstraintDecl
- createODSNativePDLLConstraintDecl
- parseTopLevelDecl
- parseNamedAttributeDecl
- parseLambdaBody
- parseArgumentDecl
- parseResultDecl
- parseUserConstraintDecl
- parseInlineUserConstraintDecl
- parseUserPDLLConstraintDecl
- parseUserRewriteDecl
- parseInlineUserRewriteDecl
- parseUserPDLLRewriteDecl
- parseUserConstraintOrRewriteDecl
- parseUserNativeConstraintOrRewriteDecl
- parseUserConstraintOrRewriteSignature
- validateUserConstraintOrRewriteReturn
- parsePatternLambdaBody
- parsePatternDecl
- parsePatternDeclMetadata
- parseTypeConstraintExpr
- checkDefineNamedDecl
- defineVariableDecl
- defineVariableDecl
- parseVariableDeclConstraintList
- parseConstraint
- parseArgOrResultConstraint
- parseExpr
- parseAttributeExpr
- parseCallExpr
- parseDeclRefExpr
- parseIdentifierExpr
- parseInlineConstraintLambdaExpr
- parseInlineRewriteLambdaExpr
- parseMemberAccessExpr
- parseNegatedExpr
- parseOperationName
- parseWrappedOperationName
- parseOperationExpr
- parseTupleExpr
- parseTypeExpr
- parseUnderscoreExpr
- parseStmt
- parseCompoundStmt
- parseEraseStmt
- parseLetStmt
- parseReplaceStmt
- parseReturnStmt
- parseRewriteStmt
- tryExtractCallableDecl
- createPatternDecl
- createUserConstraintRewriteResultType
- createUserPDLLConstraintOrRewriteDecl
- createVariableDecl
- createArgOrResultVariableDecl
- validateVariableConstraints
- validateVariableConstraint
- validateTypeConstraintExpr
- validateTypeRangeConstraintExpr
- createCallExpr
- createDeclRefExpr
- createInlineVariableExpr
- createMemberAccessExpr
- validateMemberAccess
- createOperationExpr
- validateOperationOperands
- validateOperationResults
- checkOperationResultTypeInferrence
- validateOperationOperandsOrResults
- createTupleExpr
- createEraseStmt
- createReplaceStmt
- createRewriteStmt
- codeCompleteMemberAccess
- codeCompleteAttributeName
- codeCompleteConstraintName
- codeCompleteDialectName
- codeCompleteOperationName
- codeCompletePatternMetadata
- codeCompleteIncludeFilename
- codeCompleteCallSignature
- codeCompleteOperationOperandsSignature
- codeCompleteOperationResultsSignature
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more