1//===- Context.cpp --------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Tools/PDLL/ODS/Context.h"
10#include "mlir/Tools/PDLL/ODS/Constraint.h"
11#include "mlir/Tools/PDLL/ODS/Dialect.h"
12#include "mlir/Tools/PDLL/ODS/Operation.h"
13#include "llvm/Support/ScopedPrinter.h"
14#include "llvm/Support/raw_ostream.h"
15#include <optional>
16
17using namespace mlir;
18using namespace mlir::pdll::ods;
19
20//===----------------------------------------------------------------------===//
21// Context
22//===----------------------------------------------------------------------===//
23
// The constructor and destructor are defaulted here, in the .cpp, rather than
// in the header. Presumably this lets the std::unique_ptr-valued maps own
// AttributeConstraint/TypeConstraint/Dialect objects whose full definitions
// are only included in this file. TODO(review): confirm against Context.h.
Context::Context() = default;
Context::~Context() = default;
26
27const AttributeConstraint &
28Context::insertAttributeConstraint(StringRef name, StringRef summary,
29 StringRef cppClass) {
30 std::unique_ptr<AttributeConstraint> &constraint = attributeConstraints[name];
31 if (!constraint) {
32 constraint.reset(p: new AttributeConstraint(name, summary, cppClass));
33 } else {
34 assert(constraint->getCppClass() == cppClass &&
35 constraint->getSummary() == summary &&
36 "constraint with the same name was already registered with a "
37 "different class");
38 }
39 return *constraint;
40}
41
42const TypeConstraint &Context::insertTypeConstraint(StringRef name,
43 StringRef summary,
44 StringRef cppClass) {
45 std::unique_ptr<TypeConstraint> &constraint = typeConstraints[name];
46 if (!constraint)
47 constraint.reset(p: new TypeConstraint(name, summary, cppClass));
48 return *constraint;
49}
50
51Dialect &Context::insertDialect(StringRef name) {
52 std::unique_ptr<Dialect> &dialect = dialects[name];
53 if (!dialect)
54 dialect.reset(p: new Dialect(name));
55 return *dialect;
56}
57
58const Dialect *Context::lookupDialect(StringRef name) const {
59 auto it = dialects.find(Key: name);
60 return it == dialects.end() ? nullptr : &*it->second;
61}
62
63std::pair<Operation *, bool>
64Context::insertOperation(StringRef name, StringRef summary, StringRef desc,
65 StringRef nativeClassName,
66 bool supportsResultTypeInferrence, SMLoc loc) {
67 std::pair<StringRef, StringRef> dialectAndName = name.split(Separator: '.');
68 return insertDialect(name: dialectAndName.first)
69 .insertOperation(name, summary, desc, nativeClassName,
70 supportsResultTypeInferrence, loc);
71}
72
73const Operation *Context::lookupOperation(StringRef name) const {
74 std::pair<StringRef, StringRef> dialectAndName = name.split(Separator: '.');
75 if (const Dialect *dialect = lookupDialect(name: dialectAndName.first))
76 return dialect->lookupOperation(name);
77 return nullptr;
78}
79
80template <typename T>
81SmallVector<T *> sortMapByName(const llvm::StringMap<std::unique_ptr<T>> &map) {
82 SmallVector<T *> storage;
83 for (auto &entry : map)
84 storage.push_back(entry.second.get());
85 llvm::sort(storage, [](const auto &lhs, const auto &rhs) {
86 return lhs->getName() < rhs->getName();
87 });
88 return storage;
89}
90
91void Context::print(raw_ostream &os) const {
92 auto printVariableLengthCst = [&](StringRef cst, VariableLengthKind kind) {
93 switch (kind) {
94 case VariableLengthKind::Optional:
95 os << "Optional<" << cst << ">";
96 break;
97 case VariableLengthKind::Single:
98 os << cst;
99 break;
100 case VariableLengthKind::Variadic:
101 os << "Variadic<" << cst << ">";
102 break;
103 }
104 };
105
106 llvm::ScopedPrinter printer(os);
107 llvm::DictScope odsScope(printer, "ODSContext");
108 for (const Dialect *dialect : sortMapByName(map: dialects)) {
109 printer.startLine() << "Dialect `" << dialect->getName() << "` {\n";
110 printer.indent();
111
112 for (const Operation *op : sortMapByName(map: dialect->getOperations())) {
113 printer.startLine() << "Operation `" << op->getName() << "` {\n";
114 printer.indent();
115
116 // Attributes.
117 ArrayRef<Attribute> attributes = op->getAttributes();
118 if (!attributes.empty()) {
119 printer.startLine() << "Attributes { ";
120 llvm::interleaveComma(c: attributes, os, each_fn: [&](const Attribute &attr) {
121 os << attr.getName() << " : ";
122
123 auto kind = attr.isOptional() ? VariableLengthKind::Optional
124 : VariableLengthKind::Single;
125 printVariableLengthCst(attr.getConstraint().getDemangledName(), kind);
126 });
127 os << " }\n";
128 }
129
130 // Operands.
131 ArrayRef<OperandOrResult> operands = op->getOperands();
132 if (!operands.empty()) {
133 printer.startLine() << "Operands { ";
134 llvm::interleaveComma(
135 c: operands, os, each_fn: [&](const OperandOrResult &operand) {
136 os << operand.getName() << " : ";
137 printVariableLengthCst(operand.getConstraint().getDemangledName(),
138 operand.getVariableLengthKind());
139 });
140 os << " }\n";
141 }
142
143 // Results.
144 ArrayRef<OperandOrResult> results = op->getResults();
145 if (!results.empty()) {
146 printer.startLine() << "Results { ";
147 llvm::interleaveComma(c: results, os, each_fn: [&](const OperandOrResult &result) {
148 os << result.getName() << " : ";
149 printVariableLengthCst(result.getConstraint().getDemangledName(),
150 result.getVariableLengthKind());
151 });
152 os << " }\n";
153 }
154
155 printer.objectEnd();
156 }
157 printer.objectEnd();
158 }
159 for (const AttributeConstraint *cst : sortMapByName(map: attributeConstraints)) {
160 printer.startLine() << "AttributeConstraint `" << cst->getDemangledName()
161 << "` {\n";
162 printer.indent();
163
164 printer.startLine() << "Summary: " << cst->getSummary() << "\n";
165 printer.startLine() << "CppClass: " << cst->getCppClass() << "\n";
166 printer.objectEnd();
167 }
168 for (const TypeConstraint *cst : sortMapByName(map: typeConstraints)) {
169 printer.startLine() << "TypeConstraint `" << cst->getDemangledName()
170 << "` {\n";
171 printer.indent();
172
173 printer.startLine() << "Summary: " << cst->getSummary() << "\n";
174 printer.startLine() << "CppClass: " << cst->getCppClass() << "\n";
175 printer.objectEnd();
176 }
177 printer.objectEnd();
178}
179

source code of mlir/lib/Tools/PDLL/ODS/Context.cpp