1//===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// OpInterfacesGen generates definitions for operation interfaces.
10//
11//===----------------------------------------------------------------------===//
12
13#include "CppGenUtilities.h"
14#include "DocGenUtilities.h"
15#include "mlir/TableGen/Format.h"
16#include "mlir/TableGen/GenInfo.h"
17#include "mlir/TableGen/Interfaces.h"
18#include "llvm/ADT/SmallVector.h"
19#include "llvm/ADT/StringExtras.h"
20#include "llvm/Support/FormatVariadic.h"
21#include "llvm/Support/raw_ostream.h"
22#include "llvm/TableGen/Error.h"
23#include "llvm/TableGen/Record.h"
24#include "llvm/TableGen/TableGenBackend.h"
25
26using namespace mlir;
27using llvm::Record;
28using llvm::RecordKeeper;
29using mlir::tblgen::Interface;
30using mlir::tblgen::InterfaceMethod;
31using mlir::tblgen::OpInterface;
32
33/// Emit a string corresponding to a C++ type, followed by a space if necessary.
34static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
35 type = type.trim();
36 os << type;
37 if (type.back() != '&' && type.back() != '*')
38 os << " ";
39 return os;
40}
41
42/// Emit the method name and argument list for the given method. If 'addThisArg'
43/// is true, then an argument is added to the beginning of the argument list for
44/// the concrete value.
45static void emitMethodNameAndArgs(const InterfaceMethod &method,
46 raw_ostream &os, StringRef valueType,
47 bool addThisArg, bool addConst) {
48 os << method.getName() << '(';
49 if (addThisArg) {
50 if (addConst)
51 os << "const ";
52 os << "const Concept *impl, ";
53 emitCPPType(type: valueType, os)
54 << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
55 }
56 llvm::interleaveComma(c: method.getArguments(), os,
57 each_fn: [&](const InterfaceMethod::Argument &arg) {
58 os << arg.type << " " << arg.name;
59 });
60 os << ')';
61 if (addConst)
62 os << " const";
63}
64
65/// Get an array of all OpInterface definitions but exclude those subclassing
66/// "DeclareOpInterfaceMethods".
67static std::vector<const Record *>
68getAllInterfaceDefinitions(const RecordKeeper &records, StringRef name) {
69 std::vector<const Record *> defs =
70 records.getAllDerivedDefinitions(ClassName: (name + "Interface").str());
71
72 std::string declareName = ("Declare" + name + "InterfaceMethods").str();
73 llvm::erase_if(C&: defs, P: [&](const Record *def) {
74 // Ignore any "declare methods" interfaces.
75 if (def->isSubClassOf(Name: declareName))
76 return true;
77 // Ignore interfaces defined outside of the top-level file.
78 return llvm::SrcMgr.FindBufferContainingLoc(Loc: def->getLoc()[0]) !=
79 llvm::SrcMgr.getMainFileID();
80 });
81 return defs;
82}
83
84namespace {
85/// This struct is the base generator used when processing tablegen interfaces.
86class InterfaceGenerator {
87public:
88 bool emitInterfaceDefs();
89 bool emitInterfaceDecls();
90 bool emitInterfaceDocs();
91
92protected:
93 InterfaceGenerator(std::vector<const Record *> &&defs, raw_ostream &os)
94 : defs(std::move(defs)), os(os) {}
95
96 void emitConceptDecl(const Interface &interface);
97 void emitModelDecl(const Interface &interface);
98 void emitModelMethodsDef(const Interface &interface);
99 void emitTraitDecl(const Interface &interface, StringRef interfaceName,
100 StringRef interfaceTraitsName);
101 void emitInterfaceDecl(const Interface &interface);
102
103 /// The set of interface records to emit.
104 std::vector<const Record *> defs;
105 // The stream to emit to.
106 raw_ostream &os;
107 /// The C++ value type of the interface, e.g. Operation*.
108 StringRef valueType;
109 /// The C++ base interface type.
110 StringRef interfaceBaseType;
111 /// The name of the typename for the value template.
112 StringRef valueTemplate;
113 /// The name of the substituion variable for the value.
114 StringRef substVar;
115 /// The format context to use for methods.
116 tblgen::FmtContext nonStaticMethodFmt;
117 tblgen::FmtContext traitMethodFmt;
118 tblgen::FmtContext extraDeclsFmt;
119};
120
121/// A specialized generator for attribute interfaces.
122struct AttrInterfaceGenerator : public InterfaceGenerator {
123 AttrInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
124 : InterfaceGenerator(getAllInterfaceDefinitions(records, name: "Attr"), os) {
125 valueType = "::mlir::Attribute";
126 interfaceBaseType = "AttributeInterface";
127 valueTemplate = "ConcreteAttr";
128 substVar = "_attr";
129 StringRef castCode = "(::llvm::cast<ConcreteAttr>(tablegen_opaque_val))";
130 nonStaticMethodFmt.addSubst(placeholder: substVar, subst: castCode).withSelf(subst: castCode);
131 traitMethodFmt.addSubst(placeholder: substVar,
132 subst: "(*static_cast<const ConcreteAttr *>(this))");
133 extraDeclsFmt.addSubst(placeholder: substVar, subst: "(*this)");
134 }
135};
136/// A specialized generator for operation interfaces.
137struct OpInterfaceGenerator : public InterfaceGenerator {
138 OpInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
139 : InterfaceGenerator(getAllInterfaceDefinitions(records, name: "Op"), os) {
140 valueType = "::mlir::Operation *";
141 interfaceBaseType = "OpInterface";
142 valueTemplate = "ConcreteOp";
143 substVar = "_op";
144 StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))";
145 nonStaticMethodFmt.addSubst(placeholder: "_this", subst: "impl")
146 .addSubst(placeholder: substVar, subst: castCode)
147 .withSelf(subst: castCode);
148 traitMethodFmt.addSubst(placeholder: substVar, subst: "(*static_cast<ConcreteOp *>(this))");
149 extraDeclsFmt.addSubst(placeholder: substVar, subst: "(*this)");
150 }
151};
152/// A specialized generator for type interfaces.
153struct TypeInterfaceGenerator : public InterfaceGenerator {
154 TypeInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
155 : InterfaceGenerator(getAllInterfaceDefinitions(records, name: "Type"), os) {
156 valueType = "::mlir::Type";
157 interfaceBaseType = "TypeInterface";
158 valueTemplate = "ConcreteType";
159 substVar = "_type";
160 StringRef castCode = "(::llvm::cast<ConcreteType>(tablegen_opaque_val))";
161 nonStaticMethodFmt.addSubst(placeholder: substVar, subst: castCode).withSelf(subst: castCode);
162 traitMethodFmt.addSubst(placeholder: substVar,
163 subst: "(*static_cast<const ConcreteType *>(this))");
164 extraDeclsFmt.addSubst(placeholder: substVar, subst: "(*this)");
165 }
166};
167} // namespace
168
169//===----------------------------------------------------------------------===//
170// GEN: Interface definitions
171//===----------------------------------------------------------------------===//
172
173static void emitInterfaceMethodDoc(const InterfaceMethod &method,
174 raw_ostream &os, StringRef prefix = "") {
175 if (std::optional<StringRef> description = method.getDescription())
176 tblgen::emitDescriptionComment(description: *description, os, prefix);
177}
178static void emitInterfaceDefMethods(StringRef interfaceQualName,
179 const Interface &interface,
180 StringRef valueType, const Twine &implValue,
181 raw_ostream &os, bool isOpInterface) {
182 for (auto &method : interface.getMethods()) {
183 emitInterfaceMethodDoc(method, os);
184 emitCPPType(type: method.getReturnType(), os);
185 os << interfaceQualName << "::";
186 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
187 /*addConst=*/!isOpInterface);
188
189 // Forward to the method on the concrete operation type.
190 os << " {\n return " << implValue << "->" << method.getName() << '(';
191 if (!method.isStatic()) {
192 os << implValue << ", ";
193 os << (isOpInterface ? "getOperation()" : "*this");
194 os << (method.arg_empty() ? "" : ", ");
195 }
196 llvm::interleaveComma(
197 c: method.getArguments(), os,
198 each_fn: [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
199 os << ");\n }\n";
200 }
201}
202
203static void emitInterfaceDef(const Interface &interface, StringRef valueType,
204 raw_ostream &os) {
205 std::string interfaceQualNameStr = interface.getFullyQualifiedName();
206 StringRef interfaceQualName = interfaceQualNameStr;
207 interfaceQualName.consume_front(Prefix: "::");
208
209 // Insert the method definitions.
210 bool isOpInterface = isa<OpInterface>(Val: interface);
211 emitInterfaceDefMethods(interfaceQualName, interface, valueType, implValue: "getImpl()",
212 os, isOpInterface);
213
214 // Insert the method definitions for base classes.
215 for (auto &base : interface.getBaseInterfaces()) {
216 emitInterfaceDefMethods(interfaceQualName, interface: base, valueType,
217 implValue: "getImpl()->impl" + base.getName(), os,
218 isOpInterface);
219 }
220}
221
222bool InterfaceGenerator::emitInterfaceDefs() {
223 llvm::emitSourceFileHeader(Desc: "Interface Definitions", OS&: os);
224
225 for (const auto *def : defs)
226 emitInterfaceDef(interface: Interface(def), valueType, os);
227 return false;
228}
229
230//===----------------------------------------------------------------------===//
231// GEN: Interface declarations
232//===----------------------------------------------------------------------===//
233
234void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
235 os << " struct Concept {\n";
236
237 // Insert each of the pure virtual concept methods.
238 os << " /// The methods defined by the interface.\n";
239 for (auto &method : interface.getMethods()) {
240 os << " ";
241 emitCPPType(type: method.getReturnType(), os);
242 os << "(*" << method.getName() << ")(";
243 if (!method.isStatic()) {
244 os << "const Concept *impl, ";
245 emitCPPType(type: valueType, os) << (method.arg_empty() ? "" : ", ");
246 }
247 llvm::interleaveComma(
248 c: method.getArguments(), os,
249 each_fn: [&](const InterfaceMethod::Argument &arg) { os << arg.type; });
250 os << ");\n";
251 }
252
253 // Insert a field containing a concept for each of the base interfaces.
254 auto baseInterfaces = interface.getBaseInterfaces();
255 if (!baseInterfaces.empty()) {
256 os << " /// The base classes of this interface.\n";
257 for (const auto &base : interface.getBaseInterfaces()) {
258 os << " const " << base.getFullyQualifiedName() << "::Concept *impl"
259 << base.getName() << " = nullptr;\n";
260 }
261
262 // Define an "initialize" method that allows for the initialization of the
263 // base class concepts.
264 os << "\n void initializeInterfaceConcept(::mlir::detail::InterfaceMap "
265 "&interfaceMap) {\n";
266 std::string interfaceQualName = interface.getFullyQualifiedName();
267 for (const auto &base : interface.getBaseInterfaces()) {
268 StringRef baseName = base.getName();
269 std::string baseQualName = base.getFullyQualifiedName();
270 os << " impl" << baseName << " = interfaceMap.lookup<"
271 << baseQualName << ">();\n"
272 << " assert(impl" << baseName << " && \"`" << interfaceQualName
273 << "` expected its base interface `" << baseQualName
274 << "` to be registered\");\n";
275 }
276 os << " }\n";
277 }
278
279 os << " };\n";
280}
281
282void InterfaceGenerator::emitModelDecl(const Interface &interface) {
283 // Emit the basic model and the fallback model.
284 for (const char *modelClass : {"Model", "FallbackModel"}) {
285 os << " template<typename " << valueTemplate << ">\n";
286 os << " class " << modelClass << " : public Concept {\n public:\n";
287 os << " using Interface = " << interface.getFullyQualifiedName()
288 << ";\n";
289 os << " " << modelClass << "() : Concept{";
290 llvm::interleaveComma(
291 c: interface.getMethods(), os,
292 each_fn: [&](const InterfaceMethod &method) { os << method.getName(); });
293 os << "} {}\n\n";
294
295 // Insert each of the virtual method overrides.
296 for (auto &method : interface.getMethods()) {
297 emitCPPType(type: method.getReturnType(), os&: os << " static inline ");
298 emitMethodNameAndArgs(method, os, valueType,
299 /*addThisArg=*/!method.isStatic(),
300 /*addConst=*/false);
301 os << ";\n";
302 }
303 os << " };\n";
304 }
305
306 // Emit the template for the external model.
307 os << " template<typename ConcreteModel, typename " << valueTemplate
308 << ">\n";
309 os << " class ExternalModel : public FallbackModel<ConcreteModel> {\n";
310 os << " public:\n";
311 os << " using ConcreteEntity = " << valueTemplate << ";\n";
312
313 // Emit declarations for methods that have default implementations. Other
314 // methods are expected to be implemented by the concrete derived model.
315 for (auto &method : interface.getMethods()) {
316 if (!method.getDefaultImplementation())
317 continue;
318 os << " ";
319 if (method.isStatic())
320 os << "static ";
321 emitCPPType(type: method.getReturnType(), os);
322 os << method.getName() << "(";
323 if (!method.isStatic()) {
324 emitCPPType(type: valueType, os);
325 os << "tablegen_opaque_val";
326 if (!method.arg_empty())
327 os << ", ";
328 }
329 llvm::interleaveComma(c: method.getArguments(), os,
330 each_fn: [&](const InterfaceMethod::Argument &arg) {
331 emitCPPType(type: arg.type, os);
332 os << arg.name;
333 });
334 os << ")";
335 if (!method.isStatic())
336 os << " const";
337 os << ";\n";
338 }
339 os << " };\n";
340}
341
342void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
343 llvm::SmallVector<StringRef, 2> namespaces;
344 llvm::SplitString(Source: interface.getCppNamespace(), OutFragments&: namespaces, Delimiters: "::");
345 for (StringRef ns : namespaces)
346 os << "namespace " << ns << " {\n";
347
348 for (auto &method : interface.getMethods()) {
349 os << "template<typename " << valueTemplate << ">\n";
350 emitCPPType(type: method.getReturnType(), os);
351 os << "detail::" << interface.getName() << "InterfaceTraits::Model<"
352 << valueTemplate << ">::";
353 emitMethodNameAndArgs(method, os, valueType,
354 /*addThisArg=*/!method.isStatic(),
355 /*addConst=*/false);
356 os << " {\n ";
357
358 // Check for a provided body to the function.
359 if (std::optional<StringRef> body = method.getBody()) {
360 if (method.isStatic())
361 os << body->trim();
362 else
363 os << tblgen::tgfmt(fmt: body->trim(), ctx: &nonStaticMethodFmt);
364 os << "\n}\n";
365 continue;
366 }
367
368 // Forward to the method on the concrete operation type.
369 if (method.isStatic())
370 os << "return " << valueTemplate << "::";
371 else
372 os << tblgen::tgfmt(fmt: "return $_self.", ctx: &nonStaticMethodFmt);
373
374 // Add the arguments to the call.
375 os << method.getName() << '(';
376 llvm::interleaveComma(
377 c: method.getArguments(), os,
378 each_fn: [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
379 os << ");\n}\n";
380 }
381
382 for (auto &method : interface.getMethods()) {
383 os << "template<typename " << valueTemplate << ">\n";
384 emitCPPType(type: method.getReturnType(), os);
385 os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<"
386 << valueTemplate << ">::";
387 emitMethodNameAndArgs(method, os, valueType,
388 /*addThisArg=*/!method.isStatic(),
389 /*addConst=*/false);
390 os << " {\n ";
391
392 // Forward to the method on the concrete Model implementation.
393 if (method.isStatic())
394 os << "return " << valueTemplate << "::";
395 else
396 os << "return static_cast<const " << valueTemplate << " *>(impl)->";
397
398 // Add the arguments to the call.
399 os << method.getName() << '(';
400 if (!method.isStatic())
401 os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
402 llvm::interleaveComma(
403 c: method.getArguments(), os,
404 each_fn: [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
405 os << ");\n}\n";
406 }
407
408 // Emit default implementations for the external model.
409 for (auto &method : interface.getMethods()) {
410 if (!method.getDefaultImplementation())
411 continue;
412 os << "template<typename ConcreteModel, typename " << valueTemplate
413 << ">\n";
414 emitCPPType(type: method.getReturnType(), os);
415 os << "detail::" << interface.getName()
416 << "InterfaceTraits::ExternalModel<ConcreteModel, " << valueTemplate
417 << ">::";
418
419 os << method.getName() << "(";
420 if (!method.isStatic()) {
421 emitCPPType(type: valueType, os);
422 os << "tablegen_opaque_val";
423 if (!method.arg_empty())
424 os << ", ";
425 }
426 llvm::interleaveComma(c: method.getArguments(), os,
427 each_fn: [&](const InterfaceMethod::Argument &arg) {
428 emitCPPType(type: arg.type, os);
429 os << arg.name;
430 });
431 os << ")";
432 if (!method.isStatic())
433 os << " const";
434
435 os << " {\n";
436
437 // Use the empty context for static methods.
438 tblgen::FmtContext ctx;
439 os << tblgen::tgfmt(fmt: method.getDefaultImplementation()->trim(),
440 ctx: method.isStatic() ? &ctx : &nonStaticMethodFmt);
441 os << "\n}\n";
442 }
443
444 for (StringRef ns : llvm::reverse(C&: namespaces))
445 os << "} // namespace " << ns << "\n";
446}
447
448void InterfaceGenerator::emitTraitDecl(const Interface &interface,
449 StringRef interfaceName,
450 StringRef interfaceTraitsName) {
451 os << llvm::formatv(Fmt: " template <typename {3}>\n"
452 " struct {0}Trait : public ::mlir::{2}<{0},"
453 " detail::{1}>::Trait<{3}> {{\n",
454 Vals&: interfaceName, Vals&: interfaceTraitsName, Vals&: interfaceBaseType,
455 Vals&: valueTemplate);
456
457 // Insert the default implementation for any methods.
458 bool isOpInterface = isa<OpInterface>(Val: interface);
459 for (auto &method : interface.getMethods()) {
460 // Flag interface methods named verifyTrait.
461 if (method.getName() == "verifyTrait")
462 PrintFatalError(
463 Msg: formatv(Fmt: "'verifyTrait' method cannot be specified as interface "
464 "method for '{0}'; use the 'verify' field instead",
465 Vals&: interfaceName));
466 auto defaultImpl = method.getDefaultImplementation();
467 if (!defaultImpl)
468 continue;
469
470 emitInterfaceMethodDoc(method, os, prefix: " ");
471 os << " " << (method.isStatic() ? "static " : "");
472 emitCPPType(type: method.getReturnType(), os);
473 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
474 /*addConst=*/!isOpInterface && !method.isStatic());
475 os << " {\n " << tblgen::tgfmt(fmt: defaultImpl->trim(), ctx: &traitMethodFmt)
476 << "\n }\n";
477 }
478
479 if (auto verify = interface.getVerify()) {
480 assert(isa<OpInterface>(interface) && "only OpInterface supports 'verify'");
481
482 tblgen::FmtContext verifyCtx;
483 verifyCtx.addSubst(placeholder: "_op", subst: "op");
484 os << llvm::formatv(
485 Fmt: " static ::llvm::LogicalResult {0}(::mlir::Operation *op) ",
486 Vals: (interface.verifyWithRegions() ? "verifyRegionTrait"
487 : "verifyTrait"))
488 << "{\n " << tblgen::tgfmt(fmt: verify->trim(), ctx: &verifyCtx)
489 << "\n }\n";
490 }
491 if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
492 os << tblgen::tgfmt(fmt: *extraTraitDecls, ctx: &traitMethodFmt) << "\n";
493 if (auto extraTraitDecls = interface.getExtraSharedClassDeclaration())
494 os << tblgen::tgfmt(fmt: *extraTraitDecls, ctx: &traitMethodFmt) << "\n";
495
496 os << " };\n";
497}
498
499static void emitInterfaceDeclMethods(const Interface &interface,
500 raw_ostream &os, StringRef valueType,
501 bool isOpInterface,
502 tblgen::FmtContext &extraDeclsFmt) {
503 for (auto &method : interface.getMethods()) {
504 emitInterfaceMethodDoc(method, os, prefix: " ");
505 emitCPPType(type: method.getReturnType(), os&: os << " ");
506 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
507 /*addConst=*/!isOpInterface);
508 os << ";\n";
509 }
510
511 // Emit any extra declarations.
512 if (std::optional<StringRef> extraDecls =
513 interface.getExtraClassDeclaration())
514 os << extraDecls->rtrim() << "\n";
515 if (std::optional<StringRef> extraDecls =
516 interface.getExtraSharedClassDeclaration())
517 os << tblgen::tgfmt(fmt: extraDecls->rtrim(), ctx: &extraDeclsFmt) << "\n";
518}
519
520void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
521 llvm::SmallVector<StringRef, 2> namespaces;
522 llvm::SplitString(Source: interface.getCppNamespace(), OutFragments&: namespaces, Delimiters: "::");
523 for (StringRef ns : namespaces)
524 os << "namespace " << ns << " {\n";
525
526 StringRef interfaceName = interface.getName();
527 auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
528
529 // Emit a forward declaration of the interface class so that it becomes usable
530 // in the signature of its methods.
531 std::string comments = tblgen::emitSummaryAndDescComments(
532 summary: "", description: interface.getDescription().value_or(u: ""));
533 if (!comments.empty()) {
534 os << comments << "\n";
535 }
536 os << "class " << interfaceName << ";\n";
537
538 // Emit the traits struct containing the concept and model declarations.
539 os << "namespace detail {\n"
540 << "struct " << interfaceTraitsName << " {\n";
541 emitConceptDecl(interface);
542 emitModelDecl(interface);
543 os << "};\n";
544
545 // Emit the derived trait for the interface.
546 os << "template <typename " << valueTemplate << ">\n";
547 os << "struct " << interface.getName() << "Trait;\n";
548
549 os << "\n} // namespace detail\n";
550
551 // Emit the main interface class declaration.
552 os << llvm::formatv(Fmt: "class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n"
553 "public:\n"
554 " using ::mlir::{3}<{1}, detail::{2}>::{3};\n",
555 Vals&: interfaceName, Vals&: interfaceName, Vals&: interfaceTraitsName,
556 Vals&: interfaceBaseType);
557
558 // Emit a utility wrapper trait class.
559 os << llvm::formatv(Fmt: " template <typename {1}>\n"
560 " struct Trait : public detail::{0}Trait<{1}> {{};\n",
561 Vals&: interfaceName, Vals&: valueTemplate);
562
563 // Insert the method declarations.
564 bool isOpInterface = isa<OpInterface>(Val: interface);
565 emitInterfaceDeclMethods(interface, os, valueType, isOpInterface,
566 extraDeclsFmt);
567
568 // Insert the method declarations for base classes.
569 for (auto &base : interface.getBaseInterfaces()) {
570 std::string baseQualName = base.getFullyQualifiedName();
571 os << " //"
572 "===---------------------------------------------------------------"
573 "-===//\n"
574 << " // Inherited from " << baseQualName << "\n"
575 << " //"
576 "===---------------------------------------------------------------"
577 "-===//\n\n";
578
579 // Allow implicit conversion to the base interface.
580 os << " operator " << baseQualName << " () const {\n"
581 << " if (!*this) return nullptr;\n"
582 << " return " << baseQualName << "(*this, getImpl()->impl"
583 << base.getName() << ");\n"
584 << " }\n\n";
585
586 // Inherit the base interface's methods.
587 emitInterfaceDeclMethods(interface: base, os, valueType, isOpInterface, extraDeclsFmt);
588 }
589
590 // Emit classof code if necessary.
591 if (std::optional<StringRef> extraClassOf = interface.getExtraClassOf()) {
592 auto extraClassOfFmt = tblgen::FmtContext();
593 extraClassOfFmt.addSubst(placeholder: substVar, subst: "odsInterfaceInstance");
594 os << " static bool classof(" << valueType << " base) {\n"
595 << " auto* interface = getInterfaceFor(base);\n"
596 << " if (!interface)\n"
597 " return false;\n"
598 " "
599 << interfaceName << " odsInterfaceInstance(base, interface);\n"
600 << " " << tblgen::tgfmt(fmt: extraClassOf->trim(), ctx: &extraClassOfFmt)
601 << "\n }\n";
602 }
603
604 os << "};\n";
605
606 os << "namespace detail {\n";
607 emitTraitDecl(interface, interfaceName, interfaceTraitsName);
608 os << "}// namespace detail\n";
609
610 for (StringRef ns : llvm::reverse(C&: namespaces))
611 os << "} // namespace " << ns << "\n";
612}
613
614bool InterfaceGenerator::emitInterfaceDecls() {
615 llvm::emitSourceFileHeader(Desc: "Interface Declarations", OS&: os);
616 // Sort according to ID, so defs are emitted in the order in which they appear
617 // in the Tablegen file.
618 std::vector<const Record *> sortedDefs(defs);
619 llvm::sort(C&: sortedDefs, Comp: [](const Record *lhs, const Record *rhs) {
620 return lhs->getID() < rhs->getID();
621 });
622 for (const Record *def : sortedDefs)
623 emitInterfaceDecl(interface: Interface(def));
624 for (const Record *def : sortedDefs)
625 emitModelMethodsDef(interface: Interface(def));
626 return false;
627}
628
629//===----------------------------------------------------------------------===//
630// GEN: Interface documentation
631//===----------------------------------------------------------------------===//
632
633static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) {
634 Interface interface(&interfaceDef);
635
636 // Emit the interface name followed by the description.
637 os << "\n## " << interface.getName() << " (`" << interfaceDef.getName()
638 << "`)\n";
639 if (auto description = interface.getDescription())
640 mlir::tblgen::emitDescription(description: *description, os);
641
642 // Emit the methods required by the interface.
643 os << "\n### Methods:\n";
644 for (const auto &method : interface.getMethods()) {
645 // Emit the method name.
646 os << "\n#### `" << method.getName() << "`\n\n```c++\n";
647
648 // Emit the method signature.
649 if (method.isStatic())
650 os << "static ";
651 emitCPPType(type: method.getReturnType(), os) << method.getName() << '(';
652 llvm::interleaveComma(c: method.getArguments(), os,
653 each_fn: [&](const InterfaceMethod::Argument &arg) {
654 emitCPPType(type: arg.type, os) << arg.name;
655 });
656 os << ");\n```\n";
657
658 // Emit the description.
659 if (auto description = method.getDescription())
660 mlir::tblgen::emitDescription(description: *description, os);
661
662 // If the body is not provided, this method must be provided by the user.
663 if (!method.getBody())
664 os << "\nNOTE: This method *must* be implemented by the user.";
665
666 os << "\n";
667 }
668}
669
670bool InterfaceGenerator::emitInterfaceDocs() {
671 os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
672 os << "\n# " << interfaceBaseType << " definitions\n";
673
674 for (const auto *def : defs)
675 emitInterfaceDoc(interfaceDef: *def, os);
676 return false;
677}
678
679//===----------------------------------------------------------------------===//
680// GEN: Interface registration hooks
681//===----------------------------------------------------------------------===//
682
683namespace {
684template <typename GeneratorT>
685struct InterfaceGenRegistration {
686 InterfaceGenRegistration(StringRef genArg, StringRef genDesc)
687 : genDeclArg(("gen-" + genArg + "-interface-decls").str()),
688 genDefArg(("gen-" + genArg + "-interface-defs").str()),
689 genDocArg(("gen-" + genArg + "-interface-docs").str()),
690 genDeclDesc(("Generate " + genDesc + " interface declarations").str()),
691 genDefDesc(("Generate " + genDesc + " interface definitions").str()),
692 genDocDesc(("Generate " + genDesc + " interface documentation").str()),
693 genDecls(genDeclArg, genDeclDesc,
694 [](const RecordKeeper &records, raw_ostream &os) {
695 return GeneratorT(records, os).emitInterfaceDecls();
696 }),
697 genDefs(genDefArg, genDefDesc,
698 [](const RecordKeeper &records, raw_ostream &os) {
699 return GeneratorT(records, os).emitInterfaceDefs();
700 }),
701 genDocs(genDocArg, genDocDesc,
702 [](const RecordKeeper &records, raw_ostream &os) {
703 return GeneratorT(records, os).emitInterfaceDocs();
704 }) {}
705
706 std::string genDeclArg, genDefArg, genDocArg;
707 std::string genDeclDesc, genDefDesc, genDocDesc;
708 mlir::GenRegistration genDecls, genDefs, genDocs;
709};
710} // namespace
711
712static InterfaceGenRegistration<AttrInterfaceGenerator> attrGen("attr",
713 "attribute");
714static InterfaceGenRegistration<OpInterfaceGenerator> opGen("op", "op");
715static InterfaceGenRegistration<TypeInterfaceGenerator> typeGen("type", "type");
716

// Source: mlir/tools/mlir-tblgen/OpInterfacesGen.cpp