1 | //===- NodePrinter.cpp ----------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Tools/PDLL/AST/Context.h" |
10 | #include "mlir/Tools/PDLL/AST/Nodes.h" |
11 | #include "llvm/ADT/StringExtras.h" |
12 | #include "llvm/ADT/TypeSwitch.h" |
13 | #include "llvm/Support/SaveAndRestore.h" |
14 | #include "llvm/Support/ScopedPrinter.h" |
15 | #include <optional> |
16 | |
17 | using namespace mlir; |
18 | using namespace mlir::pdll::ast; |
19 | |
20 | //===----------------------------------------------------------------------===// |
21 | // NodePrinter |
22 | //===----------------------------------------------------------------------===// |
23 | |
24 | namespace { |
25 | class NodePrinter { |
26 | public: |
27 | NodePrinter(raw_ostream &os) : os(os) {} |
28 | |
29 | /// Print the given type to the stream. |
30 | void print(Type type); |
31 | |
32 | /// Print the given node to the stream. |
33 | void print(const Node *node); |
34 | |
35 | private: |
36 | /// Print a range containing children of a node. |
37 | template <typename RangeT, |
38 | std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value> |
39 | * = nullptr> |
40 | void printChildren(RangeT &&range) { |
41 | if (range.empty()) |
42 | return; |
43 | |
44 | // Print the first N-1 elements with a prefix of "|-". |
45 | auto it = std::begin(range); |
46 | for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it) |
47 | print(*it); |
48 | |
49 | // Print the last element. |
50 | elementIndentStack.back() = true; |
51 | print(*it); |
52 | } |
53 | template <typename RangeT, typename... OthersT, |
54 | std::enable_if_t<std::is_convertible<RangeT, const Node *>::value> |
55 | * = nullptr> |
56 | void printChildren(RangeT &&range, OthersT &&...others) { |
57 | printChildren(range: ArrayRef<const Node *>({range, others...})); |
58 | } |
59 | /// Print a range containing children of a node, nesting the children under |
60 | /// the given label. |
61 | template <typename RangeT> |
62 | void printChildren(StringRef label, RangeT &&range) { |
63 | if (range.empty()) |
64 | return; |
65 | elementIndentStack.reserve(N: elementIndentStack.size() + 1); |
66 | llvm::SaveAndRestore lastElement(elementIndentStack.back(), true); |
67 | |
68 | printIndent(); |
69 | os << label << "`\n" ; |
70 | elementIndentStack.push_back(/*isLastElt*/ Elt: false); |
71 | printChildren(std::forward<RangeT>(range)); |
72 | elementIndentStack.pop_back(); |
73 | } |
74 | |
75 | /// Print the given derived node to the stream. |
76 | void printImpl(const CompoundStmt *stmt); |
77 | void printImpl(const EraseStmt *stmt); |
78 | void printImpl(const LetStmt *stmt); |
79 | void printImpl(const ReplaceStmt *stmt); |
80 | void printImpl(const ReturnStmt *stmt); |
81 | void printImpl(const RewriteStmt *stmt); |
82 | |
83 | void printImpl(const AttributeExpr *expr); |
84 | void printImpl(const CallExpr *expr); |
85 | void printImpl(const DeclRefExpr *expr); |
86 | void printImpl(const MemberAccessExpr *expr); |
87 | void printImpl(const OperationExpr *expr); |
88 | void printImpl(const RangeExpr *expr); |
89 | void printImpl(const TupleExpr *expr); |
90 | void printImpl(const TypeExpr *expr); |
91 | |
92 | void printImpl(const AttrConstraintDecl *decl); |
93 | void printImpl(const OpConstraintDecl *decl); |
94 | void printImpl(const TypeConstraintDecl *decl); |
95 | void printImpl(const TypeRangeConstraintDecl *decl); |
96 | void printImpl(const UserConstraintDecl *decl); |
97 | void printImpl(const ValueConstraintDecl *decl); |
98 | void printImpl(const ValueRangeConstraintDecl *decl); |
99 | void printImpl(const NamedAttributeDecl *decl); |
100 | void printImpl(const OpNameDecl *decl); |
101 | void printImpl(const PatternDecl *decl); |
102 | void printImpl(const UserRewriteDecl *decl); |
103 | void printImpl(const VariableDecl *decl); |
104 | void printImpl(const Module *module); |
105 | |
106 | /// Print the current indent stack. |
107 | void printIndent() { |
108 | if (elementIndentStack.empty()) |
109 | return; |
110 | |
111 | for (bool isLastElt : llvm::ArrayRef(elementIndentStack).drop_back()) |
112 | os << (isLastElt ? " " : " |" ); |
113 | os << (elementIndentStack.back() ? " `" : " |" ); |
114 | } |
115 | |
116 | /// The raw output stream. |
117 | raw_ostream &os; |
118 | |
119 | /// A stack of indents and a flag indicating if the current element being |
120 | /// printed at that indent is the last element. |
121 | SmallVector<bool> elementIndentStack; |
122 | }; |
123 | } // namespace |
124 | |
125 | void NodePrinter::print(Type type) { |
126 | // Protect against invalid inputs. |
127 | if (!type) { |
128 | os << "Type<NULL>" ; |
129 | return; |
130 | } |
131 | |
132 | TypeSwitch<Type>(type) |
133 | .Case(caseFn: [&](AttributeType) { os << "Attr" ; }) |
134 | .Case(caseFn: [&](ConstraintType) { os << "Constraint" ; }) |
135 | .Case(caseFn: [&](OperationType type) { |
136 | os << "Op" ; |
137 | if (std::optional<StringRef> name = type.getName()) |
138 | os << "<" << *name << ">" ; |
139 | }) |
140 | .Case(caseFn: [&](RangeType type) { |
141 | print(type: type.getElementType()); |
142 | os << "Range" ; |
143 | }) |
144 | .Case(caseFn: [&](RewriteType) { os << "Rewrite" ; }) |
145 | .Case(caseFn: [&](TupleType type) { |
146 | os << "Tuple<" ; |
147 | llvm::interleaveComma( |
148 | c: llvm::zip(t: type.getElementNames(), u: type.getElementTypes()), os, |
149 | each_fn: [&](auto it) { |
150 | if (!std::get<0>(it).empty()) |
151 | os << std::get<0>(it) << ": " ; |
152 | this->print(std::get<1>(it)); |
153 | }); |
154 | os << ">" ; |
155 | }) |
156 | .Case(caseFn: [&](TypeType) { os << "Type" ; }) |
157 | .Case(caseFn: [&](ValueType) { os << "Value" ; }) |
158 | .Default(defaultFn: [](Type) { llvm_unreachable("unknown AST type" ); }); |
159 | } |
160 | |
161 | void NodePrinter::print(const Node *node) { |
162 | printIndent(); |
163 | os << "-" ; |
164 | |
165 | elementIndentStack.push_back(/*isLastElt*/ Elt: false); |
166 | TypeSwitch<const Node *>(node) |
167 | .Case< |
168 | // Statements. |
169 | const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt, |
170 | const ReturnStmt, const RewriteStmt, |
171 | |
172 | // Expressions. |
173 | const AttributeExpr, const CallExpr, const DeclRefExpr, |
174 | const MemberAccessExpr, const OperationExpr, const RangeExpr, |
175 | const TupleExpr, const TypeExpr, |
176 | |
177 | // Decls. |
178 | const AttrConstraintDecl, const OpConstraintDecl, |
179 | const TypeConstraintDecl, const TypeRangeConstraintDecl, |
180 | const UserConstraintDecl, const ValueConstraintDecl, |
181 | const ValueRangeConstraintDecl, const NamedAttributeDecl, |
182 | const OpNameDecl, const PatternDecl, const UserRewriteDecl, |
183 | const VariableDecl, |
184 | |
185 | const Module>(caseFn: [&](auto derivedNode) { this->printImpl(derivedNode); }) |
186 | .Default(defaultFn: [](const Node *) { llvm_unreachable("unknown AST node" ); }); |
187 | elementIndentStack.pop_back(); |
188 | } |
189 | |
190 | void NodePrinter::printImpl(const CompoundStmt *stmt) { |
191 | os << "CompoundStmt " << stmt << "\n" ; |
192 | printChildren(range: stmt->getChildren()); |
193 | } |
194 | |
195 | void NodePrinter::printImpl(const EraseStmt *stmt) { |
196 | os << "EraseStmt " << stmt << "\n" ; |
197 | printChildren(range: stmt->getRootOpExpr()); |
198 | } |
199 | |
200 | void NodePrinter::printImpl(const LetStmt *stmt) { |
201 | os << "LetStmt " << stmt << "\n" ; |
202 | printChildren(range: stmt->getVarDecl()); |
203 | } |
204 | |
205 | void NodePrinter::printImpl(const ReplaceStmt *stmt) { |
206 | os << "ReplaceStmt " << stmt << "\n" ; |
207 | printChildren(range: stmt->getRootOpExpr()); |
208 | printChildren(label: "ReplValues" , range: stmt->getReplExprs()); |
209 | } |
210 | |
211 | void NodePrinter::printImpl(const ReturnStmt *stmt) { |
212 | os << "ReturnStmt " << stmt << "\n" ; |
213 | printChildren(range: stmt->getResultExpr()); |
214 | } |
215 | |
216 | void NodePrinter::printImpl(const RewriteStmt *stmt) { |
217 | os << "RewriteStmt " << stmt << "\n" ; |
218 | printChildren(range: stmt->getRootOpExpr(), others: stmt->getRewriteBody()); |
219 | } |
220 | |
221 | void NodePrinter::printImpl(const AttributeExpr *expr) { |
222 | os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n" ; |
223 | } |
224 | |
225 | void NodePrinter::printImpl(const CallExpr *expr) { |
226 | os << "CallExpr " << expr << " Type<" ; |
227 | print(type: expr->getType()); |
228 | os << ">" ; |
229 | if (expr->getIsNegated()) |
230 | os << " Negated" ; |
231 | os << "\n" ; |
232 | printChildren(range: expr->getCallableExpr()); |
233 | printChildren(label: "Arguments" , range: expr->getArguments()); |
234 | } |
235 | |
236 | void NodePrinter::printImpl(const DeclRefExpr *expr) { |
237 | os << "DeclRefExpr " << expr << " Type<" ; |
238 | print(type: expr->getType()); |
239 | os << ">\n" ; |
240 | printChildren(range: expr->getDecl()); |
241 | } |
242 | |
243 | void NodePrinter::printImpl(const MemberAccessExpr *expr) { |
244 | os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName() |
245 | << "> Type<" ; |
246 | print(type: expr->getType()); |
247 | os << ">\n" ; |
248 | printChildren(range: expr->getParentExpr()); |
249 | } |
250 | |
251 | void NodePrinter::printImpl(const OperationExpr *expr) { |
252 | os << "OperationExpr " << expr << " Type<" ; |
253 | print(type: expr->getType()); |
254 | os << ">\n" ; |
255 | |
256 | printChildren(range: expr->getNameDecl()); |
257 | printChildren(label: "Operands" , range: expr->getOperands()); |
258 | printChildren(label: "Result Types" , range: expr->getResultTypes()); |
259 | printChildren(label: "Attributes" , range: expr->getAttributes()); |
260 | } |
261 | |
262 | void NodePrinter::printImpl(const RangeExpr *expr) { |
263 | os << "RangeExpr " << expr << " Type<" ; |
264 | print(type: expr->getType()); |
265 | os << ">\n" ; |
266 | |
267 | printChildren(range: expr->getElements()); |
268 | } |
269 | |
270 | void NodePrinter::printImpl(const TupleExpr *expr) { |
271 | os << "TupleExpr " << expr << " Type<" ; |
272 | print(type: expr->getType()); |
273 | os << ">\n" ; |
274 | |
275 | printChildren(range: expr->getElements()); |
276 | } |
277 | |
278 | void NodePrinter::printImpl(const TypeExpr *expr) { |
279 | os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n" ; |
280 | } |
281 | |
282 | void NodePrinter::printImpl(const AttrConstraintDecl *decl) { |
283 | os << "AttrConstraintDecl " << decl << "\n" ; |
284 | if (const auto *typeExpr = decl->getTypeExpr()) |
285 | printChildren(range&: typeExpr); |
286 | } |
287 | |
288 | void NodePrinter::printImpl(const OpConstraintDecl *decl) { |
289 | os << "OpConstraintDecl " << decl << "\n" ; |
290 | printChildren(range: decl->getNameDecl()); |
291 | } |
292 | |
293 | void NodePrinter::printImpl(const TypeConstraintDecl *decl) { |
294 | os << "TypeConstraintDecl " << decl << "\n" ; |
295 | } |
296 | |
297 | void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) { |
298 | os << "TypeRangeConstraintDecl " << decl << "\n" ; |
299 | } |
300 | |
301 | void NodePrinter::printImpl(const UserConstraintDecl *decl) { |
302 | os << "UserConstraintDecl " << decl << " Name<" << decl->getName().getName() |
303 | << "> ResultType<" << decl->getResultType() << ">" ; |
304 | if (std::optional<StringRef> codeBlock = decl->getCodeBlock()) { |
305 | os << " Code<" ; |
306 | llvm::printEscapedString(Name: *codeBlock, Out&: os); |
307 | os << ">" ; |
308 | } |
309 | os << "\n" ; |
310 | printChildren(label: "Inputs" , range: decl->getInputs()); |
311 | printChildren(label: "Results" , range: decl->getResults()); |
312 | if (const CompoundStmt *body = decl->getBody()) |
313 | printChildren(range&: body); |
314 | } |
315 | |
316 | void NodePrinter::printImpl(const ValueConstraintDecl *decl) { |
317 | os << "ValueConstraintDecl " << decl << "\n" ; |
318 | if (const auto *typeExpr = decl->getTypeExpr()) |
319 | printChildren(range&: typeExpr); |
320 | } |
321 | |
322 | void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) { |
323 | os << "ValueRangeConstraintDecl " << decl << "\n" ; |
324 | if (const auto *typeExpr = decl->getTypeExpr()) |
325 | printChildren(range&: typeExpr); |
326 | } |
327 | |
328 | void NodePrinter::printImpl(const NamedAttributeDecl *decl) { |
329 | os << "NamedAttributeDecl " << decl << " Name<" << decl->getName().getName() |
330 | << ">\n" ; |
331 | printChildren(range: decl->getValue()); |
332 | } |
333 | |
334 | void NodePrinter::printImpl(const OpNameDecl *decl) { |
335 | os << "OpNameDecl " << decl; |
336 | if (std::optional<StringRef> name = decl->getName()) |
337 | os << " Name<" << *name << ">" ; |
338 | os << "\n" ; |
339 | } |
340 | |
341 | void NodePrinter::printImpl(const PatternDecl *decl) { |
342 | os << "PatternDecl " << decl; |
343 | if (const Name *name = decl->getName()) |
344 | os << " Name<" << name->getName() << ">" ; |
345 | if (std::optional<uint16_t> benefit = decl->getBenefit()) |
346 | os << " Benefit<" << *benefit << ">" ; |
347 | if (decl->hasBoundedRewriteRecursion()) |
348 | os << " Recursion" ; |
349 | |
350 | os << "\n" ; |
351 | printChildren(range: decl->getBody()); |
352 | } |
353 | |
354 | void NodePrinter::printImpl(const UserRewriteDecl *decl) { |
355 | os << "UserRewriteDecl " << decl << " Name<" << decl->getName().getName() |
356 | << "> ResultType<" << decl->getResultType() << ">" ; |
357 | if (std::optional<StringRef> codeBlock = decl->getCodeBlock()) { |
358 | os << " Code<" ; |
359 | llvm::printEscapedString(Name: *codeBlock, Out&: os); |
360 | os << ">" ; |
361 | } |
362 | os << "\n" ; |
363 | printChildren(label: "Inputs" , range: decl->getInputs()); |
364 | printChildren(label: "Results" , range: decl->getResults()); |
365 | if (const CompoundStmt *body = decl->getBody()) |
366 | printChildren(range&: body); |
367 | } |
368 | |
369 | void NodePrinter::printImpl(const VariableDecl *decl) { |
370 | os << "VariableDecl " << decl << " Name<" << decl->getName().getName() |
371 | << "> Type<" ; |
372 | print(type: decl->getType()); |
373 | os << ">\n" ; |
374 | if (Expr *initExpr = decl->getInitExpr()) |
375 | printChildren(range&: initExpr); |
376 | |
377 | auto constraints = |
378 | llvm::map_range(C: decl->getConstraints(), |
379 | F: [](const ConstraintRef &ref) { return ref.constraint; }); |
380 | printChildren(label: "Constraints" , range&: constraints); |
381 | } |
382 | |
383 | void NodePrinter::printImpl(const Module *module) { |
384 | os << "Module " << module << "\n" ; |
385 | printChildren(range: module->getChildren()); |
386 | } |
387 | |
388 | //===----------------------------------------------------------------------===// |
389 | // Entry point |
390 | //===----------------------------------------------------------------------===// |
391 | |
392 | void Node::print(raw_ostream &os) const { NodePrinter(os).print(node: this); } |
393 | |
394 | void Type::print(raw_ostream &os) const { NodePrinter(os).print(type: *this); } |
395 | |