1//===- Nodes.h --------------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_TOOLS_PDLL_AST_NODES_H_
10#define MLIR_TOOLS_PDLL_AST_NODES_H_
11
#include "mlir/Support/LLVM.h"
#include "mlir/Tools/PDLL/AST/Types.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TrailingObjects.h"
#include <optional>
#include <type_traits>
20
21namespace mlir {
22namespace pdll {
23namespace ast {
24class Context;
25class Decl;
26class Expr;
27class NamedAttributeDecl;
28class OpNameDecl;
29class VariableDecl;
30
31//===----------------------------------------------------------------------===//
32// Name
33//===----------------------------------------------------------------------===//
34
35/// This class provides a convenient API for interacting with source names. It
36/// contains a string name as well as the source location for that name.
37struct Name {
38 static const Name &create(Context &ctx, StringRef name, SMRange location);
39
40 /// Return the raw string name.
41 StringRef getName() const { return name; }
42
43 /// Get the location of this name.
44 SMRange getLoc() const { return location; }
45
46private:
47 Name() = delete;
48 Name(const Name &) = delete;
49 Name &operator=(const Name &) = delete;
50 Name(StringRef name, SMRange location) : name(name), location(location) {}
51
52 /// The string name of the decl.
53 StringRef name;
54 /// The location of the decl name.
55 SMRange location;
56};
57
58//===----------------------------------------------------------------------===//
59// DeclScope
60//===----------------------------------------------------------------------===//
61
62/// This class represents a scope for named AST decls. A scope determines the
63/// visibility and lifetime of a named declaration.
64class DeclScope {
65public:
66 /// Create a new scope with an optional parent scope.
67 DeclScope(DeclScope *parent = nullptr) : parent(parent) {}
68
69 /// Return the parent scope of this scope, or nullptr if there is no parent.
70 DeclScope *getParentScope() { return parent; }
71 const DeclScope *getParentScope() const { return parent; }
72
73 /// Return all of the decls within this scope.
74 auto getDecls() const { return llvm::make_second_range(c: decls); }
75
76 /// Add a new decl to the scope.
77 void add(Decl *decl);
78
79 /// Lookup a decl with the given name starting from this scope. Returns
80 /// nullptr if no decl could be found.
81 Decl *lookup(StringRef name);
82 template <typename T>
83 T *lookup(StringRef name) {
84 return dyn_cast_or_null<T>(lookup(name));
85 }
86 const Decl *lookup(StringRef name) const {
87 return const_cast<DeclScope *>(this)->lookup(name);
88 }
89 template <typename T>
90 const T *lookup(StringRef name) const {
91 return dyn_cast_or_null<T>(lookup(name));
92 }
93
94private:
95 /// The parent scope, or null if this is a top-level scope.
96 DeclScope *parent;
97 /// The decls defined within this scope.
98 llvm::StringMap<Decl *> decls;
99};
100
101//===----------------------------------------------------------------------===//
102// Node
103//===----------------------------------------------------------------------===//
104
105/// This class represents a base AST node. All AST nodes are derived from this
106/// class, and it contains many of the base functionality for interacting with
107/// nodes.
108class Node {
109public:
110 /// This CRTP class provides several utilies when defining new AST nodes.
111 template <typename T, typename BaseT>
112 class NodeBase : public BaseT {
113 public:
114 using Base = NodeBase<T, BaseT>;
115
116 /// Provide type casting support.
117 static bool classof(const Node *node) {
118 return node->getTypeID() == TypeID::get<T>();
119 }
120
121 protected:
122 template <typename... Args>
123 explicit NodeBase(SMRange loc, Args &&...args)
124 : BaseT(TypeID::get<T>(), loc, std::forward<Args>(args)...) {}
125 };
126
127 /// Return the type identifier of this node.
128 TypeID getTypeID() const { return typeID; }
129
130 /// Return the location of this node.
131 SMRange getLoc() const { return loc; }
132
133 /// Print this node to the given stream.
134 void print(raw_ostream &os) const;
135
136 /// Walk all of the nodes including, and nested under, this node in pre-order.
137 void walk(function_ref<void(const Node *)> walkFn) const;
138 template <typename WalkFnT, typename ArgT = typename llvm::function_traits<
139 WalkFnT>::template arg_t<0>>
140 std::enable_if_t<!std::is_convertible<const Node *, ArgT>::value>
141 walk(WalkFnT &&walkFn) const {
142 walk([&](const Node *node) {
143 if (const ArgT *derivedNode = dyn_cast<ArgT>(node))
144 walkFn(derivedNode);
145 });
146 }
147
148protected:
149 Node(TypeID typeID, SMRange loc) : typeID(typeID), loc(loc) {}
150
151private:
152 /// A unique type identifier for this node.
153 TypeID typeID;
154
155 /// The location of this node.
156 SMRange loc;
157};
158
159//===----------------------------------------------------------------------===//
160// Stmt
161//===----------------------------------------------------------------------===//
162
163/// This class represents a base AST Statement node.
164class Stmt : public Node {
165public:
166 using Node::Node;
167
168 /// Provide type casting support.
169 static bool classof(const Node *node);
170};
171
172//===----------------------------------------------------------------------===//
173// CompoundStmt
174//===----------------------------------------------------------------------===//
175
176/// This statement represents a compound statement, which contains a collection
177/// of other statements.
178class CompoundStmt final : public Node::NodeBase<CompoundStmt, Stmt>,
179 private llvm::TrailingObjects<CompoundStmt, Stmt *> {
180public:
181 static CompoundStmt *create(Context &ctx, SMRange location,
182 ArrayRef<Stmt *> children);
183
184 /// Return the children of this compound statement.
185 MutableArrayRef<Stmt *> getChildren() {
186 return {getTrailingObjects<Stmt *>(), numChildren};
187 }
188 ArrayRef<Stmt *> getChildren() const {
189 return const_cast<CompoundStmt *>(this)->getChildren();
190 }
191 ArrayRef<Stmt *>::iterator begin() const { return getChildren().begin(); }
192 ArrayRef<Stmt *>::iterator end() const { return getChildren().end(); }
193
194private:
195 CompoundStmt(SMRange location, unsigned numChildren)
196 : Base(location), numChildren(numChildren) {}
197
198 /// The number of held children statements.
199 unsigned numChildren;
200
201 // Allow access to various privates.
202 friend class llvm::TrailingObjects<CompoundStmt, Stmt *>;
203};
204
205//===----------------------------------------------------------------------===//
206// LetStmt
207//===----------------------------------------------------------------------===//
208
209/// This statement represents a `let` statement in PDLL. This statement is used
210/// to define variables.
211class LetStmt final : public Node::NodeBase<LetStmt, Stmt> {
212public:
213 static LetStmt *create(Context &ctx, SMRange loc, VariableDecl *varDecl);
214
215 /// Return the variable defined by this statement.
216 VariableDecl *getVarDecl() const { return varDecl; }
217
218private:
219 LetStmt(SMRange loc, VariableDecl *varDecl) : Base(loc), varDecl(varDecl) {}
220
221 /// The variable defined by this statement.
222 VariableDecl *varDecl;
223};
224
225//===----------------------------------------------------------------------===//
226// OpRewriteStmt
227//===----------------------------------------------------------------------===//
228
229/// This class represents a base operation rewrite statement. Operation rewrite
230/// statements perform a set of transformations on a given root operation.
231class OpRewriteStmt : public Stmt {
232public:
233 /// Provide type casting support.
234 static bool classof(const Node *node);
235
236 /// Return the root operation of this rewrite.
237 Expr *getRootOpExpr() const { return rootOp; }
238
239protected:
240 OpRewriteStmt(TypeID typeID, SMRange loc, Expr *rootOp)
241 : Stmt(typeID, loc), rootOp(rootOp) {}
242
243protected:
244 /// The root operation being rewritten.
245 Expr *rootOp;
246};
247
248//===----------------------------------------------------------------------===//
249// EraseStmt
250
251/// This statement represents the `erase` statement in PDLL. This statement
252/// erases the given root operation, corresponding roughly to the
253/// PatternRewriter::eraseOp API.
254class EraseStmt final : public Node::NodeBase<EraseStmt, OpRewriteStmt> {
255public:
256 static EraseStmt *create(Context &ctx, SMRange loc, Expr *rootOp);
257
258private:
259 EraseStmt(SMRange loc, Expr *rootOp) : Base(loc, rootOp) {}
260};
261
262//===----------------------------------------------------------------------===//
263// ReplaceStmt
264
265/// This statement represents the `replace` statement in PDLL. This statement
266/// replace the given root operation with a set of values, corresponding roughly
267/// to the PatternRewriter::replaceOp API.
268class ReplaceStmt final : public Node::NodeBase<ReplaceStmt, OpRewriteStmt>,
269 private llvm::TrailingObjects<ReplaceStmt, Expr *> {
270public:
271 static ReplaceStmt *create(Context &ctx, SMRange loc, Expr *rootOp,
272 ArrayRef<Expr *> replExprs);
273
274 /// Return the replacement values of this statement.
275 MutableArrayRef<Expr *> getReplExprs() {
276 return {getTrailingObjects<Expr *>(), numReplExprs};
277 }
278 ArrayRef<Expr *> getReplExprs() const {
279 return const_cast<ReplaceStmt *>(this)->getReplExprs();
280 }
281
282private:
283 ReplaceStmt(SMRange loc, Expr *rootOp, unsigned numReplExprs)
284 : Base(loc, rootOp), numReplExprs(numReplExprs) {}
285
286 /// The number of replacement values within this statement.
287 unsigned numReplExprs;
288
289 /// TrailingObject utilities.
290 friend class llvm::TrailingObjects<ReplaceStmt, Expr *>;
291};
292
293//===----------------------------------------------------------------------===//
294// RewriteStmt
295
296/// This statement represents an operation rewrite that contains a block of
297/// nested rewrite commands. This allows for building more complex operation
298/// rewrites that span across multiple statements, which may be unconnected.
299class RewriteStmt final : public Node::NodeBase<RewriteStmt, OpRewriteStmt> {
300public:
301 static RewriteStmt *create(Context &ctx, SMRange loc, Expr *rootOp,
302 CompoundStmt *rewriteBody);
303
304 /// Return the compound rewrite body.
305 CompoundStmt *getRewriteBody() const { return rewriteBody; }
306
307private:
308 RewriteStmt(SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
309 : Base(loc, rootOp), rewriteBody(rewriteBody) {}
310
311 /// The body of nested rewriters within this statement.
312 CompoundStmt *rewriteBody;
313};
314
315//===----------------------------------------------------------------------===//
316// ReturnStmt
317//===----------------------------------------------------------------------===//
318
319/// This statement represents a return from a "callable" like decl, e.g. a
320/// Constraint or a Rewrite.
321class ReturnStmt final : public Node::NodeBase<ReturnStmt, Stmt> {
322public:
323 static ReturnStmt *create(Context &ctx, SMRange loc, Expr *resultExpr);
324
325 /// Return the result expression of this statement.
326 Expr *getResultExpr() { return resultExpr; }
327 const Expr *getResultExpr() const { return resultExpr; }
328
329 /// Set the result expression of this statement.
330 void setResultExpr(Expr *expr) { resultExpr = expr; }
331
332private:
333 ReturnStmt(SMRange loc, Expr *resultExpr)
334 : Base(loc), resultExpr(resultExpr) {}
335
336 // The result expression of this statement.
337 Expr *resultExpr;
338};
339
340//===----------------------------------------------------------------------===//
341// Expr
342//===----------------------------------------------------------------------===//
343
344/// This class represents a base AST Expression node.
345class Expr : public Stmt {
346public:
347 /// Return the type of this expression.
348 Type getType() const { return type; }
349
350 /// Provide type casting support.
351 static bool classof(const Node *node);
352
353protected:
354 Expr(TypeID typeID, SMRange loc, Type type) : Stmt(typeID, loc), type(type) {}
355
356private:
357 /// The type of this expression.
358 Type type;
359};
360
361//===----------------------------------------------------------------------===//
362// AttributeExpr
363//===----------------------------------------------------------------------===//
364
365/// This expression represents a literal MLIR Attribute, and contains the
366/// textual assembly format of that attribute.
367class AttributeExpr : public Node::NodeBase<AttributeExpr, Expr> {
368public:
369 static AttributeExpr *create(Context &ctx, SMRange loc, StringRef value);
370
371 /// Get the raw value of this expression. This is the textual assembly format
372 /// of the MLIR Attribute.
373 StringRef getValue() const { return value; }
374
375private:
376 AttributeExpr(Context &ctx, SMRange loc, StringRef value)
377 : Base(loc, AttributeType::get(context&: ctx)), value(value) {}
378
379 /// The value referenced by this expression.
380 StringRef value;
381};
382
383//===----------------------------------------------------------------------===//
384// CallExpr
385//===----------------------------------------------------------------------===//
386
387/// This expression represents a call to a decl, such as a
388/// UserConstraintDecl/UserRewriteDecl.
389class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
390 private llvm::TrailingObjects<CallExpr, Expr *> {
391public:
392 static CallExpr *create(Context &ctx, SMRange loc, Expr *callable,
393 ArrayRef<Expr *> arguments, Type resultType,
394 bool isNegated = false);
395
396 /// Return the callable of this call.
397 Expr *getCallableExpr() const { return callable; }
398
399 /// Return the arguments of this call.
400 MutableArrayRef<Expr *> getArguments() {
401 return {getTrailingObjects<Expr *>(), numArgs};
402 }
403 ArrayRef<Expr *> getArguments() const {
404 return const_cast<CallExpr *>(this)->getArguments();
405 }
406
407 /// Returns whether the result of this call is to be negated.
408 bool getIsNegated() const { return isNegated; }
409
410private:
411 CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs,
412 bool isNegated)
413 : Base(loc, type), callable(callable), numArgs(numArgs),
414 isNegated(isNegated) {}
415
416 /// The callable of this call.
417 Expr *callable;
418
419 /// The number of arguments of the call.
420 unsigned numArgs;
421
422 /// TrailingObject utilities.
423 friend llvm::TrailingObjects<CallExpr, Expr *>;
424
425 // Is the result of this call to be negated.
426 bool isNegated;
427};
428
429//===----------------------------------------------------------------------===//
430// DeclRefExpr
431//===----------------------------------------------------------------------===//
432
433/// This expression represents a reference to a Decl node.
434class DeclRefExpr : public Node::NodeBase<DeclRefExpr, Expr> {
435public:
436 static DeclRefExpr *create(Context &ctx, SMRange loc, Decl *decl, Type type);
437
438 /// Get the decl referenced by this expression.
439 Decl *getDecl() const { return decl; }
440
441private:
442 DeclRefExpr(SMRange loc, Decl *decl, Type type)
443 : Base(loc, type), decl(decl) {}
444
445 /// The decl referenced by this expression.
446 Decl *decl;
447};
448
449//===----------------------------------------------------------------------===//
450// MemberAccessExpr
451//===----------------------------------------------------------------------===//
452
453/// This expression represents a named member or field access of a given parent
454/// expression.
455class MemberAccessExpr : public Node::NodeBase<MemberAccessExpr, Expr> {
456public:
457 static MemberAccessExpr *create(Context &ctx, SMRange loc,
458 const Expr *parentExpr, StringRef memberName,
459 Type type);
460
461 /// Get the parent expression of this access.
462 const Expr *getParentExpr() const { return parentExpr; }
463
464 /// Return the name of the member being accessed.
465 StringRef getMemberName() const { return memberName; }
466
467private:
468 MemberAccessExpr(SMRange loc, const Expr *parentExpr, StringRef memberName,
469 Type type)
470 : Base(loc, type), parentExpr(parentExpr), memberName(memberName) {}
471
472 /// The parent expression of this access.
473 const Expr *parentExpr;
474
475 /// The name of the member being accessed from the parent.
476 StringRef memberName;
477};
478
479//===----------------------------------------------------------------------===//
480// AllResultsMemberAccessExpr
481
482/// This class represents an instance of MemberAccessExpr that references all
483/// results of an operation.
484class AllResultsMemberAccessExpr : public MemberAccessExpr {
485public:
486 /// Return the member name used for the "all-results" access.
487 static StringRef getMemberName() { return "$results"; }
488
489 static AllResultsMemberAccessExpr *create(Context &ctx, SMRange loc,
490 const Expr *parentExpr, Type type) {
491 return cast<AllResultsMemberAccessExpr>(
492 Val: MemberAccessExpr::create(ctx, loc, parentExpr, memberName: getMemberName(), type));
493 }
494
495 /// Provide type casting support.
496 static bool classof(const Node *node) {
497 const MemberAccessExpr *memAccess = dyn_cast<MemberAccessExpr>(Val: node);
498 return memAccess && memAccess->getMemberName() == getMemberName();
499 }
500};
501
502//===----------------------------------------------------------------------===//
503// OperationExpr
504//===----------------------------------------------------------------------===//
505
506/// This expression represents the structural form of an MLIR Operation. It
507/// represents either an input operation to match, or an operation to create
508/// within a rewrite.
509class OperationExpr final
510 : public Node::NodeBase<OperationExpr, Expr>,
511 private llvm::TrailingObjects<OperationExpr, Expr *,
512 NamedAttributeDecl *> {
513public:
514 static OperationExpr *create(Context &ctx, SMRange loc,
515 const ods::Operation *odsOp,
516 const OpNameDecl *nameDecl,
517 ArrayRef<Expr *> operands,
518 ArrayRef<Expr *> resultTypes,
519 ArrayRef<NamedAttributeDecl *> attributes);
520
521 /// Return the name of the operation, or std::nullopt if there isn't one.
522 std::optional<StringRef> getName() const;
523
524 /// Return the declaration of the operation name.
525 const OpNameDecl *getNameDecl() const { return nameDecl; }
526
527 /// Return the location of the name of the operation expression, or an invalid
528 /// location if there isn't a name.
529 SMRange getNameLoc() const { return nameLoc; }
530
531 /// Return the operands of this operation.
532 MutableArrayRef<Expr *> getOperands() {
533 return {getTrailingObjects<Expr *>(), numOperands};
534 }
535 ArrayRef<Expr *> getOperands() const {
536 return const_cast<OperationExpr *>(this)->getOperands();
537 }
538
539 /// Return the result types of this operation.
540 MutableArrayRef<Expr *> getResultTypes() {
541 return {getTrailingObjects<Expr *>() + numOperands, numResultTypes};
542 }
543 MutableArrayRef<Expr *> getResultTypes() const {
544 return const_cast<OperationExpr *>(this)->getResultTypes();
545 }
546
547 /// Return the attributes of this operation.
548 MutableArrayRef<NamedAttributeDecl *> getAttributes() {
549 return {getTrailingObjects<NamedAttributeDecl *>(), numAttributes};
550 }
551 MutableArrayRef<NamedAttributeDecl *> getAttributes() const {
552 return const_cast<OperationExpr *>(this)->getAttributes();
553 }
554
555private:
556 OperationExpr(SMRange loc, Type type, const OpNameDecl *nameDecl,
557 unsigned numOperands, unsigned numResultTypes,
558 unsigned numAttributes, SMRange nameLoc)
559 : Base(loc, type), nameDecl(nameDecl), numOperands(numOperands),
560 numResultTypes(numResultTypes), numAttributes(numAttributes),
561 nameLoc(nameLoc) {}
562
563 /// The name decl of this expression.
564 const OpNameDecl *nameDecl;
565
566 /// The number of operands, result types, and attributes of the operation.
567 unsigned numOperands, numResultTypes, numAttributes;
568
569 /// The location of the operation name in the expression if it has a name.
570 SMRange nameLoc;
571
572 /// TrailingObject utilities.
573 friend llvm::TrailingObjects<OperationExpr, Expr *, NamedAttributeDecl *>;
574 size_t numTrailingObjects(OverloadToken<Expr *>) const {
575 return numOperands + numResultTypes;
576 }
577};
578
579//===----------------------------------------------------------------------===//
580// RangeExpr
581//===----------------------------------------------------------------------===//
582
583/// This expression builds a range from a set of element values (which may be
584/// ranges themselves).
585class RangeExpr final : public Node::NodeBase<RangeExpr, Expr>,
586 private llvm::TrailingObjects<RangeExpr, Expr *> {
587public:
588 static RangeExpr *create(Context &ctx, SMRange loc, ArrayRef<Expr *> elements,
589 RangeType type);
590
591 /// Return the element expressions of this range.
592 MutableArrayRef<Expr *> getElements() {
593 return {getTrailingObjects<Expr *>(), numElements};
594 }
595 ArrayRef<Expr *> getElements() const {
596 return const_cast<RangeExpr *>(this)->getElements();
597 }
598
599 /// Return the range result type of this expression.
600 RangeType getType() const { return Base::getType().cast<RangeType>(); }
601
602private:
603 RangeExpr(SMRange loc, RangeType type, unsigned numElements)
604 : Base(loc, type), numElements(numElements) {}
605
606 /// The number of element values for this range.
607 unsigned numElements;
608
609 /// TrailingObject utilities.
610 friend class llvm::TrailingObjects<RangeExpr, Expr *>;
611};
612
613//===----------------------------------------------------------------------===//
614// TupleExpr
615//===----------------------------------------------------------------------===//
616
617/// This expression builds a tuple from a set of element values.
618class TupleExpr final : public Node::NodeBase<TupleExpr, Expr>,
619 private llvm::TrailingObjects<TupleExpr, Expr *> {
620public:
621 static TupleExpr *create(Context &ctx, SMRange loc, ArrayRef<Expr *> elements,
622 ArrayRef<StringRef> elementNames);
623
624 /// Return the element expressions of this tuple.
625 MutableArrayRef<Expr *> getElements() {
626 return {getTrailingObjects<Expr *>(), getType().size()};
627 }
628 ArrayRef<Expr *> getElements() const {
629 return const_cast<TupleExpr *>(this)->getElements();
630 }
631
632 /// Return the tuple result type of this expression.
633 TupleType getType() const { return Base::getType().cast<TupleType>(); }
634
635private:
636 TupleExpr(SMRange loc, TupleType type) : Base(loc, type) {}
637
638 /// TrailingObject utilities.
639 friend class llvm::TrailingObjects<TupleExpr, Expr *>;
640};
641
642//===----------------------------------------------------------------------===//
643// TypeExpr
644//===----------------------------------------------------------------------===//
645
646/// This expression represents a literal MLIR Type, and contains the textual
647/// assembly format of that type.
648class TypeExpr : public Node::NodeBase<TypeExpr, Expr> {
649public:
650 static TypeExpr *create(Context &ctx, SMRange loc, StringRef value);
651
652 /// Get the raw value of this expression. This is the textual assembly format
653 /// of the MLIR Type.
654 StringRef getValue() const { return value; }
655
656private:
657 TypeExpr(Context &ctx, SMRange loc, StringRef value)
658 : Base(loc, TypeType::get(context&: ctx)), value(value) {}
659
660 /// The value referenced by this expression.
661 StringRef value;
662};
663
664//===----------------------------------------------------------------------===//
665// Decl
666//===----------------------------------------------------------------------===//
667
668/// This class represents the base Decl node.
669class Decl : public Node {
670public:
671 /// Return the name of the decl, or nullptr if it doesn't have one.
672 const Name *getName() const { return name; }
673
674 /// Provide type casting support.
675 static bool classof(const Node *node);
676
677 /// Set the documentation comment for this decl.
678 void setDocComment(Context &ctx, StringRef comment);
679
680 /// Return the documentation comment attached to this decl if it has been set.
681 /// Otherwise, returns std::nullopt.
682 std::optional<StringRef> getDocComment() const { return docComment; }
683
684protected:
685 Decl(TypeID typeID, SMRange loc, const Name *name = nullptr)
686 : Node(typeID, loc), name(name) {}
687
688private:
689 /// The name of the decl. This is optional for some decls, such as
690 /// PatternDecl.
691 const Name *name;
692
693 /// The documentation comment attached to this decl. Defaults to std::nullopt
694 /// if the comment is unset/unknown.
695 std::optional<StringRef> docComment;
696};
697
698//===----------------------------------------------------------------------===//
699// ConstraintDecl
700//===----------------------------------------------------------------------===//
701
702/// This class represents the base of all AST Constraint decls. Constraints
703/// apply matcher conditions to, and define the type of PDLL variables.
704class ConstraintDecl : public Decl {
705public:
706 /// Provide type casting support.
707 static bool classof(const Node *node);
708
709protected:
710 ConstraintDecl(TypeID typeID, SMRange loc, const Name *name = nullptr)
711 : Decl(typeID, loc, name) {}
712};
713
714/// This class represents a reference to a constraint, and contains a constraint
715/// and the location of the reference.
716struct ConstraintRef {
717 ConstraintRef(const ConstraintDecl *constraint, SMRange refLoc)
718 : constraint(constraint), referenceLoc(refLoc) {}
719 explicit ConstraintRef(const ConstraintDecl *constraint)
720 : ConstraintRef(constraint, constraint->getLoc()) {}
721
722 const ConstraintDecl *constraint;
723 SMRange referenceLoc;
724};
725
726//===----------------------------------------------------------------------===//
727// CoreConstraintDecl
728//===----------------------------------------------------------------------===//
729
730/// This class represents the base of all "core" constraints. Core constraints
731/// are those that generally represent a concrete IR construct, such as
732/// `Type`s or `Value`s.
733class CoreConstraintDecl : public ConstraintDecl {
734public:
735 /// Provide type casting support.
736 static bool classof(const Node *node);
737
738protected:
739 CoreConstraintDecl(TypeID typeID, SMRange loc, const Name *name = nullptr)
740 : ConstraintDecl(typeID, loc, name) {}
741};
742
743//===----------------------------------------------------------------------===//
744// AttrConstraintDecl
745
746/// The class represents an Attribute constraint, and constrains a variable to
747/// be an Attribute.
748class AttrConstraintDecl
749 : public Node::NodeBase<AttrConstraintDecl, CoreConstraintDecl> {
750public:
751 static AttrConstraintDecl *create(Context &ctx, SMRange loc,
752 Expr *typeExpr = nullptr);
753
754 /// Return the optional type the attribute is constrained to.
755 Expr *getTypeExpr() { return typeExpr; }
756 const Expr *getTypeExpr() const { return typeExpr; }
757
758protected:
759 AttrConstraintDecl(SMRange loc, Expr *typeExpr)
760 : Base(loc), typeExpr(typeExpr) {}
761
762 /// An optional type that the attribute is constrained to.
763 Expr *typeExpr;
764};
765
766//===----------------------------------------------------------------------===//
767// OpConstraintDecl
768
769/// The class represents an Operation constraint, and constrains a variable to
770/// be an Operation.
771class OpConstraintDecl
772 : public Node::NodeBase<OpConstraintDecl, CoreConstraintDecl> {
773public:
774 static OpConstraintDecl *create(Context &ctx, SMRange loc,
775 const OpNameDecl *nameDecl = nullptr);
776
777 /// Return the name of the operation, or std::nullopt if there isn't one.
778 std::optional<StringRef> getName() const;
779
780 /// Return the declaration of the operation name.
781 const OpNameDecl *getNameDecl() const { return nameDecl; }
782
783protected:
784 explicit OpConstraintDecl(SMRange loc, const OpNameDecl *nameDecl)
785 : Base(loc), nameDecl(nameDecl) {}
786
787 /// The operation name of this constraint.
788 const OpNameDecl *nameDecl;
789};
790
791//===----------------------------------------------------------------------===//
792// TypeConstraintDecl
793
794/// The class represents a Type constraint, and constrains a variable to be a
795/// Type.
796class TypeConstraintDecl
797 : public Node::NodeBase<TypeConstraintDecl, CoreConstraintDecl> {
798public:
799 static TypeConstraintDecl *create(Context &ctx, SMRange loc);
800
801protected:
802 using Base::Base;
803};
804
805//===----------------------------------------------------------------------===//
806// TypeRangeConstraintDecl
807
808/// The class represents a TypeRange constraint, and constrains a variable to be
809/// a TypeRange.
810class TypeRangeConstraintDecl
811 : public Node::NodeBase<TypeRangeConstraintDecl, CoreConstraintDecl> {
812public:
813 static TypeRangeConstraintDecl *create(Context &ctx, SMRange loc);
814
815protected:
816 using Base::Base;
817};
818
819//===----------------------------------------------------------------------===//
820// ValueConstraintDecl
821
822/// The class represents a Value constraint, and constrains a variable to be a
823/// Value.
824class ValueConstraintDecl
825 : public Node::NodeBase<ValueConstraintDecl, CoreConstraintDecl> {
826public:
827 static ValueConstraintDecl *create(Context &ctx, SMRange loc, Expr *typeExpr);
828
829 /// Return the optional type the value is constrained to.
830 Expr *getTypeExpr() { return typeExpr; }
831 const Expr *getTypeExpr() const { return typeExpr; }
832
833protected:
834 ValueConstraintDecl(SMRange loc, Expr *typeExpr)
835 : Base(loc), typeExpr(typeExpr) {}
836
837 /// An optional type that the value is constrained to.
838 Expr *typeExpr;
839};
840
841//===----------------------------------------------------------------------===//
842// ValueRangeConstraintDecl
843
844/// The class represents a ValueRange constraint, and constrains a variable to
845/// be a ValueRange.
846class ValueRangeConstraintDecl
847 : public Node::NodeBase<ValueRangeConstraintDecl, CoreConstraintDecl> {
848public:
849 static ValueRangeConstraintDecl *create(Context &ctx, SMRange loc,
850 Expr *typeExpr = nullptr);
851
852 /// Return the optional type the value range is constrained to.
853 Expr *getTypeExpr() { return typeExpr; }
854 const Expr *getTypeExpr() const { return typeExpr; }
855
856protected:
857 ValueRangeConstraintDecl(SMRange loc, Expr *typeExpr)
858 : Base(loc), typeExpr(typeExpr) {}
859
860 /// An optional type that the value range is constrained to.
861 Expr *typeExpr;
862};
863
864//===----------------------------------------------------------------------===//
865// UserConstraintDecl
866//===----------------------------------------------------------------------===//
867
868/// This decl represents a user defined constraint. This is either:
869/// * an imported native constraint
870/// - Similar to an external function declaration. This is a native
871/// constraint defined externally, and imported into PDLL via a
872/// declaration.
873/// * a native constraint defined in PDLL
874/// - This is a native constraint, i.e. a constraint whose implementation is
875/// defined in C++(or potentially some other non-PDLL language). The
876/// implementation of this constraint is specified as a string code block
877/// in PDLL.
878/// * a PDLL constraint
879/// - This is a constraint which is defined using only PDLL constructs.
880class UserConstraintDecl final
881 : public Node::NodeBase<UserConstraintDecl, ConstraintDecl>,
882 llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef> {
883public:
884 /// Create a native constraint with the given optional code block.
885 static UserConstraintDecl *
886 createNative(Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
887 ArrayRef<VariableDecl *> results,
888 std::optional<StringRef> codeBlock, Type resultType,
889 ArrayRef<StringRef> nativeInputTypes = {}) {
890 return createImpl(ctx, name, inputs, nativeInputTypes, results, codeBlock,
891 /*body=*/body: nullptr, resultType);
892 }
893
894 /// Create a PDLL constraint with the given body.
895 static UserConstraintDecl *createPDLL(Context &ctx, const Name &name,
896 ArrayRef<VariableDecl *> inputs,
897 ArrayRef<VariableDecl *> results,
898 const CompoundStmt *body,
899 Type resultType) {
900 return createImpl(ctx, name, inputs, /*nativeInputTypes=*/nativeInputTypes: std::nullopt,
901 results, /*codeBlock=*/codeBlock: std::nullopt, body, resultType);
902 }
903
904 /// Return the name of the constraint.
905 const Name &getName() const { return *Decl::getName(); }
906
907 /// Return the input arguments of this constraint.
908 MutableArrayRef<VariableDecl *> getInputs() {
909 return {getTrailingObjects<VariableDecl *>(), numInputs};
910 }
911 ArrayRef<VariableDecl *> getInputs() const {
912 return const_cast<UserConstraintDecl *>(this)->getInputs();
913 }
914
915 /// Return the explicit native type to use for the given input. Returns
916 /// std::nullopt if no explicit type was set.
917 std::optional<StringRef> getNativeInputType(unsigned index) const;
918
919 /// Return the explicit results of the constraint declaration. May be empty,
920 /// even if the constraint has results (e.g. in the case of inferred results).
921 MutableArrayRef<VariableDecl *> getResults() {
922 return {getTrailingObjects<VariableDecl *>() + numInputs, numResults};
923 }
924 ArrayRef<VariableDecl *> getResults() const {
925 return const_cast<UserConstraintDecl *>(this)->getResults();
926 }
927
928 /// Return the optional code block of this constraint, if this is a native
929 /// constraint with a provided implementation.
930 std::optional<StringRef> getCodeBlock() const { return codeBlock; }
931
932 /// Return the body of this constraint if this constraint is a PDLL
933 /// constraint, otherwise returns nullptr.
934 const CompoundStmt *getBody() const { return constraintBody; }
935
936 /// Return the result type of this constraint.
937 Type getResultType() const { return resultType; }
938
939 /// Returns true if this constraint is external.
940 bool isExternal() const { return !constraintBody && !codeBlock; }
941
942private:
943 /// Create either a PDLL constraint or a native constraint with the given
944 /// components.
945 static UserConstraintDecl *createImpl(Context &ctx, const Name &name,
946 ArrayRef<VariableDecl *> inputs,
947 ArrayRef<StringRef> nativeInputTypes,
948 ArrayRef<VariableDecl *> results,
949 std::optional<StringRef> codeBlock,
950 const CompoundStmt *body,
951 Type resultType);
952
953 UserConstraintDecl(const Name &name, unsigned numInputs,
954 bool hasNativeInputTypes, unsigned numResults,
955 std::optional<StringRef> codeBlock,
956 const CompoundStmt *body, Type resultType)
957 : Base(name.getLoc(), &name), numInputs(numInputs),
958 numResults(numResults), codeBlock(codeBlock), constraintBody(body),
959 resultType(resultType), hasNativeInputTypes(hasNativeInputTypes) {}
960
961 /// The number of inputs to this constraint.
962 unsigned numInputs;
963
964 /// The number of explicit results to this constraint.
965 unsigned numResults;
966
967 /// The optional code block of this constraint.
968 std::optional<StringRef> codeBlock;
969
970 /// The optional body of this constraint.
971 const CompoundStmt *constraintBody;
972
973 /// The result type of the constraint.
974 Type resultType;
975
976 /// Flag indicating if this constraint has explicit native input types.
977 bool hasNativeInputTypes;
978
979 /// Allow access to various internals.
980 friend llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef>;
981 size_t numTrailingObjects(OverloadToken<VariableDecl *>) const {
982 return numInputs + numResults;
983 }
984};
985
986//===----------------------------------------------------------------------===//
987// NamedAttributeDecl
988//===----------------------------------------------------------------------===//
989
990/// This Decl represents a NamedAttribute, and contains a string name and
991/// attribute value.
992class NamedAttributeDecl : public Node::NodeBase<NamedAttributeDecl, Decl> {
993public:
994 static NamedAttributeDecl *create(Context &ctx, const Name &name,
995 Expr *value);
996
997 /// Return the name of the attribute.
998 const Name &getName() const { return *Decl::getName(); }
999
1000 /// Return value of the attribute.
1001 Expr *getValue() const { return value; }
1002
1003private:
1004 NamedAttributeDecl(const Name &name, Expr *value)
1005 : Base(name.getLoc(), &name), value(value) {}
1006
1007 /// The value of the attribute.
1008 Expr *value;
1009};
1010
1011//===----------------------------------------------------------------------===//
1012// OpNameDecl
1013//===----------------------------------------------------------------------===//
1014
1015/// This Decl represents an OperationName.
1016class OpNameDecl : public Node::NodeBase<OpNameDecl, Decl> {
1017public:
1018 static OpNameDecl *create(Context &ctx, const Name &name);
1019 static OpNameDecl *create(Context &ctx, SMRange loc);
1020
1021 /// Return the name of this operation, or std::nullopt if the name is unknown.
1022 std::optional<StringRef> getName() const {
1023 const Name *name = Decl::getName();
1024 return name ? std::optional<StringRef>(name->getName()) : std::nullopt;
1025 }
1026
1027private:
1028 explicit OpNameDecl(const Name &name) : Base(name.getLoc(), &name) {}
1029 explicit OpNameDecl(SMRange loc) : Base(loc) {}
1030};
1031
1032//===----------------------------------------------------------------------===//
1033// PatternDecl
1034//===----------------------------------------------------------------------===//
1035
1036/// This Decl represents a single Pattern.
1037class PatternDecl : public Node::NodeBase<PatternDecl, Decl> {
1038public:
1039 static PatternDecl *create(Context &ctx, SMRange location, const Name *name,
1040 std::optional<uint16_t> benefit,
1041 bool hasBoundedRecursion,
1042 const CompoundStmt *body);
1043
1044 /// Return the benefit of this pattern if specified, or std::nullopt.
1045 std::optional<uint16_t> getBenefit() const { return benefit; }
1046
1047 /// Return if this pattern has bounded rewrite recursion.
1048 bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; }
1049
1050 /// Return the body of this pattern.
1051 const CompoundStmt *getBody() const { return patternBody; }
1052
1053 /// Return the root rewrite statement of this pattern.
1054 const OpRewriteStmt *getRootRewriteStmt() const {
1055 return cast<OpRewriteStmt>(Val: patternBody->getChildren().back());
1056 }
1057
1058private:
1059 PatternDecl(SMRange loc, const Name *name, std::optional<uint16_t> benefit,
1060 bool hasBoundedRecursion, const CompoundStmt *body)
1061 : Base(loc, name), benefit(benefit),
1062 hasBoundedRecursion(hasBoundedRecursion), patternBody(body) {}
1063
1064 /// The benefit of the pattern if it was explicitly specified, std::nullopt
1065 /// otherwise.
1066 std::optional<uint16_t> benefit;
1067
1068 /// If the pattern has properly bounded rewrite recursion or not.
1069 bool hasBoundedRecursion;
1070
1071 /// The compound statement representing the body of the pattern.
1072 const CompoundStmt *patternBody;
1073};
1074
1075//===----------------------------------------------------------------------===//
1076// UserRewriteDecl
1077//===----------------------------------------------------------------------===//
1078
1079/// This decl represents a user defined rewrite. This is either:
1080/// * an imported native rewrite
1081/// - Similar to an external function declaration. This is a native
1082/// rewrite defined externally, and imported into PDLL via a declaration.
1083/// * a native rewrite defined in PDLL
1084/// - This is a native rewrite, i.e. a rewrite whose implementation is
1085/// defined in C++(or potentially some other non-PDLL language). The
1086/// implementation of this rewrite is specified as a string code block
1087/// in PDLL.
1088/// * a PDLL rewrite
1089/// - This is a rewrite which is defined using only PDLL constructs.
1090class UserRewriteDecl final
1091 : public Node::NodeBase<UserRewriteDecl, Decl>,
1092 llvm::TrailingObjects<UserRewriteDecl, VariableDecl *> {
1093public:
1094 /// Create a native rewrite with the given optional code block.
1095 static UserRewriteDecl *createNative(Context &ctx, const Name &name,
1096 ArrayRef<VariableDecl *> inputs,
1097 ArrayRef<VariableDecl *> results,
1098 std::optional<StringRef> codeBlock,
1099 Type resultType) {
1100 return createImpl(ctx, name, inputs, results, codeBlock, /*body=*/body: nullptr,
1101 resultType);
1102 }
1103
1104 /// Create a PDLL rewrite with the given body.
1105 static UserRewriteDecl *createPDLL(Context &ctx, const Name &name,
1106 ArrayRef<VariableDecl *> inputs,
1107 ArrayRef<VariableDecl *> results,
1108 const CompoundStmt *body,
1109 Type resultType) {
1110 return createImpl(ctx, name, inputs, results, /*codeBlock=*/codeBlock: std::nullopt,
1111 body, resultType);
1112 }
1113
1114 /// Return the name of the rewrite.
1115 const Name &getName() const { return *Decl::getName(); }
1116
1117 /// Return the input arguments of this rewrite.
1118 MutableArrayRef<VariableDecl *> getInputs() {
1119 return {getTrailingObjects<VariableDecl *>(), numInputs};
1120 }
1121 ArrayRef<VariableDecl *> getInputs() const {
1122 return const_cast<UserRewriteDecl *>(this)->getInputs();
1123 }
1124
1125 /// Return the explicit results of the rewrite declaration. May be empty,
1126 /// even if the rewrite has results (e.g. in the case of inferred results).
1127 MutableArrayRef<VariableDecl *> getResults() {
1128 return {getTrailingObjects<VariableDecl *>() + numInputs, numResults};
1129 }
1130 ArrayRef<VariableDecl *> getResults() const {
1131 return const_cast<UserRewriteDecl *>(this)->getResults();
1132 }
1133
1134 /// Return the optional code block of this rewrite, if this is a native
1135 /// rewrite with a provided implementation.
1136 std::optional<StringRef> getCodeBlock() const { return codeBlock; }
1137
1138 /// Return the body of this rewrite if this rewrite is a PDLL rewrite,
1139 /// otherwise returns nullptr.
1140 const CompoundStmt *getBody() const { return rewriteBody; }
1141
1142 /// Return the result type of this rewrite.
1143 Type getResultType() const { return resultType; }
1144
1145 /// Returns true if this rewrite is external.
1146 bool isExternal() const { return !rewriteBody && !codeBlock; }
1147
1148private:
1149 /// Create either a PDLL rewrite or a native rewrite with the given
1150 /// components.
1151 static UserRewriteDecl *createImpl(Context &ctx, const Name &name,
1152 ArrayRef<VariableDecl *> inputs,
1153 ArrayRef<VariableDecl *> results,
1154 std::optional<StringRef> codeBlock,
1155 const CompoundStmt *body, Type resultType);
1156
1157 UserRewriteDecl(const Name &name, unsigned numInputs, unsigned numResults,
1158 std::optional<StringRef> codeBlock, const CompoundStmt *body,
1159 Type resultType)
1160 : Base(name.getLoc(), &name), numInputs(numInputs),
1161 numResults(numResults), codeBlock(codeBlock), rewriteBody(body),
1162 resultType(resultType) {}
1163
1164 /// The number of inputs to this rewrite.
1165 unsigned numInputs;
1166
1167 /// The number of explicit results to this rewrite.
1168 unsigned numResults;
1169
1170 /// The optional code block of this rewrite.
1171 std::optional<StringRef> codeBlock;
1172
1173 /// The optional body of this rewrite.
1174 const CompoundStmt *rewriteBody;
1175
1176 /// The result type of the rewrite.
1177 Type resultType;
1178
1179 /// Allow access to various internals.
1180 friend llvm::TrailingObjects<UserRewriteDecl, VariableDecl *>;
1181};
1182
1183//===----------------------------------------------------------------------===//
1184// CallableDecl
1185//===----------------------------------------------------------------------===//
1186
1187/// This decl represents a shared interface for all callable decls.
1188class CallableDecl : public Decl {
1189public:
1190 /// Return the callable type of this decl.
1191 StringRef getCallableType() const {
1192 if (isa<UserConstraintDecl>(Val: this))
1193 return "constraint";
1194 assert(isa<UserRewriteDecl>(this) && "unknown callable type");
1195 return "rewrite";
1196 }
1197
1198 /// Return the inputs of this decl.
1199 ArrayRef<VariableDecl *> getInputs() const {
1200 if (const auto *cst = dyn_cast<UserConstraintDecl>(Val: this))
1201 return cst->getInputs();
1202 return cast<UserRewriteDecl>(Val: this)->getInputs();
1203 }
1204
1205 /// Return the result type of this decl.
1206 Type getResultType() const {
1207 if (const auto *cst = dyn_cast<UserConstraintDecl>(Val: this))
1208 return cst->getResultType();
1209 return cast<UserRewriteDecl>(Val: this)->getResultType();
1210 }
1211
1212 /// Return the explicit results of the declaration. Note that these may be
1213 /// empty, even if the callable has results (e.g. in the case of inferred
1214 /// results).
1215 ArrayRef<VariableDecl *> getResults() const {
1216 if (const auto *cst = dyn_cast<UserConstraintDecl>(Val: this))
1217 return cst->getResults();
1218 return cast<UserRewriteDecl>(Val: this)->getResults();
1219 }
1220
1221 /// Return the optional code block of this callable, if this is a native
1222 /// callable with a provided implementation.
1223 std::optional<StringRef> getCodeBlock() const {
1224 if (const auto *cst = dyn_cast<UserConstraintDecl>(Val: this))
1225 return cst->getCodeBlock();
1226 return cast<UserRewriteDecl>(Val: this)->getCodeBlock();
1227 }
1228
1229 /// Support LLVM type casting facilities.
1230 static bool classof(const Node *decl) {
1231 return isa<UserConstraintDecl, UserRewriteDecl>(Val: decl);
1232 }
1233};
1234
1235//===----------------------------------------------------------------------===//
1236// VariableDecl
1237//===----------------------------------------------------------------------===//
1238
1239/// This Decl represents the definition of a PDLL variable.
1240class VariableDecl final
1241 : public Node::NodeBase<VariableDecl, Decl>,
1242 private llvm::TrailingObjects<VariableDecl, ConstraintRef> {
1243public:
1244 static VariableDecl *create(Context &ctx, const Name &name, Type type,
1245 Expr *initExpr,
1246 ArrayRef<ConstraintRef> constraints);
1247
1248 /// Return the constraints of this variable.
1249 MutableArrayRef<ConstraintRef> getConstraints() {
1250 return {getTrailingObjects<ConstraintRef>(), numConstraints};
1251 }
1252 ArrayRef<ConstraintRef> getConstraints() const {
1253 return const_cast<VariableDecl *>(this)->getConstraints();
1254 }
1255
1256 /// Return the initializer expression of this statement, or nullptr if there
1257 /// was no initializer.
1258 Expr *getInitExpr() const { return initExpr; }
1259
1260 /// Return the name of the decl.
1261 const Name &getName() const { return *Decl::getName(); }
1262
1263 /// Return the type of the decl.
1264 Type getType() const { return type; }
1265
1266private:
1267 VariableDecl(const Name &name, Type type, Expr *initExpr,
1268 unsigned numConstraints)
1269 : Base(name.getLoc(), &name), type(type), initExpr(initExpr),
1270 numConstraints(numConstraints) {}
1271
1272 /// The type of the variable.
1273 Type type;
1274
1275 /// The optional initializer expression of this statement.
1276 Expr *initExpr;
1277
1278 /// The number of constraints attached to this variable.
1279 unsigned numConstraints;
1280
1281 /// Allow access to various internals.
1282 friend llvm::TrailingObjects<VariableDecl, ConstraintRef>;
1283};
1284
1285//===----------------------------------------------------------------------===//
1286// Module
1287//===----------------------------------------------------------------------===//
1288
1289/// This class represents a top-level AST module.
1290class Module final : public Node::NodeBase<Module, Node>,
1291 private llvm::TrailingObjects<Module, Decl *> {
1292public:
1293 static Module *create(Context &ctx, SMLoc loc, ArrayRef<Decl *> children);
1294
1295 /// Return the children of this module.
1296 MutableArrayRef<Decl *> getChildren() {
1297 return {getTrailingObjects<Decl *>(), numChildren};
1298 }
1299 ArrayRef<Decl *> getChildren() const {
1300 return const_cast<Module *>(this)->getChildren();
1301 }
1302
1303private:
1304 Module(SMLoc loc, unsigned numChildren)
1305 : Base(SMRange{loc, loc}), numChildren(numChildren) {}
1306
1307 /// The number of decls held by this module.
1308 unsigned numChildren;
1309
1310 /// Allow access to various internals.
1311 friend llvm::TrailingObjects<Module, Decl *>;
1312};
1313
1314//===----------------------------------------------------------------------===//
// Deferred Method Definitions
1316//===----------------------------------------------------------------------===//
1317
1318inline bool Decl::classof(const Node *node) {
1319 return isa<ConstraintDecl, NamedAttributeDecl, OpNameDecl, PatternDecl,
1320 UserRewriteDecl, VariableDecl>(Val: node);
1321}
1322
1323inline bool ConstraintDecl::classof(const Node *node) {
1324 return isa<CoreConstraintDecl, UserConstraintDecl>(Val: node);
1325}
1326
1327inline bool CoreConstraintDecl::classof(const Node *node) {
1328 return isa<AttrConstraintDecl, OpConstraintDecl, TypeConstraintDecl,
1329 TypeRangeConstraintDecl, ValueConstraintDecl,
1330 ValueRangeConstraintDecl>(Val: node);
1331}
1332
1333inline bool Expr::classof(const Node *node) {
1334 return isa<AttributeExpr, CallExpr, DeclRefExpr, MemberAccessExpr,
1335 OperationExpr, RangeExpr, TupleExpr, TypeExpr>(Val: node);
1336}
1337
1338inline bool OpRewriteStmt::classof(const Node *node) {
1339 return isa<EraseStmt, ReplaceStmt, RewriteStmt>(Val: node);
1340}
1341
1342inline bool Stmt::classof(const Node *node) {
1343 return isa<CompoundStmt, LetStmt, OpRewriteStmt, Expr>(Val: node);
1344}
1345
1346} // namespace ast
1347} // namespace pdll
1348} // namespace mlir
1349
1350#endif // MLIR_TOOLS_PDLL_AST_NODES_H_
1351

// source code of mlir/include/mlir/Tools/PDLL/AST/Nodes.h