1 | //===- Nodes.h --------------------------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_TOOLS_PDLL_AST_NODES_H_ |
10 | #define MLIR_TOOLS_PDLL_AST_NODES_H_ |
11 | |
12 | #include "mlir/Support/LLVM.h" |
13 | #include "mlir/Tools/PDLL/AST/Types.h" |
14 | #include "llvm/ADT/StringMap.h" |
15 | #include "llvm/ADT/StringRef.h" |
16 | #include "llvm/Support/SMLoc.h" |
17 | #include "llvm/Support/SourceMgr.h" |
18 | #include "llvm/Support/TrailingObjects.h" |
19 | #include <optional> |
20 | |
21 | namespace mlir { |
22 | namespace pdll { |
23 | namespace ast { |
24 | class Context; |
25 | class Decl; |
26 | class Expr; |
27 | class NamedAttributeDecl; |
28 | class OpNameDecl; |
29 | class VariableDecl; |
30 | |
31 | //===----------------------------------------------------------------------===// |
32 | // Name |
33 | //===----------------------------------------------------------------------===// |
34 | |
35 | /// This class provides a convenient API for interacting with source names. It |
36 | /// contains a string name as well as the source location for that name. |
37 | struct Name { |
38 | static const Name &create(Context &ctx, StringRef name, SMRange location); |
39 | |
40 | /// Return the raw string name. |
41 | StringRef getName() const { return name; } |
42 | |
43 | /// Get the location of this name. |
44 | SMRange getLoc() const { return location; } |
45 | |
46 | private: |
47 | Name() = delete; |
48 | Name(const Name &) = delete; |
49 | Name &operator=(const Name &) = delete; |
50 | Name(StringRef name, SMRange location) : name(name), location(location) {} |
51 | |
52 | /// The string name of the decl. |
53 | StringRef name; |
54 | /// The location of the decl name. |
55 | SMRange location; |
56 | }; |
57 | |
58 | //===----------------------------------------------------------------------===// |
59 | // DeclScope |
60 | //===----------------------------------------------------------------------===// |
61 | |
62 | /// This class represents a scope for named AST decls. A scope determines the |
63 | /// visibility and lifetime of a named declaration. |
64 | class DeclScope { |
65 | public: |
66 | /// Create a new scope with an optional parent scope. |
67 | DeclScope(DeclScope *parent = nullptr) : parent(parent) {} |
68 | |
69 | /// Return the parent scope of this scope, or nullptr if there is no parent. |
70 | DeclScope *getParentScope() { return parent; } |
71 | const DeclScope *getParentScope() const { return parent; } |
72 | |
73 | /// Return all of the decls within this scope. |
74 | auto getDecls() const { return llvm::make_second_range(c: decls); } |
75 | |
76 | /// Add a new decl to the scope. |
77 | void add(Decl *decl); |
78 | |
79 | /// Lookup a decl with the given name starting from this scope. Returns |
80 | /// nullptr if no decl could be found. |
81 | Decl *lookup(StringRef name); |
82 | template <typename T> |
83 | T *lookup(StringRef name) { |
84 | return dyn_cast_or_null<T>(lookup(name)); |
85 | } |
86 | const Decl *lookup(StringRef name) const { |
87 | return const_cast<DeclScope *>(this)->lookup(name); |
88 | } |
89 | template <typename T> |
90 | const T *lookup(StringRef name) const { |
91 | return dyn_cast_or_null<T>(lookup(name)); |
92 | } |
93 | |
94 | private: |
95 | /// The parent scope, or null if this is a top-level scope. |
96 | DeclScope *parent; |
97 | /// The decls defined within this scope. |
98 | llvm::StringMap<Decl *> decls; |
99 | }; |
100 | |
101 | //===----------------------------------------------------------------------===// |
102 | // Node |
103 | //===----------------------------------------------------------------------===// |
104 | |
105 | /// This class represents a base AST node. All AST nodes are derived from this |
106 | /// class, and it contains many of the base functionality for interacting with |
107 | /// nodes. |
108 | class Node { |
109 | public: |
110 | /// This CRTP class provides several utilies when defining new AST nodes. |
111 | template <typename T, typename BaseT> |
112 | class NodeBase : public BaseT { |
113 | public: |
114 | using Base = NodeBase<T, BaseT>; |
115 | |
116 | /// Provide type casting support. |
117 | static bool classof(const Node *node) { |
118 | return node->getTypeID() == TypeID::get<T>(); |
119 | } |
120 | |
121 | protected: |
122 | template <typename... Args> |
123 | explicit NodeBase(SMRange loc, Args &&...args) |
124 | : BaseT(TypeID::get<T>(), loc, std::forward<Args>(args)...) {} |
125 | }; |
126 | |
127 | /// Return the type identifier of this node. |
128 | TypeID getTypeID() const { return typeID; } |
129 | |
130 | /// Return the location of this node. |
131 | SMRange getLoc() const { return loc; } |
132 | |
133 | /// Print this node to the given stream. |
134 | void print(raw_ostream &os) const; |
135 | |
136 | /// Walk all of the nodes including, and nested under, this node in pre-order. |
137 | void walk(function_ref<void(const Node *)> walkFn) const; |
138 | template <typename WalkFnT, typename ArgT = typename llvm::function_traits< |
139 | WalkFnT>::template arg_t<0>> |
140 | std::enable_if_t<!std::is_convertible<const Node *, ArgT>::value> |
141 | walk(WalkFnT &&walkFn) const { |
142 | walk([&](const Node *node) { |
143 | if (const ArgT *derivedNode = dyn_cast<ArgT>(node)) |
144 | walkFn(derivedNode); |
145 | }); |
146 | } |
147 | |
148 | protected: |
149 | Node(TypeID typeID, SMRange loc) : typeID(typeID), loc(loc) {} |
150 | |
151 | private: |
152 | /// A unique type identifier for this node. |
153 | TypeID typeID; |
154 | |
155 | /// The location of this node. |
156 | SMRange loc; |
157 | }; |
158 | |
159 | //===----------------------------------------------------------------------===// |
160 | // Stmt |
161 | //===----------------------------------------------------------------------===// |
162 | |
163 | /// This class represents a base AST Statement node. |
164 | class Stmt : public Node { |
165 | public: |
166 | using Node::Node; |
167 | |
168 | /// Provide type casting support. |
169 | static bool classof(const Node *node); |
170 | }; |
171 | |
172 | //===----------------------------------------------------------------------===// |
173 | // CompoundStmt |
174 | //===----------------------------------------------------------------------===// |
175 | |
176 | /// This statement represents a compound statement, which contains a collection |
177 | /// of other statements. |
178 | class CompoundStmt final : public Node::NodeBase<CompoundStmt, Stmt>, |
179 | private llvm::TrailingObjects<CompoundStmt, Stmt *> { |
180 | public: |
181 | static CompoundStmt *create(Context &ctx, SMRange location, |
182 | ArrayRef<Stmt *> children); |
183 | |
184 | /// Return the children of this compound statement. |
185 | MutableArrayRef<Stmt *> getChildren() { |
186 | return {getTrailingObjects<Stmt *>(), numChildren}; |
187 | } |
188 | ArrayRef<Stmt *> getChildren() const { |
189 | return const_cast<CompoundStmt *>(this)->getChildren(); |
190 | } |
191 | ArrayRef<Stmt *>::iterator begin() const { return getChildren().begin(); } |
192 | ArrayRef<Stmt *>::iterator end() const { return getChildren().end(); } |
193 | |
194 | private: |
195 | CompoundStmt(SMRange location, unsigned numChildren) |
196 | : Base(location), numChildren(numChildren) {} |
197 | |
198 | /// The number of held children statements. |
199 | unsigned numChildren; |
200 | |
201 | // Allow access to various privates. |
202 | friend class llvm::TrailingObjects<CompoundStmt, Stmt *>; |
203 | }; |
204 | |
205 | //===----------------------------------------------------------------------===// |
206 | // LetStmt |
207 | //===----------------------------------------------------------------------===// |
208 | |
209 | /// This statement represents a `let` statement in PDLL. This statement is used |
210 | /// to define variables. |
211 | class LetStmt final : public Node::NodeBase<LetStmt, Stmt> { |
212 | public: |
213 | static LetStmt *create(Context &ctx, SMRange loc, VariableDecl *varDecl); |
214 | |
215 | /// Return the variable defined by this statement. |
216 | VariableDecl *getVarDecl() const { return varDecl; } |
217 | |
218 | private: |
219 | LetStmt(SMRange loc, VariableDecl *varDecl) : Base(loc), varDecl(varDecl) {} |
220 | |
221 | /// The variable defined by this statement. |
222 | VariableDecl *varDecl; |
223 | }; |
224 | |
225 | //===----------------------------------------------------------------------===// |
226 | // OpRewriteStmt |
227 | //===----------------------------------------------------------------------===// |
228 | |
229 | /// This class represents a base operation rewrite statement. Operation rewrite |
230 | /// statements perform a set of transformations on a given root operation. |
231 | class OpRewriteStmt : public Stmt { |
232 | public: |
233 | /// Provide type casting support. |
234 | static bool classof(const Node *node); |
235 | |
236 | /// Return the root operation of this rewrite. |
237 | Expr *getRootOpExpr() const { return rootOp; } |
238 | |
239 | protected: |
240 | OpRewriteStmt(TypeID typeID, SMRange loc, Expr *rootOp) |
241 | : Stmt(typeID, loc), rootOp(rootOp) {} |
242 | |
243 | protected: |
244 | /// The root operation being rewritten. |
245 | Expr *rootOp; |
246 | }; |
247 | |
248 | //===----------------------------------------------------------------------===// |
249 | // EraseStmt |
250 | |
251 | /// This statement represents the `erase` statement in PDLL. This statement |
252 | /// erases the given root operation, corresponding roughly to the |
253 | /// PatternRewriter::eraseOp API. |
254 | class EraseStmt final : public Node::NodeBase<EraseStmt, OpRewriteStmt> { |
255 | public: |
256 | static EraseStmt *create(Context &ctx, SMRange loc, Expr *rootOp); |
257 | |
258 | private: |
259 | EraseStmt(SMRange loc, Expr *rootOp) : Base(loc, rootOp) {} |
260 | }; |
261 | |
262 | //===----------------------------------------------------------------------===// |
263 | // ReplaceStmt |
264 | |
265 | /// This statement represents the `replace` statement in PDLL. This statement |
266 | /// replace the given root operation with a set of values, corresponding roughly |
267 | /// to the PatternRewriter::replaceOp API. |
268 | class ReplaceStmt final : public Node::NodeBase<ReplaceStmt, OpRewriteStmt>, |
269 | private llvm::TrailingObjects<ReplaceStmt, Expr *> { |
270 | public: |
271 | static ReplaceStmt *create(Context &ctx, SMRange loc, Expr *rootOp, |
272 | ArrayRef<Expr *> replExprs); |
273 | |
274 | /// Return the replacement values of this statement. |
275 | MutableArrayRef<Expr *> getReplExprs() { |
276 | return {getTrailingObjects<Expr *>(), numReplExprs}; |
277 | } |
278 | ArrayRef<Expr *> getReplExprs() const { |
279 | return const_cast<ReplaceStmt *>(this)->getReplExprs(); |
280 | } |
281 | |
282 | private: |
283 | ReplaceStmt(SMRange loc, Expr *rootOp, unsigned numReplExprs) |
284 | : Base(loc, rootOp), numReplExprs(numReplExprs) {} |
285 | |
286 | /// The number of replacement values within this statement. |
287 | unsigned numReplExprs; |
288 | |
289 | /// TrailingObject utilities. |
290 | friend class llvm::TrailingObjects<ReplaceStmt, Expr *>; |
291 | }; |
292 | |
293 | //===----------------------------------------------------------------------===// |
294 | // RewriteStmt |
295 | |
296 | /// This statement represents an operation rewrite that contains a block of |
297 | /// nested rewrite commands. This allows for building more complex operation |
298 | /// rewrites that span across multiple statements, which may be unconnected. |
299 | class RewriteStmt final : public Node::NodeBase<RewriteStmt, OpRewriteStmt> { |
300 | public: |
301 | static RewriteStmt *create(Context &ctx, SMRange loc, Expr *rootOp, |
302 | CompoundStmt *rewriteBody); |
303 | |
304 | /// Return the compound rewrite body. |
305 | CompoundStmt *getRewriteBody() const { return rewriteBody; } |
306 | |
307 | private: |
308 | RewriteStmt(SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody) |
309 | : Base(loc, rootOp), rewriteBody(rewriteBody) {} |
310 | |
311 | /// The body of nested rewriters within this statement. |
312 | CompoundStmt *rewriteBody; |
313 | }; |
314 | |
315 | //===----------------------------------------------------------------------===// |
316 | // ReturnStmt |
317 | //===----------------------------------------------------------------------===// |
318 | |
319 | /// This statement represents a return from a "callable" like decl, e.g. a |
320 | /// Constraint or a Rewrite. |
321 | class ReturnStmt final : public Node::NodeBase<ReturnStmt, Stmt> { |
322 | public: |
323 | static ReturnStmt *create(Context &ctx, SMRange loc, Expr *resultExpr); |
324 | |
325 | /// Return the result expression of this statement. |
326 | Expr *getResultExpr() { return resultExpr; } |
327 | const Expr *getResultExpr() const { return resultExpr; } |
328 | |
329 | /// Set the result expression of this statement. |
330 | void setResultExpr(Expr *expr) { resultExpr = expr; } |
331 | |
332 | private: |
333 | ReturnStmt(SMRange loc, Expr *resultExpr) |
334 | : Base(loc), resultExpr(resultExpr) {} |
335 | |
336 | // The result expression of this statement. |
337 | Expr *resultExpr; |
338 | }; |
339 | |
340 | //===----------------------------------------------------------------------===// |
341 | // Expr |
342 | //===----------------------------------------------------------------------===// |
343 | |
344 | /// This class represents a base AST Expression node. |
345 | class Expr : public Stmt { |
346 | public: |
347 | /// Return the type of this expression. |
348 | Type getType() const { return type; } |
349 | |
350 | /// Provide type casting support. |
351 | static bool classof(const Node *node); |
352 | |
353 | protected: |
354 | Expr(TypeID typeID, SMRange loc, Type type) : Stmt(typeID, loc), type(type) {} |
355 | |
356 | private: |
357 | /// The type of this expression. |
358 | Type type; |
359 | }; |
360 | |
361 | //===----------------------------------------------------------------------===// |
362 | // AttributeExpr |
363 | //===----------------------------------------------------------------------===// |
364 | |
365 | /// This expression represents a literal MLIR Attribute, and contains the |
366 | /// textual assembly format of that attribute. |
367 | class AttributeExpr : public Node::NodeBase<AttributeExpr, Expr> { |
368 | public: |
369 | static AttributeExpr *create(Context &ctx, SMRange loc, StringRef value); |
370 | |
371 | /// Get the raw value of this expression. This is the textual assembly format |
372 | /// of the MLIR Attribute. |
373 | StringRef getValue() const { return value; } |
374 | |
375 | private: |
376 | AttributeExpr(Context &ctx, SMRange loc, StringRef value) |
377 | : Base(loc, AttributeType::get(context&: ctx)), value(value) {} |
378 | |
379 | /// The value referenced by this expression. |
380 | StringRef value; |
381 | }; |
382 | |
383 | //===----------------------------------------------------------------------===// |
384 | // CallExpr |
385 | //===----------------------------------------------------------------------===// |
386 | |
387 | /// This expression represents a call to a decl, such as a |
388 | /// UserConstraintDecl/UserRewriteDecl. |
389 | class CallExpr final : public Node::NodeBase<CallExpr, Expr>, |
390 | private llvm::TrailingObjects<CallExpr, Expr *> { |
391 | public: |
392 | static CallExpr *create(Context &ctx, SMRange loc, Expr *callable, |
393 | ArrayRef<Expr *> arguments, Type resultType, |
394 | bool isNegated = false); |
395 | |
396 | /// Return the callable of this call. |
397 | Expr *getCallableExpr() const { return callable; } |
398 | |
399 | /// Return the arguments of this call. |
400 | MutableArrayRef<Expr *> getArguments() { |
401 | return {getTrailingObjects<Expr *>(), numArgs}; |
402 | } |
403 | ArrayRef<Expr *> getArguments() const { |
404 | return const_cast<CallExpr *>(this)->getArguments(); |
405 | } |
406 | |
407 | /// Returns whether the result of this call is to be negated. |
408 | bool getIsNegated() const { return isNegated; } |
409 | |
410 | private: |
411 | CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs, |
412 | bool isNegated) |
413 | : Base(loc, type), callable(callable), numArgs(numArgs), |
414 | isNegated(isNegated) {} |
415 | |
416 | /// The callable of this call. |
417 | Expr *callable; |
418 | |
419 | /// The number of arguments of the call. |
420 | unsigned numArgs; |
421 | |
422 | /// TrailingObject utilities. |
423 | friend llvm::TrailingObjects<CallExpr, Expr *>; |
424 | |
425 | // Is the result of this call to be negated. |
426 | bool isNegated; |
427 | }; |
428 | |
429 | //===----------------------------------------------------------------------===// |
430 | // DeclRefExpr |
431 | //===----------------------------------------------------------------------===// |
432 | |
433 | /// This expression represents a reference to a Decl node. |
434 | class DeclRefExpr : public Node::NodeBase<DeclRefExpr, Expr> { |
435 | public: |
436 | static DeclRefExpr *create(Context &ctx, SMRange loc, Decl *decl, Type type); |
437 | |
438 | /// Get the decl referenced by this expression. |
439 | Decl *getDecl() const { return decl; } |
440 | |
441 | private: |
442 | DeclRefExpr(SMRange loc, Decl *decl, Type type) |
443 | : Base(loc, type), decl(decl) {} |
444 | |
445 | /// The decl referenced by this expression. |
446 | Decl *decl; |
447 | }; |
448 | |
449 | //===----------------------------------------------------------------------===// |
450 | // MemberAccessExpr |
451 | //===----------------------------------------------------------------------===// |
452 | |
453 | /// This expression represents a named member or field access of a given parent |
454 | /// expression. |
455 | class MemberAccessExpr : public Node::NodeBase<MemberAccessExpr, Expr> { |
456 | public: |
457 | static MemberAccessExpr *create(Context &ctx, SMRange loc, |
458 | const Expr *parentExpr, StringRef memberName, |
459 | Type type); |
460 | |
461 | /// Get the parent expression of this access. |
462 | const Expr *getParentExpr() const { return parentExpr; } |
463 | |
464 | /// Return the name of the member being accessed. |
465 | StringRef getMemberName() const { return memberName; } |
466 | |
467 | private: |
468 | MemberAccessExpr(SMRange loc, const Expr *parentExpr, StringRef memberName, |
469 | Type type) |
470 | : Base(loc, type), parentExpr(parentExpr), memberName(memberName) {} |
471 | |
472 | /// The parent expression of this access. |
473 | const Expr *parentExpr; |
474 | |
475 | /// The name of the member being accessed from the parent. |
476 | StringRef memberName; |
477 | }; |
478 | |
479 | //===----------------------------------------------------------------------===// |
480 | // AllResultsMemberAccessExpr |
481 | |
482 | /// This class represents an instance of MemberAccessExpr that references all |
483 | /// results of an operation. |
484 | class AllResultsMemberAccessExpr : public MemberAccessExpr { |
485 | public: |
486 | /// Return the member name used for the "all-results" access. |
487 | static StringRef getMemberName() { return "$results" ; } |
488 | |
489 | static AllResultsMemberAccessExpr *create(Context &ctx, SMRange loc, |
490 | const Expr *parentExpr, Type type) { |
491 | return cast<AllResultsMemberAccessExpr>( |
492 | Val: MemberAccessExpr::create(ctx, loc, parentExpr, memberName: getMemberName(), type)); |
493 | } |
494 | |
495 | /// Provide type casting support. |
496 | static bool classof(const Node *node) { |
497 | const MemberAccessExpr *memAccess = dyn_cast<MemberAccessExpr>(Val: node); |
498 | return memAccess && memAccess->getMemberName() == getMemberName(); |
499 | } |
500 | }; |
501 | |
502 | //===----------------------------------------------------------------------===// |
503 | // OperationExpr |
504 | //===----------------------------------------------------------------------===// |
505 | |
506 | /// This expression represents the structural form of an MLIR Operation. It |
507 | /// represents either an input operation to match, or an operation to create |
508 | /// within a rewrite. |
509 | class OperationExpr final |
510 | : public Node::NodeBase<OperationExpr, Expr>, |
511 | private llvm::TrailingObjects<OperationExpr, Expr *, |
512 | NamedAttributeDecl *> { |
513 | public: |
514 | static OperationExpr *create(Context &ctx, SMRange loc, |
515 | const ods::Operation *odsOp, |
516 | const OpNameDecl *nameDecl, |
517 | ArrayRef<Expr *> operands, |
518 | ArrayRef<Expr *> resultTypes, |
519 | ArrayRef<NamedAttributeDecl *> attributes); |
520 | |
521 | /// Return the name of the operation, or std::nullopt if there isn't one. |
522 | std::optional<StringRef> getName() const; |
523 | |
524 | /// Return the declaration of the operation name. |
525 | const OpNameDecl *getNameDecl() const { return nameDecl; } |
526 | |
527 | /// Return the location of the name of the operation expression, or an invalid |
528 | /// location if there isn't a name. |
529 | SMRange getNameLoc() const { return nameLoc; } |
530 | |
531 | /// Return the operands of this operation. |
532 | MutableArrayRef<Expr *> getOperands() { |
533 | return {getTrailingObjects<Expr *>(), numOperands}; |
534 | } |
535 | ArrayRef<Expr *> getOperands() const { |
536 | return const_cast<OperationExpr *>(this)->getOperands(); |
537 | } |
538 | |
539 | /// Return the result types of this operation. |
540 | MutableArrayRef<Expr *> getResultTypes() { |
541 | return {getTrailingObjects<Expr *>() + numOperands, numResultTypes}; |
542 | } |
543 | MutableArrayRef<Expr *> getResultTypes() const { |
544 | return const_cast<OperationExpr *>(this)->getResultTypes(); |
545 | } |
546 | |
547 | /// Return the attributes of this operation. |
548 | MutableArrayRef<NamedAttributeDecl *> getAttributes() { |
549 | return {getTrailingObjects<NamedAttributeDecl *>(), numAttributes}; |
550 | } |
551 | MutableArrayRef<NamedAttributeDecl *> getAttributes() const { |
552 | return const_cast<OperationExpr *>(this)->getAttributes(); |
553 | } |
554 | |
555 | private: |
556 | OperationExpr(SMRange loc, Type type, const OpNameDecl *nameDecl, |
557 | unsigned numOperands, unsigned numResultTypes, |
558 | unsigned numAttributes, SMRange nameLoc) |
559 | : Base(loc, type), nameDecl(nameDecl), numOperands(numOperands), |
560 | numResultTypes(numResultTypes), numAttributes(numAttributes), |
561 | nameLoc(nameLoc) {} |
562 | |
563 | /// The name decl of this expression. |
564 | const OpNameDecl *nameDecl; |
565 | |
566 | /// The number of operands, result types, and attributes of the operation. |
567 | unsigned numOperands, numResultTypes, numAttributes; |
568 | |
569 | /// The location of the operation name in the expression if it has a name. |
570 | SMRange nameLoc; |
571 | |
572 | /// TrailingObject utilities. |
573 | friend llvm::TrailingObjects<OperationExpr, Expr *, NamedAttributeDecl *>; |
574 | size_t numTrailingObjects(OverloadToken<Expr *>) const { |
575 | return numOperands + numResultTypes; |
576 | } |
577 | }; |
578 | |
579 | //===----------------------------------------------------------------------===// |
580 | // RangeExpr |
581 | //===----------------------------------------------------------------------===// |
582 | |
583 | /// This expression builds a range from a set of element values (which may be |
584 | /// ranges themselves). |
585 | class RangeExpr final : public Node::NodeBase<RangeExpr, Expr>, |
586 | private llvm::TrailingObjects<RangeExpr, Expr *> { |
587 | public: |
588 | static RangeExpr *create(Context &ctx, SMRange loc, ArrayRef<Expr *> elements, |
589 | RangeType type); |
590 | |
591 | /// Return the element expressions of this range. |
592 | MutableArrayRef<Expr *> getElements() { |
593 | return {getTrailingObjects<Expr *>(), numElements}; |
594 | } |
595 | ArrayRef<Expr *> getElements() const { |
596 | return const_cast<RangeExpr *>(this)->getElements(); |
597 | } |
598 | |
599 | /// Return the range result type of this expression. |
600 | RangeType getType() const { return Base::getType().cast<RangeType>(); } |
601 | |
602 | private: |
603 | RangeExpr(SMRange loc, RangeType type, unsigned numElements) |
604 | : Base(loc, type), numElements(numElements) {} |
605 | |
606 | /// The number of element values for this range. |
607 | unsigned numElements; |
608 | |
609 | /// TrailingObject utilities. |
610 | friend class llvm::TrailingObjects<RangeExpr, Expr *>; |
611 | }; |
612 | |
613 | //===----------------------------------------------------------------------===// |
614 | // TupleExpr |
615 | //===----------------------------------------------------------------------===// |
616 | |
617 | /// This expression builds a tuple from a set of element values. |
618 | class TupleExpr final : public Node::NodeBase<TupleExpr, Expr>, |
619 | private llvm::TrailingObjects<TupleExpr, Expr *> { |
620 | public: |
621 | static TupleExpr *create(Context &ctx, SMRange loc, ArrayRef<Expr *> elements, |
622 | ArrayRef<StringRef> elementNames); |
623 | |
624 | /// Return the element expressions of this tuple. |
625 | MutableArrayRef<Expr *> getElements() { |
626 | return {getTrailingObjects<Expr *>(), getType().size()}; |
627 | } |
628 | ArrayRef<Expr *> getElements() const { |
629 | return const_cast<TupleExpr *>(this)->getElements(); |
630 | } |
631 | |
632 | /// Return the tuple result type of this expression. |
633 | TupleType getType() const { return Base::getType().cast<TupleType>(); } |
634 | |
635 | private: |
636 | TupleExpr(SMRange loc, TupleType type) : Base(loc, type) {} |
637 | |
638 | /// TrailingObject utilities. |
639 | friend class llvm::TrailingObjects<TupleExpr, Expr *>; |
640 | }; |
641 | |
642 | //===----------------------------------------------------------------------===// |
643 | // TypeExpr |
644 | //===----------------------------------------------------------------------===// |
645 | |
646 | /// This expression represents a literal MLIR Type, and contains the textual |
647 | /// assembly format of that type. |
648 | class TypeExpr : public Node::NodeBase<TypeExpr, Expr> { |
649 | public: |
650 | static TypeExpr *create(Context &ctx, SMRange loc, StringRef value); |
651 | |
652 | /// Get the raw value of this expression. This is the textual assembly format |
653 | /// of the MLIR Type. |
654 | StringRef getValue() const { return value; } |
655 | |
656 | private: |
657 | TypeExpr(Context &ctx, SMRange loc, StringRef value) |
658 | : Base(loc, TypeType::get(context&: ctx)), value(value) {} |
659 | |
660 | /// The value referenced by this expression. |
661 | StringRef value; |
662 | }; |
663 | |
664 | //===----------------------------------------------------------------------===// |
665 | // Decl |
666 | //===----------------------------------------------------------------------===// |
667 | |
668 | /// This class represents the base Decl node. |
669 | class Decl : public Node { |
670 | public: |
671 | /// Return the name of the decl, or nullptr if it doesn't have one. |
672 | const Name *getName() const { return name; } |
673 | |
674 | /// Provide type casting support. |
675 | static bool classof(const Node *node); |
676 | |
677 | /// Set the documentation comment for this decl. |
678 | void (Context &ctx, StringRef ); |
679 | |
680 | /// Return the documentation comment attached to this decl if it has been set. |
681 | /// Otherwise, returns std::nullopt. |
682 | std::optional<StringRef> () const { return docComment; } |
683 | |
684 | protected: |
685 | Decl(TypeID typeID, SMRange loc, const Name *name = nullptr) |
686 | : Node(typeID, loc), name(name) {} |
687 | |
688 | private: |
689 | /// The name of the decl. This is optional for some decls, such as |
690 | /// PatternDecl. |
691 | const Name *name; |
692 | |
693 | /// The documentation comment attached to this decl. Defaults to std::nullopt |
694 | /// if the comment is unset/unknown. |
695 | std::optional<StringRef> ; |
696 | }; |
697 | |
698 | //===----------------------------------------------------------------------===// |
699 | // ConstraintDecl |
700 | //===----------------------------------------------------------------------===// |
701 | |
702 | /// This class represents the base of all AST Constraint decls. Constraints |
703 | /// apply matcher conditions to, and define the type of PDLL variables. |
704 | class ConstraintDecl : public Decl { |
705 | public: |
706 | /// Provide type casting support. |
707 | static bool classof(const Node *node); |
708 | |
709 | protected: |
710 | ConstraintDecl(TypeID typeID, SMRange loc, const Name *name = nullptr) |
711 | : Decl(typeID, loc, name) {} |
712 | }; |
713 | |
714 | /// This class represents a reference to a constraint, and contains a constraint |
715 | /// and the location of the reference. |
716 | struct ConstraintRef { |
717 | ConstraintRef(const ConstraintDecl *constraint, SMRange refLoc) |
718 | : constraint(constraint), referenceLoc(refLoc) {} |
719 | explicit ConstraintRef(const ConstraintDecl *constraint) |
720 | : ConstraintRef(constraint, constraint->getLoc()) {} |
721 | |
722 | const ConstraintDecl *constraint; |
723 | SMRange referenceLoc; |
724 | }; |
725 | |
726 | //===----------------------------------------------------------------------===// |
727 | // CoreConstraintDecl |
728 | //===----------------------------------------------------------------------===// |
729 | |
730 | /// This class represents the base of all "core" constraints. Core constraints |
731 | /// are those that generally represent a concrete IR construct, such as |
732 | /// `Type`s or `Value`s. |
733 | class CoreConstraintDecl : public ConstraintDecl { |
734 | public: |
735 | /// Provide type casting support. |
736 | static bool classof(const Node *node); |
737 | |
738 | protected: |
739 | CoreConstraintDecl(TypeID typeID, SMRange loc, const Name *name = nullptr) |
740 | : ConstraintDecl(typeID, loc, name) {} |
741 | }; |
742 | |
743 | //===----------------------------------------------------------------------===// |
744 | // AttrConstraintDecl |
745 | |
746 | /// The class represents an Attribute constraint, and constrains a variable to |
747 | /// be an Attribute. |
748 | class AttrConstraintDecl |
749 | : public Node::NodeBase<AttrConstraintDecl, CoreConstraintDecl> { |
750 | public: |
751 | static AttrConstraintDecl *create(Context &ctx, SMRange loc, |
752 | Expr *typeExpr = nullptr); |
753 | |
754 | /// Return the optional type the attribute is constrained to. |
755 | Expr *getTypeExpr() { return typeExpr; } |
756 | const Expr *getTypeExpr() const { return typeExpr; } |
757 | |
758 | protected: |
759 | AttrConstraintDecl(SMRange loc, Expr *typeExpr) |
760 | : Base(loc), typeExpr(typeExpr) {} |
761 | |
762 | /// An optional type that the attribute is constrained to. |
763 | Expr *typeExpr; |
764 | }; |
765 | |
766 | //===----------------------------------------------------------------------===// |
767 | // OpConstraintDecl |
768 | |
769 | /// The class represents an Operation constraint, and constrains a variable to |
770 | /// be an Operation. |
771 | class OpConstraintDecl |
772 | : public Node::NodeBase<OpConstraintDecl, CoreConstraintDecl> { |
773 | public: |
774 | static OpConstraintDecl *create(Context &ctx, SMRange loc, |
775 | const OpNameDecl *nameDecl = nullptr); |
776 | |
777 | /// Return the name of the operation, or std::nullopt if there isn't one. |
778 | std::optional<StringRef> getName() const; |
779 | |
780 | /// Return the declaration of the operation name. |
781 | const OpNameDecl *getNameDecl() const { return nameDecl; } |
782 | |
783 | protected: |
784 | explicit OpConstraintDecl(SMRange loc, const OpNameDecl *nameDecl) |
785 | : Base(loc), nameDecl(nameDecl) {} |
786 | |
787 | /// The operation name of this constraint. |
788 | const OpNameDecl *nameDecl; |
789 | }; |
790 | |
791 | //===----------------------------------------------------------------------===// |
792 | // TypeConstraintDecl |
793 | |
/// The class represents a Type constraint, and constrains a variable to be a
/// Type.
class TypeConstraintDecl
    : public Node::NodeBase<TypeConstraintDecl, CoreConstraintDecl> {
public:
  /// Create a Type constraint at the given source location.
  static TypeConstraintDecl *create(Context &ctx, SMRange loc);

protected:
  /// Inherit the NodeBase constructors; this constraint declares no state of
  /// its own.
  using Base::Base;
};
804 | |
805 | //===----------------------------------------------------------------------===// |
806 | // TypeRangeConstraintDecl |
807 | |
/// The class represents a TypeRange constraint, and constrains a variable to be
/// a TypeRange.
class TypeRangeConstraintDecl
    : public Node::NodeBase<TypeRangeConstraintDecl, CoreConstraintDecl> {
public:
  /// Create a TypeRange constraint at the given source location.
  static TypeRangeConstraintDecl *create(Context &ctx, SMRange loc);

protected:
  /// Inherit the NodeBase constructors; this constraint declares no state of
  /// its own.
  using Base::Base;
};
818 | |
819 | //===----------------------------------------------------------------------===// |
820 | // ValueConstraintDecl |
821 | |
822 | /// The class represents a Value constraint, and constrains a variable to be a |
823 | /// Value. |
824 | class ValueConstraintDecl |
825 | : public Node::NodeBase<ValueConstraintDecl, CoreConstraintDecl> { |
826 | public: |
827 | static ValueConstraintDecl *create(Context &ctx, SMRange loc, Expr *typeExpr); |
828 | |
829 | /// Return the optional type the value is constrained to. |
830 | Expr *getTypeExpr() { return typeExpr; } |
831 | const Expr *getTypeExpr() const { return typeExpr; } |
832 | |
833 | protected: |
834 | ValueConstraintDecl(SMRange loc, Expr *typeExpr) |
835 | : Base(loc), typeExpr(typeExpr) {} |
836 | |
837 | /// An optional type that the value is constrained to. |
838 | Expr *typeExpr; |
839 | }; |
840 | |
841 | //===----------------------------------------------------------------------===// |
842 | // ValueRangeConstraintDecl |
843 | |
844 | /// The class represents a ValueRange constraint, and constrains a variable to |
845 | /// be a ValueRange. |
846 | class ValueRangeConstraintDecl |
847 | : public Node::NodeBase<ValueRangeConstraintDecl, CoreConstraintDecl> { |
848 | public: |
849 | static ValueRangeConstraintDecl *create(Context &ctx, SMRange loc, |
850 | Expr *typeExpr = nullptr); |
851 | |
852 | /// Return the optional type the value range is constrained to. |
853 | Expr *getTypeExpr() { return typeExpr; } |
854 | const Expr *getTypeExpr() const { return typeExpr; } |
855 | |
856 | protected: |
857 | ValueRangeConstraintDecl(SMRange loc, Expr *typeExpr) |
858 | : Base(loc), typeExpr(typeExpr) {} |
859 | |
860 | /// An optional type that the value range is constrained to. |
861 | Expr *typeExpr; |
862 | }; |
863 | |
864 | //===----------------------------------------------------------------------===// |
865 | // UserConstraintDecl |
866 | //===----------------------------------------------------------------------===// |
867 | |
868 | /// This decl represents a user defined constraint. This is either: |
869 | /// * an imported native constraint |
870 | /// - Similar to an external function declaration. This is a native |
871 | /// constraint defined externally, and imported into PDLL via a |
872 | /// declaration. |
873 | /// * a native constraint defined in PDLL |
874 | /// - This is a native constraint, i.e. a constraint whose implementation is |
875 | /// defined in C++(or potentially some other non-PDLL language). The |
876 | /// implementation of this constraint is specified as a string code block |
877 | /// in PDLL. |
878 | /// * a PDLL constraint |
879 | /// - This is a constraint which is defined using only PDLL constructs. |
880 | class UserConstraintDecl final |
881 | : public Node::NodeBase<UserConstraintDecl, ConstraintDecl>, |
882 | llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef> { |
883 | public: |
884 | /// Create a native constraint with the given optional code block. |
885 | static UserConstraintDecl * |
886 | createNative(Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs, |
887 | ArrayRef<VariableDecl *> results, |
888 | std::optional<StringRef> codeBlock, Type resultType, |
889 | ArrayRef<StringRef> nativeInputTypes = {}) { |
890 | return createImpl(ctx, name, inputs, nativeInputTypes, results, codeBlock, |
891 | /*body=*/body: nullptr, resultType); |
892 | } |
893 | |
894 | /// Create a PDLL constraint with the given body. |
895 | static UserConstraintDecl *createPDLL(Context &ctx, const Name &name, |
896 | ArrayRef<VariableDecl *> inputs, |
897 | ArrayRef<VariableDecl *> results, |
898 | const CompoundStmt *body, |
899 | Type resultType) { |
900 | return createImpl(ctx, name, inputs, /*nativeInputTypes=*/nativeInputTypes: std::nullopt, |
901 | results, /*codeBlock=*/codeBlock: std::nullopt, body, resultType); |
902 | } |
903 | |
904 | /// Return the name of the constraint. |
905 | const Name &getName() const { return *Decl::getName(); } |
906 | |
907 | /// Return the input arguments of this constraint. |
908 | MutableArrayRef<VariableDecl *> getInputs() { |
909 | return {getTrailingObjects<VariableDecl *>(), numInputs}; |
910 | } |
911 | ArrayRef<VariableDecl *> getInputs() const { |
912 | return const_cast<UserConstraintDecl *>(this)->getInputs(); |
913 | } |
914 | |
915 | /// Return the explicit native type to use for the given input. Returns |
916 | /// std::nullopt if no explicit type was set. |
917 | std::optional<StringRef> getNativeInputType(unsigned index) const; |
918 | |
919 | /// Return the explicit results of the constraint declaration. May be empty, |
920 | /// even if the constraint has results (e.g. in the case of inferred results). |
921 | MutableArrayRef<VariableDecl *> getResults() { |
922 | return {getTrailingObjects<VariableDecl *>() + numInputs, numResults}; |
923 | } |
924 | ArrayRef<VariableDecl *> getResults() const { |
925 | return const_cast<UserConstraintDecl *>(this)->getResults(); |
926 | } |
927 | |
928 | /// Return the optional code block of this constraint, if this is a native |
929 | /// constraint with a provided implementation. |
930 | std::optional<StringRef> getCodeBlock() const { return codeBlock; } |
931 | |
932 | /// Return the body of this constraint if this constraint is a PDLL |
933 | /// constraint, otherwise returns nullptr. |
934 | const CompoundStmt *getBody() const { return constraintBody; } |
935 | |
936 | /// Return the result type of this constraint. |
937 | Type getResultType() const { return resultType; } |
938 | |
939 | /// Returns true if this constraint is external. |
940 | bool isExternal() const { return !constraintBody && !codeBlock; } |
941 | |
942 | private: |
943 | /// Create either a PDLL constraint or a native constraint with the given |
944 | /// components. |
945 | static UserConstraintDecl *createImpl(Context &ctx, const Name &name, |
946 | ArrayRef<VariableDecl *> inputs, |
947 | ArrayRef<StringRef> nativeInputTypes, |
948 | ArrayRef<VariableDecl *> results, |
949 | std::optional<StringRef> codeBlock, |
950 | const CompoundStmt *body, |
951 | Type resultType); |
952 | |
953 | UserConstraintDecl(const Name &name, unsigned numInputs, |
954 | bool hasNativeInputTypes, unsigned numResults, |
955 | std::optional<StringRef> codeBlock, |
956 | const CompoundStmt *body, Type resultType) |
957 | : Base(name.getLoc(), &name), numInputs(numInputs), |
958 | numResults(numResults), codeBlock(codeBlock), constraintBody(body), |
959 | resultType(resultType), hasNativeInputTypes(hasNativeInputTypes) {} |
960 | |
961 | /// The number of inputs to this constraint. |
962 | unsigned numInputs; |
963 | |
964 | /// The number of explicit results to this constraint. |
965 | unsigned numResults; |
966 | |
967 | /// The optional code block of this constraint. |
968 | std::optional<StringRef> codeBlock; |
969 | |
970 | /// The optional body of this constraint. |
971 | const CompoundStmt *constraintBody; |
972 | |
973 | /// The result type of the constraint. |
974 | Type resultType; |
975 | |
976 | /// Flag indicating if this constraint has explicit native input types. |
977 | bool hasNativeInputTypes; |
978 | |
979 | /// Allow access to various internals. |
980 | friend llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef>; |
981 | size_t numTrailingObjects(OverloadToken<VariableDecl *>) const { |
982 | return numInputs + numResults; |
983 | } |
984 | }; |
985 | |
986 | //===----------------------------------------------------------------------===// |
987 | // NamedAttributeDecl |
988 | //===----------------------------------------------------------------------===// |
989 | |
990 | /// This Decl represents a NamedAttribute, and contains a string name and |
991 | /// attribute value. |
992 | class NamedAttributeDecl : public Node::NodeBase<NamedAttributeDecl, Decl> { |
993 | public: |
994 | static NamedAttributeDecl *create(Context &ctx, const Name &name, |
995 | Expr *value); |
996 | |
997 | /// Return the name of the attribute. |
998 | const Name &getName() const { return *Decl::getName(); } |
999 | |
1000 | /// Return value of the attribute. |
1001 | Expr *getValue() const { return value; } |
1002 | |
1003 | private: |
1004 | NamedAttributeDecl(const Name &name, Expr *value) |
1005 | : Base(name.getLoc(), &name), value(value) {} |
1006 | |
1007 | /// The value of the attribute. |
1008 | Expr *value; |
1009 | }; |
1010 | |
1011 | //===----------------------------------------------------------------------===// |
1012 | // OpNameDecl |
1013 | //===----------------------------------------------------------------------===// |
1014 | |
1015 | /// This Decl represents an OperationName. |
1016 | class OpNameDecl : public Node::NodeBase<OpNameDecl, Decl> { |
1017 | public: |
1018 | static OpNameDecl *create(Context &ctx, const Name &name); |
1019 | static OpNameDecl *create(Context &ctx, SMRange loc); |
1020 | |
1021 | /// Return the name of this operation, or std::nullopt if the name is unknown. |
1022 | std::optional<StringRef> getName() const { |
1023 | const Name *name = Decl::getName(); |
1024 | return name ? std::optional<StringRef>(name->getName()) : std::nullopt; |
1025 | } |
1026 | |
1027 | private: |
1028 | explicit OpNameDecl(const Name &name) : Base(name.getLoc(), &name) {} |
1029 | explicit OpNameDecl(SMRange loc) : Base(loc) {} |
1030 | }; |
1031 | |
1032 | //===----------------------------------------------------------------------===// |
1033 | // PatternDecl |
1034 | //===----------------------------------------------------------------------===// |
1035 | |
1036 | /// This Decl represents a single Pattern. |
1037 | class PatternDecl : public Node::NodeBase<PatternDecl, Decl> { |
1038 | public: |
1039 | static PatternDecl *create(Context &ctx, SMRange location, const Name *name, |
1040 | std::optional<uint16_t> benefit, |
1041 | bool hasBoundedRecursion, |
1042 | const CompoundStmt *body); |
1043 | |
1044 | /// Return the benefit of this pattern if specified, or std::nullopt. |
1045 | std::optional<uint16_t> getBenefit() const { return benefit; } |
1046 | |
1047 | /// Return if this pattern has bounded rewrite recursion. |
1048 | bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; } |
1049 | |
1050 | /// Return the body of this pattern. |
1051 | const CompoundStmt *getBody() const { return patternBody; } |
1052 | |
1053 | /// Return the root rewrite statement of this pattern. |
1054 | const OpRewriteStmt *getRootRewriteStmt() const { |
1055 | return cast<OpRewriteStmt>(Val: patternBody->getChildren().back()); |
1056 | } |
1057 | |
1058 | private: |
1059 | PatternDecl(SMRange loc, const Name *name, std::optional<uint16_t> benefit, |
1060 | bool hasBoundedRecursion, const CompoundStmt *body) |
1061 | : Base(loc, name), benefit(benefit), |
1062 | hasBoundedRecursion(hasBoundedRecursion), patternBody(body) {} |
1063 | |
1064 | /// The benefit of the pattern if it was explicitly specified, std::nullopt |
1065 | /// otherwise. |
1066 | std::optional<uint16_t> benefit; |
1067 | |
1068 | /// If the pattern has properly bounded rewrite recursion or not. |
1069 | bool hasBoundedRecursion; |
1070 | |
1071 | /// The compound statement representing the body of the pattern. |
1072 | const CompoundStmt *patternBody; |
1073 | }; |
1074 | |
1075 | //===----------------------------------------------------------------------===// |
1076 | // UserRewriteDecl |
1077 | //===----------------------------------------------------------------------===// |
1078 | |
1079 | /// This decl represents a user defined rewrite. This is either: |
1080 | /// * an imported native rewrite |
1081 | /// - Similar to an external function declaration. This is a native |
1082 | /// rewrite defined externally, and imported into PDLL via a declaration. |
1083 | /// * a native rewrite defined in PDLL |
1084 | /// - This is a native rewrite, i.e. a rewrite whose implementation is |
1085 | /// defined in C++(or potentially some other non-PDLL language). The |
1086 | /// implementation of this rewrite is specified as a string code block |
1087 | /// in PDLL. |
1088 | /// * a PDLL rewrite |
1089 | /// - This is a rewrite which is defined using only PDLL constructs. |
1090 | class UserRewriteDecl final |
1091 | : public Node::NodeBase<UserRewriteDecl, Decl>, |
1092 | llvm::TrailingObjects<UserRewriteDecl, VariableDecl *> { |
1093 | public: |
1094 | /// Create a native rewrite with the given optional code block. |
1095 | static UserRewriteDecl *createNative(Context &ctx, const Name &name, |
1096 | ArrayRef<VariableDecl *> inputs, |
1097 | ArrayRef<VariableDecl *> results, |
1098 | std::optional<StringRef> codeBlock, |
1099 | Type resultType) { |
1100 | return createImpl(ctx, name, inputs, results, codeBlock, /*body=*/body: nullptr, |
1101 | resultType); |
1102 | } |
1103 | |
1104 | /// Create a PDLL rewrite with the given body. |
1105 | static UserRewriteDecl *createPDLL(Context &ctx, const Name &name, |
1106 | ArrayRef<VariableDecl *> inputs, |
1107 | ArrayRef<VariableDecl *> results, |
1108 | const CompoundStmt *body, |
1109 | Type resultType) { |
1110 | return createImpl(ctx, name, inputs, results, /*codeBlock=*/codeBlock: std::nullopt, |
1111 | body, resultType); |
1112 | } |
1113 | |
1114 | /// Return the name of the rewrite. |
1115 | const Name &getName() const { return *Decl::getName(); } |
1116 | |
1117 | /// Return the input arguments of this rewrite. |
1118 | MutableArrayRef<VariableDecl *> getInputs() { |
1119 | return {getTrailingObjects<VariableDecl *>(), numInputs}; |
1120 | } |
1121 | ArrayRef<VariableDecl *> getInputs() const { |
1122 | return const_cast<UserRewriteDecl *>(this)->getInputs(); |
1123 | } |
1124 | |
1125 | /// Return the explicit results of the rewrite declaration. May be empty, |
1126 | /// even if the rewrite has results (e.g. in the case of inferred results). |
1127 | MutableArrayRef<VariableDecl *> getResults() { |
1128 | return {getTrailingObjects<VariableDecl *>() + numInputs, numResults}; |
1129 | } |
1130 | ArrayRef<VariableDecl *> getResults() const { |
1131 | return const_cast<UserRewriteDecl *>(this)->getResults(); |
1132 | } |
1133 | |
1134 | /// Return the optional code block of this rewrite, if this is a native |
1135 | /// rewrite with a provided implementation. |
1136 | std::optional<StringRef> getCodeBlock() const { return codeBlock; } |
1137 | |
1138 | /// Return the body of this rewrite if this rewrite is a PDLL rewrite, |
1139 | /// otherwise returns nullptr. |
1140 | const CompoundStmt *getBody() const { return rewriteBody; } |
1141 | |
1142 | /// Return the result type of this rewrite. |
1143 | Type getResultType() const { return resultType; } |
1144 | |
1145 | /// Returns true if this rewrite is external. |
1146 | bool isExternal() const { return !rewriteBody && !codeBlock; } |
1147 | |
1148 | private: |
1149 | /// Create either a PDLL rewrite or a native rewrite with the given |
1150 | /// components. |
1151 | static UserRewriteDecl *createImpl(Context &ctx, const Name &name, |
1152 | ArrayRef<VariableDecl *> inputs, |
1153 | ArrayRef<VariableDecl *> results, |
1154 | std::optional<StringRef> codeBlock, |
1155 | const CompoundStmt *body, Type resultType); |
1156 | |
1157 | UserRewriteDecl(const Name &name, unsigned numInputs, unsigned numResults, |
1158 | std::optional<StringRef> codeBlock, const CompoundStmt *body, |
1159 | Type resultType) |
1160 | : Base(name.getLoc(), &name), numInputs(numInputs), |
1161 | numResults(numResults), codeBlock(codeBlock), rewriteBody(body), |
1162 | resultType(resultType) {} |
1163 | |
1164 | /// The number of inputs to this rewrite. |
1165 | unsigned numInputs; |
1166 | |
1167 | /// The number of explicit results to this rewrite. |
1168 | unsigned numResults; |
1169 | |
1170 | /// The optional code block of this rewrite. |
1171 | std::optional<StringRef> codeBlock; |
1172 | |
1173 | /// The optional body of this rewrite. |
1174 | const CompoundStmt *rewriteBody; |
1175 | |
1176 | /// The result type of the rewrite. |
1177 | Type resultType; |
1178 | |
1179 | /// Allow access to various internals. |
1180 | friend llvm::TrailingObjects<UserRewriteDecl, VariableDecl *>; |
1181 | }; |
1182 | |
1183 | //===----------------------------------------------------------------------===// |
1184 | // CallableDecl |
1185 | //===----------------------------------------------------------------------===// |
1186 | |
1187 | /// This decl represents a shared interface for all callable decls. |
1188 | class CallableDecl : public Decl { |
1189 | public: |
1190 | /// Return the callable type of this decl. |
1191 | StringRef getCallableType() const { |
1192 | if (isa<UserConstraintDecl>(Val: this)) |
1193 | return "constraint" ; |
1194 | assert(isa<UserRewriteDecl>(this) && "unknown callable type" ); |
1195 | return "rewrite" ; |
1196 | } |
1197 | |
1198 | /// Return the inputs of this decl. |
1199 | ArrayRef<VariableDecl *> getInputs() const { |
1200 | if (const auto *cst = dyn_cast<UserConstraintDecl>(Val: this)) |
1201 | return cst->getInputs(); |
1202 | return cast<UserRewriteDecl>(Val: this)->getInputs(); |
1203 | } |
1204 | |
1205 | /// Return the result type of this decl. |
1206 | Type getResultType() const { |
1207 | if (const auto *cst = dyn_cast<UserConstraintDecl>(Val: this)) |
1208 | return cst->getResultType(); |
1209 | return cast<UserRewriteDecl>(Val: this)->getResultType(); |
1210 | } |
1211 | |
1212 | /// Return the explicit results of the declaration. Note that these may be |
1213 | /// empty, even if the callable has results (e.g. in the case of inferred |
1214 | /// results). |
1215 | ArrayRef<VariableDecl *> getResults() const { |
1216 | if (const auto *cst = dyn_cast<UserConstraintDecl>(Val: this)) |
1217 | return cst->getResults(); |
1218 | return cast<UserRewriteDecl>(Val: this)->getResults(); |
1219 | } |
1220 | |
1221 | /// Return the optional code block of this callable, if this is a native |
1222 | /// callable with a provided implementation. |
1223 | std::optional<StringRef> getCodeBlock() const { |
1224 | if (const auto *cst = dyn_cast<UserConstraintDecl>(Val: this)) |
1225 | return cst->getCodeBlock(); |
1226 | return cast<UserRewriteDecl>(Val: this)->getCodeBlock(); |
1227 | } |
1228 | |
1229 | /// Support LLVM type casting facilities. |
1230 | static bool classof(const Node *decl) { |
1231 | return isa<UserConstraintDecl, UserRewriteDecl>(Val: decl); |
1232 | } |
1233 | }; |
1234 | |
1235 | //===----------------------------------------------------------------------===// |
1236 | // VariableDecl |
1237 | //===----------------------------------------------------------------------===// |
1238 | |
1239 | /// This Decl represents the definition of a PDLL variable. |
1240 | class VariableDecl final |
1241 | : public Node::NodeBase<VariableDecl, Decl>, |
1242 | private llvm::TrailingObjects<VariableDecl, ConstraintRef> { |
1243 | public: |
1244 | static VariableDecl *create(Context &ctx, const Name &name, Type type, |
1245 | Expr *initExpr, |
1246 | ArrayRef<ConstraintRef> constraints); |
1247 | |
1248 | /// Return the constraints of this variable. |
1249 | MutableArrayRef<ConstraintRef> getConstraints() { |
1250 | return {getTrailingObjects<ConstraintRef>(), numConstraints}; |
1251 | } |
1252 | ArrayRef<ConstraintRef> getConstraints() const { |
1253 | return const_cast<VariableDecl *>(this)->getConstraints(); |
1254 | } |
1255 | |
1256 | /// Return the initializer expression of this statement, or nullptr if there |
1257 | /// was no initializer. |
1258 | Expr *getInitExpr() const { return initExpr; } |
1259 | |
1260 | /// Return the name of the decl. |
1261 | const Name &getName() const { return *Decl::getName(); } |
1262 | |
1263 | /// Return the type of the decl. |
1264 | Type getType() const { return type; } |
1265 | |
1266 | private: |
1267 | VariableDecl(const Name &name, Type type, Expr *initExpr, |
1268 | unsigned numConstraints) |
1269 | : Base(name.getLoc(), &name), type(type), initExpr(initExpr), |
1270 | numConstraints(numConstraints) {} |
1271 | |
1272 | /// The type of the variable. |
1273 | Type type; |
1274 | |
1275 | /// The optional initializer expression of this statement. |
1276 | Expr *initExpr; |
1277 | |
1278 | /// The number of constraints attached to this variable. |
1279 | unsigned numConstraints; |
1280 | |
1281 | /// Allow access to various internals. |
1282 | friend llvm::TrailingObjects<VariableDecl, ConstraintRef>; |
1283 | }; |
1284 | |
1285 | //===----------------------------------------------------------------------===// |
1286 | // Module |
1287 | //===----------------------------------------------------------------------===// |
1288 | |
1289 | /// This class represents a top-level AST module. |
1290 | class Module final : public Node::NodeBase<Module, Node>, |
1291 | private llvm::TrailingObjects<Module, Decl *> { |
1292 | public: |
1293 | static Module *create(Context &ctx, SMLoc loc, ArrayRef<Decl *> children); |
1294 | |
1295 | /// Return the children of this module. |
1296 | MutableArrayRef<Decl *> getChildren() { |
1297 | return {getTrailingObjects<Decl *>(), numChildren}; |
1298 | } |
1299 | ArrayRef<Decl *> getChildren() const { |
1300 | return const_cast<Module *>(this)->getChildren(); |
1301 | } |
1302 | |
1303 | private: |
1304 | Module(SMLoc loc, unsigned numChildren) |
1305 | : Base(SMRange{loc, loc}), numChildren(numChildren) {} |
1306 | |
1307 | /// The number of decls held by this module. |
1308 | unsigned numChildren; |
1309 | |
1310 | /// Allow access to various internals. |
1311 | friend llvm::TrailingObjects<Module, Decl *>; |
1312 | }; |
1313 | |
1314 | //===----------------------------------------------------------------------===// |
// Deferred Method Definitions
1316 | //===----------------------------------------------------------------------===// |
1317 | |
1318 | inline bool Decl::classof(const Node *node) { |
1319 | return isa<ConstraintDecl, NamedAttributeDecl, OpNameDecl, PatternDecl, |
1320 | UserRewriteDecl, VariableDecl>(Val: node); |
1321 | } |
1322 | |
1323 | inline bool ConstraintDecl::classof(const Node *node) { |
1324 | return isa<CoreConstraintDecl, UserConstraintDecl>(Val: node); |
1325 | } |
1326 | |
1327 | inline bool CoreConstraintDecl::classof(const Node *node) { |
1328 | return isa<AttrConstraintDecl, OpConstraintDecl, TypeConstraintDecl, |
1329 | TypeRangeConstraintDecl, ValueConstraintDecl, |
1330 | ValueRangeConstraintDecl>(Val: node); |
1331 | } |
1332 | |
1333 | inline bool Expr::classof(const Node *node) { |
1334 | return isa<AttributeExpr, CallExpr, DeclRefExpr, MemberAccessExpr, |
1335 | OperationExpr, RangeExpr, TupleExpr, TypeExpr>(Val: node); |
1336 | } |
1337 | |
1338 | inline bool OpRewriteStmt::classof(const Node *node) { |
1339 | return isa<EraseStmt, ReplaceStmt, RewriteStmt>(Val: node); |
1340 | } |
1341 | |
1342 | inline bool Stmt::classof(const Node *node) { |
1343 | return isa<CompoundStmt, LetStmt, OpRewriteStmt, Expr>(Val: node); |
1344 | } |
1345 | |
1346 | } // namespace ast |
1347 | } // namespace pdll |
1348 | } // namespace mlir |
1349 | |
1350 | #endif // MLIR_TOOLS_PDLL_AST_NODES_H_ |
1351 | |