1//===- NodePrinter.cpp ----------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Tools/PDLL/AST/Context.h"
10#include "mlir/Tools/PDLL/AST/Nodes.h"
11#include "llvm/ADT/StringExtras.h"
12#include "llvm/ADT/TypeSwitch.h"
13#include "llvm/Support/SaveAndRestore.h"
14#include "llvm/Support/ScopedPrinter.h"
15#include <optional>
16
17using namespace mlir;
18using namespace mlir::pdll::ast;
19
20//===----------------------------------------------------------------------===//
21// NodePrinter
22//===----------------------------------------------------------------------===//
23
24namespace {
25class NodePrinter {
26public:
27 NodePrinter(raw_ostream &os) : os(os) {}
28
29 /// Print the given type to the stream.
30 void print(Type type);
31
32 /// Print the given node to the stream.
33 void print(const Node *node);
34
35private:
36 /// Print a range containing children of a node.
37 template <typename RangeT,
38 std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value>
39 * = nullptr>
40 void printChildren(RangeT &&range) {
41 if (range.empty())
42 return;
43
44 // Print the first N-1 elements with a prefix of "|-".
45 auto it = std::begin(range);
46 for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it)
47 print(*it);
48
49 // Print the last element.
50 elementIndentStack.back() = true;
51 print(*it);
52 }
53 template <typename RangeT, typename... OthersT,
54 std::enable_if_t<std::is_convertible<RangeT, const Node *>::value>
55 * = nullptr>
56 void printChildren(RangeT &&range, OthersT &&...others) {
57 printChildren(range: ArrayRef<const Node *>({range, others...}));
58 }
59 /// Print a range containing children of a node, nesting the children under
60 /// the given label.
61 template <typename RangeT>
62 void printChildren(StringRef label, RangeT &&range) {
63 if (range.empty())
64 return;
65 elementIndentStack.reserve(N: elementIndentStack.size() + 1);
66 llvm::SaveAndRestore lastElement(elementIndentStack.back(), true);
67
68 printIndent();
69 os << label << "`\n";
70 elementIndentStack.push_back(/*isLastElt*/ Elt: false);
71 printChildren(std::forward<RangeT>(range));
72 elementIndentStack.pop_back();
73 }
74
75 /// Print the given derived node to the stream.
76 void printImpl(const CompoundStmt *stmt);
77 void printImpl(const EraseStmt *stmt);
78 void printImpl(const LetStmt *stmt);
79 void printImpl(const ReplaceStmt *stmt);
80 void printImpl(const ReturnStmt *stmt);
81 void printImpl(const RewriteStmt *stmt);
82
83 void printImpl(const AttributeExpr *expr);
84 void printImpl(const CallExpr *expr);
85 void printImpl(const DeclRefExpr *expr);
86 void printImpl(const MemberAccessExpr *expr);
87 void printImpl(const OperationExpr *expr);
88 void printImpl(const RangeExpr *expr);
89 void printImpl(const TupleExpr *expr);
90 void printImpl(const TypeExpr *expr);
91
92 void printImpl(const AttrConstraintDecl *decl);
93 void printImpl(const OpConstraintDecl *decl);
94 void printImpl(const TypeConstraintDecl *decl);
95 void printImpl(const TypeRangeConstraintDecl *decl);
96 void printImpl(const UserConstraintDecl *decl);
97 void printImpl(const ValueConstraintDecl *decl);
98 void printImpl(const ValueRangeConstraintDecl *decl);
99 void printImpl(const NamedAttributeDecl *decl);
100 void printImpl(const OpNameDecl *decl);
101 void printImpl(const PatternDecl *decl);
102 void printImpl(const UserRewriteDecl *decl);
103 void printImpl(const VariableDecl *decl);
104 void printImpl(const Module *module);
105
106 /// Print the current indent stack.
107 void printIndent() {
108 if (elementIndentStack.empty())
109 return;
110
111 for (bool isLastElt : llvm::ArrayRef(elementIndentStack).drop_back())
112 os << (isLastElt ? " " : " |");
113 os << (elementIndentStack.back() ? " `" : " |");
114 }
115
116 /// The raw output stream.
117 raw_ostream &os;
118
119 /// A stack of indents and a flag indicating if the current element being
120 /// printed at that indent is the last element.
121 SmallVector<bool> elementIndentStack;
122};
123} // namespace
124
125void NodePrinter::print(Type type) {
126 // Protect against invalid inputs.
127 if (!type) {
128 os << "Type<NULL>";
129 return;
130 }
131
132 TypeSwitch<Type>(type)
133 .Case(caseFn: [&](AttributeType) { os << "Attr"; })
134 .Case(caseFn: [&](ConstraintType) { os << "Constraint"; })
135 .Case(caseFn: [&](OperationType type) {
136 os << "Op";
137 if (std::optional<StringRef> name = type.getName())
138 os << "<" << *name << ">";
139 })
140 .Case(caseFn: [&](RangeType type) {
141 print(type: type.getElementType());
142 os << "Range";
143 })
144 .Case(caseFn: [&](RewriteType) { os << "Rewrite"; })
145 .Case(caseFn: [&](TupleType type) {
146 os << "Tuple<";
147 llvm::interleaveComma(
148 c: llvm::zip(t: type.getElementNames(), u: type.getElementTypes()), os,
149 each_fn: [&](auto it) {
150 if (!std::get<0>(it).empty())
151 os << std::get<0>(it) << ": ";
152 this->print(std::get<1>(it));
153 });
154 os << ">";
155 })
156 .Case(caseFn: [&](TypeType) { os << "Type"; })
157 .Case(caseFn: [&](ValueType) { os << "Value"; })
158 .Default(defaultFn: [](Type) { llvm_unreachable("unknown AST type"); });
159}
160
161void NodePrinter::print(const Node *node) {
162 printIndent();
163 os << "-";
164
165 elementIndentStack.push_back(/*isLastElt*/ Elt: false);
166 TypeSwitch<const Node *>(node)
167 .Case<
168 // Statements.
169 const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt,
170 const ReturnStmt, const RewriteStmt,
171
172 // Expressions.
173 const AttributeExpr, const CallExpr, const DeclRefExpr,
174 const MemberAccessExpr, const OperationExpr, const RangeExpr,
175 const TupleExpr, const TypeExpr,
176
177 // Decls.
178 const AttrConstraintDecl, const OpConstraintDecl,
179 const TypeConstraintDecl, const TypeRangeConstraintDecl,
180 const UserConstraintDecl, const ValueConstraintDecl,
181 const ValueRangeConstraintDecl, const NamedAttributeDecl,
182 const OpNameDecl, const PatternDecl, const UserRewriteDecl,
183 const VariableDecl,
184
185 const Module>(caseFn: [&](auto derivedNode) { this->printImpl(derivedNode); })
186 .Default(defaultFn: [](const Node *) { llvm_unreachable("unknown AST node"); });
187 elementIndentStack.pop_back();
188}
189
190void NodePrinter::printImpl(const CompoundStmt *stmt) {
191 os << "CompoundStmt " << stmt << "\n";
192 printChildren(range: stmt->getChildren());
193}
194
195void NodePrinter::printImpl(const EraseStmt *stmt) {
196 os << "EraseStmt " << stmt << "\n";
197 printChildren(range: stmt->getRootOpExpr());
198}
199
200void NodePrinter::printImpl(const LetStmt *stmt) {
201 os << "LetStmt " << stmt << "\n";
202 printChildren(range: stmt->getVarDecl());
203}
204
205void NodePrinter::printImpl(const ReplaceStmt *stmt) {
206 os << "ReplaceStmt " << stmt << "\n";
207 printChildren(range: stmt->getRootOpExpr());
208 printChildren(label: "ReplValues", range: stmt->getReplExprs());
209}
210
211void NodePrinter::printImpl(const ReturnStmt *stmt) {
212 os << "ReturnStmt " << stmt << "\n";
213 printChildren(range: stmt->getResultExpr());
214}
215
216void NodePrinter::printImpl(const RewriteStmt *stmt) {
217 os << "RewriteStmt " << stmt << "\n";
218 printChildren(range: stmt->getRootOpExpr(), others: stmt->getRewriteBody());
219}
220
221void NodePrinter::printImpl(const AttributeExpr *expr) {
222 os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
223}
224
225void NodePrinter::printImpl(const CallExpr *expr) {
226 os << "CallExpr " << expr << " Type<";
227 print(type: expr->getType());
228 os << ">";
229 if (expr->getIsNegated())
230 os << " Negated";
231 os << "\n";
232 printChildren(range: expr->getCallableExpr());
233 printChildren(label: "Arguments", range: expr->getArguments());
234}
235
236void NodePrinter::printImpl(const DeclRefExpr *expr) {
237 os << "DeclRefExpr " << expr << " Type<";
238 print(type: expr->getType());
239 os << ">\n";
240 printChildren(range: expr->getDecl());
241}
242
243void NodePrinter::printImpl(const MemberAccessExpr *expr) {
244 os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName()
245 << "> Type<";
246 print(type: expr->getType());
247 os << ">\n";
248 printChildren(range: expr->getParentExpr());
249}
250
251void NodePrinter::printImpl(const OperationExpr *expr) {
252 os << "OperationExpr " << expr << " Type<";
253 print(type: expr->getType());
254 os << ">\n";
255
256 printChildren(range: expr->getNameDecl());
257 printChildren(label: "Operands", range: expr->getOperands());
258 printChildren(label: "Result Types", range: expr->getResultTypes());
259 printChildren(label: "Attributes", range: expr->getAttributes());
260}
261
262void NodePrinter::printImpl(const RangeExpr *expr) {
263 os << "RangeExpr " << expr << " Type<";
264 print(type: expr->getType());
265 os << ">\n";
266
267 printChildren(range: expr->getElements());
268}
269
270void NodePrinter::printImpl(const TupleExpr *expr) {
271 os << "TupleExpr " << expr << " Type<";
272 print(type: expr->getType());
273 os << ">\n";
274
275 printChildren(range: expr->getElements());
276}
277
278void NodePrinter::printImpl(const TypeExpr *expr) {
279 os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
280}
281
282void NodePrinter::printImpl(const AttrConstraintDecl *decl) {
283 os << "AttrConstraintDecl " << decl << "\n";
284 if (const auto *typeExpr = decl->getTypeExpr())
285 printChildren(range&: typeExpr);
286}
287
288void NodePrinter::printImpl(const OpConstraintDecl *decl) {
289 os << "OpConstraintDecl " << decl << "\n";
290 printChildren(range: decl->getNameDecl());
291}
292
293void NodePrinter::printImpl(const TypeConstraintDecl *decl) {
294 os << "TypeConstraintDecl " << decl << "\n";
295}
296
297void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
298 os << "TypeRangeConstraintDecl " << decl << "\n";
299}
300
301void NodePrinter::printImpl(const UserConstraintDecl *decl) {
302 os << "UserConstraintDecl " << decl << " Name<" << decl->getName().getName()
303 << "> ResultType<" << decl->getResultType() << ">";
304 if (std::optional<StringRef> codeBlock = decl->getCodeBlock()) {
305 os << " Code<";
306 llvm::printEscapedString(Name: *codeBlock, Out&: os);
307 os << ">";
308 }
309 os << "\n";
310 printChildren(label: "Inputs", range: decl->getInputs());
311 printChildren(label: "Results", range: decl->getResults());
312 if (const CompoundStmt *body = decl->getBody())
313 printChildren(range&: body);
314}
315
316void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
317 os << "ValueConstraintDecl " << decl << "\n";
318 if (const auto *typeExpr = decl->getTypeExpr())
319 printChildren(range&: typeExpr);
320}
321
322void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) {
323 os << "ValueRangeConstraintDecl " << decl << "\n";
324 if (const auto *typeExpr = decl->getTypeExpr())
325 printChildren(range&: typeExpr);
326}
327
328void NodePrinter::printImpl(const NamedAttributeDecl *decl) {
329 os << "NamedAttributeDecl " << decl << " Name<" << decl->getName().getName()
330 << ">\n";
331 printChildren(range: decl->getValue());
332}
333
334void NodePrinter::printImpl(const OpNameDecl *decl) {
335 os << "OpNameDecl " << decl;
336 if (std::optional<StringRef> name = decl->getName())
337 os << " Name<" << *name << ">";
338 os << "\n";
339}
340
341void NodePrinter::printImpl(const PatternDecl *decl) {
342 os << "PatternDecl " << decl;
343 if (const Name *name = decl->getName())
344 os << " Name<" << name->getName() << ">";
345 if (std::optional<uint16_t> benefit = decl->getBenefit())
346 os << " Benefit<" << *benefit << ">";
347 if (decl->hasBoundedRewriteRecursion())
348 os << " Recursion";
349
350 os << "\n";
351 printChildren(range: decl->getBody());
352}
353
354void NodePrinter::printImpl(const UserRewriteDecl *decl) {
355 os << "UserRewriteDecl " << decl << " Name<" << decl->getName().getName()
356 << "> ResultType<" << decl->getResultType() << ">";
357 if (std::optional<StringRef> codeBlock = decl->getCodeBlock()) {
358 os << " Code<";
359 llvm::printEscapedString(Name: *codeBlock, Out&: os);
360 os << ">";
361 }
362 os << "\n";
363 printChildren(label: "Inputs", range: decl->getInputs());
364 printChildren(label: "Results", range: decl->getResults());
365 if (const CompoundStmt *body = decl->getBody())
366 printChildren(range&: body);
367}
368
369void NodePrinter::printImpl(const VariableDecl *decl) {
370 os << "VariableDecl " << decl << " Name<" << decl->getName().getName()
371 << "> Type<";
372 print(type: decl->getType());
373 os << ">\n";
374 if (Expr *initExpr = decl->getInitExpr())
375 printChildren(range&: initExpr);
376
377 auto constraints =
378 llvm::map_range(C: decl->getConstraints(),
379 F: [](const ConstraintRef &ref) { return ref.constraint; });
380 printChildren(label: "Constraints", range&: constraints);
381}
382
383void NodePrinter::printImpl(const Module *module) {
384 os << "Module " << module << "\n";
385 printChildren(range: module->getChildren());
386}
387
388//===----------------------------------------------------------------------===//
389// Entry point
390//===----------------------------------------------------------------------===//
391
392void Node::print(raw_ostream &os) const { NodePrinter(os).print(node: this); }
393
394void Type::print(raw_ostream &os) const { NodePrinter(os).print(type: *this); }
395

source code of mlir/lib/Tools/PDLL/AST/NodePrinter.cpp