1//===- CodeGenHelpers.cpp - MLIR op definitions generator ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// OpDefinitionsGen uses the description of operations to generate C++
10// definitions for ops.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/TableGen/CodeGenHelpers.h"
15#include "mlir/TableGen/Operator.h"
16#include "mlir/TableGen/Pattern.h"
17#include "llvm/Support/FormatVariadic.h"
18#include "llvm/Support/Path.h"
19#include "llvm/TableGen/Record.h"
20
21using namespace llvm;
22using namespace mlir;
23using namespace mlir::tblgen;
24
25/// Generate a unique label based on the current file name to prevent name
26/// collisions if multiple generated files are included at once.
27static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
28 // Use the input file name when generating a unique name.
29 std::string inputFilename = records.getInputFilename();
30
31 // Drop all but the base filename.
32 StringRef nameRef = llvm::sys::path::filename(path: inputFilename);
33 nameRef.consume_back(Suffix: ".td");
34
35 // Sanitize any invalid characters.
36 std::string uniqueName;
37 for (char c : nameRef) {
38 if (llvm::isAlnum(C: c) || c == '_')
39 uniqueName.push_back(c: c);
40 else
41 uniqueName.append(str: llvm::utohexstr(X: (unsigned char)c));
42 }
43 return uniqueName;
44}
45
46StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
47 raw_ostream &os, const llvm::RecordKeeper &records)
48 : os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {}
49
50void StaticVerifierFunctionEmitter::emitOpConstraints(
51 ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
52 collectOpConstraints(opDefs);
53 if (emitDecl)
54 return;
55
56 NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
57 emitTypeConstraints();
58 emitAttrConstraints();
59 emitSuccessorConstraints();
60 emitRegionConstraints();
61}
62
63void StaticVerifierFunctionEmitter::emitPatternConstraints(
64 const llvm::ArrayRef<DagLeaf> constraints) {
65 collectPatternConstraints(constraints);
66 emitPatternConstraints();
67}
68
69//===----------------------------------------------------------------------===//
70// Constraint Getters
71
72StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn(
73 const Constraint &constraint) const {
74 const auto *it = typeConstraints.find(Key: constraint);
75 assert(it != typeConstraints.end() && "expected to find a type constraint");
76 return it->second;
77}
78
79// Find a uniqued attribute constraint. Since not all attribute constraints can
80// be uniqued, return std::nullopt if one was not found.
81std::optional<StringRef> StaticVerifierFunctionEmitter::getAttrConstraintFn(
82 const Constraint &constraint) const {
83 const auto *it = attrConstraints.find(Key: constraint);
84 return it == attrConstraints.end() ? std::optional<StringRef>()
85 : StringRef(it->second);
86}
87
88StringRef StaticVerifierFunctionEmitter::getSuccessorConstraintFn(
89 const Constraint &constraint) const {
90 const auto *it = successorConstraints.find(Key: constraint);
91 assert(it != successorConstraints.end() &&
92 "expected to find a sucessor constraint");
93 return it->second;
94}
95
96StringRef StaticVerifierFunctionEmitter::getRegionConstraintFn(
97 const Constraint &constraint) const {
98 const auto *it = regionConstraints.find(Key: constraint);
99 assert(it != regionConstraints.end() &&
100 "expected to find a region constraint");
101 return it->second;
102}
103
104//===----------------------------------------------------------------------===//
105// Constraint Emission
106
/// Code templates for emitting type, attribute, successor, and region
/// constraints. Each of these templates requires the following `formatv`
/// arguments:
///
/// {0}: The unique constraint name.
/// {1}: The constraint code.
/// {2}: The constraint description.
113
/// Code for a type constraint. These may be called on the type of either
/// operands or results.
///
/// The emitted function evaluates the condition {1} on `type`. On failure it
/// reports an op error naming the value kind and index, together with the
/// expected description {2} and the actual type.
///
/// NOTE(review): unlike attrConstraintCode, the literal braces here are not
/// doubled as `{{`. This appears to rely on `formatv` treating a `{` as literal
/// text when another `{` occurs before the next `}`. Confirm against the
/// formatv parser before editing this template.
static const char *const typeConstraintCode = R"(
static ::mlir::LogicalResult {0}(
 ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind,
 unsigned valueIndex) {
 if (!({1})) {
 return op->emitOpError(valueKind) << " #" << valueIndex
 << " must be {2}, but got " << type;
 }
 return ::mlir::success();
}
)";
127
/// Code for an attribute constraint. These may be called from ops only.
/// Attribute constraints cannot reference anything other than `$_self` and
/// `$_op`.
///
/// Two overloads are emitted under the same name {0}:
/// - The first checks `attr` against the condition {1}. It reports failure
///   through the caller-supplied `emitError` callback. A null `attr` passes
///   without the condition being evaluated.
/// - The second takes the op and forwards to the first, using
///   `op->emitOpError()` as the callback.
///
/// Literal braces are doubled (`{{`) so that `formatv` emits a single `{`.
///
/// NOTE(review): the callback overload does not itself need an op, so the
/// "called from ops only" remark above may be stale. Confirm against the
/// callers.
///
/// TODO: Unique constraints for adaptors. However, most Adaptor::verify
/// functions are stripped anyway.
static const char *const attrConstraintCode = R"(
static ::mlir::LogicalResult {0}(
 ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{
 if (attr && !({1}))
 return emitError() << "attribute '" << attrName
 << "' failed to satisfy constraint: {2}";
 return ::mlir::success();
}
static ::mlir::LogicalResult {0}(
 ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) {{
 return {0}(attr, attrName, [op]() {{
 return op->emitOpError();
 });
}
)";
149
/// Code for a successor constraint.
///
/// The emitted function evaluates the condition {1} on `successor`. On failure
/// it reports an op error naming the successor's index and name, e.g.
///   successor #0 ('dest') failed to verify constraint: <description>
/// The closing quote precedes the closing parenthesis, matching the region
/// constraint diagnostic. Previously this was swapped and printed "('dest)'".
static const char *const successorConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Block *successor,
    ::llvm::StringRef successorName, unsigned successorIndex) {
  if (!({1})) {
    return op->emitOpError("successor #") << successorIndex << " ('"
        << successorName << "') failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";
162
/// Code for a region constraint. Callers will need to pass in the region's name
/// for emitting an error message.
///
/// The emitted function evaluates the condition {1} on `region`. On failure it
/// reports an op error with the region index. The name is included as
/// " ('name') " only when `regionName` is non-empty; otherwise a single space
/// separates the index from the rest of the message.
static const char *const regionConstraintCode = R"(
static ::mlir::LogicalResult {0}(
 ::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName,
 unsigned regionIndex) {
 if (!({1})) {
 return op->emitOpError("region #") << regionIndex
 << (regionName.empty() ? " " : " ('" + regionName + "') ")
 << "failed to verify constraint: {2}";
 }
 return ::mlir::success();
}
)";
177
/// Code for a pattern type or attribute constraint.
///
/// {3}: "Type type" or "Attribute attr".
///
/// The emitted function evaluates the condition {1}. On failure it returns the
/// result of `rewriter.notifyMatchFailure`, whose diagnostic is the
/// caller-supplied `failureStr` followed by ": " and the description {2}.
/// NOTE(review): presumably notifyMatchFailure yields a failure result; confirm
/// against the PatternRewriter API.
static const char *const patternAttrOrTypeConstraintCode = R"(
static ::mlir::LogicalResult {0}(
 ::mlir::PatternRewriter &rewriter, ::mlir::Operation *op, ::mlir::{3},
 ::llvm::StringRef failureStr) {
 if (!({1})) {
 return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
 diag << failureStr << ": {2}";
 });
 }
 return ::mlir::success();
}
)";
193
194void StaticVerifierFunctionEmitter::emitConstraints(
195 const ConstraintMap &constraints, StringRef selfName,
196 const char *const codeTemplate) {
197 FmtContext ctx;
198 ctx.addSubst(placeholder: "_op", subst: "*op").withSelf(subst: selfName);
199 for (auto &it : constraints) {
200 os << formatv(Fmt: codeTemplate, Vals: it.second,
201 Vals: tgfmt(fmt: it.first.getConditionTemplate(), ctx: &ctx),
202 Vals: escapeString(value: it.first.getSummary()));
203 }
204}
205
206void StaticVerifierFunctionEmitter::emitTypeConstraints() {
207 emitConstraints(constraints: typeConstraints, selfName: "type", codeTemplate: typeConstraintCode);
208}
209
210void StaticVerifierFunctionEmitter::emitAttrConstraints() {
211 emitConstraints(constraints: attrConstraints, selfName: "attr", codeTemplate: attrConstraintCode);
212}
213
214void StaticVerifierFunctionEmitter::emitSuccessorConstraints() {
215 emitConstraints(constraints: successorConstraints, selfName: "successor", codeTemplate: successorConstraintCode);
216}
217
218void StaticVerifierFunctionEmitter::emitRegionConstraints() {
219 emitConstraints(constraints: regionConstraints, selfName: "region", codeTemplate: regionConstraintCode);
220}
221
222void StaticVerifierFunctionEmitter::emitPatternConstraints() {
223 FmtContext ctx;
224 ctx.addSubst(placeholder: "_op", subst: "*op").withBuilder(subst: "rewriter").withSelf(subst: "type");
225 for (auto &it : typeConstraints) {
226 os << formatv(Fmt: patternAttrOrTypeConstraintCode, Vals&: it.second,
227 Vals: tgfmt(fmt: it.first.getConditionTemplate(), ctx: &ctx),
228 Vals: escapeString(value: it.first.getSummary()), Vals: "Type type");
229 }
230 ctx.withSelf(subst: "attr");
231 for (auto &it : attrConstraints) {
232 os << formatv(Fmt: patternAttrOrTypeConstraintCode, Vals&: it.second,
233 Vals: tgfmt(fmt: it.first.getConditionTemplate(), ctx: &ctx),
234 Vals: escapeString(value: it.first.getSummary()), Vals: "Attribute attr");
235 }
236}
237
238//===----------------------------------------------------------------------===//
239// Constraint Uniquing
240
241/// An attribute constraint that references anything other than itself and the
242/// current op cannot be generically extracted into a function. Most
243/// prohibitive are operands and results, which require calls to
244/// `getODSOperands` or `getODSResults`. Attribute references are tricky too
245/// because ops use cached identifiers.
246static bool canUniqueAttrConstraint(Attribute attr) {
247 FmtContext ctx;
248 auto test = tgfmt(fmt: attr.getConditionTemplate(),
249 ctx: &ctx.withSelf(subst: "attr").addSubst(placeholder: "_op", subst: "*op"))
250 .str();
251 return !StringRef(test).contains(Other: "<no-subst-found>");
252}
253
254std::string StaticVerifierFunctionEmitter::getUniqueName(StringRef kind,
255 unsigned index) {
256 return ("__mlir_ods_local_" + kind + "_constraint_" + uniqueOutputLabel +
257 Twine(index))
258 .str();
259}
260
261void StaticVerifierFunctionEmitter::collectConstraint(ConstraintMap &map,
262 StringRef kind,
263 Constraint constraint) {
264 auto *it = map.find(Key: constraint);
265 if (it == map.end())
266 map.insert(KV: {constraint, getUniqueName(kind, index: map.size())});
267}
268
269void StaticVerifierFunctionEmitter::collectOpConstraints(
270 ArrayRef<Record *> opDefs) {
271 const auto collectTypeConstraints = [&](Operator::const_value_range values) {
272 for (const NamedTypeConstraint &value : values)
273 if (value.hasPredicate())
274 collectConstraint(map&: typeConstraints, kind: "type", constraint: value.constraint);
275 };
276
277 for (Record *def : opDefs) {
278 Operator op(*def);
279 /// Collect type constraints.
280 collectTypeConstraints(op.getOperands());
281 collectTypeConstraints(op.getResults());
282 /// Collect attribute constraints.
283 for (const NamedAttribute &namedAttr : op.getAttributes()) {
284 if (!namedAttr.attr.getPredicate().isNull() &&
285 !namedAttr.attr.isDerivedAttr() &&
286 canUniqueAttrConstraint(attr: namedAttr.attr))
287 collectConstraint(map&: attrConstraints, kind: "attr", constraint: namedAttr.attr);
288 }
289 /// Collect successor constraints.
290 for (const NamedSuccessor &successor : op.getSuccessors()) {
291 if (!successor.constraint.getPredicate().isNull()) {
292 collectConstraint(map&: successorConstraints, kind: "successor",
293 constraint: successor.constraint);
294 }
295 }
296 /// Collect region constraints.
297 for (const NamedRegion &region : op.getRegions())
298 if (!region.constraint.getPredicate().isNull())
299 collectConstraint(map&: regionConstraints, kind: "region", constraint: region.constraint);
300 }
301}
302
303void StaticVerifierFunctionEmitter::collectPatternConstraints(
304 const llvm::ArrayRef<DagLeaf> constraints) {
305 for (auto &leaf : constraints) {
306 assert(leaf.isOperandMatcher() || leaf.isAttrMatcher());
307 collectConstraint(
308 map&: leaf.isOperandMatcher() ? typeConstraints : attrConstraints,
309 kind: leaf.isOperandMatcher() ? "type" : "attr", constraint: leaf.getAsConstraint());
310 }
311}
312
313//===----------------------------------------------------------------------===//
314// Public Utility Functions
315//===----------------------------------------------------------------------===//
316
317std::string mlir::tblgen::escapeString(StringRef value) {
318 std::string ret;
319 llvm::raw_string_ostream os(ret);
320 os.write_escaped(Str: value);
321 return os.str();
322}
323

source code of mlir/lib/TableGen/CodeGenHelpers.cpp