1 | //===- Parser.cpp ---------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Tools/PDLL/Parser/Parser.h" |
10 | #include "Lexer.h" |
11 | #include "mlir/Support/IndentedOstream.h" |
12 | #include "mlir/Support/LogicalResult.h" |
13 | #include "mlir/TableGen/Argument.h" |
14 | #include "mlir/TableGen/Attribute.h" |
15 | #include "mlir/TableGen/Constraint.h" |
16 | #include "mlir/TableGen/Format.h" |
17 | #include "mlir/TableGen/Operator.h" |
18 | #include "mlir/Tools/PDLL/AST/Context.h" |
19 | #include "mlir/Tools/PDLL/AST/Diagnostic.h" |
20 | #include "mlir/Tools/PDLL/AST/Nodes.h" |
21 | #include "mlir/Tools/PDLL/AST/Types.h" |
22 | #include "mlir/Tools/PDLL/ODS/Constraint.h" |
23 | #include "mlir/Tools/PDLL/ODS/Context.h" |
24 | #include "mlir/Tools/PDLL/ODS/Operation.h" |
25 | #include "mlir/Tools/PDLL/Parser/CodeComplete.h" |
26 | #include "llvm/ADT/StringExtras.h" |
27 | #include "llvm/ADT/TypeSwitch.h" |
28 | #include "llvm/Support/FormatVariadic.h" |
29 | #include "llvm/Support/ManagedStatic.h" |
30 | #include "llvm/Support/SaveAndRestore.h" |
31 | #include "llvm/Support/ScopedPrinter.h" |
32 | #include "llvm/TableGen/Error.h" |
33 | #include "llvm/TableGen/Parser.h" |
34 | #include <optional> |
35 | #include <string> |
36 | |
37 | using namespace mlir; |
38 | using namespace mlir::pdll; |
39 | |
40 | //===----------------------------------------------------------------------===// |
41 | // Parser |
42 | //===----------------------------------------------------------------------===// |
43 | |
44 | namespace { |
45 | class Parser { |
46 | public: |
47 | Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr, |
48 | bool enableDocumentation, CodeCompleteContext *codeCompleteContext) |
49 | : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext), |
50 | curToken(lexer.lexToken()), enableDocumentation(enableDocumentation), |
51 | typeTy(ast::TypeType::get(context&: ctx)), valueTy(ast::ValueType::get(context&: ctx)), |
52 | typeRangeTy(ast::TypeRangeType::get(context&: ctx)), |
53 | valueRangeTy(ast::ValueRangeType::get(context&: ctx)), |
54 | attrTy(ast::AttributeType::get(context&: ctx)), |
55 | codeCompleteContext(codeCompleteContext) {} |
56 | |
57 | /// Try to parse a new module. Returns nullptr in the case of failure. |
58 | FailureOr<ast::Module *> parseModule(); |
59 | |
60 | private: |
61 | /// The current context of the parser. It allows for the parser to know a bit |
62 | /// about the construct it is nested within during parsing. This is used |
63 | /// specifically to provide additional verification during parsing, e.g. to |
64 | /// prevent using rewrites within a match context, matcher constraints within |
65 | /// a rewrite section, etc. |
66 | enum class ParserContext { |
67 | /// The parser is in the global context. |
68 | Global, |
69 | /// The parser is currently within a Constraint, which disallows all types |
70 | /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.). |
71 | Constraint, |
72 | /// The parser is currently within the matcher portion of a Pattern, which |
73 | /// is allows a terminal operation rewrite statement but no other rewrite |
74 | /// transformations. |
75 | PatternMatch, |
76 | /// The parser is currently within a Rewrite, which disallows calls to |
77 | /// constraints, requires operation expressions to have names, etc. |
78 | Rewrite, |
79 | }; |
80 | |
81 | /// The current specification context of an operations result type. This |
82 | /// indicates how the result types of an operation may be inferred. |
83 | enum class OpResultTypeContext { |
84 | /// The result types of the operation are not known to be inferred. |
85 | Explicit, |
86 | /// The result types of the operation are inferred from the root input of a |
87 | /// `replace` statement. |
88 | Replacement, |
89 | /// The result types of the operation are inferred by using the |
90 | /// `InferTypeOpInterface` interface provided by the operation. |
91 | Interface, |
92 | }; |
93 | |
94 | //===--------------------------------------------------------------------===// |
95 | // Parsing |
96 | //===--------------------------------------------------------------------===// |
97 | |
98 | /// Push a new decl scope onto the lexer. |
99 | ast::DeclScope *pushDeclScope() { |
100 | ast::DeclScope *newScope = |
101 | new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); |
102 | return (curDeclScope = newScope); |
103 | } |
104 | void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } |
105 | |
106 | /// Pop the last decl scope from the lexer. |
107 | void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } |
108 | |
109 | /// Parse the body of an AST module. |
110 | LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls); |
111 | |
112 | /// Try to convert the given expression to `type`. Returns failure and emits |
113 | /// an error if a conversion is not viable. On failure, `noteAttachFn` is |
114 | /// invoked to attach notes to the emitted error diagnostic. On success, |
115 | /// `expr` is updated to the expression used to convert to `type`. |
116 | LogicalResult convertExpressionTo( |
117 | ast::Expr *&expr, ast::Type type, |
118 | function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {}); |
119 | LogicalResult |
120 | convertOpExpressionTo(ast::Expr *&expr, ast::OperationType exprType, |
121 | ast::Type type, |
122 | function_ref<ast::InFlightDiagnostic()> emitErrorFn); |
123 | LogicalResult convertTupleExpressionTo( |
124 | ast::Expr *&expr, ast::TupleType exprType, ast::Type type, |
125 | function_ref<ast::InFlightDiagnostic()> emitErrorFn, |
126 | function_ref<void(ast::Diagnostic &diag)> noteAttachFn); |
127 | |
128 | /// Given an operation expression, convert it to a Value or ValueRange |
129 | /// typed expression. |
130 | ast::Expr *convertOpToValue(const ast::Expr *opExpr); |
131 | |
132 | /// Lookup ODS information for the given operation, returns nullptr if no |
133 | /// information is found. |
134 | const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) { |
135 | return opName ? ctx.getODSContext().lookupOperation(name: *opName) : nullptr; |
136 | } |
137 | |
138 | /// Process the given documentation string, or return an empty string if |
139 | /// documentation isn't enabled. |
140 | StringRef processDoc(StringRef doc) { |
141 | return enableDocumentation ? doc : StringRef(); |
142 | } |
143 | |
144 | /// Process the given documentation string and format it, or return an empty |
145 | /// string if documentation isn't enabled. |
146 | std::string processAndFormatDoc(const Twine &doc) { |
147 | if (!enableDocumentation) |
148 | return "" ; |
149 | std::string docStr; |
150 | { |
151 | llvm::raw_string_ostream docOS(docStr); |
152 | std::string tmpDocStr = doc.str(); |
153 | raw_indented_ostream(docOS).printReindented( |
154 | str: StringRef(tmpDocStr).rtrim(Chars: " \t" )); |
155 | } |
156 | return docStr; |
157 | } |
158 | |
159 | //===--------------------------------------------------------------------===// |
160 | // Directives |
161 | |
162 | LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls); |
163 | LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls); |
164 | LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc, |
165 | SmallVectorImpl<ast::Decl *> &decls); |
166 | |
167 | /// Process the records of a parsed tablegen include file. |
168 | void processTdIncludeRecords(llvm::RecordKeeper &tdRecords, |
169 | SmallVectorImpl<ast::Decl *> &decls); |
170 | |
171 | /// Create a user defined native constraint for a constraint imported from |
172 | /// ODS. |
173 | template <typename ConstraintT> |
174 | ast::Decl * |
175 | createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock, |
176 | SMRange loc, ast::Type type, |
177 | StringRef nativeType, StringRef docString); |
178 | template <typename ConstraintT> |
179 | ast::Decl * |
180 | createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, |
181 | SMRange loc, ast::Type type, |
182 | StringRef nativeType); |
183 | |
184 | //===--------------------------------------------------------------------===// |
185 | // Decls |
186 | |
187 | /// This structure contains the set of pattern metadata that may be parsed. |
188 | struct ParsedPatternMetadata { |
189 | std::optional<uint16_t> benefit; |
190 | bool hasBoundedRecursion = false; |
191 | }; |
192 | |
193 | FailureOr<ast::Decl *> parseTopLevelDecl(); |
194 | FailureOr<ast::NamedAttributeDecl *> |
195 | parseNamedAttributeDecl(std::optional<StringRef> parentOpName); |
196 | |
197 | /// Parse an argument variable as part of the signature of a |
198 | /// UserConstraintDecl or UserRewriteDecl. |
199 | FailureOr<ast::VariableDecl *> parseArgumentDecl(); |
200 | |
201 | /// Parse a result variable as part of the signature of a UserConstraintDecl |
202 | /// or UserRewriteDecl. |
203 | FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum); |
204 | |
205 | /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being |
206 | /// defined in a non-global context. |
207 | FailureOr<ast::UserConstraintDecl *> |
208 | parseUserConstraintDecl(bool isInline = false); |
209 | |
210 | /// Parse an inline UserConstraintDecl. An inline decl is one defined in a |
211 | /// non-global context, such as within a Pattern/Constraint/etc. |
212 | FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl(); |
213 | |
214 | /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using |
215 | /// PDLL constructs. |
216 | FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl( |
217 | const ast::Name &name, bool isInline, |
218 | ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, |
219 | ArrayRef<ast::VariableDecl *> results, ast::Type resultType); |
220 | |
221 | /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being |
222 | /// defined in a non-global context. |
223 | FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false); |
224 | |
225 | /// Parse an inline UserRewriteDecl. An inline decl is one defined in a |
226 | /// non-global context, such as within a Pattern/Rewrite/etc. |
227 | FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl(); |
228 | |
229 | /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using |
230 | /// PDLL constructs. |
231 | FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl( |
232 | const ast::Name &name, bool isInline, |
233 | ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, |
234 | ArrayRef<ast::VariableDecl *> results, ast::Type resultType); |
235 | |
236 | /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have |
237 | /// effectively the same syntax, and only differ on slight semantics (given |
238 | /// the different parsing contexts). |
239 | template <typename T, typename ParseUserPDLLDeclFnT> |
240 | FailureOr<T *> parseUserConstraintOrRewriteDecl( |
241 | ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, |
242 | StringRef anonymousNamePrefix, bool isInline); |
243 | |
244 | /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl. |
245 | /// These decls have effectively the same syntax. |
246 | template <typename T> |
247 | FailureOr<T *> parseUserNativeConstraintOrRewriteDecl( |
248 | const ast::Name &name, bool isInline, |
249 | ArrayRef<ast::VariableDecl *> arguments, |
250 | ArrayRef<ast::VariableDecl *> results, ast::Type resultType); |
251 | |
252 | /// Parse the functional signature (i.e. the arguments and results) of a |
253 | /// UserConstraintDecl or UserRewriteDecl. |
254 | LogicalResult parseUserConstraintOrRewriteSignature( |
255 | SmallVectorImpl<ast::VariableDecl *> &arguments, |
256 | SmallVectorImpl<ast::VariableDecl *> &results, |
257 | ast::DeclScope *&argumentScope, ast::Type &resultType); |
258 | |
259 | /// Validate the return (which if present is specified by bodyIt) of a |
260 | /// UserConstraintDecl or UserRewriteDecl. |
261 | LogicalResult validateUserConstraintOrRewriteReturn( |
262 | StringRef declType, ast::CompoundStmt *body, |
263 | ArrayRef<ast::Stmt *>::iterator bodyIt, |
264 | ArrayRef<ast::Stmt *>::iterator bodyE, |
265 | ArrayRef<ast::VariableDecl *> results, ast::Type &resultType); |
266 | |
267 | FailureOr<ast::CompoundStmt *> |
268 | parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, |
269 | bool expectTerminalSemicolon = true); |
270 | FailureOr<ast::CompoundStmt *> parsePatternLambdaBody(); |
271 | FailureOr<ast::Decl *> parsePatternDecl(); |
272 | LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata); |
273 | |
274 | /// Check to see if a decl has already been defined with the given name, if |
275 | /// one has emit and error and return failure. Returns success otherwise. |
276 | LogicalResult checkDefineNamedDecl(const ast::Name &name); |
277 | |
278 | /// Try to define a variable decl with the given components, returns the |
279 | /// variable on success. |
280 | FailureOr<ast::VariableDecl *> |
281 | defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, |
282 | ast::Expr *initExpr, |
283 | ArrayRef<ast::ConstraintRef> constraints); |
284 | FailureOr<ast::VariableDecl *> |
285 | defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, |
286 | ArrayRef<ast::ConstraintRef> constraints); |
287 | |
288 | /// Parse the constraint reference list for a variable decl. |
289 | LogicalResult parseVariableDeclConstraintList( |
290 | SmallVectorImpl<ast::ConstraintRef> &constraints); |
291 | |
292 | /// Parse the expression used within a type constraint, e.g. Attr<type-expr>. |
293 | FailureOr<ast::Expr *> parseTypeConstraintExpr(); |
294 | |
295 | /// Try to parse a single reference to a constraint. `typeConstraint` is the |
296 | /// location of a previously parsed type constraint for the entity that will |
297 | /// be constrained by the parsed constraint. `existingConstraints` are any |
298 | /// existing constraints that have already been parsed for the same entity |
299 | /// that will be constrained by this constraint. `allowInlineTypeConstraints` |
300 | /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`. |
301 | FailureOr<ast::ConstraintRef> |
302 | parseConstraint(std::optional<SMRange> &typeConstraint, |
303 | ArrayRef<ast::ConstraintRef> existingConstraints, |
304 | bool allowInlineTypeConstraints); |
305 | |
306 | /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl |
307 | /// argument or result variable. The constraints for these variables do not |
308 | /// allow inline type constraints, and only permit a single constraint. |
309 | FailureOr<ast::ConstraintRef> parseArgOrResultConstraint(); |
310 | |
311 | //===--------------------------------------------------------------------===// |
312 | // Exprs |
313 | |
314 | FailureOr<ast::Expr *> parseExpr(); |
315 | |
316 | /// Identifier expressions. |
317 | FailureOr<ast::Expr *> parseAttributeExpr(); |
318 | FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr, |
319 | bool isNegated = false); |
320 | FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc); |
321 | FailureOr<ast::Expr *> parseIdentifierExpr(); |
322 | FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr(); |
323 | FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr(); |
324 | FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr); |
325 | FailureOr<ast::Expr *> parseNegatedExpr(); |
326 | FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false); |
327 | FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName); |
328 | FailureOr<ast::Expr *> |
329 | parseOperationExpr(OpResultTypeContext inputResultTypeContext = |
330 | OpResultTypeContext::Explicit); |
331 | FailureOr<ast::Expr *> parseTupleExpr(); |
332 | FailureOr<ast::Expr *> parseTypeExpr(); |
333 | FailureOr<ast::Expr *> parseUnderscoreExpr(); |
334 | |
335 | //===--------------------------------------------------------------------===// |
336 | // Stmts |
337 | |
338 | FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true); |
339 | FailureOr<ast::CompoundStmt *> parseCompoundStmt(); |
340 | FailureOr<ast::EraseStmt *> parseEraseStmt(); |
341 | FailureOr<ast::LetStmt *> parseLetStmt(); |
342 | FailureOr<ast::ReplaceStmt *> parseReplaceStmt(); |
343 | FailureOr<ast::ReturnStmt *> parseReturnStmt(); |
344 | FailureOr<ast::RewriteStmt *> parseRewriteStmt(); |
345 | |
346 | //===--------------------------------------------------------------------===// |
347 | // Creation+Analysis |
348 | //===--------------------------------------------------------------------===// |
349 | |
350 | //===--------------------------------------------------------------------===// |
351 | // Decls |
352 | |
353 | /// Try to extract a callable from the given AST node. Returns nullptr on |
354 | /// failure. |
355 | ast::CallableDecl *tryExtractCallableDecl(ast::Node *node); |
356 | |
357 | /// Try to create a pattern decl with the given components, returning the |
358 | /// Pattern on success. |
359 | FailureOr<ast::PatternDecl *> |
360 | createPatternDecl(SMRange loc, const ast::Name *name, |
361 | const ParsedPatternMetadata &metadata, |
362 | ast::CompoundStmt *body); |
363 | |
364 | /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set |
365 | /// of results, defined as part of the signature. |
366 | ast::Type |
367 | createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results); |
368 | |
369 | /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl. |
370 | template <typename T> |
371 | FailureOr<T *> createUserPDLLConstraintOrRewriteDecl( |
372 | const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, |
373 | ArrayRef<ast::VariableDecl *> results, ast::Type resultType, |
374 | ast::CompoundStmt *body); |
375 | |
376 | /// Try to create a variable decl with the given components, returning the |
377 | /// Variable on success. |
378 | FailureOr<ast::VariableDecl *> |
379 | createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, |
380 | ArrayRef<ast::ConstraintRef> constraints); |
381 | |
382 | /// Create a variable for an argument or result defined as part of the |
383 | /// signature of a UserConstraintDecl/UserRewriteDecl. |
384 | FailureOr<ast::VariableDecl *> |
385 | createArgOrResultVariableDecl(StringRef name, SMRange loc, |
386 | const ast::ConstraintRef &constraint); |
387 | |
388 | /// Validate the constraints used to constraint a variable decl. |
389 | /// `inferredType` is the type of the variable inferred by the constraints |
390 | /// within the list, and is updated to the most refined type as determined by |
391 | /// the constraints. Returns success if the constraint list is valid, failure |
392 | /// otherwise. |
393 | LogicalResult |
394 | validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, |
395 | ast::Type &inferredType); |
396 | /// Validate a single reference to a constraint. `inferredType` contains the |
397 | /// currently inferred variabled type and is refined within the type defined |
398 | /// by the constraint. Returns success if the constraint is valid, failure |
399 | /// otherwise. |
400 | LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, |
401 | ast::Type &inferredType); |
402 | LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); |
403 | LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); |
404 | |
405 | //===--------------------------------------------------------------------===// |
406 | // Exprs |
407 | |
408 | FailureOr<ast::CallExpr *> |
409 | createCallExpr(SMRange loc, ast::Expr *parentExpr, |
410 | MutableArrayRef<ast::Expr *> arguments, |
411 | bool isNegated = false); |
412 | FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl); |
413 | FailureOr<ast::DeclRefExpr *> |
414 | createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, |
415 | ArrayRef<ast::ConstraintRef> constraints); |
416 | FailureOr<ast::MemberAccessExpr *> |
417 | createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); |
418 | |
419 | /// Validate the member access `name` into the given parent expression. On |
420 | /// success, this also returns the type of the member accessed. |
421 | FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr, |
422 | StringRef name, SMRange loc); |
423 | FailureOr<ast::OperationExpr *> |
424 | createOperationExpr(SMRange loc, const ast::OpNameDecl *name, |
425 | OpResultTypeContext resultTypeContext, |
426 | SmallVectorImpl<ast::Expr *> &operands, |
427 | MutableArrayRef<ast::NamedAttributeDecl *> attributes, |
428 | SmallVectorImpl<ast::Expr *> &results); |
429 | LogicalResult |
430 | validateOperationOperands(SMRange loc, std::optional<StringRef> name, |
431 | const ods::Operation *odsOp, |
432 | SmallVectorImpl<ast::Expr *> &operands); |
433 | LogicalResult validateOperationResults(SMRange loc, |
434 | std::optional<StringRef> name, |
435 | const ods::Operation *odsOp, |
436 | SmallVectorImpl<ast::Expr *> &results); |
437 | void checkOperationResultTypeInferrence(SMRange loc, StringRef name, |
438 | const ods::Operation *odsOp); |
439 | LogicalResult validateOperationOperandsOrResults( |
440 | StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc, |
441 | std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values, |
442 | ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, |
443 | ast::RangeType rangeTy); |
444 | FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, |
445 | ArrayRef<ast::Expr *> elements, |
446 | ArrayRef<StringRef> elementNames); |
447 | |
448 | //===--------------------------------------------------------------------===// |
449 | // Stmts |
450 | |
451 | FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp); |
452 | FailureOr<ast::ReplaceStmt *> |
453 | createReplaceStmt(SMRange loc, ast::Expr *rootOp, |
454 | MutableArrayRef<ast::Expr *> replValues); |
455 | FailureOr<ast::RewriteStmt *> |
456 | createRewriteStmt(SMRange loc, ast::Expr *rootOp, |
457 | ast::CompoundStmt *rewriteBody); |
458 | |
459 | //===--------------------------------------------------------------------===// |
460 | // Code Completion |
461 | //===--------------------------------------------------------------------===// |
462 | |
463 | /// The set of various code completion methods. Every completion method |
464 | /// returns `failure` to stop the parsing process after providing completion |
465 | /// results. |
466 | |
467 | LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr); |
468 | LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName); |
469 | LogicalResult codeCompleteConstraintName(ast::Type inferredType, |
470 | bool allowInlineTypeConstraints); |
471 | LogicalResult codeCompleteDialectName(); |
472 | LogicalResult codeCompleteOperationName(StringRef dialectName); |
473 | LogicalResult codeCompletePatternMetadata(); |
474 | LogicalResult codeCompleteIncludeFilename(StringRef curPath); |
475 | |
476 | void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs); |
477 | void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName, |
478 | unsigned currentNumOperands); |
479 | void codeCompleteOperationResultsSignature(std::optional<StringRef> opName, |
480 | unsigned currentNumResults); |
481 | |
482 | //===--------------------------------------------------------------------===// |
483 | // Lexer Utilities |
484 | //===--------------------------------------------------------------------===// |
485 | |
486 | /// If the current token has the specified kind, consume it and return true. |
487 | /// If not, return false. |
488 | bool consumeIf(Token::Kind kind) { |
489 | if (curToken.isNot(k: kind)) |
490 | return false; |
491 | consumeToken(kind); |
492 | return true; |
493 | } |
494 | |
495 | /// Advance the current lexer onto the next token. |
496 | void consumeToken() { |
497 | assert(curToken.isNot(Token::eof, Token::error) && |
498 | "shouldn't advance past EOF or errors" ); |
499 | curToken = lexer.lexToken(); |
500 | } |
501 | |
502 | /// Advance the current lexer onto the next token, asserting what the expected |
503 | /// current token is. This is preferred to the above method because it leads |
504 | /// to more self-documenting code with better checking. |
505 | void consumeToken(Token::Kind kind) { |
506 | assert(curToken.is(kind) && "consumed an unexpected token" ); |
507 | consumeToken(); |
508 | } |
509 | |
510 | /// Reset the lexer to the location at the given position. |
511 | void resetToken(SMRange tokLoc) { |
512 | lexer.resetPointer(newPointer: tokLoc.Start.getPointer()); |
513 | curToken = lexer.lexToken(); |
514 | } |
515 | |
516 | /// Consume the specified token if present and return success. On failure, |
517 | /// output a diagnostic and return failure. |
518 | LogicalResult parseToken(Token::Kind kind, const Twine &msg) { |
519 | if (curToken.getKind() != kind) |
520 | return emitError(loc: curToken.getLoc(), msg); |
521 | consumeToken(); |
522 | return success(); |
523 | } |
524 | LogicalResult emitError(SMRange loc, const Twine &msg) { |
525 | lexer.emitError(loc, msg); |
526 | return failure(); |
527 | } |
528 | LogicalResult emitError(const Twine &msg) { |
529 | return emitError(loc: curToken.getLoc(), msg); |
530 | } |
531 | LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, |
532 | const Twine ¬e) { |
533 | lexer.emitErrorAndNote(loc, msg, noteLoc, note); |
534 | return failure(); |
535 | } |
536 | |
537 | //===--------------------------------------------------------------------===// |
538 | // Fields |
539 | //===--------------------------------------------------------------------===// |
540 | |
541 | /// The owning AST context. |
542 | ast::Context &ctx; |
543 | |
544 | /// The lexer of this parser. |
545 | Lexer lexer; |
546 | |
547 | /// The current token within the lexer. |
548 | Token curToken; |
549 | |
550 | /// A flag indicating if the parser should add documentation to AST nodes when |
551 | /// viable. |
552 | bool enableDocumentation; |
553 | |
554 | /// The most recently defined decl scope. |
555 | ast::DeclScope *curDeclScope = nullptr; |
556 | llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator; |
557 | |
558 | /// The current context of the parser. |
559 | ParserContext parserContext = ParserContext::Global; |
560 | |
561 | /// Cached types to simplify verification and expression creation. |
562 | ast::Type typeTy, valueTy; |
563 | ast::RangeType typeRangeTy, valueRangeTy; |
564 | ast::Type attrTy; |
565 | |
566 | /// A counter used when naming anonymous constraints and rewrites. |
567 | unsigned anonymousDeclNameCounter = 0; |
568 | |
569 | /// The optional code completion context. |
570 | CodeCompleteContext *codeCompleteContext; |
571 | }; |
572 | } // namespace |
573 | |
574 | FailureOr<ast::Module *> Parser::parseModule() { |
575 | SMLoc moduleLoc = curToken.getStartLoc(); |
576 | pushDeclScope(); |
577 | |
578 | // Parse the top-level decls of the module. |
579 | SmallVector<ast::Decl *> decls; |
580 | if (failed(result: parseModuleBody(decls))) |
581 | return popDeclScope(), failure(); |
582 | |
583 | popDeclScope(); |
584 | return ast::Module::create(ctx, loc: moduleLoc, children: decls); |
585 | } |
586 | |
587 | LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) { |
588 | while (curToken.isNot(k: Token::eof)) { |
589 | if (curToken.is(k: Token::directive)) { |
590 | if (failed(result: parseDirective(decls))) |
591 | return failure(); |
592 | continue; |
593 | } |
594 | |
595 | FailureOr<ast::Decl *> decl = parseTopLevelDecl(); |
596 | if (failed(result: decl)) |
597 | return failure(); |
598 | decls.push_back(Elt: *decl); |
599 | } |
600 | return success(); |
601 | } |
602 | |
603 | ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) { |
604 | return ast::AllResultsMemberAccessExpr::create(ctx, loc: opExpr->getLoc(), parentExpr: opExpr, |
605 | type: valueRangeTy); |
606 | } |
607 | |
608 | LogicalResult Parser::convertExpressionTo( |
609 | ast::Expr *&expr, ast::Type type, |
610 | function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { |
611 | ast::Type exprType = expr->getType(); |
612 | if (exprType == type) |
613 | return success(); |
614 | |
615 | auto emitConvertError = [&]() -> ast::InFlightDiagnostic { |
616 | ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError( |
617 | loc: expr->getLoc(), msg: llvm::formatv(Fmt: "unable to convert expression of type " |
618 | "`{0}` to the expected type of " |
619 | "`{1}`" , |
620 | Vals&: exprType, Vals&: type)); |
621 | if (noteAttachFn) |
622 | noteAttachFn(*diag); |
623 | return diag; |
624 | }; |
625 | |
626 | if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) |
627 | return convertOpExpressionTo(expr, exprType: exprOpType, type, emitErrorFn: emitConvertError); |
628 | |
629 | // FIXME: Decide how to allow/support converting a single result to multiple, |
630 | // and multiple to a single result. For now, we just allow Single->Range, |
631 | // but this isn't something really supported in the PDL dialect. We should |
632 | // figure out some way to support both. |
633 | if ((exprType == valueTy || exprType == valueRangeTy) && |
634 | (type == valueTy || type == valueRangeTy)) |
635 | return success(); |
636 | if ((exprType == typeTy || exprType == typeRangeTy) && |
637 | (type == typeTy || type == typeRangeTy)) |
638 | return success(); |
639 | |
640 | // Handle tuple types. |
641 | if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) |
642 | return convertTupleExpressionTo(expr, exprType: exprTupleType, type, emitErrorFn: emitConvertError, |
643 | noteAttachFn); |
644 | |
645 | return emitConvertError(); |
646 | } |
647 | |
648 | LogicalResult Parser::convertOpExpressionTo( |
649 | ast::Expr *&expr, ast::OperationType exprType, ast::Type type, |
650 | function_ref<ast::InFlightDiagnostic()> emitErrorFn) { |
651 | // Two operation types are compatible if they have the same name, or if the |
652 | // expected type is more general. |
653 | if (auto opType = type.dyn_cast<ast::OperationType>()) { |
654 | if (opType.getName()) |
655 | return emitErrorFn(); |
656 | return success(); |
657 | } |
658 | |
659 | // An operation can always convert to a ValueRange. |
660 | if (type == valueRangeTy) { |
661 | expr = ast::AllResultsMemberAccessExpr::create(ctx, loc: expr->getLoc(), parentExpr: expr, |
662 | type: valueRangeTy); |
663 | return success(); |
664 | } |
665 | |
666 | // Allow conversion to a single value by constraining the result range. |
667 | if (type == valueTy) { |
668 | // If the operation is registered, we can verify if it can ever have a |
669 | // single result. |
670 | if (const ods::Operation *odsOp = exprType.getODSOperation()) { |
671 | if (odsOp->getResults().empty()) { |
672 | return emitErrorFn()->attachNote( |
673 | msg: llvm::formatv(Fmt: "see the definition of `{0}`, which was defined " |
674 | "with zero results" , |
675 | Vals: odsOp->getName()), |
676 | noteLoc: odsOp->getLoc()); |
677 | } |
678 | |
679 | unsigned numSingleResults = llvm::count_if( |
680 | Range: odsOp->getResults(), P: [](const ods::OperandOrResult &result) { |
681 | return result.getVariableLengthKind() == |
682 | ods::VariableLengthKind::Single; |
683 | }); |
684 | if (numSingleResults > 1) { |
685 | return emitErrorFn()->attachNote( |
686 | msg: llvm::formatv(Fmt: "see the definition of `{0}`, which was defined " |
687 | "with at least {1} results" , |
688 | Vals: odsOp->getName(), Vals&: numSingleResults), |
689 | noteLoc: odsOp->getLoc()); |
690 | } |
691 | } |
692 | |
693 | expr = ast::AllResultsMemberAccessExpr::create(ctx, loc: expr->getLoc(), parentExpr: expr, |
694 | type: valueTy); |
695 | return success(); |
696 | } |
697 | return emitErrorFn(); |
698 | } |
699 | |
700 | LogicalResult Parser::convertTupleExpressionTo( |
701 | ast::Expr *&expr, ast::TupleType exprType, ast::Type type, |
702 | function_ref<ast::InFlightDiagnostic()> emitErrorFn, |
703 | function_ref<void(ast::Diagnostic &diag)> noteAttachFn) { |
704 | // Handle conversions between tuples. |
705 | if (auto tupleType = type.dyn_cast<ast::TupleType>()) { |
706 | if (tupleType.size() != exprType.size()) |
707 | return emitErrorFn(); |
708 | |
709 | // Build a new tuple expression using each of the elements of the current |
710 | // tuple. |
711 | SmallVector<ast::Expr *> newExprs; |
712 | for (unsigned i = 0, e = exprType.size(); i < e; ++i) { |
713 | newExprs.push_back(Elt: ast::MemberAccessExpr::create( |
714 | ctx, loc: expr->getLoc(), parentExpr: expr, memberName: llvm::to_string(Value: i), |
715 | type: exprType.getElementTypes()[i])); |
716 | |
717 | auto diagFn = [&](ast::Diagnostic &diag) { |
718 | diag.attachNote(msg: llvm::formatv(Fmt: "when converting element #{0} of `{1}`" , |
719 | Vals&: i, Vals&: exprType)); |
720 | if (noteAttachFn) |
721 | noteAttachFn(diag); |
722 | }; |
723 | if (failed(result: convertExpressionTo(expr&: newExprs.back(), |
724 | type: tupleType.getElementTypes()[i], noteAttachFn: diagFn))) |
725 | return failure(); |
726 | } |
727 | expr = ast::TupleExpr::create(ctx, loc: expr->getLoc(), elements: newExprs, |
728 | elementNames: tupleType.getElementNames()); |
729 | return success(); |
730 | } |
731 | |
732 | // Handle conversion to a range. |
733 | auto convertToRange = [&](ArrayRef<ast::Type> allowedElementTypes, |
734 | ast::RangeType resultTy) -> LogicalResult { |
735 | // TODO: We currently only allow range conversion within a rewrite context. |
736 | if (parserContext != ParserContext::Rewrite) { |
737 | return emitErrorFn()->attachNote(msg: "Tuple to Range conversion is currently " |
738 | "only allowed within a rewrite context" ); |
739 | } |
740 | |
741 | // All of the tuple elements must be allowed types. |
742 | for (ast::Type elementType : exprType.getElementTypes()) |
743 | if (!llvm::is_contained(Range&: allowedElementTypes, Element: elementType)) |
744 | return emitErrorFn(); |
745 | |
746 | // Build a new tuple expression using each of the elements of the current |
747 | // tuple. |
748 | SmallVector<ast::Expr *> newExprs; |
749 | for (unsigned i = 0, e = exprType.size(); i < e; ++i) { |
750 | newExprs.push_back(Elt: ast::MemberAccessExpr::create( |
751 | ctx, loc: expr->getLoc(), parentExpr: expr, memberName: llvm::to_string(Value: i), |
752 | type: exprType.getElementTypes()[i])); |
753 | } |
754 | expr = ast::RangeExpr::create(ctx, loc: expr->getLoc(), elements: newExprs, type: resultTy); |
755 | return success(); |
756 | }; |
757 | if (type == valueRangeTy) |
758 | return convertToRange({valueTy, valueRangeTy}, valueRangeTy); |
759 | if (type == typeRangeTy) |
760 | return convertToRange({typeTy, typeRangeTy}, typeRangeTy); |
761 | |
762 | return emitErrorFn(); |
763 | } |
764 | |
765 | //===----------------------------------------------------------------------===// |
766 | // Directives |
767 | |
768 | LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) { |
769 | StringRef directive = curToken.getSpelling(); |
770 | if (directive == "#include" ) |
771 | return parseInclude(decls); |
772 | |
773 | return emitError(msg: "unknown directive `" + directive + "`" ); |
774 | } |
775 | |
776 | LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) { |
777 | SMRange loc = curToken.getLoc(); |
778 | consumeToken(kind: Token::directive); |
779 | |
780 | // Handle code completion of the include file path. |
781 | if (curToken.is(k: Token::code_complete_string)) |
782 | return codeCompleteIncludeFilename(curPath: curToken.getStringValue()); |
783 | |
784 | // Parse the file being included. |
785 | if (!curToken.isString()) |
786 | return emitError(loc, |
787 | msg: "expected string file name after `include` directive" ); |
788 | SMRange fileLoc = curToken.getLoc(); |
789 | std::string filenameStr = curToken.getStringValue(); |
790 | StringRef filename = filenameStr; |
791 | consumeToken(); |
792 | |
793 | // Check the type of include. If ending with `.pdll`, this is another pdl file |
794 | // to be parsed along with the current module. |
795 | if (filename.ends_with(Suffix: ".pdll" )) { |
796 | if (failed(result: lexer.pushInclude(filename, includeLoc: fileLoc))) |
797 | return emitError(loc: fileLoc, |
798 | msg: "unable to open include file `" + filename + "`" ); |
799 | |
800 | // If we added the include successfully, parse it into the current module. |
801 | // Make sure to update to the next token after we finish parsing the nested |
802 | // file. |
803 | curToken = lexer.lexToken(); |
804 | LogicalResult result = parseModuleBody(decls); |
805 | curToken = lexer.lexToken(); |
806 | return result; |
807 | } |
808 | |
809 | // Otherwise, this must be a `.td` include. |
810 | if (filename.ends_with(Suffix: ".td" )) |
811 | return parseTdInclude(filename, fileLoc, decls); |
812 | |
813 | return emitError(loc: fileLoc, |
814 | msg: "expected include filename to end with `.pdll` or `.td`" ); |
815 | } |
816 | |
817 | LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc, |
818 | SmallVectorImpl<ast::Decl *> &decls) { |
819 | llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr(); |
820 | |
821 | // Use the source manager to open the file, but don't yet add it. |
822 | std::string includedFile; |
823 | llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer = |
824 | parserSrcMgr.OpenIncludeFile(Filename: filename.str(), IncludedFile&: includedFile); |
825 | if (!includeBuffer) |
826 | return emitError(loc: fileLoc, msg: "unable to open include file `" + filename + "`" ); |
827 | |
828 | // Setup the source manager for parsing the tablegen file. |
829 | llvm::SourceMgr tdSrcMgr; |
830 | tdSrcMgr.AddNewSourceBuffer(F: std::move(*includeBuffer), IncludeLoc: SMLoc()); |
831 | tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs()); |
832 | |
833 | // This class provides a context argument for the llvm::SourceMgr diagnostic |
834 | // handler. |
835 | struct DiagHandlerContext { |
836 | Parser &parser; |
837 | StringRef filename; |
838 | llvm::SMRange loc; |
839 | } handlerContext{.parser: *this, .filename: filename, .loc: fileLoc}; |
840 | |
841 | // Set the diagnostic handler for the tablegen source manager. |
842 | tdSrcMgr.setDiagHandler( |
843 | DH: [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) { |
844 | auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext); |
845 | (void)ctx->parser.emitError( |
846 | loc: ctx->loc, |
847 | msg: llvm::formatv(Fmt: "error while processing include file `{0}`: {1}" , |
848 | Vals&: ctx->filename, Vals: diag.getMessage())); |
849 | }, |
850 | Ctx: &handlerContext); |
851 | |
852 | // Parse the tablegen file. |
853 | llvm::RecordKeeper tdRecords; |
854 | if (llvm::TableGenParseFile(InputSrcMgr&: tdSrcMgr, Records&: tdRecords)) |
855 | return failure(); |
856 | |
857 | // Process the parsed records. |
858 | processTdIncludeRecords(tdRecords, decls); |
859 | |
860 | // After we are done processing, move all of the tablegen source buffers to |
861 | // the main parser source mgr. This allows for directly using source locations |
862 | // from the .td files without needing to remap them. |
863 | parserSrcMgr.takeSourceBuffersFrom(SrcMgr&: tdSrcMgr, MainBufferIncludeLoc: fileLoc.End); |
864 | return success(); |
865 | } |
866 | |
867 | void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords, |
868 | SmallVectorImpl<ast::Decl *> &decls) { |
869 | // Return the length kind of the given value. |
870 | auto getLengthKind = [](const auto &value) { |
871 | if (value.isOptional()) |
872 | return ods::VariableLengthKind::Optional; |
873 | return value.isVariadic() ? ods::VariableLengthKind::Variadic |
874 | : ods::VariableLengthKind::Single; |
875 | }; |
876 | |
877 | // Insert a type constraint into the ODS context. |
878 | ods::Context &odsContext = ctx.getODSContext(); |
879 | auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst) |
880 | -> const ods::TypeConstraint & { |
881 | return odsContext.insertTypeConstraint( |
882 | name: cst.constraint.getUniqueDefName(), |
883 | summary: processDoc(doc: cst.constraint.getSummary()), |
884 | cppClass: cst.constraint.getCPPClassName()); |
885 | }; |
886 | auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange { |
887 | return {loc, llvm::SMLoc::getFromPointer(Ptr: loc.getPointer() + 1)}; |
888 | }; |
889 | |
890 | // Process the parsed tablegen records to build ODS information. |
891 | /// Operations. |
892 | for (llvm::Record *def : tdRecords.getAllDerivedDefinitions(ClassName: "Op" )) { |
893 | tblgen::Operator op(def); |
894 | |
895 | // Check to see if this operation is known to support type inferrence. |
896 | bool supportsResultTypeInferrence = |
897 | op.getTrait(trait: "::mlir::InferTypeOpInterface::Trait" ); |
898 | |
899 | auto [odsOp, inserted] = odsContext.insertOperation( |
900 | name: op.getOperationName(), summary: processDoc(doc: op.getSummary()), |
901 | desc: processAndFormatDoc(doc: op.getDescription()), nativeClassName: op.getQualCppClassName(), |
902 | supportsResultTypeInferrence, loc: op.getLoc().front()); |
903 | |
904 | // Ignore operations that have already been added. |
905 | if (!inserted) |
906 | continue; |
907 | |
908 | for (const tblgen::NamedAttribute &attr : op.getAttributes()) { |
909 | odsOp->appendAttribute(name: attr.name, optional: attr.attr.isOptional(), |
910 | constraint: odsContext.insertAttributeConstraint( |
911 | name: attr.attr.getUniqueDefName(), |
912 | summary: processDoc(doc: attr.attr.getSummary()), |
913 | cppClass: attr.attr.getStorageType())); |
914 | } |
915 | for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) { |
916 | odsOp->appendOperand(name: operand.name, variableLengthKind: getLengthKind(operand), |
917 | constraint: addTypeConstraint(operand)); |
918 | } |
919 | for (const tblgen::NamedTypeConstraint &result : op.getResults()) { |
920 | odsOp->appendResult(name: result.name, variableLengthKind: getLengthKind(result), |
921 | constraint: addTypeConstraint(result)); |
922 | } |
923 | } |
924 | |
925 | auto shouldBeSkipped = [this](llvm::Record *def) { |
926 | return def->isAnonymous() || curDeclScope->lookup(name: def->getName()) || |
927 | def->isSubClassOf(Name: "DeclareInterfaceMethods" ); |
928 | }; |
929 | |
930 | /// Attr constraints. |
931 | for (llvm::Record *def : tdRecords.getAllDerivedDefinitions(ClassName: "Attr" )) { |
932 | if (shouldBeSkipped(def)) |
933 | continue; |
934 | |
935 | tblgen::Attribute constraint(def); |
936 | decls.push_back(Elt: createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>( |
937 | constraint, loc: convertLocToRange(def->getLoc().front()), type: attrTy, |
938 | nativeType: constraint.getStorageType())); |
939 | } |
940 | /// Type constraints. |
941 | for (llvm::Record *def : tdRecords.getAllDerivedDefinitions(ClassName: "Type" )) { |
942 | if (shouldBeSkipped(def)) |
943 | continue; |
944 | |
945 | tblgen::TypeConstraint constraint(def); |
946 | decls.push_back(Elt: createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>( |
947 | constraint, loc: convertLocToRange(def->getLoc().front()), type: typeTy, |
948 | nativeType: constraint.getCPPClassName())); |
949 | } |
950 | /// OpInterfaces. |
951 | ast::Type opTy = ast::OperationType::get(context&: ctx); |
952 | for (llvm::Record *def : tdRecords.getAllDerivedDefinitions(ClassName: "OpInterface" )) { |
953 | if (shouldBeSkipped(def)) |
954 | continue; |
955 | |
956 | SMRange loc = convertLocToRange(def->getLoc().front()); |
957 | |
958 | std::string cppClassName = |
959 | llvm::formatv(Fmt: "{0}::{1}" , Vals: def->getValueAsString(FieldName: "cppNamespace" ), |
960 | Vals: def->getValueAsString(FieldName: "cppInterfaceName" )) |
961 | .str(); |
962 | std::string codeBlock = |
963 | llvm::formatv(Fmt: "return ::mlir::success(llvm::isa<{0}>(self));" , |
964 | Vals&: cppClassName) |
965 | .str(); |
966 | |
967 | std::string desc = |
968 | processAndFormatDoc(doc: def->getValueAsString(FieldName: "description" )); |
969 | decls.push_back(Elt: createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>( |
970 | name: def->getName(), codeBlock, loc, type: opTy, nativeType: cppClassName, docString: desc)); |
971 | } |
972 | } |
973 | |
974 | template <typename ConstraintT> |
975 | ast::Decl *Parser::createODSNativePDLLConstraintDecl( |
976 | StringRef name, StringRef codeBlock, SMRange loc, ast::Type type, |
977 | StringRef nativeType, StringRef docString) { |
978 | // Build the single input parameter. |
979 | ast::DeclScope *argScope = pushDeclScope(); |
980 | auto *paramVar = ast::VariableDecl::create( |
981 | ctx, name: ast::Name::create(ctx, name: "self" , location: loc), type, |
982 | /*initExpr=*/nullptr, constraints: ast::ConstraintRef(ConstraintT::create(ctx, loc))); |
983 | argScope->add(decl: paramVar); |
984 | popDeclScope(); |
985 | |
986 | // Build the native constraint. |
987 | auto *constraintDecl = ast::UserConstraintDecl::createNative( |
988 | ctx, name: ast::Name::create(ctx, name, location: loc), inputs: paramVar, |
989 | /*results=*/std::nullopt, codeBlock, resultType: ast::TupleType::get(context&: ctx), |
990 | nativeInputTypes: nativeType); |
991 | constraintDecl->setDocComment(ctx, comment: docString); |
992 | curDeclScope->add(decl: constraintDecl); |
993 | return constraintDecl; |
994 | } |
995 | |
996 | template <typename ConstraintT> |
997 | ast::Decl * |
998 | Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, |
999 | SMRange loc, ast::Type type, |
1000 | StringRef nativeType) { |
1001 | // Format the condition template. |
1002 | tblgen::FmtContext fmtContext; |
1003 | fmtContext.withSelf(subst: "self" ); |
1004 | std::string codeBlock = tblgen::tgfmt( |
1005 | fmt: "return ::mlir::success(" + constraint.getConditionTemplate() + ");" , |
1006 | ctx: &fmtContext); |
1007 | |
1008 | // If documentation was enabled, build the doc string for the generated |
1009 | // constraint. It would be nice to do this lazily, but TableGen information is |
1010 | // destroyed after we finish parsing the file. |
1011 | std::string docString; |
1012 | if (enableDocumentation) { |
1013 | StringRef desc = constraint.getDescription(); |
1014 | docString = processAndFormatDoc( |
1015 | doc: constraint.getSummary() + |
1016 | (desc.empty() ? "" : ("\n\n" + constraint.getDescription()))); |
1017 | } |
1018 | |
1019 | return createODSNativePDLLConstraintDecl<ConstraintT>( |
1020 | constraint.getUniqueDefName(), codeBlock, loc, type, nativeType, |
1021 | docString); |
1022 | } |
1023 | |
1024 | //===----------------------------------------------------------------------===// |
1025 | // Decls |
1026 | |
1027 | FailureOr<ast::Decl *> Parser::parseTopLevelDecl() { |
1028 | FailureOr<ast::Decl *> decl; |
1029 | switch (curToken.getKind()) { |
1030 | case Token::kw_Constraint: |
1031 | decl = parseUserConstraintDecl(); |
1032 | break; |
1033 | case Token::kw_Pattern: |
1034 | decl = parsePatternDecl(); |
1035 | break; |
1036 | case Token::kw_Rewrite: |
1037 | decl = parseUserRewriteDecl(); |
1038 | break; |
1039 | default: |
1040 | return emitError(msg: "expected top-level declaration, such as a `Pattern`" ); |
1041 | } |
1042 | if (failed(result: decl)) |
1043 | return failure(); |
1044 | |
1045 | // If the decl has a name, add it to the current scope. |
1046 | if (const ast::Name *name = (*decl)->getName()) { |
1047 | if (failed(result: checkDefineNamedDecl(name: *name))) |
1048 | return failure(); |
1049 | curDeclScope->add(decl: *decl); |
1050 | } |
1051 | return decl; |
1052 | } |
1053 | |
1054 | FailureOr<ast::NamedAttributeDecl *> |
1055 | Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) { |
1056 | // Check for name code completion. |
1057 | if (curToken.is(k: Token::code_complete)) |
1058 | return codeCompleteAttributeName(opName: parentOpName); |
1059 | |
1060 | std::string attrNameStr; |
1061 | if (curToken.isString()) |
1062 | attrNameStr = curToken.getStringValue(); |
1063 | else if (curToken.is(k: Token::identifier) || curToken.isKeyword()) |
1064 | attrNameStr = curToken.getSpelling().str(); |
1065 | else |
1066 | return emitError(msg: "expected identifier or string attribute name" ); |
1067 | const auto &name = ast::Name::create(ctx, name: attrNameStr, location: curToken.getLoc()); |
1068 | consumeToken(); |
1069 | |
1070 | // Check for a value of the attribute. |
1071 | ast::Expr *attrValue = nullptr; |
1072 | if (consumeIf(kind: Token::equal)) { |
1073 | FailureOr<ast::Expr *> attrExpr = parseExpr(); |
1074 | if (failed(result: attrExpr)) |
1075 | return failure(); |
1076 | attrValue = *attrExpr; |
1077 | } else { |
1078 | // If there isn't a concrete value, create an expression representing a |
1079 | // UnitAttr. |
1080 | attrValue = ast::AttributeExpr::create(ctx, loc: name.getLoc(), value: "unit" ); |
1081 | } |
1082 | |
1083 | return ast::NamedAttributeDecl::create(ctx, name, value: attrValue); |
1084 | } |
1085 | |
1086 | FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody( |
1087 | function_ref<LogicalResult(ast::Stmt *&)> processStatementFn, |
1088 | bool expectTerminalSemicolon) { |
1089 | consumeToken(kind: Token::equal_arrow); |
1090 | |
1091 | // Parse the single statement of the lambda body. |
1092 | SMLoc bodyStartLoc = curToken.getStartLoc(); |
1093 | pushDeclScope(); |
1094 | FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon); |
1095 | bool failedToParse = |
1096 | failed(result: singleStatement) || failed(result: processStatementFn(*singleStatement)); |
1097 | popDeclScope(); |
1098 | if (failedToParse) |
1099 | return failure(); |
1100 | |
1101 | SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc()); |
1102 | return ast::CompoundStmt::create(ctx, location: bodyLoc, children: *singleStatement); |
1103 | } |
1104 | |
1105 | FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() { |
1106 | // Ensure that the argument is named. |
1107 | if (curToken.isNot(k: Token::identifier) && !curToken.isDependentKeyword()) |
1108 | return emitError(msg: "expected identifier argument name" ); |
1109 | |
1110 | // Parse the argument similarly to a normal variable. |
1111 | StringRef name = curToken.getSpelling(); |
1112 | SMRange nameLoc = curToken.getLoc(); |
1113 | consumeToken(); |
1114 | |
1115 | if (failed( |
1116 | result: parseToken(kind: Token::colon, msg: "expected `:` before argument constraint" ))) |
1117 | return failure(); |
1118 | |
1119 | FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); |
1120 | if (failed(result: cst)) |
1121 | return failure(); |
1122 | |
1123 | return createArgOrResultVariableDecl(name, loc: nameLoc, constraint: *cst); |
1124 | } |
1125 | |
1126 | FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) { |
1127 | // Check to see if this result is named. |
1128 | if (curToken.is(k: Token::identifier) || curToken.isDependentKeyword()) { |
1129 | // Check to see if this name actually refers to a Constraint. |
1130 | if (!curDeclScope->lookup<ast::ConstraintDecl>(name: curToken.getSpelling())) { |
1131 | // If it wasn't a constraint, parse the result similarly to a variable. If |
1132 | // there is already an existing decl, we will emit an error when defining |
1133 | // this variable later. |
1134 | StringRef name = curToken.getSpelling(); |
1135 | SMRange nameLoc = curToken.getLoc(); |
1136 | consumeToken(); |
1137 | |
1138 | if (failed(result: parseToken(kind: Token::colon, |
1139 | msg: "expected `:` before result constraint" ))) |
1140 | return failure(); |
1141 | |
1142 | FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); |
1143 | if (failed(result: cst)) |
1144 | return failure(); |
1145 | |
1146 | return createArgOrResultVariableDecl(name, loc: nameLoc, constraint: *cst); |
1147 | } |
1148 | } |
1149 | |
1150 | // If it isn't named, we parse the constraint directly and create an unnamed |
1151 | // result variable. |
1152 | FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint(); |
1153 | if (failed(result: cst)) |
1154 | return failure(); |
1155 | |
1156 | return createArgOrResultVariableDecl(name: "" , loc: cst->referenceLoc, constraint: *cst); |
1157 | } |
1158 | |
1159 | FailureOr<ast::UserConstraintDecl *> |
1160 | Parser::parseUserConstraintDecl(bool isInline) { |
1161 | // Constraints and rewrites have very similar formats, dispatch to a shared |
1162 | // interface for parsing. |
1163 | return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>( |
1164 | parseUserPDLLFn: [&](auto &&...args) { |
1165 | return this->parseUserPDLLConstraintDecl(name: args...); |
1166 | }, |
1167 | declContext: ParserContext::Constraint, anonymousNamePrefix: "constraint" , isInline); |
1168 | } |
1169 | |
1170 | FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() { |
1171 | FailureOr<ast::UserConstraintDecl *> decl = |
1172 | parseUserConstraintDecl(/*isInline=*/true); |
1173 | if (failed(result: decl) || failed(result: checkDefineNamedDecl(name: (*decl)->getName()))) |
1174 | return failure(); |
1175 | |
1176 | curDeclScope->add(decl: *decl); |
1177 | return decl; |
1178 | } |
1179 | |
1180 | FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl( |
1181 | const ast::Name &name, bool isInline, |
1182 | ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, |
1183 | ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { |
1184 | // Push the argument scope back onto the list, so that the body can |
1185 | // reference arguments. |
1186 | pushDeclScope(scope: argumentScope); |
1187 | |
1188 | // Parse the body of the constraint. The body is either defined as a compound |
1189 | // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`. |
1190 | ast::CompoundStmt *body; |
1191 | if (curToken.is(k: Token::equal_arrow)) { |
1192 | FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( |
1193 | processStatementFn: [&](ast::Stmt *&stmt) -> LogicalResult { |
1194 | ast::Expr *stmtExpr = dyn_cast<ast::Expr>(Val: stmt); |
1195 | if (!stmtExpr) { |
1196 | return emitError(loc: stmt->getLoc(), |
1197 | msg: "expected `Constraint` lambda body to contain a " |
1198 | "single expression" ); |
1199 | } |
1200 | stmt = ast::ReturnStmt::create(ctx, loc: stmt->getLoc(), resultExpr: stmtExpr); |
1201 | return success(); |
1202 | }, |
1203 | /*expectTerminalSemicolon=*/!isInline); |
1204 | if (failed(result: bodyResult)) |
1205 | return failure(); |
1206 | body = *bodyResult; |
1207 | } else { |
1208 | FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); |
1209 | if (failed(result: bodyResult)) |
1210 | return failure(); |
1211 | body = *bodyResult; |
1212 | |
1213 | // Verify the structure of the body. |
1214 | auto bodyIt = body->begin(), bodyE = body->end(); |
1215 | for (; bodyIt != bodyE; ++bodyIt) |
1216 | if (isa<ast::ReturnStmt>(Val: *bodyIt)) |
1217 | break; |
1218 | if (failed(result: validateUserConstraintOrRewriteReturn( |
1219 | declType: "Constraint" , body, bodyIt, bodyE, results, resultType))) |
1220 | return failure(); |
1221 | } |
1222 | popDeclScope(); |
1223 | |
1224 | return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>( |
1225 | name, arguments, results, resultType, body); |
1226 | } |
1227 | |
1228 | FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) { |
1229 | // Constraints and rewrites have very similar formats, dispatch to a shared |
1230 | // interface for parsing. |
1231 | return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>( |
1232 | parseUserPDLLFn: [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(name: args...); }, |
1233 | declContext: ParserContext::Rewrite, anonymousNamePrefix: "rewrite" , isInline); |
1234 | } |
1235 | |
1236 | FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() { |
1237 | FailureOr<ast::UserRewriteDecl *> decl = |
1238 | parseUserRewriteDecl(/*isInline=*/true); |
1239 | if (failed(result: decl) || failed(result: checkDefineNamedDecl(name: (*decl)->getName()))) |
1240 | return failure(); |
1241 | |
1242 | curDeclScope->add(decl: *decl); |
1243 | return decl; |
1244 | } |
1245 | |
1246 | FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl( |
1247 | const ast::Name &name, bool isInline, |
1248 | ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope, |
1249 | ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { |
1250 | // Push the argument scope back onto the list, so that the body can |
1251 | // reference arguments. |
1252 | curDeclScope = argumentScope; |
1253 | ast::CompoundStmt *body; |
1254 | if (curToken.is(k: Token::equal_arrow)) { |
1255 | FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody( |
1256 | processStatementFn: [&](ast::Stmt *&statement) -> LogicalResult { |
1257 | if (isa<ast::OpRewriteStmt>(Val: statement)) |
1258 | return success(); |
1259 | |
1260 | ast::Expr *statementExpr = dyn_cast<ast::Expr>(Val: statement); |
1261 | if (!statementExpr) { |
1262 | return emitError( |
1263 | loc: statement->getLoc(), |
1264 | msg: "expected `Rewrite` lambda body to contain a single expression " |
1265 | "or an operation rewrite statement; such as `erase`, " |
1266 | "`replace`, or `rewrite`" ); |
1267 | } |
1268 | statement = |
1269 | ast::ReturnStmt::create(ctx, loc: statement->getLoc(), resultExpr: statementExpr); |
1270 | return success(); |
1271 | }, |
1272 | /*expectTerminalSemicolon=*/!isInline); |
1273 | if (failed(result: bodyResult)) |
1274 | return failure(); |
1275 | body = *bodyResult; |
1276 | } else { |
1277 | FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); |
1278 | if (failed(result: bodyResult)) |
1279 | return failure(); |
1280 | body = *bodyResult; |
1281 | } |
1282 | popDeclScope(); |
1283 | |
1284 | // Verify the structure of the body. |
1285 | auto bodyIt = body->begin(), bodyE = body->end(); |
1286 | for (; bodyIt != bodyE; ++bodyIt) |
1287 | if (isa<ast::ReturnStmt>(Val: *bodyIt)) |
1288 | break; |
1289 | if (failed(result: validateUserConstraintOrRewriteReturn(declType: "Rewrite" , body, bodyIt, |
1290 | bodyE, results, resultType))) |
1291 | return failure(); |
1292 | return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>( |
1293 | name, arguments, results, resultType, body); |
1294 | } |
1295 | |
1296 | template <typename T, typename ParseUserPDLLDeclFnT> |
1297 | FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl( |
1298 | ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, |
1299 | StringRef anonymousNamePrefix, bool isInline) { |
1300 | SMRange loc = curToken.getLoc(); |
1301 | consumeToken(); |
1302 | llvm::SaveAndRestore saveCtx(parserContext, declContext); |
1303 | |
1304 | // Parse the name of the decl. |
1305 | const ast::Name *name = nullptr; |
1306 | if (curToken.isNot(k: Token::identifier)) { |
1307 | // Only inline decls can be un-named. Inline decls are similar to "lambdas" |
1308 | // in C++, so being unnamed is fine. |
1309 | if (!isInline) |
1310 | return emitError(msg: "expected identifier name" ); |
1311 | |
1312 | // Create a unique anonymous name to use, as the name for this decl is not |
1313 | // important. |
1314 | std::string anonName = |
1315 | llvm::formatv(Fmt: "<anonymous_{0}_{1}>" , Vals&: anonymousNamePrefix, |
1316 | Vals: anonymousDeclNameCounter++) |
1317 | .str(); |
1318 | name = &ast::Name::create(ctx, name: anonName, location: loc); |
1319 | } else { |
1320 | // If a name was provided, we can use it directly. |
1321 | name = &ast::Name::create(ctx, name: curToken.getSpelling(), location: curToken.getLoc()); |
1322 | consumeToken(kind: Token::identifier); |
1323 | } |
1324 | |
1325 | // Parse the functional signature of the decl. |
1326 | SmallVector<ast::VariableDecl *> arguments, results; |
1327 | ast::DeclScope *argumentScope; |
1328 | ast::Type resultType; |
1329 | if (failed(result: parseUserConstraintOrRewriteSignature(arguments, results, |
1330 | argumentScope, resultType))) |
1331 | return failure(); |
1332 | |
1333 | // Check to see which type of constraint this is. If the constraint contains a |
1334 | // compound body, this is a PDLL decl. |
1335 | if (curToken.isAny(k1: Token::l_brace, k2: Token::equal_arrow)) |
1336 | return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results, |
1337 | resultType); |
1338 | |
1339 | // Otherwise, this is a native decl. |
1340 | return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments, |
1341 | results, resultType); |
1342 | } |
1343 | |
1344 | template <typename T> |
1345 | FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl( |
1346 | const ast::Name &name, bool isInline, |
1347 | ArrayRef<ast::VariableDecl *> arguments, |
1348 | ArrayRef<ast::VariableDecl *> results, ast::Type resultType) { |
1349 | // If followed by a string, the native code body has also been specified. |
1350 | std::string codeStrStorage; |
1351 | std::optional<StringRef> optCodeStr; |
1352 | if (curToken.isString()) { |
1353 | codeStrStorage = curToken.getStringValue(); |
1354 | optCodeStr = codeStrStorage; |
1355 | consumeToken(); |
1356 | } else if (isInline) { |
1357 | return emitError(loc: name.getLoc(), |
1358 | msg: "external declarations must be declared in global scope" ); |
1359 | } else if (curToken.is(k: Token::error)) { |
1360 | return failure(); |
1361 | } |
1362 | if (failed(result: parseToken(kind: Token::semicolon, |
1363 | msg: "expected `;` after native declaration" ))) |
1364 | return failure(); |
1365 | return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); |
1366 | } |
1367 | |
1368 | LogicalResult Parser::parseUserConstraintOrRewriteSignature( |
1369 | SmallVectorImpl<ast::VariableDecl *> &arguments, |
1370 | SmallVectorImpl<ast::VariableDecl *> &results, |
1371 | ast::DeclScope *&argumentScope, ast::Type &resultType) { |
1372 | // Parse the argument list of the decl. |
1373 | if (failed(result: parseToken(kind: Token::l_paren, msg: "expected `(` to start argument list" ))) |
1374 | return failure(); |
1375 | |
1376 | argumentScope = pushDeclScope(); |
1377 | if (curToken.isNot(k: Token::r_paren)) { |
1378 | do { |
1379 | FailureOr<ast::VariableDecl *> argument = parseArgumentDecl(); |
1380 | if (failed(result: argument)) |
1381 | return failure(); |
1382 | arguments.emplace_back(Args&: *argument); |
1383 | } while (consumeIf(kind: Token::comma)); |
1384 | } |
1385 | popDeclScope(); |
1386 | if (failed(result: parseToken(kind: Token::r_paren, msg: "expected `)` to end argument list" ))) |
1387 | return failure(); |
1388 | |
1389 | // Parse the results of the decl. |
1390 | pushDeclScope(); |
1391 | if (consumeIf(kind: Token::arrow)) { |
1392 | auto parseResultFn = [&]() -> LogicalResult { |
1393 | FailureOr<ast::VariableDecl *> result = parseResultDecl(resultNum: results.size()); |
1394 | if (failed(result)) |
1395 | return failure(); |
1396 | results.emplace_back(Args&: *result); |
1397 | return success(); |
1398 | }; |
1399 | |
1400 | // Check for a list of results. |
1401 | if (consumeIf(kind: Token::l_paren)) { |
1402 | do { |
1403 | if (failed(result: parseResultFn())) |
1404 | return failure(); |
1405 | } while (consumeIf(kind: Token::comma)); |
1406 | if (failed(result: parseToken(kind: Token::r_paren, msg: "expected `)` to end result list" ))) |
1407 | return failure(); |
1408 | |
1409 | // Otherwise, there is only one result. |
1410 | } else if (failed(result: parseResultFn())) { |
1411 | return failure(); |
1412 | } |
1413 | } |
1414 | popDeclScope(); |
1415 | |
1416 | // Compute the result type of the decl. |
1417 | resultType = createUserConstraintRewriteResultType(results); |
1418 | |
1419 | // Verify that results are only named if there are more than one. |
1420 | if (results.size() == 1 && !results.front()->getName().getName().empty()) { |
1421 | return emitError( |
1422 | loc: results.front()->getLoc(), |
1423 | msg: "cannot create a single-element tuple with an element label" ); |
1424 | } |
1425 | return success(); |
1426 | } |
1427 | |
1428 | LogicalResult Parser::validateUserConstraintOrRewriteReturn( |
1429 | StringRef declType, ast::CompoundStmt *body, |
1430 | ArrayRef<ast::Stmt *>::iterator bodyIt, |
1431 | ArrayRef<ast::Stmt *>::iterator bodyE, |
1432 | ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) { |
1433 | // Handle if a `return` was provided. |
1434 | if (bodyIt != bodyE) { |
1435 | // Emit an error if we have trailing statements after the return. |
1436 | if (std::next(x: bodyIt) != bodyE) { |
1437 | return emitError( |
1438 | loc: (*std::next(x: bodyIt))->getLoc(), |
1439 | msg: llvm::formatv(Fmt: "`return` terminated the `{0}` body, but found " |
1440 | "trailing statements afterwards" , |
1441 | Vals&: declType)); |
1442 | } |
1443 | |
1444 | // Otherwise if a return wasn't provided, check that no results are |
1445 | // expected. |
1446 | } else if (!results.empty()) { |
1447 | return emitError( |
1448 | loc: {body->getLoc().End, body->getLoc().End}, |
1449 | msg: llvm::formatv(Fmt: "missing return in a `{0}` expected to return `{1}`" , |
1450 | Vals&: declType, Vals&: resultType)); |
1451 | } |
1452 | return success(); |
1453 | } |
1454 | |
1455 | FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() { |
1456 | return parseLambdaBody(processStatementFn: [&](ast::Stmt *&statement) -> LogicalResult { |
1457 | if (isa<ast::OpRewriteStmt>(Val: statement)) |
1458 | return success(); |
1459 | return emitError( |
1460 | loc: statement->getLoc(), |
1461 | msg: "expected Pattern lambda body to contain a single operation " |
1462 | "rewrite statement, such as `erase`, `replace`, or `rewrite`" ); |
1463 | }); |
1464 | } |
1465 | |
1466 | FailureOr<ast::Decl *> Parser::parsePatternDecl() { |
1467 | SMRange loc = curToken.getLoc(); |
1468 | consumeToken(kind: Token::kw_Pattern); |
1469 | llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch); |
1470 | |
1471 | // Check for an optional identifier for the pattern name. |
1472 | const ast::Name *name = nullptr; |
1473 | if (curToken.is(k: Token::identifier)) { |
1474 | name = &ast::Name::create(ctx, name: curToken.getSpelling(), location: curToken.getLoc()); |
1475 | consumeToken(kind: Token::identifier); |
1476 | } |
1477 | |
1478 | // Parse any pattern metadata. |
1479 | ParsedPatternMetadata metadata; |
1480 | if (consumeIf(kind: Token::kw_with) && failed(result: parsePatternDeclMetadata(metadata))) |
1481 | return failure(); |
1482 | |
1483 | // Parse the pattern body. |
1484 | ast::CompoundStmt *body; |
1485 | |
1486 | // Handle a lambda body. |
1487 | if (curToken.is(k: Token::equal_arrow)) { |
1488 | FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody(); |
1489 | if (failed(result: bodyResult)) |
1490 | return failure(); |
1491 | body = *bodyResult; |
1492 | } else { |
1493 | if (curToken.isNot(k: Token::l_brace)) |
1494 | return emitError(msg: "expected `{` or `=>` to start pattern body" ); |
1495 | FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt(); |
1496 | if (failed(result: bodyResult)) |
1497 | return failure(); |
1498 | body = *bodyResult; |
1499 | |
1500 | // Verify the body of the pattern. |
1501 | auto bodyIt = body->begin(), bodyE = body->end(); |
1502 | for (; bodyIt != bodyE; ++bodyIt) { |
1503 | if (isa<ast::ReturnStmt>(Val: *bodyIt)) { |
1504 | return emitError(loc: (*bodyIt)->getLoc(), |
1505 | msg: "`return` statements are only permitted within a " |
1506 | "`Constraint` or `Rewrite` body" ); |
1507 | } |
1508 | // Break when we've found the rewrite statement. |
1509 | if (isa<ast::OpRewriteStmt>(Val: *bodyIt)) |
1510 | break; |
1511 | } |
1512 | if (bodyIt == bodyE) { |
1513 | return emitError(loc, |
1514 | msg: "expected Pattern body to terminate with an operation " |
1515 | "rewrite statement, such as `erase`" ); |
1516 | } |
1517 | if (std::next(x: bodyIt) != bodyE) { |
1518 | return emitError(loc: (*std::next(x: bodyIt))->getLoc(), |
1519 | msg: "Pattern body was terminated by an operation " |
1520 | "rewrite statement, but found trailing statements" ); |
1521 | } |
1522 | } |
1523 | |
1524 | return createPatternDecl(loc, name, metadata, body); |
1525 | } |
1526 | |
1527 | LogicalResult |
1528 | Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) { |
1529 | std::optional<SMRange> benefitLoc; |
1530 | std::optional<SMRange> hasBoundedRecursionLoc; |
1531 | |
1532 | do { |
1533 | // Handle metadata code completion. |
1534 | if (curToken.is(k: Token::code_complete)) |
1535 | return codeCompletePatternMetadata(); |
1536 | |
1537 | if (curToken.isNot(k: Token::identifier)) |
1538 | return emitError(msg: "expected pattern metadata identifier" ); |
1539 | StringRef metadataStr = curToken.getSpelling(); |
1540 | SMRange metadataLoc = curToken.getLoc(); |
1541 | consumeToken(kind: Token::identifier); |
1542 | |
1543 | // Parse the benefit metadata: benefit(<integer-value>) |
1544 | if (metadataStr == "benefit" ) { |
1545 | if (benefitLoc) { |
1546 | return emitErrorAndNote(loc: metadataLoc, |
1547 | msg: "pattern benefit has already been specified" , |
1548 | noteLoc: *benefitLoc, note: "see previous definition here" ); |
1549 | } |
1550 | if (failed(result: parseToken(kind: Token::l_paren, |
1551 | msg: "expected `(` before pattern benefit" ))) |
1552 | return failure(); |
1553 | |
1554 | uint16_t benefitValue = 0; |
1555 | if (curToken.isNot(k: Token::integer)) |
1556 | return emitError(msg: "expected integral pattern benefit" ); |
1557 | if (curToken.getSpelling().getAsInteger(/*Radix=*/10, Result&: benefitValue)) |
1558 | return emitError( |
1559 | msg: "expected pattern benefit to fit within a 16-bit integer" ); |
1560 | consumeToken(kind: Token::integer); |
1561 | |
1562 | metadata.benefit = benefitValue; |
1563 | benefitLoc = metadataLoc; |
1564 | |
1565 | if (failed( |
1566 | result: parseToken(kind: Token::r_paren, msg: "expected `)` after pattern benefit" ))) |
1567 | return failure(); |
1568 | continue; |
1569 | } |
1570 | |
1571 | // Parse the bounded recursion metadata: recursion |
1572 | if (metadataStr == "recursion" ) { |
1573 | if (hasBoundedRecursionLoc) { |
1574 | return emitErrorAndNote( |
1575 | loc: metadataLoc, |
1576 | msg: "pattern recursion metadata has already been specified" , |
1577 | noteLoc: *hasBoundedRecursionLoc, note: "see previous definition here" ); |
1578 | } |
1579 | metadata.hasBoundedRecursion = true; |
1580 | hasBoundedRecursionLoc = metadataLoc; |
1581 | continue; |
1582 | } |
1583 | |
1584 | return emitError(loc: metadataLoc, msg: "unknown pattern metadata" ); |
1585 | } while (consumeIf(kind: Token::comma)); |
1586 | |
1587 | return success(); |
1588 | } |
1589 | |
1590 | FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() { |
1591 | consumeToken(kind: Token::less); |
1592 | |
1593 | FailureOr<ast::Expr *> typeExpr = parseExpr(); |
1594 | if (failed(result: typeExpr) || |
1595 | failed(result: parseToken(kind: Token::greater, |
1596 | msg: "expected `>` after variable type constraint" ))) |
1597 | return failure(); |
1598 | return typeExpr; |
1599 | } |
1600 | |
1601 | LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) { |
1602 | assert(curDeclScope && "defining decl outside of a decl scope" ); |
1603 | if (ast::Decl *lastDecl = curDeclScope->lookup(name: name.getName())) { |
1604 | return emitErrorAndNote( |
1605 | loc: name.getLoc(), msg: "`" + name.getName() + "` has already been defined" , |
1606 | noteLoc: lastDecl->getName()->getLoc(), note: "see previous definition here" ); |
1607 | } |
1608 | return success(); |
1609 | } |
1610 | |
1611 | FailureOr<ast::VariableDecl *> |
1612 | Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, |
1613 | ast::Expr *initExpr, |
1614 | ArrayRef<ast::ConstraintRef> constraints) { |
1615 | assert(curDeclScope && "defining variable outside of decl scope" ); |
1616 | const ast::Name &nameDecl = ast::Name::create(ctx, name, location: nameLoc); |
1617 | |
1618 | // If the name of the variable indicates a special variable, we don't add it |
1619 | // to the scope. This variable is local to the definition point. |
1620 | if (name.empty() || name == "_" ) { |
1621 | return ast::VariableDecl::create(ctx, name: nameDecl, type, initExpr, |
1622 | constraints); |
1623 | } |
1624 | if (failed(result: checkDefineNamedDecl(name: nameDecl))) |
1625 | return failure(); |
1626 | |
1627 | auto *varDecl = |
1628 | ast::VariableDecl::create(ctx, name: nameDecl, type, initExpr, constraints); |
1629 | curDeclScope->add(decl: varDecl); |
1630 | return varDecl; |
1631 | } |
1632 | |
1633 | FailureOr<ast::VariableDecl *> |
1634 | Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, |
1635 | ArrayRef<ast::ConstraintRef> constraints) { |
1636 | return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, |
1637 | constraints); |
1638 | } |
1639 | |
1640 | LogicalResult Parser::parseVariableDeclConstraintList( |
1641 | SmallVectorImpl<ast::ConstraintRef> &constraints) { |
1642 | std::optional<SMRange> typeConstraint; |
1643 | auto parseSingleConstraint = [&] { |
1644 | FailureOr<ast::ConstraintRef> constraint = parseConstraint( |
1645 | typeConstraint, existingConstraints: constraints, /*allowInlineTypeConstraints=*/true); |
1646 | if (failed(result: constraint)) |
1647 | return failure(); |
1648 | constraints.push_back(Elt: *constraint); |
1649 | return success(); |
1650 | }; |
1651 | |
1652 | // Check to see if this is a single constraint, or a list. |
1653 | if (!consumeIf(kind: Token::l_square)) |
1654 | return parseSingleConstraint(); |
1655 | |
1656 | do { |
1657 | if (failed(result: parseSingleConstraint())) |
1658 | return failure(); |
1659 | } while (consumeIf(kind: Token::comma)); |
1660 | return parseToken(kind: Token::r_square, msg: "expected `]` after constraint list" ); |
1661 | } |
1662 | |
1663 | FailureOr<ast::ConstraintRef> |
1664 | Parser::parseConstraint(std::optional<SMRange> &typeConstraint, |
1665 | ArrayRef<ast::ConstraintRef> existingConstraints, |
1666 | bool allowInlineTypeConstraints) { |
1667 | auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult { |
1668 | if (!allowInlineTypeConstraints) { |
1669 | return emitError( |
1670 | loc: curToken.getLoc(), |
1671 | msg: "inline `Attr`, `Value`, and `ValueRange` type constraints are not " |
1672 | "permitted on arguments or results" ); |
1673 | } |
1674 | if (typeConstraint) |
1675 | return emitErrorAndNote( |
1676 | loc: curToken.getLoc(), |
1677 | msg: "the type of this variable has already been constrained" , |
1678 | noteLoc: *typeConstraint, note: "see previous constraint location here" ); |
1679 | FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr(); |
1680 | if (failed(result: constraintExpr)) |
1681 | return failure(); |
1682 | typeExpr = *constraintExpr; |
1683 | typeConstraint = typeExpr->getLoc(); |
1684 | return success(); |
1685 | }; |
1686 | |
1687 | SMRange loc = curToken.getLoc(); |
1688 | switch (curToken.getKind()) { |
1689 | case Token::kw_Attr: { |
1690 | consumeToken(kind: Token::kw_Attr); |
1691 | |
1692 | // Check for a type constraint. |
1693 | ast::Expr *typeExpr = nullptr; |
1694 | if (curToken.is(k: Token::less) && failed(result: parseTypeConstraint(typeExpr))) |
1695 | return failure(); |
1696 | return ast::ConstraintRef( |
1697 | ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc); |
1698 | } |
1699 | case Token::kw_Op: { |
1700 | consumeToken(kind: Token::kw_Op); |
1701 | |
1702 | // Parse an optional operation name. If the name isn't provided, this refers |
1703 | // to "any" operation. |
1704 | FailureOr<ast::OpNameDecl *> opName = |
1705 | parseWrappedOperationName(/*allowEmptyName=*/true); |
1706 | if (failed(result: opName)) |
1707 | return failure(); |
1708 | |
1709 | return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, nameDecl: *opName), |
1710 | loc); |
1711 | } |
1712 | case Token::kw_Type: |
1713 | consumeToken(kind: Token::kw_Type); |
1714 | return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc); |
1715 | case Token::kw_TypeRange: |
1716 | consumeToken(kind: Token::kw_TypeRange); |
1717 | return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc), |
1718 | loc); |
1719 | case Token::kw_Value: { |
1720 | consumeToken(kind: Token::kw_Value); |
1721 | |
1722 | // Check for a type constraint. |
1723 | ast::Expr *typeExpr = nullptr; |
1724 | if (curToken.is(k: Token::less) && failed(result: parseTypeConstraint(typeExpr))) |
1725 | return failure(); |
1726 | |
1727 | return ast::ConstraintRef( |
1728 | ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc); |
1729 | } |
1730 | case Token::kw_ValueRange: { |
1731 | consumeToken(kind: Token::kw_ValueRange); |
1732 | |
1733 | // Check for a type constraint. |
1734 | ast::Expr *typeExpr = nullptr; |
1735 | if (curToken.is(k: Token::less) && failed(result: parseTypeConstraint(typeExpr))) |
1736 | return failure(); |
1737 | |
1738 | return ast::ConstraintRef( |
1739 | ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc); |
1740 | } |
1741 | |
1742 | case Token::kw_Constraint: { |
1743 | // Handle an inline constraint. |
1744 | FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); |
1745 | if (failed(result: decl)) |
1746 | return failure(); |
1747 | return ast::ConstraintRef(*decl, loc); |
1748 | } |
1749 | case Token::identifier: { |
1750 | StringRef constraintName = curToken.getSpelling(); |
1751 | consumeToken(kind: Token::identifier); |
1752 | |
1753 | // Lookup the referenced constraint. |
1754 | ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(name: constraintName); |
1755 | if (!cstDecl) { |
1756 | return emitError(loc, msg: "unknown reference to constraint `" + |
1757 | constraintName + "`" ); |
1758 | } |
1759 | |
1760 | // Handle a reference to a proper constraint. |
1761 | if (auto *cst = dyn_cast<ast::ConstraintDecl>(Val: cstDecl)) |
1762 | return ast::ConstraintRef(cst, loc); |
1763 | |
1764 | return emitErrorAndNote( |
1765 | loc, msg: "invalid reference to non-constraint" , noteLoc: cstDecl->getLoc(), |
1766 | note: "see the definition of `" + constraintName + "` here" ); |
1767 | } |
1768 | // Handle single entity constraint code completion. |
1769 | case Token::code_complete: { |
1770 | // Try to infer the current type for use by code completion. |
1771 | ast::Type inferredType; |
1772 | if (failed(result: validateVariableConstraints(constraints: existingConstraints, inferredType))) |
1773 | return failure(); |
1774 | |
1775 | return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints); |
1776 | } |
1777 | default: |
1778 | break; |
1779 | } |
1780 | return emitError(loc, msg: "expected identifier constraint" ); |
1781 | } |
1782 | |
1783 | FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() { |
1784 | std::optional<SMRange> typeConstraint; |
1785 | return parseConstraint(typeConstraint, /*existingConstraints=*/std::nullopt, |
1786 | /*allowInlineTypeConstraints=*/false); |
1787 | } |
1788 | |
1789 | //===----------------------------------------------------------------------===// |
1790 | // Exprs |
1791 | |
1792 | FailureOr<ast::Expr *> Parser::parseExpr() { |
1793 | if (curToken.is(k: Token::underscore)) |
1794 | return parseUnderscoreExpr(); |
1795 | |
1796 | // Parse the LHS expression. |
1797 | FailureOr<ast::Expr *> lhsExpr; |
1798 | switch (curToken.getKind()) { |
1799 | case Token::kw_attr: |
1800 | lhsExpr = parseAttributeExpr(); |
1801 | break; |
1802 | case Token::kw_Constraint: |
1803 | lhsExpr = parseInlineConstraintLambdaExpr(); |
1804 | break; |
1805 | case Token::kw_not: |
1806 | lhsExpr = parseNegatedExpr(); |
1807 | break; |
1808 | case Token::identifier: |
1809 | lhsExpr = parseIdentifierExpr(); |
1810 | break; |
1811 | case Token::kw_op: |
1812 | lhsExpr = parseOperationExpr(); |
1813 | break; |
1814 | case Token::kw_Rewrite: |
1815 | lhsExpr = parseInlineRewriteLambdaExpr(); |
1816 | break; |
1817 | case Token::kw_type: |
1818 | lhsExpr = parseTypeExpr(); |
1819 | break; |
1820 | case Token::l_paren: |
1821 | lhsExpr = parseTupleExpr(); |
1822 | break; |
1823 | default: |
1824 | return emitError(msg: "expected expression" ); |
1825 | } |
1826 | if (failed(result: lhsExpr)) |
1827 | return failure(); |
1828 | |
1829 | // Check for an operator expression. |
1830 | while (true) { |
1831 | switch (curToken.getKind()) { |
1832 | case Token::dot: |
1833 | lhsExpr = parseMemberAccessExpr(parentExpr: *lhsExpr); |
1834 | break; |
1835 | case Token::l_paren: |
1836 | lhsExpr = parseCallExpr(parentExpr: *lhsExpr); |
1837 | break; |
1838 | default: |
1839 | return lhsExpr; |
1840 | } |
1841 | if (failed(result: lhsExpr)) |
1842 | return failure(); |
1843 | } |
1844 | } |
1845 | |
1846 | FailureOr<ast::Expr *> Parser::parseAttributeExpr() { |
1847 | SMRange loc = curToken.getLoc(); |
1848 | consumeToken(kind: Token::kw_attr); |
1849 | |
1850 | // If we aren't followed by a `<`, the `attr` keyword is treated as a normal |
1851 | // identifier. |
1852 | if (!consumeIf(kind: Token::less)) { |
1853 | resetToken(tokLoc: loc); |
1854 | return parseIdentifierExpr(); |
1855 | } |
1856 | |
1857 | if (!curToken.isString()) |
1858 | return emitError(msg: "expected string literal containing MLIR attribute" ); |
1859 | std::string attrExpr = curToken.getStringValue(); |
1860 | consumeToken(); |
1861 | |
1862 | loc.End = curToken.getEndLoc(); |
1863 | if (failed( |
1864 | result: parseToken(kind: Token::greater, msg: "expected `>` after attribute literal" ))) |
1865 | return failure(); |
1866 | return ast::AttributeExpr::create(ctx, loc, value: attrExpr); |
1867 | } |
1868 | |
1869 | FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr, |
1870 | bool isNegated) { |
1871 | consumeToken(kind: Token::l_paren); |
1872 | |
1873 | // Parse the arguments of the call. |
1874 | SmallVector<ast::Expr *> arguments; |
1875 | if (curToken.isNot(k: Token::r_paren)) { |
1876 | do { |
1877 | // Handle code completion for the call arguments. |
1878 | if (curToken.is(k: Token::code_complete)) { |
1879 | codeCompleteCallSignature(parent: parentExpr, currentNumArgs: arguments.size()); |
1880 | return failure(); |
1881 | } |
1882 | |
1883 | FailureOr<ast::Expr *> argument = parseExpr(); |
1884 | if (failed(result: argument)) |
1885 | return failure(); |
1886 | arguments.push_back(Elt: *argument); |
1887 | } while (consumeIf(kind: Token::comma)); |
1888 | } |
1889 | |
1890 | SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc()); |
1891 | if (failed(result: parseToken(kind: Token::r_paren, msg: "expected `)` after argument list" ))) |
1892 | return failure(); |
1893 | |
1894 | return createCallExpr(loc, parentExpr, arguments, isNegated); |
1895 | } |
1896 | |
1897 | FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) { |
1898 | ast::Decl *decl = curDeclScope->lookup(name); |
1899 | if (!decl) |
1900 | return emitError(loc, msg: "undefined reference to `" + name + "`" ); |
1901 | |
1902 | return createDeclRefExpr(loc, decl); |
1903 | } |
1904 | |
1905 | FailureOr<ast::Expr *> Parser::parseIdentifierExpr() { |
1906 | StringRef name = curToken.getSpelling(); |
1907 | SMRange nameLoc = curToken.getLoc(); |
1908 | consumeToken(); |
1909 | |
1910 | // Check to see if this is a decl ref expression that defines a variable |
1911 | // inline. |
1912 | if (consumeIf(kind: Token::colon)) { |
1913 | SmallVector<ast::ConstraintRef> constraints; |
1914 | if (failed(result: parseVariableDeclConstraintList(constraints))) |
1915 | return failure(); |
1916 | ast::Type type; |
1917 | if (failed(result: validateVariableConstraints(constraints, inferredType&: type))) |
1918 | return failure(); |
1919 | return createInlineVariableExpr(type, name, loc: nameLoc, constraints); |
1920 | } |
1921 | |
1922 | return parseDeclRefExpr(name, loc: nameLoc); |
1923 | } |
1924 | |
1925 | FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() { |
1926 | FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl(); |
1927 | if (failed(result: decl)) |
1928 | return failure(); |
1929 | |
1930 | return ast::DeclRefExpr::create(ctx, loc: (*decl)->getLoc(), decl: *decl, |
1931 | type: ast::ConstraintType::get(context&: ctx)); |
1932 | } |
1933 | |
1934 | FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() { |
1935 | FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl(); |
1936 | if (failed(result: decl)) |
1937 | return failure(); |
1938 | |
1939 | return ast::DeclRefExpr::create(ctx, loc: (*decl)->getLoc(), decl: *decl, |
1940 | type: ast::RewriteType::get(context&: ctx)); |
1941 | } |
1942 | |
1943 | FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { |
1944 | SMRange dotLoc = curToken.getLoc(); |
1945 | consumeToken(kind: Token::dot); |
1946 | |
1947 | // Check for code completion of the member name. |
1948 | if (curToken.is(k: Token::code_complete)) |
1949 | return codeCompleteMemberAccess(parentExpr); |
1950 | |
1951 | // Parse the member name. |
1952 | Token memberNameTok = curToken; |
1953 | if (memberNameTok.isNot(k1: Token::identifier, k2: Token::integer) && |
1954 | !memberNameTok.isKeyword()) |
1955 | return emitError(loc: dotLoc, msg: "expected identifier or numeric member name" ); |
1956 | StringRef memberName = memberNameTok.getSpelling(); |
1957 | SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc()); |
1958 | consumeToken(); |
1959 | |
1960 | return createMemberAccessExpr(parentExpr, name: memberName, loc); |
1961 | } |
1962 | |
1963 | FailureOr<ast::Expr *> Parser::parseNegatedExpr() { |
1964 | consumeToken(kind: Token::kw_not); |
1965 | // Only native constraints are supported after negation |
1966 | if (!curToken.is(k: Token::identifier)) |
1967 | return emitError(msg: "expected native constraint" ); |
1968 | FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr(); |
1969 | if (failed(result: identifierExpr)) |
1970 | return failure(); |
1971 | if (!curToken.is(k: Token::l_paren)) |
1972 | return emitError(msg: "expected `(` after function name" ); |
1973 | return parseCallExpr(parentExpr: *identifierExpr, /*isNegated = */ true); |
1974 | } |
1975 | |
1976 | FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) { |
1977 | SMRange loc = curToken.getLoc(); |
1978 | |
1979 | // Check for code completion for the dialect name. |
1980 | if (curToken.is(k: Token::code_complete)) |
1981 | return codeCompleteDialectName(); |
1982 | |
1983 | // Handle the case of an no operation name. |
1984 | if (curToken.isNot(k: Token::identifier) && !curToken.isKeyword()) { |
1985 | if (allowEmptyName) |
1986 | return ast::OpNameDecl::create(ctx, loc: SMRange()); |
1987 | return emitError(msg: "expected dialect namespace" ); |
1988 | } |
1989 | StringRef name = curToken.getSpelling(); |
1990 | consumeToken(); |
1991 | |
1992 | // Otherwise, this is a literal operation name. |
1993 | if (failed(result: parseToken(kind: Token::dot, msg: "expected `.` after dialect namespace" ))) |
1994 | return failure(); |
1995 | |
1996 | // Check for code completion for the operation name. |
1997 | if (curToken.is(k: Token::code_complete)) |
1998 | return codeCompleteOperationName(dialectName: name); |
1999 | |
2000 | if (curToken.isNot(k: Token::identifier) && !curToken.isKeyword()) |
2001 | return emitError(msg: "expected operation name after dialect namespace" ); |
2002 | |
2003 | name = StringRef(name.data(), name.size() + 1); |
2004 | do { |
2005 | name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); |
2006 | loc.End = curToken.getEndLoc(); |
2007 | consumeToken(); |
2008 | } while (curToken.isAny(k1: Token::identifier, k2: Token::dot) || |
2009 | curToken.isKeyword()); |
2010 | return ast::OpNameDecl::create(ctx, name: ast::Name::create(ctx, name, location: loc)); |
2011 | } |
2012 | |
2013 | FailureOr<ast::OpNameDecl *> |
2014 | Parser::parseWrappedOperationName(bool allowEmptyName) { |
2015 | if (!consumeIf(kind: Token::less)) |
2016 | return ast::OpNameDecl::create(ctx, loc: SMRange()); |
2017 | |
2018 | FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName); |
2019 | if (failed(result: opNameDecl)) |
2020 | return failure(); |
2021 | |
2022 | if (failed(result: parseToken(kind: Token::greater, msg: "expected `>` after operation name" ))) |
2023 | return failure(); |
2024 | return opNameDecl; |
2025 | } |
2026 | |
2027 | FailureOr<ast::Expr *> |
2028 | Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) { |
2029 | SMRange loc = curToken.getLoc(); |
2030 | consumeToken(kind: Token::kw_op); |
2031 | |
2032 | // If it isn't followed by a `<`, the `op` keyword is treated as a normal |
2033 | // identifier. |
2034 | if (curToken.isNot(k: Token::less)) { |
2035 | resetToken(tokLoc: loc); |
2036 | return parseIdentifierExpr(); |
2037 | } |
2038 | |
2039 | // Parse the operation name. The name may be elided, in which case the |
2040 | // operation refers to "any" operation(i.e. a difference between `MyOp` and |
2041 | // `Operation*`). Operation names within a rewrite context must be named. |
2042 | bool allowEmptyName = parserContext != ParserContext::Rewrite; |
2043 | FailureOr<ast::OpNameDecl *> opNameDecl = |
2044 | parseWrappedOperationName(allowEmptyName); |
2045 | if (failed(result: opNameDecl)) |
2046 | return failure(); |
2047 | std::optional<StringRef> opName = (*opNameDecl)->getName(); |
2048 | |
2049 | // Functor used to create an implicit range variable, used for implicit "all" |
2050 | // operand or results variables. |
2051 | auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) { |
2052 | FailureOr<ast::VariableDecl *> rangeVar = |
2053 | defineVariableDecl(name: "_" , nameLoc: loc, type, constraints: ast::ConstraintRef(cst, loc)); |
2054 | assert(succeeded(rangeVar) && "expected range variable to be valid" ); |
2055 | return ast::DeclRefExpr::create(ctx, loc, decl: *rangeVar, type); |
2056 | }; |
2057 | |
2058 | // Check for the optional list of operands. |
2059 | SmallVector<ast::Expr *> operands; |
2060 | if (!consumeIf(kind: Token::l_paren)) { |
2061 | // If the operand list isn't specified and we are in a match context, define |
2062 | // an inplace unconstrained operand range corresponding to all of the |
2063 | // operands of the operation. This avoids treating zero operands the same |
2064 | // way as "unconstrained operands". |
2065 | if (parserContext != ParserContext::Rewrite) { |
2066 | operands.push_back(Elt: createImplicitRangeVar( |
2067 | ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy)); |
2068 | } |
2069 | } else if (!consumeIf(kind: Token::r_paren)) { |
2070 | // If the operand list was specified and non-empty, parse the operands. |
2071 | do { |
2072 | // Check for operand signature code completion. |
2073 | if (curToken.is(k: Token::code_complete)) { |
2074 | codeCompleteOperationOperandsSignature(opName, currentNumOperands: operands.size()); |
2075 | return failure(); |
2076 | } |
2077 | |
2078 | FailureOr<ast::Expr *> operand = parseExpr(); |
2079 | if (failed(result: operand)) |
2080 | return failure(); |
2081 | operands.push_back(Elt: *operand); |
2082 | } while (consumeIf(kind: Token::comma)); |
2083 | |
2084 | if (failed(result: parseToken(kind: Token::r_paren, |
2085 | msg: "expected `)` after operation operand list" ))) |
2086 | return failure(); |
2087 | } |
2088 | |
2089 | // Check for the optional list of attributes. |
2090 | SmallVector<ast::NamedAttributeDecl *> attributes; |
2091 | if (consumeIf(kind: Token::l_brace)) { |
2092 | do { |
2093 | FailureOr<ast::NamedAttributeDecl *> decl = |
2094 | parseNamedAttributeDecl(parentOpName: opName); |
2095 | if (failed(result: decl)) |
2096 | return failure(); |
2097 | attributes.emplace_back(Args&: *decl); |
2098 | } while (consumeIf(kind: Token::comma)); |
2099 | |
2100 | if (failed(result: parseToken(kind: Token::r_brace, |
2101 | msg: "expected `}` after operation attribute list" ))) |
2102 | return failure(); |
2103 | } |
2104 | |
2105 | // Handle the result types of the operation. |
2106 | SmallVector<ast::Expr *> resultTypes; |
2107 | OpResultTypeContext resultTypeContext = inputResultTypeContext; |
2108 | |
2109 | // Check for an explicit list of result types. |
2110 | if (consumeIf(kind: Token::arrow)) { |
2111 | if (failed(result: parseToken(kind: Token::l_paren, |
2112 | msg: "expected `(` before operation result type list" ))) |
2113 | return failure(); |
2114 | |
2115 | // If result types are provided, initially assume that the operation does |
2116 | // not rely on type inferrence. We don't assert that it isn't, because we |
2117 | // may be inferring the value of some type/type range variables, but given |
2118 | // that these variables may be defined in calls we can't always discern when |
2119 | // this is the case. |
2120 | resultTypeContext = OpResultTypeContext::Explicit; |
2121 | |
2122 | // Handle the case of an empty result list. |
2123 | if (!consumeIf(kind: Token::r_paren)) { |
2124 | do { |
2125 | // Check for result signature code completion. |
2126 | if (curToken.is(k: Token::code_complete)) { |
2127 | codeCompleteOperationResultsSignature(opName, currentNumResults: resultTypes.size()); |
2128 | return failure(); |
2129 | } |
2130 | |
2131 | FailureOr<ast::Expr *> resultTypeExpr = parseExpr(); |
2132 | if (failed(result: resultTypeExpr)) |
2133 | return failure(); |
2134 | resultTypes.push_back(Elt: *resultTypeExpr); |
2135 | } while (consumeIf(kind: Token::comma)); |
2136 | |
2137 | if (failed(result: parseToken(kind: Token::r_paren, |
2138 | msg: "expected `)` after operation result type list" ))) |
2139 | return failure(); |
2140 | } |
2141 | } else if (parserContext != ParserContext::Rewrite) { |
2142 | // If the result list isn't specified and we are in a match context, define |
2143 | // an inplace unconstrained result range corresponding to all of the results |
2144 | // of the operation. This avoids treating zero results the same way as |
2145 | // "unconstrained results". |
2146 | resultTypes.push_back(Elt: createImplicitRangeVar( |
2147 | ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy)); |
2148 | } else if (resultTypeContext == OpResultTypeContext::Explicit) { |
2149 | // If the result list isn't specified and we are in a rewrite, try to infer |
2150 | // them at runtime instead. |
2151 | resultTypeContext = OpResultTypeContext::Interface; |
2152 | } |
2153 | |
2154 | return createOperationExpr(loc, name: *opNameDecl, resultTypeContext, operands, |
2155 | attributes, results&: resultTypes); |
2156 | } |
2157 | |
2158 | FailureOr<ast::Expr *> Parser::parseTupleExpr() { |
2159 | SMRange loc = curToken.getLoc(); |
2160 | consumeToken(kind: Token::l_paren); |
2161 | |
2162 | DenseMap<StringRef, SMRange> usedNames; |
2163 | SmallVector<StringRef> elementNames; |
2164 | SmallVector<ast::Expr *> elements; |
2165 | if (curToken.isNot(k: Token::r_paren)) { |
2166 | do { |
2167 | // Check for the optional element name assignment before the value. |
2168 | StringRef elementName; |
2169 | if (curToken.is(k: Token::identifier) || curToken.isDependentKeyword()) { |
2170 | Token elementNameTok = curToken; |
2171 | consumeToken(); |
2172 | |
2173 | // The element name is only present if followed by an `=`. |
2174 | if (consumeIf(kind: Token::equal)) { |
2175 | elementName = elementNameTok.getSpelling(); |
2176 | |
2177 | // Check to see if this name is already used. |
2178 | auto elementNameIt = |
2179 | usedNames.try_emplace(Key: elementName, Args: elementNameTok.getLoc()); |
2180 | if (!elementNameIt.second) { |
2181 | return emitErrorAndNote( |
2182 | loc: elementNameTok.getLoc(), |
2183 | msg: llvm::formatv(Fmt: "duplicate tuple element label `{0}`" , |
2184 | Vals&: elementName), |
2185 | noteLoc: elementNameIt.first->getSecond(), |
2186 | note: "see previous label use here" ); |
2187 | } |
2188 | } else { |
2189 | // Otherwise, we treat this as part of an expression so reset the |
2190 | // lexer. |
2191 | resetToken(tokLoc: elementNameTok.getLoc()); |
2192 | } |
2193 | } |
2194 | elementNames.push_back(Elt: elementName); |
2195 | |
2196 | // Parse the tuple element value. |
2197 | FailureOr<ast::Expr *> element = parseExpr(); |
2198 | if (failed(result: element)) |
2199 | return failure(); |
2200 | elements.push_back(Elt: *element); |
2201 | } while (consumeIf(kind: Token::comma)); |
2202 | } |
2203 | loc.End = curToken.getEndLoc(); |
2204 | if (failed( |
2205 | result: parseToken(kind: Token::r_paren, msg: "expected `)` after tuple element list" ))) |
2206 | return failure(); |
2207 | return createTupleExpr(loc, elements, elementNames); |
2208 | } |
2209 | |
2210 | FailureOr<ast::Expr *> Parser::parseTypeExpr() { |
2211 | SMRange loc = curToken.getLoc(); |
2212 | consumeToken(kind: Token::kw_type); |
2213 | |
2214 | // If we aren't followed by a `<`, the `type` keyword is treated as a normal |
2215 | // identifier. |
2216 | if (!consumeIf(kind: Token::less)) { |
2217 | resetToken(tokLoc: loc); |
2218 | return parseIdentifierExpr(); |
2219 | } |
2220 | |
2221 | if (!curToken.isString()) |
2222 | return emitError(msg: "expected string literal containing MLIR type" ); |
2223 | std::string attrExpr = curToken.getStringValue(); |
2224 | consumeToken(); |
2225 | |
2226 | loc.End = curToken.getEndLoc(); |
2227 | if (failed(result: parseToken(kind: Token::greater, msg: "expected `>` after type literal" ))) |
2228 | return failure(); |
2229 | return ast::TypeExpr::create(ctx, loc, value: attrExpr); |
2230 | } |
2231 | |
2232 | FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() { |
2233 | StringRef name = curToken.getSpelling(); |
2234 | SMRange nameLoc = curToken.getLoc(); |
2235 | consumeToken(kind: Token::underscore); |
2236 | |
2237 | // Underscore expressions require a constraint list. |
2238 | if (failed(result: parseToken(kind: Token::colon, msg: "expected `:` after `_` variable" ))) |
2239 | return failure(); |
2240 | |
2241 | // Parse the constraints for the expression. |
2242 | SmallVector<ast::ConstraintRef> constraints; |
2243 | if (failed(result: parseVariableDeclConstraintList(constraints))) |
2244 | return failure(); |
2245 | |
2246 | ast::Type type; |
2247 | if (failed(result: validateVariableConstraints(constraints, inferredType&: type))) |
2248 | return failure(); |
2249 | return createInlineVariableExpr(type, name, loc: nameLoc, constraints); |
2250 | } |
2251 | |
2252 | //===----------------------------------------------------------------------===// |
2253 | // Stmts |
2254 | |
2255 | FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) { |
2256 | FailureOr<ast::Stmt *> stmt; |
2257 | switch (curToken.getKind()) { |
2258 | case Token::kw_erase: |
2259 | stmt = parseEraseStmt(); |
2260 | break; |
2261 | case Token::kw_let: |
2262 | stmt = parseLetStmt(); |
2263 | break; |
2264 | case Token::kw_replace: |
2265 | stmt = parseReplaceStmt(); |
2266 | break; |
2267 | case Token::kw_return: |
2268 | stmt = parseReturnStmt(); |
2269 | break; |
2270 | case Token::kw_rewrite: |
2271 | stmt = parseRewriteStmt(); |
2272 | break; |
2273 | default: |
2274 | stmt = parseExpr(); |
2275 | break; |
2276 | } |
2277 | if (failed(result: stmt) || |
2278 | (expectTerminalSemicolon && |
2279 | failed(result: parseToken(kind: Token::semicolon, msg: "expected `;` after statement" )))) |
2280 | return failure(); |
2281 | return stmt; |
2282 | } |
2283 | |
2284 | FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() { |
2285 | SMLoc startLoc = curToken.getStartLoc(); |
2286 | consumeToken(kind: Token::l_brace); |
2287 | |
2288 | // Push a new block scope and parse any nested statements. |
2289 | pushDeclScope(); |
2290 | SmallVector<ast::Stmt *> statements; |
2291 | while (curToken.isNot(k: Token::r_brace)) { |
2292 | FailureOr<ast::Stmt *> statement = parseStmt(); |
2293 | if (failed(result: statement)) |
2294 | return popDeclScope(), failure(); |
2295 | statements.push_back(Elt: *statement); |
2296 | } |
2297 | popDeclScope(); |
2298 | |
2299 | // Consume the end brace. |
2300 | SMRange location(startLoc, curToken.getEndLoc()); |
2301 | consumeToken(kind: Token::r_brace); |
2302 | |
2303 | return ast::CompoundStmt::create(ctx, location, children: statements); |
2304 | } |
2305 | |
2306 | FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() { |
2307 | if (parserContext == ParserContext::Constraint) |
2308 | return emitError(msg: "`erase` cannot be used within a Constraint" ); |
2309 | SMRange loc = curToken.getLoc(); |
2310 | consumeToken(kind: Token::kw_erase); |
2311 | |
2312 | // Parse the root operation expression. |
2313 | FailureOr<ast::Expr *> rootOp = parseExpr(); |
2314 | if (failed(result: rootOp)) |
2315 | return failure(); |
2316 | |
2317 | return createEraseStmt(loc, rootOp: *rootOp); |
2318 | } |
2319 | |
2320 | FailureOr<ast::LetStmt *> Parser::parseLetStmt() { |
2321 | SMRange loc = curToken.getLoc(); |
2322 | consumeToken(kind: Token::kw_let); |
2323 | |
2324 | // Parse the name of the new variable. |
2325 | SMRange varLoc = curToken.getLoc(); |
2326 | if (curToken.isNot(k: Token::identifier) && !curToken.isDependentKeyword()) { |
2327 | // `_` is a reserved variable name. |
2328 | if (curToken.is(k: Token::underscore)) { |
2329 | return emitError(loc: varLoc, |
2330 | msg: "`_` may only be used to define \"inline\" variables" ); |
2331 | } |
2332 | return emitError(loc: varLoc, |
2333 | msg: "expected identifier after `let` to name a new variable" ); |
2334 | } |
2335 | StringRef varName = curToken.getSpelling(); |
2336 | consumeToken(); |
2337 | |
2338 | // Parse the optional set of constraints. |
2339 | SmallVector<ast::ConstraintRef> constraints; |
2340 | if (consumeIf(kind: Token::colon) && |
2341 | failed(result: parseVariableDeclConstraintList(constraints))) |
2342 | return failure(); |
2343 | |
2344 | // Parse the optional initializer expression. |
2345 | ast::Expr *initializer = nullptr; |
2346 | if (consumeIf(kind: Token::equal)) { |
2347 | FailureOr<ast::Expr *> initOrFailure = parseExpr(); |
2348 | if (failed(result: initOrFailure)) |
2349 | return failure(); |
2350 | initializer = *initOrFailure; |
2351 | |
2352 | // Check that the constraints are compatible with having an initializer, |
2353 | // e.g. type constraints cannot be used with initializers. |
2354 | for (ast::ConstraintRef constraint : constraints) { |
2355 | LogicalResult result = |
2356 | TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint) |
2357 | .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, |
2358 | ast::ValueRangeConstraintDecl>(caseFn: [&](const auto *cst) { |
2359 | if (cst->getTypeExpr()) { |
2360 | return this->emitError( |
2361 | loc: constraint.referenceLoc, |
2362 | msg: "type constraints are not permitted on variables with " |
2363 | "initializers" ); |
2364 | } |
2365 | return success(); |
2366 | }) |
2367 | .Default(defaultResult: success()); |
2368 | if (failed(result)) |
2369 | return failure(); |
2370 | } |
2371 | } |
2372 | |
2373 | FailureOr<ast::VariableDecl *> varDecl = |
2374 | createVariableDecl(name: varName, loc: varLoc, initializer, constraints); |
2375 | if (failed(result: varDecl)) |
2376 | return failure(); |
2377 | return ast::LetStmt::create(ctx, loc, varDecl: *varDecl); |
2378 | } |
2379 | |
2380 | FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() { |
2381 | if (parserContext == ParserContext::Constraint) |
2382 | return emitError(msg: "`replace` cannot be used within a Constraint" ); |
2383 | SMRange loc = curToken.getLoc(); |
2384 | consumeToken(kind: Token::kw_replace); |
2385 | |
2386 | // Parse the root operation expression. |
2387 | FailureOr<ast::Expr *> rootOp = parseExpr(); |
2388 | if (failed(result: rootOp)) |
2389 | return failure(); |
2390 | |
2391 | if (failed( |
2392 | result: parseToken(kind: Token::kw_with, msg: "expected `with` after root operation" ))) |
2393 | return failure(); |
2394 | |
2395 | // The replacement portion of this statement is within a rewrite context. |
2396 | llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite); |
2397 | |
2398 | // Parse the replacement values. |
2399 | SmallVector<ast::Expr *> replValues; |
2400 | if (consumeIf(kind: Token::l_paren)) { |
2401 | if (consumeIf(kind: Token::r_paren)) { |
2402 | return emitError( |
2403 | loc, msg: "expected at least one replacement value, consider using " |
2404 | "`erase` if no replacement values are desired" ); |
2405 | } |
2406 | |
2407 | do { |
2408 | FailureOr<ast::Expr *> replExpr = parseExpr(); |
2409 | if (failed(result: replExpr)) |
2410 | return failure(); |
2411 | replValues.emplace_back(Args&: *replExpr); |
2412 | } while (consumeIf(kind: Token::comma)); |
2413 | |
2414 | if (failed(result: parseToken(kind: Token::r_paren, |
2415 | msg: "expected `)` after replacement values" ))) |
2416 | return failure(); |
2417 | } else { |
2418 | // Handle replacement with an operation uniquely, as the replacement |
2419 | // operation supports type inferrence from the root operation. |
2420 | FailureOr<ast::Expr *> replExpr; |
2421 | if (curToken.is(k: Token::kw_op)) |
2422 | replExpr = parseOperationExpr(inputResultTypeContext: OpResultTypeContext::Replacement); |
2423 | else |
2424 | replExpr = parseExpr(); |
2425 | if (failed(result: replExpr)) |
2426 | return failure(); |
2427 | replValues.emplace_back(Args&: *replExpr); |
2428 | } |
2429 | |
2430 | return createReplaceStmt(loc, rootOp: *rootOp, replValues); |
2431 | } |
2432 | |
2433 | FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() { |
2434 | SMRange loc = curToken.getLoc(); |
2435 | consumeToken(kind: Token::kw_return); |
2436 | |
2437 | // Parse the result value. |
2438 | FailureOr<ast::Expr *> resultExpr = parseExpr(); |
2439 | if (failed(result: resultExpr)) |
2440 | return failure(); |
2441 | |
2442 | return ast::ReturnStmt::create(ctx, loc, resultExpr: *resultExpr); |
2443 | } |
2444 | |
2445 | FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() { |
2446 | if (parserContext == ParserContext::Constraint) |
2447 | return emitError(msg: "`rewrite` cannot be used within a Constraint" ); |
2448 | SMRange loc = curToken.getLoc(); |
2449 | consumeToken(kind: Token::kw_rewrite); |
2450 | |
2451 | // Parse the root operation. |
2452 | FailureOr<ast::Expr *> rootOp = parseExpr(); |
2453 | if (failed(result: rootOp)) |
2454 | return failure(); |
2455 | |
2456 | if (failed(result: parseToken(kind: Token::kw_with, msg: "expected `with` before rewrite body" ))) |
2457 | return failure(); |
2458 | |
2459 | if (curToken.isNot(k: Token::l_brace)) |
2460 | return emitError(msg: "expected `{` to start rewrite body" ); |
2461 | |
2462 | // The rewrite body of this statement is within a rewrite context. |
2463 | llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite); |
2464 | |
2465 | FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt(); |
2466 | if (failed(result: rewriteBody)) |
2467 | return failure(); |
2468 | |
2469 | // Verify the rewrite body. |
2470 | for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) { |
2471 | if (isa<ast::ReturnStmt>(Val: stmt)) { |
2472 | return emitError(loc: stmt->getLoc(), |
2473 | msg: "`return` statements are only permitted within a " |
2474 | "`Constraint` or `Rewrite` body" ); |
2475 | } |
2476 | } |
2477 | |
2478 | return createRewriteStmt(loc, rootOp: *rootOp, rewriteBody: *rewriteBody); |
2479 | } |
2480 | |
2481 | //===----------------------------------------------------------------------===// |
2482 | // Creation+Analysis |
2483 | //===----------------------------------------------------------------------===// |
2484 | |
2485 | //===----------------------------------------------------------------------===// |
2486 | // Decls |
2487 | |
2488 | ast::CallableDecl *Parser::(ast::Node *node) { |
2489 | // Unwrap reference expressions. |
2490 | if (auto *init = dyn_cast<ast::DeclRefExpr>(Val: node)) |
2491 | node = init->getDecl(); |
2492 | return dyn_cast<ast::CallableDecl>(Val: node); |
2493 | } |
2494 | |
2495 | FailureOr<ast::PatternDecl *> |
2496 | Parser::createPatternDecl(SMRange loc, const ast::Name *name, |
2497 | const ParsedPatternMetadata &metadata, |
2498 | ast::CompoundStmt *body) { |
2499 | return ast::PatternDecl::create(ctx, location: loc, name, benefit: metadata.benefit, |
2500 | hasBoundedRecursion: metadata.hasBoundedRecursion, body); |
2501 | } |
2502 | |
2503 | ast::Type Parser::createUserConstraintRewriteResultType( |
2504 | ArrayRef<ast::VariableDecl *> results) { |
2505 | // Single result decls use the type of the single result. |
2506 | if (results.size() == 1) |
2507 | return results[0]->getType(); |
2508 | |
2509 | // Multiple results use a tuple type, with the types and names grabbed from |
2510 | // the result variable decls. |
2511 | auto resultTypes = llvm::map_range( |
2512 | C&: results, F: [&](const auto *result) { return result->getType(); }); |
2513 | auto resultNames = llvm::map_range( |
2514 | C&: results, F: [&](const auto *result) { return result->getName().getName(); }); |
2515 | return ast::TupleType::get(context&: ctx, elementTypes: llvm::to_vector(Range&: resultTypes), |
2516 | elementNames: llvm::to_vector(Range&: resultNames)); |
2517 | } |
2518 | |
2519 | template <typename T> |
2520 | FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl( |
2521 | const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments, |
2522 | ArrayRef<ast::VariableDecl *> results, ast::Type resultType, |
2523 | ast::CompoundStmt *body) { |
2524 | if (!body->getChildren().empty()) { |
2525 | if (auto *retStmt = dyn_cast<ast::ReturnStmt>(Val: body->getChildren().back())) { |
2526 | ast::Expr *resultExpr = retStmt->getResultExpr(); |
2527 | |
2528 | // Process the result of the decl. If no explicit signature results |
2529 | // were provided, check for return type inference. Otherwise, check that |
2530 | // the return expression can be converted to the expected type. |
2531 | if (results.empty()) |
2532 | resultType = resultExpr->getType(); |
2533 | else if (failed(result: convertExpressionTo(expr&: resultExpr, type: resultType))) |
2534 | return failure(); |
2535 | else |
2536 | retStmt->setResultExpr(resultExpr); |
2537 | } |
2538 | } |
2539 | return T::createPDLL(ctx, name, arguments, results, body, resultType); |
2540 | } |
2541 | |
2542 | FailureOr<ast::VariableDecl *> |
2543 | Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, |
2544 | ArrayRef<ast::ConstraintRef> constraints) { |
2545 | // The type of the variable, which is expected to be inferred by either a |
2546 | // constraint or an initializer expression. |
2547 | ast::Type type; |
2548 | if (failed(result: validateVariableConstraints(constraints, inferredType&: type))) |
2549 | return failure(); |
2550 | |
2551 | if (initializer) { |
2552 | // Update the variable type based on the initializer, or try to convert the |
2553 | // initializer to the existing type. |
2554 | if (!type) |
2555 | type = initializer->getType(); |
2556 | else if (ast::Type mergedType = type.refineWith(other: initializer->getType())) |
2557 | type = mergedType; |
2558 | else if (failed(result: convertExpressionTo(expr&: initializer, type))) |
2559 | return failure(); |
2560 | |
2561 | // Otherwise, if there is no initializer check that the type has already |
2562 | // been resolved from the constraint list. |
2563 | } else if (!type) { |
2564 | return emitErrorAndNote( |
2565 | loc, msg: "unable to infer type for variable `" + name + "`" , noteLoc: loc, |
2566 | note: "the type of a variable must be inferable from the constraint " |
2567 | "list or the initializer" ); |
2568 | } |
2569 | |
2570 | // Constraint types cannot be used when defining variables. |
2571 | if (type.isa<ast::ConstraintType, ast::RewriteType>()) { |
2572 | return emitError( |
2573 | loc, msg: llvm::formatv(Fmt: "unable to define variable of `{0}` type" , Vals&: type)); |
2574 | } |
2575 | |
2576 | // Try to define a variable with the given name. |
2577 | FailureOr<ast::VariableDecl *> varDecl = |
2578 | defineVariableDecl(name, nameLoc: loc, type, initExpr: initializer, constraints); |
2579 | if (failed(result: varDecl)) |
2580 | return failure(); |
2581 | |
2582 | return *varDecl; |
2583 | } |
2584 | |
2585 | FailureOr<ast::VariableDecl *> |
2586 | Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc, |
2587 | const ast::ConstraintRef &constraint) { |
2588 | ast::Type argType; |
2589 | if (failed(result: validateVariableConstraint(ref: constraint, inferredType&: argType))) |
2590 | return failure(); |
2591 | return defineVariableDecl(name, nameLoc: loc, type: argType, constraints: constraint); |
2592 | } |
2593 | |
2594 | LogicalResult |
2595 | Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints, |
2596 | ast::Type &inferredType) { |
2597 | for (const ast::ConstraintRef &ref : constraints) |
2598 | if (failed(result: validateVariableConstraint(ref, inferredType))) |
2599 | return failure(); |
2600 | return success(); |
2601 | } |
2602 | |
2603 | LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, |
2604 | ast::Type &inferredType) { |
2605 | ast::Type constraintType; |
2606 | if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(Val: ref.constraint)) { |
2607 | if (const ast::Expr *typeExpr = cst->getTypeExpr()) { |
2608 | if (failed(result: validateTypeConstraintExpr(typeExpr))) |
2609 | return failure(); |
2610 | } |
2611 | constraintType = ast::AttributeType::get(context&: ctx); |
2612 | } else if (const auto *cst = |
2613 | dyn_cast<ast::OpConstraintDecl>(Val: ref.constraint)) { |
2614 | constraintType = ast::OperationType::get( |
2615 | context&: ctx, name: cst->getName(), odsOp: lookupODSOperation(opName: cst->getName())); |
2616 | } else if (isa<ast::TypeConstraintDecl>(Val: ref.constraint)) { |
2617 | constraintType = typeTy; |
2618 | } else if (isa<ast::TypeRangeConstraintDecl>(Val: ref.constraint)) { |
2619 | constraintType = typeRangeTy; |
2620 | } else if (const auto *cst = |
2621 | dyn_cast<ast::ValueConstraintDecl>(Val: ref.constraint)) { |
2622 | if (const ast::Expr *typeExpr = cst->getTypeExpr()) { |
2623 | if (failed(result: validateTypeConstraintExpr(typeExpr))) |
2624 | return failure(); |
2625 | } |
2626 | constraintType = valueTy; |
2627 | } else if (const auto *cst = |
2628 | dyn_cast<ast::ValueRangeConstraintDecl>(Val: ref.constraint)) { |
2629 | if (const ast::Expr *typeExpr = cst->getTypeExpr()) { |
2630 | if (failed(result: validateTypeRangeConstraintExpr(typeExpr))) |
2631 | return failure(); |
2632 | } |
2633 | constraintType = valueRangeTy; |
2634 | } else if (const auto *cst = |
2635 | dyn_cast<ast::UserConstraintDecl>(Val: ref.constraint)) { |
2636 | ArrayRef<ast::VariableDecl *> inputs = cst->getInputs(); |
2637 | if (inputs.size() != 1) { |
2638 | return emitErrorAndNote(loc: ref.referenceLoc, |
2639 | msg: "`Constraint`s applied via a variable constraint " |
2640 | "list must take a single input, but got " + |
2641 | Twine(inputs.size()), |
2642 | noteLoc: cst->getLoc(), |
2643 | note: "see definition of constraint here" ); |
2644 | } |
2645 | constraintType = inputs.front()->getType(); |
2646 | } else { |
2647 | llvm_unreachable("unknown constraint type" ); |
2648 | } |
2649 | |
2650 | // Check that the constraint type is compatible with the current inferred |
2651 | // type. |
2652 | if (!inferredType) { |
2653 | inferredType = constraintType; |
2654 | } else if (ast::Type mergedTy = inferredType.refineWith(other: constraintType)) { |
2655 | inferredType = mergedTy; |
2656 | } else { |
2657 | return emitError(loc: ref.referenceLoc, |
2658 | msg: llvm::formatv(Fmt: "constraint type `{0}` is incompatible " |
2659 | "with the previously inferred type `{1}`" , |
2660 | Vals&: constraintType, Vals&: inferredType)); |
2661 | } |
2662 | return success(); |
2663 | } |
2664 | |
2665 | LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { |
2666 | ast::Type typeExprType = typeExpr->getType(); |
2667 | if (typeExprType != typeTy) { |
2668 | return emitError(loc: typeExpr->getLoc(), |
2669 | msg: "expected expression of `Type` in type constraint" ); |
2670 | } |
2671 | return success(); |
2672 | } |
2673 | |
2674 | LogicalResult |
2675 | Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { |
2676 | ast::Type typeExprType = typeExpr->getType(); |
2677 | if (typeExprType != typeRangeTy) { |
2678 | return emitError(loc: typeExpr->getLoc(), |
2679 | msg: "expected expression of `TypeRange` in type constraint" ); |
2680 | } |
2681 | return success(); |
2682 | } |
2683 | |
2684 | //===----------------------------------------------------------------------===// |
2685 | // Exprs |
2686 | |
2687 | FailureOr<ast::CallExpr *> |
2688 | Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, |
2689 | MutableArrayRef<ast::Expr *> arguments, bool isNegated) { |
2690 | ast::Type parentType = parentExpr->getType(); |
2691 | |
2692 | ast::CallableDecl *callableDecl = tryExtractCallableDecl(node: parentExpr); |
2693 | if (!callableDecl) { |
2694 | return emitError(loc, |
2695 | msg: llvm::formatv(Fmt: "expected a reference to a callable " |
2696 | "`Constraint` or `Rewrite`, but got: `{0}`" , |
2697 | Vals&: parentType)); |
2698 | } |
2699 | if (parserContext == ParserContext::Rewrite) { |
2700 | if (isa<ast::UserConstraintDecl>(Val: callableDecl)) |
2701 | return emitError( |
2702 | loc, msg: "unable to invoke `Constraint` within a rewrite section" ); |
2703 | if (isNegated) |
2704 | return emitError(loc, msg: "unable to negate a Rewrite" ); |
2705 | } else { |
2706 | if (isa<ast::UserRewriteDecl>(Val: callableDecl)) |
2707 | return emitError(loc, |
2708 | msg: "unable to invoke `Rewrite` within a match section" ); |
2709 | if (isNegated && cast<ast::UserConstraintDecl>(Val: callableDecl)->getBody()) |
2710 | return emitError(loc, msg: "unable to negate non native constraints" ); |
2711 | } |
2712 | |
2713 | // Verify the arguments of the call. |
2714 | /// Handle size mismatch. |
2715 | ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs(); |
2716 | if (callArgs.size() != arguments.size()) { |
2717 | return emitErrorAndNote( |
2718 | loc, |
2719 | msg: llvm::formatv(Fmt: "invalid number of arguments for {0} call; expected " |
2720 | "{1}, but got {2}" , |
2721 | Vals: callableDecl->getCallableType(), Vals: callArgs.size(), |
2722 | Vals: arguments.size()), |
2723 | noteLoc: callableDecl->getLoc(), |
2724 | note: llvm::formatv(Fmt: "see the definition of {0} here" , |
2725 | Vals: callableDecl->getName()->getName())); |
2726 | } |
2727 | |
2728 | /// Handle argument type mismatch. |
2729 | auto attachDiagFn = [&](ast::Diagnostic &diag) { |
2730 | diag.attachNote(msg: llvm::formatv(Fmt: "see the definition of `{0}` here" , |
2731 | Vals: callableDecl->getName()->getName()), |
2732 | noteLoc: callableDecl->getLoc()); |
2733 | }; |
2734 | for (auto it : llvm::zip(t&: callArgs, u&: arguments)) { |
2735 | if (failed(result: convertExpressionTo(expr&: std::get<1>(t&: it), type: std::get<0>(t&: it)->getType(), |
2736 | noteAttachFn: attachDiagFn))) |
2737 | return failure(); |
2738 | } |
2739 | |
2740 | return ast::CallExpr::create(ctx, loc, callable: parentExpr, arguments, |
2741 | resultType: callableDecl->getResultType(), isNegated); |
2742 | } |
2743 | |
2744 | FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc, |
2745 | ast::Decl *decl) { |
2746 | // Check the type of decl being referenced. |
2747 | ast::Type declType; |
2748 | if (isa<ast::ConstraintDecl>(Val: decl)) |
2749 | declType = ast::ConstraintType::get(context&: ctx); |
2750 | else if (isa<ast::UserRewriteDecl>(Val: decl)) |
2751 | declType = ast::RewriteType::get(context&: ctx); |
2752 | else if (auto *varDecl = dyn_cast<ast::VariableDecl>(Val: decl)) |
2753 | declType = varDecl->getType(); |
2754 | else |
2755 | return emitError(loc, msg: "invalid reference to `" + |
2756 | decl->getName()->getName() + "`" ); |
2757 | |
2758 | return ast::DeclRefExpr::create(ctx, loc, decl, type: declType); |
2759 | } |
2760 | |
2761 | FailureOr<ast::DeclRefExpr *> |
2762 | Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, |
2763 | ArrayRef<ast::ConstraintRef> constraints) { |
2764 | FailureOr<ast::VariableDecl *> decl = |
2765 | defineVariableDecl(name, nameLoc: loc, type, constraints); |
2766 | if (failed(result: decl)) |
2767 | return failure(); |
2768 | return ast::DeclRefExpr::create(ctx, loc, decl: *decl, type); |
2769 | } |
2770 | |
2771 | FailureOr<ast::MemberAccessExpr *> |
2772 | Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, |
2773 | SMRange loc) { |
2774 | // Validate the member name for the given parent expression. |
2775 | FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc); |
2776 | if (failed(result: memberType)) |
2777 | return failure(); |
2778 | |
2779 | return ast::MemberAccessExpr::create(ctx, loc, parentExpr, memberName: name, type: *memberType); |
2780 | } |
2781 | |
2782 | FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, |
2783 | StringRef name, SMRange loc) { |
2784 | ast::Type parentType = parentExpr->getType(); |
2785 | if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) { |
2786 | if (name == ast::AllResultsMemberAccessExpr::getMemberName()) |
2787 | return valueRangeTy; |
2788 | |
2789 | // Verify member access based on the operation type. |
2790 | if (const ods::Operation *odsOp = opType.getODSOperation()) { |
2791 | auto results = odsOp->getResults(); |
2792 | |
2793 | // Handle indexed results. |
2794 | unsigned index = 0; |
2795 | if (llvm::isDigit(C: name[0]) && !name.getAsInteger(/*Radix=*/10, Result&: index) && |
2796 | index < results.size()) { |
2797 | return results[index].isVariadic() ? valueRangeTy : valueTy; |
2798 | } |
2799 | |
2800 | // Handle named results. |
2801 | const auto *it = llvm::find_if(Range&: results, P: [&](const auto &result) { |
2802 | return result.getName() == name; |
2803 | }); |
2804 | if (it != results.end()) |
2805 | return it->isVariadic() ? valueRangeTy : valueTy; |
2806 | } else if (llvm::isDigit(C: name[0])) { |
2807 | // Allow unchecked numeric indexing of the results of unregistered |
2808 | // operations. It returns a single value. |
2809 | return valueTy; |
2810 | } |
2811 | } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { |
2812 | // Handle indexed results. |
2813 | unsigned index = 0; |
2814 | if (llvm::isDigit(C: name[0]) && !name.getAsInteger(/*Radix=*/10, Result&: index) && |
2815 | index < tupleType.size()) { |
2816 | return tupleType.getElementTypes()[index]; |
2817 | } |
2818 | |
2819 | // Handle named results. |
2820 | auto elementNames = tupleType.getElementNames(); |
2821 | const auto *it = llvm::find(Range&: elementNames, Val: name); |
2822 | if (it != elementNames.end()) |
2823 | return tupleType.getElementTypes()[it - elementNames.begin()]; |
2824 | } |
2825 | return emitError( |
2826 | loc, |
2827 | msg: llvm::formatv(Fmt: "invalid member access `{0}` on expression of type `{1}`" , |
2828 | Vals&: name, Vals&: parentType)); |
2829 | } |
2830 | |
2831 | FailureOr<ast::OperationExpr *> Parser::createOperationExpr( |
2832 | SMRange loc, const ast::OpNameDecl *name, |
2833 | OpResultTypeContext resultTypeContext, |
2834 | SmallVectorImpl<ast::Expr *> &operands, |
2835 | MutableArrayRef<ast::NamedAttributeDecl *> attributes, |
2836 | SmallVectorImpl<ast::Expr *> &results) { |
2837 | std::optional<StringRef> opNameRef = name->getName(); |
2838 | const ods::Operation *odsOp = lookupODSOperation(opName: opNameRef); |
2839 | |
2840 | // Verify the inputs operands. |
2841 | if (failed(result: validateOperationOperands(loc, name: opNameRef, odsOp, operands))) |
2842 | return failure(); |
2843 | |
2844 | // Verify the attribute list. |
2845 | for (ast::NamedAttributeDecl *attr : attributes) { |
2846 | // Check for an attribute type, or a type awaiting resolution. |
2847 | ast::Type attrType = attr->getValue()->getType(); |
2848 | if (!attrType.isa<ast::AttributeType>()) { |
2849 | return emitError( |
2850 | loc: attr->getValue()->getLoc(), |
2851 | msg: llvm::formatv(Fmt: "expected `Attr` expression, but got `{0}`" , Vals&: attrType)); |
2852 | } |
2853 | } |
2854 | |
2855 | assert( |
2856 | (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) && |
2857 | "unexpected inferrence when results were explicitly specified" ); |
2858 | |
2859 | // If we aren't relying on type inferrence, or explicit results were provided, |
2860 | // validate them. |
2861 | if (resultTypeContext == OpResultTypeContext::Explicit) { |
2862 | if (failed(result: validateOperationResults(loc, name: opNameRef, odsOp, results))) |
2863 | return failure(); |
2864 | |
2865 | // Validate the use of interface based type inferrence for this operation. |
2866 | } else if (resultTypeContext == OpResultTypeContext::Interface) { |
2867 | assert(opNameRef && |
2868 | "expected valid operation name when inferring operation results" ); |
2869 | checkOperationResultTypeInferrence(loc, name: *opNameRef, odsOp); |
2870 | } |
2871 | |
2872 | return ast::OperationExpr::create(ctx, loc, odsOp, nameDecl: name, operands, resultTypes: results, |
2873 | attributes); |
2874 | } |
2875 | |
2876 | LogicalResult |
2877 | Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name, |
2878 | const ods::Operation *odsOp, |
2879 | SmallVectorImpl<ast::Expr *> &operands) { |
2880 | return validateOperationOperandsOrResults( |
2881 | groupName: "operand" , loc, odsOpLoc: odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name, |
2882 | values&: operands, odsValues: odsOp ? odsOp->getOperands() : std::nullopt, singleTy: valueTy, |
2883 | rangeTy: valueRangeTy); |
2884 | } |
2885 | |
2886 | LogicalResult |
2887 | Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name, |
2888 | const ods::Operation *odsOp, |
2889 | SmallVectorImpl<ast::Expr *> &results) { |
2890 | return validateOperationOperandsOrResults( |
2891 | groupName: "result" , loc, odsOpLoc: odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name, |
2892 | values&: results, odsValues: odsOp ? odsOp->getResults() : std::nullopt, singleTy: typeTy, rangeTy: typeRangeTy); |
2893 | } |
2894 | |
2895 | void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName, |
2896 | const ods::Operation *odsOp) { |
2897 | // If the operation might not have inferrence support, emit a warning to the |
2898 | // user. We don't emit an error because the interface might be added to the |
2899 | // operation at runtime. It's rare, but it could still happen. We emit a |
2900 | // warning here instead. |
2901 | |
2902 | // Handle inferrence warnings for unknown operations. |
2903 | if (!odsOp) { |
2904 | ctx.getDiagEngine().emitWarning( |
2905 | loc, msg: llvm::formatv( |
2906 | Fmt: "operation result types are marked to be inferred, but " |
2907 | "`{0}` is unknown. Ensure that `{0}` supports zero " |
2908 | "results or implements `InferTypeOpInterface`. Include " |
2909 | "the ODS definition of this operation to remove this warning." , |
2910 | Vals&: opName)); |
2911 | return; |
2912 | } |
2913 | |
2914 | // Handle inferrence warnings for known operations that expected at least one |
2915 | // result, but don't have inference support. An elided results list can mean |
2916 | // "zero-results", and we don't want to warn when that is the expected |
2917 | // behavior. |
2918 | bool requiresInferrence = |
2919 | llvm::any_of(Range: odsOp->getResults(), P: [](const ods::OperandOrResult &result) { |
2920 | return !result.isVariableLength(); |
2921 | }); |
2922 | if (requiresInferrence && !odsOp->hasResultTypeInferrence()) { |
2923 | ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning( |
2924 | loc, |
2925 | msg: llvm::formatv(Fmt: "operation result types are marked to be inferred, but " |
2926 | "`{0}` does not provide an implementation of " |
2927 | "`InferTypeOpInterface`. Ensure that `{0}` attaches " |
2928 | "`InferTypeOpInterface` at runtime, or add support to " |
2929 | "the ODS definition to remove this warning." , |
2930 | Vals&: opName)); |
2931 | diag->attachNote(msg: llvm::formatv(Fmt: "see the definition of `{0}` here" , Vals&: opName), |
2932 | noteLoc: odsOp->getLoc()); |
2933 | return; |
2934 | } |
2935 | } |
2936 | |
2937 | LogicalResult Parser::validateOperationOperandsOrResults( |
2938 | StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc, |
2939 | std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values, |
2940 | ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy, |
2941 | ast::RangeType rangeTy) { |
2942 | // All operation types accept a single range parameter. |
2943 | if (values.size() == 1) { |
2944 | if (failed(result: convertExpressionTo(expr&: values[0], type: rangeTy))) |
2945 | return failure(); |
2946 | return success(); |
2947 | } |
2948 | |
2949 | /// If the operation has ODS information, we can more accurately verify the |
2950 | /// values. |
2951 | if (odsOpLoc) { |
2952 | auto emitSizeMismatchError = [&] { |
2953 | return emitErrorAndNote( |
2954 | loc, |
2955 | msg: llvm::formatv(Fmt: "invalid number of {0} groups for `{1}`; expected " |
2956 | "{2}, but got {3}" , |
2957 | Vals&: groupName, Vals&: *name, Vals: odsValues.size(), Vals: values.size()), |
2958 | noteLoc: *odsOpLoc, note: llvm::formatv(Fmt: "see the definition of `{0}` here" , Vals&: *name)); |
2959 | }; |
2960 | |
2961 | // Handle the case where no values were provided. |
2962 | if (values.empty()) { |
2963 | // If we don't expect any on the ODS side, we are done. |
2964 | if (odsValues.empty()) |
2965 | return success(); |
2966 | |
2967 | // If we do, check if we actually need to provide values (i.e. if any of |
2968 | // the values are actually required). |
2969 | unsigned numVariadic = 0; |
2970 | for (const auto &odsValue : odsValues) { |
2971 | if (!odsValue.isVariableLength()) |
2972 | return emitSizeMismatchError(); |
2973 | ++numVariadic; |
2974 | } |
2975 | |
2976 | // If we are in a non-rewrite context, we don't need to do anything more. |
2977 | // Zero-values is a valid constraint on the operation. |
2978 | if (parserContext != ParserContext::Rewrite) |
2979 | return success(); |
2980 | |
2981 | // Otherwise, when in a rewrite we may need to provide values to match the |
2982 | // ODS signature of the operation to create. |
2983 | |
2984 | // If we only have one variadic value, just use an empty list. |
2985 | if (numVariadic == 1) |
2986 | return success(); |
2987 | |
2988 | // Otherwise, create dummy values for each of the entries so that we |
2989 | // adhere to the ODS signature. |
2990 | for (unsigned i = 0, e = odsValues.size(); i < e; ++i) { |
2991 | values.push_back(Elt: ast::RangeExpr::create( |
2992 | ctx, loc, /*elements=*/std::nullopt, type: rangeTy)); |
2993 | } |
2994 | return success(); |
2995 | } |
2996 | |
2997 | // Verify that the number of values provided matches the number of value |
2998 | // groups ODS expects. |
2999 | if (odsValues.size() != values.size()) |
3000 | return emitSizeMismatchError(); |
3001 | |
3002 | auto diagFn = [&](ast::Diagnostic &diag) { |
3003 | diag.attachNote(msg: llvm::formatv(Fmt: "see the definition of `{0}` here" , Vals&: *name), |
3004 | noteLoc: *odsOpLoc); |
3005 | }; |
3006 | for (unsigned i = 0, e = values.size(); i < e; ++i) { |
3007 | ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy; |
3008 | if (failed(result: convertExpressionTo(expr&: values[i], type: expectedType, noteAttachFn: diagFn))) |
3009 | return failure(); |
3010 | } |
3011 | return success(); |
3012 | } |
3013 | |
3014 | // Otherwise, accept the value groups as they have been defined and just |
3015 | // ensure they are one of the expected types. |
3016 | for (ast::Expr *&valueExpr : values) { |
3017 | ast::Type valueExprType = valueExpr->getType(); |
3018 | |
3019 | // Check if this is one of the expected types. |
3020 | if (valueExprType == rangeTy || valueExprType == singleTy) |
3021 | continue; |
3022 | |
3023 | // If the operand is an Operation, allow converting to a Value or |
3024 | // ValueRange. This situations arises quite often with nested operation |
3025 | // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)` |
3026 | if (singleTy == valueTy) { |
3027 | if (valueExprType.isa<ast::OperationType>()) { |
3028 | valueExpr = convertOpToValue(opExpr: valueExpr); |
3029 | continue; |
3030 | } |
3031 | } |
3032 | |
3033 | // Otherwise, try to convert the expression to a range. |
3034 | if (succeeded(result: convertExpressionTo(expr&: valueExpr, type: rangeTy))) |
3035 | continue; |
3036 | |
3037 | return emitError( |
3038 | loc: valueExpr->getLoc(), |
3039 | msg: llvm::formatv( |
3040 | Fmt: "expected `{0}` or `{1}` convertible expression, but got `{2}`" , |
3041 | Vals&: singleTy, Vals&: rangeTy, Vals&: valueExprType)); |
3042 | } |
3043 | return success(); |
3044 | } |
3045 | |
3046 | FailureOr<ast::TupleExpr *> |
3047 | Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements, |
3048 | ArrayRef<StringRef> elementNames) { |
3049 | for (const ast::Expr *element : elements) { |
3050 | ast::Type eleTy = element->getType(); |
3051 | if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) { |
3052 | return emitError( |
3053 | loc: element->getLoc(), |
3054 | msg: llvm::formatv(Fmt: "unable to build a tuple with `{0}` element" , Vals&: eleTy)); |
3055 | } |
3056 | } |
3057 | return ast::TupleExpr::create(ctx, loc, elements, elementNames); |
3058 | } |
3059 | |
//===----------------------------------------------------------------------===//
// Stmts
//===----------------------------------------------------------------------===//
3062 | |
3063 | FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc, |
3064 | ast::Expr *rootOp) { |
3065 | // Check that root is an Operation. |
3066 | ast::Type rootType = rootOp->getType(); |
3067 | if (!rootType.isa<ast::OperationType>()) |
3068 | return emitError(loc: rootOp->getLoc(), msg: "expected `Op` expression" ); |
3069 | |
3070 | return ast::EraseStmt::create(ctx, loc, rootOp); |
3071 | } |
3072 | |
3073 | FailureOr<ast::ReplaceStmt *> |
3074 | Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp, |
3075 | MutableArrayRef<ast::Expr *> replValues) { |
3076 | // Check that root is an Operation. |
3077 | ast::Type rootType = rootOp->getType(); |
3078 | if (!rootType.isa<ast::OperationType>()) { |
3079 | return emitError( |
3080 | loc: rootOp->getLoc(), |
3081 | msg: llvm::formatv(Fmt: "expected `Op` expression, but got `{0}`" , Vals&: rootType)); |
3082 | } |
3083 | |
3084 | // If there are multiple replacement values, we implicitly convert any Op |
3085 | // expressions to the value form. |
3086 | bool shouldConvertOpToValues = replValues.size() > 1; |
3087 | for (ast::Expr *&replExpr : replValues) { |
3088 | ast::Type replType = replExpr->getType(); |
3089 | |
3090 | // Check that replExpr is an Operation, Value, or ValueRange. |
3091 | if (replType.isa<ast::OperationType>()) { |
3092 | if (shouldConvertOpToValues) |
3093 | replExpr = convertOpToValue(opExpr: replExpr); |
3094 | continue; |
3095 | } |
3096 | |
3097 | if (replType != valueTy && replType != valueRangeTy) { |
3098 | return emitError(loc: replExpr->getLoc(), |
3099 | msg: llvm::formatv(Fmt: "expected `Op`, `Value` or `ValueRange` " |
3100 | "expression, but got `{0}`" , |
3101 | Vals&: replType)); |
3102 | } |
3103 | } |
3104 | |
3105 | return ast::ReplaceStmt::create(ctx, loc, rootOp, replExprs: replValues); |
3106 | } |
3107 | |
3108 | FailureOr<ast::RewriteStmt *> |
3109 | Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp, |
3110 | ast::CompoundStmt *rewriteBody) { |
3111 | // Check that root is an Operation. |
3112 | ast::Type rootType = rootOp->getType(); |
3113 | if (!rootType.isa<ast::OperationType>()) { |
3114 | return emitError( |
3115 | loc: rootOp->getLoc(), |
3116 | msg: llvm::formatv(Fmt: "expected `Op` expression, but got `{0}`" , Vals&: rootType)); |
3117 | } |
3118 | |
3119 | return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); |
3120 | } |
3121 | |
3122 | //===----------------------------------------------------------------------===// |
3123 | // Code Completion |
3124 | //===----------------------------------------------------------------------===// |
3125 | |
3126 | LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) { |
3127 | ast::Type parentType = parentExpr->getType(); |
3128 | if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) |
3129 | codeCompleteContext->codeCompleteOperationMemberAccess(opType); |
3130 | else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>()) |
3131 | codeCompleteContext->codeCompleteTupleMemberAccess(tupleType); |
3132 | return failure(); |
3133 | } |
3134 | |
3135 | LogicalResult |
3136 | Parser::codeCompleteAttributeName(std::optional<StringRef> opName) { |
3137 | if (opName) |
3138 | codeCompleteContext->codeCompleteOperationAttributeName(opName: *opName); |
3139 | return failure(); |
3140 | } |
3141 | |
3142 | LogicalResult |
3143 | Parser::codeCompleteConstraintName(ast::Type inferredType, |
3144 | bool allowInlineTypeConstraints) { |
3145 | codeCompleteContext->codeCompleteConstraintName( |
3146 | currentType: inferredType, allowInlineTypeConstraints, scope: curDeclScope); |
3147 | return failure(); |
3148 | } |
3149 | |
3150 | LogicalResult Parser::codeCompleteDialectName() { |
3151 | codeCompleteContext->codeCompleteDialectName(); |
3152 | return failure(); |
3153 | } |
3154 | |
3155 | LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) { |
3156 | codeCompleteContext->codeCompleteOperationName(dialectName); |
3157 | return failure(); |
3158 | } |
3159 | |
3160 | LogicalResult Parser::codeCompletePatternMetadata() { |
3161 | codeCompleteContext->codeCompletePatternMetadata(); |
3162 | return failure(); |
3163 | } |
3164 | |
3165 | LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) { |
3166 | codeCompleteContext->codeCompleteIncludeFilename(curPath); |
3167 | return failure(); |
3168 | } |
3169 | |
3170 | void Parser::codeCompleteCallSignature(ast::Node *parent, |
3171 | unsigned currentNumArgs) { |
3172 | ast::CallableDecl *callableDecl = tryExtractCallableDecl(node: parent); |
3173 | if (!callableDecl) |
3174 | return; |
3175 | |
3176 | codeCompleteContext->codeCompleteCallSignature(callable: callableDecl, currentNumArgs); |
3177 | } |
3178 | |
3179 | void Parser::codeCompleteOperationOperandsSignature( |
3180 | std::optional<StringRef> opName, unsigned currentNumOperands) { |
3181 | codeCompleteContext->codeCompleteOperationOperandsSignature( |
3182 | opName, currentNumOperands); |
3183 | } |
3184 | |
3185 | void Parser::codeCompleteOperationResultsSignature( |
3186 | std::optional<StringRef> opName, unsigned currentNumResults) { |
3187 | codeCompleteContext->codeCompleteOperationResultsSignature(opName, |
3188 | currentNumResults); |
3189 | } |
3190 | |
3191 | //===----------------------------------------------------------------------===// |
3192 | // Parser |
3193 | //===----------------------------------------------------------------------===// |
3194 | |
3195 | FailureOr<ast::Module *> |
3196 | mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, |
3197 | bool enableDocumentation, |
3198 | CodeCompleteContext *codeCompleteContext) { |
3199 | Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext); |
3200 | return parser.parseModule(); |
3201 | } |
3202 | |