1 | //===- Context.cpp --------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Tools/PDLL/ODS/Context.h" |
10 | #include "mlir/Tools/PDLL/ODS/Constraint.h" |
11 | #include "mlir/Tools/PDLL/ODS/Dialect.h" |
12 | #include "mlir/Tools/PDLL/ODS/Operation.h" |
13 | #include "llvm/Support/ScopedPrinter.h" |
14 | #include "llvm/Support/raw_ostream.h" |
15 | #include <optional> |
16 | |
17 | using namespace mlir; |
18 | using namespace mlir::pdll::ods; |
19 | |
20 | //===----------------------------------------------------------------------===// |
21 | // Context |
22 | //===----------------------------------------------------------------------===// |
23 | |
/// Default-construct / destroy the context; every member uses its default
/// initialization and destruction.
// NOTE(review): both special members are defaulted out of line, presumably so
// that Context.h can hold std::unique_ptr members of types that are only
// forward-declared there (the complete types come from the Constraint.h,
// Dialect.h and Operation.h includes in this file). Confirm against Context.h
// before moving these into the header.
Context::Context() = default;
Context::~Context() = default;
26 | |
27 | const AttributeConstraint & |
28 | Context::insertAttributeConstraint(StringRef name, StringRef summary, |
29 | StringRef cppClass) { |
30 | std::unique_ptr<AttributeConstraint> &constraint = attributeConstraints[name]; |
31 | if (!constraint) { |
32 | constraint.reset(p: new AttributeConstraint(name, summary, cppClass)); |
33 | } else { |
34 | assert(constraint->getCppClass() == cppClass && |
35 | constraint->getSummary() == summary && |
36 | "constraint with the same name was already registered with a " |
37 | "different class" ); |
38 | } |
39 | return *constraint; |
40 | } |
41 | |
42 | const TypeConstraint &Context::insertTypeConstraint(StringRef name, |
43 | StringRef summary, |
44 | StringRef cppClass) { |
45 | std::unique_ptr<TypeConstraint> &constraint = typeConstraints[name]; |
46 | if (!constraint) |
47 | constraint.reset(p: new TypeConstraint(name, summary, cppClass)); |
48 | return *constraint; |
49 | } |
50 | |
51 | Dialect &Context::insertDialect(StringRef name) { |
52 | std::unique_ptr<Dialect> &dialect = dialects[name]; |
53 | if (!dialect) |
54 | dialect.reset(p: new Dialect(name)); |
55 | return *dialect; |
56 | } |
57 | |
58 | const Dialect *Context::lookupDialect(StringRef name) const { |
59 | auto it = dialects.find(Key: name); |
60 | return it == dialects.end() ? nullptr : &*it->second; |
61 | } |
62 | |
63 | std::pair<Operation *, bool> |
64 | Context::insertOperation(StringRef name, StringRef summary, StringRef desc, |
65 | StringRef nativeClassName, |
66 | bool supportsResultTypeInferrence, SMLoc loc) { |
67 | std::pair<StringRef, StringRef> dialectAndName = name.split(Separator: '.'); |
68 | return insertDialect(name: dialectAndName.first) |
69 | .insertOperation(name, summary, desc, nativeClassName, |
70 | supportsResultTypeInferrence, loc); |
71 | } |
72 | |
73 | const Operation *Context::lookupOperation(StringRef name) const { |
74 | std::pair<StringRef, StringRef> dialectAndName = name.split(Separator: '.'); |
75 | if (const Dialect *dialect = lookupDialect(name: dialectAndName.first)) |
76 | return dialect->lookupOperation(name); |
77 | return nullptr; |
78 | } |
79 | |
80 | template <typename T> |
81 | SmallVector<T *> sortMapByName(const llvm::StringMap<std::unique_ptr<T>> &map) { |
82 | SmallVector<T *> storage; |
83 | for (auto &entry : map) |
84 | storage.push_back(entry.second.get()); |
85 | llvm::sort(storage, [](const auto &lhs, const auto &rhs) { |
86 | return lhs->getName() < rhs->getName(); |
87 | }); |
88 | return storage; |
89 | } |
90 | |
91 | void Context::print(raw_ostream &os) const { |
92 | auto printVariableLengthCst = [&](StringRef cst, VariableLengthKind kind) { |
93 | switch (kind) { |
94 | case VariableLengthKind::Optional: |
95 | os << "Optional<" << cst << ">" ; |
96 | break; |
97 | case VariableLengthKind::Single: |
98 | os << cst; |
99 | break; |
100 | case VariableLengthKind::Variadic: |
101 | os << "Variadic<" << cst << ">" ; |
102 | break; |
103 | } |
104 | }; |
105 | |
106 | llvm::ScopedPrinter printer(os); |
107 | llvm::DictScope odsScope(printer, "ODSContext" ); |
108 | for (const Dialect *dialect : sortMapByName(map: dialects)) { |
109 | printer.startLine() << "Dialect `" << dialect->getName() << "` {\n" ; |
110 | printer.indent(); |
111 | |
112 | for (const Operation *op : sortMapByName(map: dialect->getOperations())) { |
113 | printer.startLine() << "Operation `" << op->getName() << "` {\n" ; |
114 | printer.indent(); |
115 | |
116 | // Attributes. |
117 | ArrayRef<Attribute> attributes = op->getAttributes(); |
118 | if (!attributes.empty()) { |
119 | printer.startLine() << "Attributes { " ; |
120 | llvm::interleaveComma(c: attributes, os, each_fn: [&](const Attribute &attr) { |
121 | os << attr.getName() << " : " ; |
122 | |
123 | auto kind = attr.isOptional() ? VariableLengthKind::Optional |
124 | : VariableLengthKind::Single; |
125 | printVariableLengthCst(attr.getConstraint().getDemangledName(), kind); |
126 | }); |
127 | os << " }\n" ; |
128 | } |
129 | |
130 | // Operands. |
131 | ArrayRef<OperandOrResult> operands = op->getOperands(); |
132 | if (!operands.empty()) { |
133 | printer.startLine() << "Operands { " ; |
134 | llvm::interleaveComma( |
135 | c: operands, os, each_fn: [&](const OperandOrResult &operand) { |
136 | os << operand.getName() << " : " ; |
137 | printVariableLengthCst(operand.getConstraint().getDemangledName(), |
138 | operand.getVariableLengthKind()); |
139 | }); |
140 | os << " }\n" ; |
141 | } |
142 | |
143 | // Results. |
144 | ArrayRef<OperandOrResult> results = op->getResults(); |
145 | if (!results.empty()) { |
146 | printer.startLine() << "Results { " ; |
147 | llvm::interleaveComma(c: results, os, each_fn: [&](const OperandOrResult &result) { |
148 | os << result.getName() << " : " ; |
149 | printVariableLengthCst(result.getConstraint().getDemangledName(), |
150 | result.getVariableLengthKind()); |
151 | }); |
152 | os << " }\n" ; |
153 | } |
154 | |
155 | printer.objectEnd(); |
156 | } |
157 | printer.objectEnd(); |
158 | } |
159 | for (const AttributeConstraint *cst : sortMapByName(map: attributeConstraints)) { |
160 | printer.startLine() << "AttributeConstraint `" << cst->getDemangledName() |
161 | << "` {\n" ; |
162 | printer.indent(); |
163 | |
164 | printer.startLine() << "Summary: " << cst->getSummary() << "\n" ; |
165 | printer.startLine() << "CppClass: " << cst->getCppClass() << "\n" ; |
166 | printer.objectEnd(); |
167 | } |
168 | for (const TypeConstraint *cst : sortMapByName(map: typeConstraints)) { |
169 | printer.startLine() << "TypeConstraint `" << cst->getDemangledName() |
170 | << "` {\n" ; |
171 | printer.indent(); |
172 | |
173 | printer.startLine() << "Summary: " << cst->getSummary() << "\n" ; |
174 | printer.startLine() << "CppClass: " << cst->getCppClass() << "\n" ; |
175 | printer.objectEnd(); |
176 | } |
177 | printer.objectEnd(); |
178 | } |
179 | |