//===- CodeGenHelpers.cpp - MLIR code generation helpers ------------------===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file defines common helpers for generating C++ from TableGen records,
// such as the emitter of uniqued static verifier functions for op and pattern
// constraints.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/TableGen/CodeGenHelpers.h" |
15 | #include "mlir/TableGen/Operator.h" |
16 | #include "mlir/TableGen/Pattern.h" |
17 | #include "llvm/Support/FormatVariadic.h" |
18 | #include "llvm/Support/Path.h" |
19 | #include "llvm/TableGen/Record.h" |
20 | |
21 | using namespace llvm; |
22 | using namespace mlir; |
23 | using namespace mlir::tblgen; |
24 | |
25 | /// Generate a unique label based on the current file name to prevent name |
26 | /// collisions if multiple generated files are included at once. |
27 | static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) { |
28 | // Use the input file name when generating a unique name. |
29 | std::string inputFilename = records.getInputFilename(); |
30 | |
31 | // Drop all but the base filename. |
32 | StringRef nameRef = llvm::sys::path::filename(path: inputFilename); |
33 | nameRef.consume_back(Suffix: ".td" ); |
34 | |
35 | // Sanitize any invalid characters. |
36 | std::string uniqueName; |
37 | for (char c : nameRef) { |
38 | if (llvm::isAlnum(C: c) || c == '_') |
39 | uniqueName.push_back(c: c); |
40 | else |
41 | uniqueName.append(str: llvm::utohexstr(X: (unsigned char)c)); |
42 | } |
43 | return uniqueName; |
44 | } |
45 | |
46 | StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter( |
47 | raw_ostream &os, const llvm::RecordKeeper &records) |
48 | : os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {} |
49 | |
50 | void StaticVerifierFunctionEmitter::emitOpConstraints( |
51 | ArrayRef<llvm::Record *> opDefs, bool emitDecl) { |
52 | collectOpConstraints(opDefs); |
53 | if (emitDecl) |
54 | return; |
55 | |
56 | NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace()); |
57 | emitTypeConstraints(); |
58 | emitAttrConstraints(); |
59 | emitSuccessorConstraints(); |
60 | emitRegionConstraints(); |
61 | } |
62 | |
63 | void StaticVerifierFunctionEmitter::emitPatternConstraints( |
64 | const llvm::ArrayRef<DagLeaf> constraints) { |
65 | collectPatternConstraints(constraints); |
66 | emitPatternConstraints(); |
67 | } |
68 | |
69 | //===----------------------------------------------------------------------===// |
70 | // Constraint Getters |
71 | |
72 | StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn( |
73 | const Constraint &constraint) const { |
74 | const auto *it = typeConstraints.find(Key: constraint); |
75 | assert(it != typeConstraints.end() && "expected to find a type constraint" ); |
76 | return it->second; |
77 | } |
78 | |
79 | // Find a uniqued attribute constraint. Since not all attribute constraints can |
80 | // be uniqued, return std::nullopt if one was not found. |
81 | std::optional<StringRef> StaticVerifierFunctionEmitter::getAttrConstraintFn( |
82 | const Constraint &constraint) const { |
83 | const auto *it = attrConstraints.find(Key: constraint); |
84 | return it == attrConstraints.end() ? std::optional<StringRef>() |
85 | : StringRef(it->second); |
86 | } |
87 | |
88 | StringRef StaticVerifierFunctionEmitter::getSuccessorConstraintFn( |
89 | const Constraint &constraint) const { |
90 | const auto *it = successorConstraints.find(Key: constraint); |
91 | assert(it != successorConstraints.end() && |
92 | "expected to find a sucessor constraint" ); |
93 | return it->second; |
94 | } |
95 | |
96 | StringRef StaticVerifierFunctionEmitter::getRegionConstraintFn( |
97 | const Constraint &constraint) const { |
98 | const auto *it = regionConstraints.find(Key: constraint); |
99 | assert(it != regionConstraints.end() && |
100 | "expected to find a region constraint" ); |
101 | return it->second; |
102 | } |
103 | |
104 | //===----------------------------------------------------------------------===// |
105 | // Constraint Emission |
106 | |
107 | /// Code templates for emitting type, attribute, successor, and region |
/// constraints. Each of these templates requires the following arguments:
109 | /// |
110 | /// {0}: The unique constraint name. |
111 | /// {1}: The constraint code. |
112 | /// {2}: The constraint description. |
113 | |
/// Code for a type constraint. These may be called on the type of either
/// operands or results.
///
/// This is a formatv template: literal opening braces are escaped as `{{`.
static const char *const typeConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind,
    unsigned valueIndex) {{
  if (!({1})) {{
    return op->emitOpError(valueKind) << " #" << valueIndex
           << " must be {2}, but got " << type;
  }
  return ::mlir::success();
}
)";
127 | |
/// Code for an attribute constraint. These may be called from ops only.
/// Attribute constraints cannot reference anything other than `$_self` and
/// `$_op`.
///
/// Two overloads are emitted: one that reports failure through a caller
/// supplied `emitError` callback, and one that reports through
/// `op->emitOpError()`. All names in the emitted code are fully qualified, so
/// it is valid in any enclosing namespace.
///
/// TODO: Unique constraints for adaptors. However, most Adaptor::verify
/// functions are stripped anyway.
static const char *const attrConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Attribute attr, ::llvm::StringRef attrName, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{
  if (attr && !({1}))
    return emitError() << "attribute '" << attrName
        << "' failed to satisfy constraint: {2}";
  return ::mlir::success();
}
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) {{
  return {0}(attr, attrName, [op]() {{
    return op->emitOpError();
  });
}
)";
149 | |
/// Code for a successor constraint. The diagnostic reads
/// `successor #<index> ('<name>') failed to verify constraint: <summary>`.
///
/// This is a formatv template: literal opening braces are escaped as `{{`.
static const char *const successorConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Block *successor,
    ::llvm::StringRef successorName, unsigned successorIndex) {{
  if (!({1})) {{
    return op->emitOpError("successor #") << successorIndex << " ('"
        << successorName << "') failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";
162 | |
/// Code for a region constraint. Callers will need to pass in the region's name
/// for emitting an error message; an empty name is simply omitted.
///
/// This is a formatv template: literal opening braces are escaped as `{{`.
static const char *const regionConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName,
    unsigned regionIndex) {{
  if (!({1})) {{
    return op->emitOpError("region #") << regionIndex
        << (regionName.empty() ? " " : " ('" + regionName + "') ")
        << "failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";
177 | |
/// Code for a pattern type or attribute constraint. On failure, the function
/// reports through `rewriter.notifyMatchFailure` with `failureStr` as the
/// message prefix.
///
/// {3}: "Type type" or "Attribute attr".
///
/// This is a formatv template: literal opening braces are escaped as `{{`.
static const char *const patternAttrOrTypeConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::PatternRewriter &rewriter, ::mlir::Operation *op, ::mlir::{3},
    ::llvm::StringRef failureStr) {{
  if (!({1})) {{
    return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {{
      diag << failureStr << ": {2}";
    });
  }
  return ::mlir::success();
}
)";
193 | |
194 | void StaticVerifierFunctionEmitter::emitConstraints( |
195 | const ConstraintMap &constraints, StringRef selfName, |
196 | const char *const codeTemplate) { |
197 | FmtContext ctx; |
198 | ctx.addSubst(placeholder: "_op" , subst: "*op" ).withSelf(subst: selfName); |
199 | for (auto &it : constraints) { |
200 | os << formatv(Fmt: codeTemplate, Vals: it.second, |
201 | Vals: tgfmt(fmt: it.first.getConditionTemplate(), ctx: &ctx), |
202 | Vals: escapeString(value: it.first.getSummary())); |
203 | } |
204 | } |
205 | |
206 | void StaticVerifierFunctionEmitter::emitTypeConstraints() { |
207 | emitConstraints(constraints: typeConstraints, selfName: "type" , codeTemplate: typeConstraintCode); |
208 | } |
209 | |
210 | void StaticVerifierFunctionEmitter::emitAttrConstraints() { |
211 | emitConstraints(constraints: attrConstraints, selfName: "attr" , codeTemplate: attrConstraintCode); |
212 | } |
213 | |
214 | void StaticVerifierFunctionEmitter::emitSuccessorConstraints() { |
215 | emitConstraints(constraints: successorConstraints, selfName: "successor" , codeTemplate: successorConstraintCode); |
216 | } |
217 | |
218 | void StaticVerifierFunctionEmitter::emitRegionConstraints() { |
219 | emitConstraints(constraints: regionConstraints, selfName: "region" , codeTemplate: regionConstraintCode); |
220 | } |
221 | |
222 | void StaticVerifierFunctionEmitter::emitPatternConstraints() { |
223 | FmtContext ctx; |
224 | ctx.addSubst(placeholder: "_op" , subst: "*op" ).withBuilder(subst: "rewriter" ).withSelf(subst: "type" ); |
225 | for (auto &it : typeConstraints) { |
226 | os << formatv(Fmt: patternAttrOrTypeConstraintCode, Vals&: it.second, |
227 | Vals: tgfmt(fmt: it.first.getConditionTemplate(), ctx: &ctx), |
228 | Vals: escapeString(value: it.first.getSummary()), Vals: "Type type" ); |
229 | } |
230 | ctx.withSelf(subst: "attr" ); |
231 | for (auto &it : attrConstraints) { |
232 | os << formatv(Fmt: patternAttrOrTypeConstraintCode, Vals&: it.second, |
233 | Vals: tgfmt(fmt: it.first.getConditionTemplate(), ctx: &ctx), |
234 | Vals: escapeString(value: it.first.getSummary()), Vals: "Attribute attr" ); |
235 | } |
236 | } |
237 | |
238 | //===----------------------------------------------------------------------===// |
239 | // Constraint Uniquing |
240 | |
241 | /// An attribute constraint that references anything other than itself and the |
242 | /// current op cannot be generically extracted into a function. Most |
243 | /// prohibitive are operands and results, which require calls to |
244 | /// `getODSOperands` or `getODSResults`. Attribute references are tricky too |
245 | /// because ops use cached identifiers. |
246 | static bool canUniqueAttrConstraint(Attribute attr) { |
247 | FmtContext ctx; |
248 | auto test = tgfmt(fmt: attr.getConditionTemplate(), |
249 | ctx: &ctx.withSelf(subst: "attr" ).addSubst(placeholder: "_op" , subst: "*op" )) |
250 | .str(); |
251 | return !StringRef(test).contains(Other: "<no-subst-found>" ); |
252 | } |
253 | |
254 | std::string StaticVerifierFunctionEmitter::getUniqueName(StringRef kind, |
255 | unsigned index) { |
256 | return ("__mlir_ods_local_" + kind + "_constraint_" + uniqueOutputLabel + |
257 | Twine(index)) |
258 | .str(); |
259 | } |
260 | |
261 | void StaticVerifierFunctionEmitter::collectConstraint(ConstraintMap &map, |
262 | StringRef kind, |
263 | Constraint constraint) { |
264 | auto *it = map.find(Key: constraint); |
265 | if (it == map.end()) |
266 | map.insert(KV: {constraint, getUniqueName(kind, index: map.size())}); |
267 | } |
268 | |
269 | void StaticVerifierFunctionEmitter::collectOpConstraints( |
270 | ArrayRef<Record *> opDefs) { |
271 | const auto collectTypeConstraints = [&](Operator::const_value_range values) { |
272 | for (const NamedTypeConstraint &value : values) |
273 | if (value.hasPredicate()) |
274 | collectConstraint(map&: typeConstraints, kind: "type" , constraint: value.constraint); |
275 | }; |
276 | |
277 | for (Record *def : opDefs) { |
278 | Operator op(*def); |
279 | /// Collect type constraints. |
280 | collectTypeConstraints(op.getOperands()); |
281 | collectTypeConstraints(op.getResults()); |
282 | /// Collect attribute constraints. |
283 | for (const NamedAttribute &namedAttr : op.getAttributes()) { |
284 | if (!namedAttr.attr.getPredicate().isNull() && |
285 | !namedAttr.attr.isDerivedAttr() && |
286 | canUniqueAttrConstraint(attr: namedAttr.attr)) |
287 | collectConstraint(map&: attrConstraints, kind: "attr" , constraint: namedAttr.attr); |
288 | } |
289 | /// Collect successor constraints. |
290 | for (const NamedSuccessor &successor : op.getSuccessors()) { |
291 | if (!successor.constraint.getPredicate().isNull()) { |
292 | collectConstraint(map&: successorConstraints, kind: "successor" , |
293 | constraint: successor.constraint); |
294 | } |
295 | } |
296 | /// Collect region constraints. |
297 | for (const NamedRegion ®ion : op.getRegions()) |
298 | if (!region.constraint.getPredicate().isNull()) |
299 | collectConstraint(map&: regionConstraints, kind: "region" , constraint: region.constraint); |
300 | } |
301 | } |
302 | |
303 | void StaticVerifierFunctionEmitter::collectPatternConstraints( |
304 | const llvm::ArrayRef<DagLeaf> constraints) { |
305 | for (auto &leaf : constraints) { |
306 | assert(leaf.isOperandMatcher() || leaf.isAttrMatcher()); |
307 | collectConstraint( |
308 | map&: leaf.isOperandMatcher() ? typeConstraints : attrConstraints, |
309 | kind: leaf.isOperandMatcher() ? "type" : "attr" , constraint: leaf.getAsConstraint()); |
310 | } |
311 | } |
312 | |
313 | //===----------------------------------------------------------------------===// |
314 | // Public Utility Functions |
315 | //===----------------------------------------------------------------------===// |
316 | |
317 | std::string mlir::tblgen::escapeString(StringRef value) { |
318 | std::string ret; |
319 | llvm::raw_string_ostream os(ret); |
320 | os.write_escaped(Str: value); |
321 | return os.str(); |
322 | } |
323 | |