//===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpInterfacesGen generates definitions for operation interfaces.
//
//===----------------------------------------------------------------------===//
12
13#include "DocGenUtilities.h"
14#include "mlir/TableGen/Format.h"
15#include "mlir/TableGen/GenInfo.h"
16#include "mlir/TableGen/Interfaces.h"
17#include "llvm/ADT/SmallVector.h"
18#include "llvm/ADT/StringExtras.h"
19#include "llvm/Support/FormatVariadic.h"
20#include "llvm/Support/raw_ostream.h"
21#include "llvm/TableGen/Error.h"
22#include "llvm/TableGen/Record.h"
23#include "llvm/TableGen/TableGenBackend.h"
24
25using namespace mlir;
26using mlir::tblgen::Interface;
27using mlir::tblgen::InterfaceMethod;
28using mlir::tblgen::OpInterface;
29
30/// Emit a string corresponding to a C++ type, followed by a space if necessary.
31static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
32 type = type.trim();
33 os << type;
34 if (type.back() != '&' && type.back() != '*')
35 os << " ";
36 return os;
37}
38
39/// Emit the method name and argument list for the given method. If 'addThisArg'
40/// is true, then an argument is added to the beginning of the argument list for
41/// the concrete value.
42static void emitMethodNameAndArgs(const InterfaceMethod &method,
43 raw_ostream &os, StringRef valueType,
44 bool addThisArg, bool addConst) {
45 os << method.getName() << '(';
46 if (addThisArg) {
47 if (addConst)
48 os << "const ";
49 os << "const Concept *impl, ";
50 emitCPPType(type: valueType, os)
51 << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
52 }
53 llvm::interleaveComma(c: method.getArguments(), os,
54 each_fn: [&](const InterfaceMethod::Argument &arg) {
55 os << arg.type << " " << arg.name;
56 });
57 os << ')';
58 if (addConst)
59 os << " const";
60}
61
62/// Get an array of all OpInterface definitions but exclude those subclassing
63/// "DeclareOpInterfaceMethods".
64static std::vector<llvm::Record *>
65getAllInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper,
66 StringRef name) {
67 std::vector<llvm::Record *> defs =
68 recordKeeper.getAllDerivedDefinitions(ClassName: (name + "Interface").str());
69
70 std::string declareName = ("Declare" + name + "InterfaceMethods").str();
71 llvm::erase_if(C&: defs, P: [&](const llvm::Record *def) {
72 // Ignore any "declare methods" interfaces.
73 if (def->isSubClassOf(Name: declareName))
74 return true;
75 // Ignore interfaces defined outside of the top-level file.
76 return llvm::SrcMgr.FindBufferContainingLoc(Loc: def->getLoc()[0]) !=
77 llvm::SrcMgr.getMainFileID();
78 });
79 return defs;
80}
81
82namespace {
83/// This struct is the base generator used when processing tablegen interfaces.
84class InterfaceGenerator {
85public:
86 bool emitInterfaceDefs();
87 bool emitInterfaceDecls();
88 bool emitInterfaceDocs();
89
90protected:
91 InterfaceGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os)
92 : defs(std::move(defs)), os(os) {}
93
94 void emitConceptDecl(const Interface &interface);
95 void emitModelDecl(const Interface &interface);
96 void emitModelMethodsDef(const Interface &interface);
97 void emitTraitDecl(const Interface &interface, StringRef interfaceName,
98 StringRef interfaceTraitsName);
99 void emitInterfaceDecl(const Interface &interface);
100
101 /// The set of interface records to emit.
102 std::vector<llvm::Record *> defs;
103 // The stream to emit to.
104 raw_ostream &os;
105 /// The C++ value type of the interface, e.g. Operation*.
106 StringRef valueType;
107 /// The C++ base interface type.
108 StringRef interfaceBaseType;
109 /// The name of the typename for the value template.
110 StringRef valueTemplate;
111 /// The name of the substituion variable for the value.
112 StringRef substVar;
113 /// The format context to use for methods.
114 tblgen::FmtContext nonStaticMethodFmt;
115 tblgen::FmtContext traitMethodFmt;
116 tblgen::FmtContext extraDeclsFmt;
117};
118
119/// A specialized generator for attribute interfaces.
120struct AttrInterfaceGenerator : public InterfaceGenerator {
121 AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
122 : InterfaceGenerator(getAllInterfaceDefinitions(recordKeeper: records, name: "Attr"), os) {
123 valueType = "::mlir::Attribute";
124 interfaceBaseType = "AttributeInterface";
125 valueTemplate = "ConcreteAttr";
126 substVar = "_attr";
127 StringRef castCode = "(::llvm::cast<ConcreteAttr>(tablegen_opaque_val))";
128 nonStaticMethodFmt.addSubst(placeholder: substVar, subst: castCode).withSelf(subst: castCode);
129 traitMethodFmt.addSubst(placeholder: substVar,
130 subst: "(*static_cast<const ConcreteAttr *>(this))");
131 extraDeclsFmt.addSubst(placeholder: substVar, subst: "(*this)");
132 }
133};
134/// A specialized generator for operation interfaces.
135struct OpInterfaceGenerator : public InterfaceGenerator {
136 OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
137 : InterfaceGenerator(getAllInterfaceDefinitions(recordKeeper: records, name: "Op"), os) {
138 valueType = "::mlir::Operation *";
139 interfaceBaseType = "OpInterface";
140 valueTemplate = "ConcreteOp";
141 substVar = "_op";
142 StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))";
143 nonStaticMethodFmt.addSubst(placeholder: "_this", subst: "impl")
144 .addSubst(placeholder: substVar, subst: castCode)
145 .withSelf(subst: castCode);
146 traitMethodFmt.addSubst(placeholder: substVar, subst: "(*static_cast<ConcreteOp *>(this))");
147 extraDeclsFmt.addSubst(placeholder: substVar, subst: "(*this)");
148 }
149};
150/// A specialized generator for type interfaces.
151struct TypeInterfaceGenerator : public InterfaceGenerator {
152 TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
153 : InterfaceGenerator(getAllInterfaceDefinitions(recordKeeper: records, name: "Type"), os) {
154 valueType = "::mlir::Type";
155 interfaceBaseType = "TypeInterface";
156 valueTemplate = "ConcreteType";
157 substVar = "_type";
158 StringRef castCode = "(::llvm::cast<ConcreteType>(tablegen_opaque_val))";
159 nonStaticMethodFmt.addSubst(placeholder: substVar, subst: castCode).withSelf(subst: castCode);
160 traitMethodFmt.addSubst(placeholder: substVar,
161 subst: "(*static_cast<const ConcreteType *>(this))");
162 extraDeclsFmt.addSubst(placeholder: substVar, subst: "(*this)");
163 }
164};
165} // namespace
166
167//===----------------------------------------------------------------------===//
168// GEN: Interface definitions
169//===----------------------------------------------------------------------===//
170
171static void emitInterfaceMethodDoc(const InterfaceMethod &method,
172 raw_ostream &os, StringRef prefix = "") {
173 if (std::optional<StringRef> description = method.getDescription())
174 tblgen::emitDescriptionComment(description: *description, os, prefix);
175}
176static void emitInterfaceDefMethods(StringRef interfaceQualName,
177 const Interface &interface,
178 StringRef valueType, const Twine &implValue,
179 raw_ostream &os, bool isOpInterface) {
180 for (auto &method : interface.getMethods()) {
181 emitInterfaceMethodDoc(method, os);
182 emitCPPType(type: method.getReturnType(), os);
183 os << interfaceQualName << "::";
184 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
185 /*addConst=*/!isOpInterface);
186
187 // Forward to the method on the concrete operation type.
188 os << " {\n return " << implValue << "->" << method.getName() << '(';
189 if (!method.isStatic()) {
190 os << implValue << ", ";
191 os << (isOpInterface ? "getOperation()" : "*this");
192 os << (method.arg_empty() ? "" : ", ");
193 }
194 llvm::interleaveComma(
195 c: method.getArguments(), os,
196 each_fn: [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
197 os << ");\n }\n";
198 }
199}
200
201static void emitInterfaceDef(const Interface &interface, StringRef valueType,
202 raw_ostream &os) {
203 std::string interfaceQualNameStr = interface.getFullyQualifiedName();
204 StringRef interfaceQualName = interfaceQualNameStr;
205 interfaceQualName.consume_front(Prefix: "::");
206
207 // Insert the method definitions.
208 bool isOpInterface = isa<OpInterface>(Val: interface);
209 emitInterfaceDefMethods(interfaceQualName, interface, valueType, implValue: "getImpl()",
210 os, isOpInterface);
211
212 // Insert the method definitions for base classes.
213 for (auto &base : interface.getBaseInterfaces()) {
214 emitInterfaceDefMethods(interfaceQualName, interface: base, valueType,
215 implValue: "getImpl()->impl" + base.getName(), os,
216 isOpInterface);
217 }
218}
219
220bool InterfaceGenerator::emitInterfaceDefs() {
221 llvm::emitSourceFileHeader(Desc: "Interface Definitions", OS&: os);
222
223 for (const auto *def : defs)
224 emitInterfaceDef(interface: Interface(def), valueType, os);
225 return false;
226}
227
228//===----------------------------------------------------------------------===//
229// GEN: Interface declarations
230//===----------------------------------------------------------------------===//
231
232void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
233 os << " struct Concept {\n";
234
235 // Insert each of the pure virtual concept methods.
236 os << " /// The methods defined by the interface.\n";
237 for (auto &method : interface.getMethods()) {
238 os << " ";
239 emitCPPType(type: method.getReturnType(), os);
240 os << "(*" << method.getName() << ")(";
241 if (!method.isStatic()) {
242 os << "const Concept *impl, ";
243 emitCPPType(type: valueType, os) << (method.arg_empty() ? "" : ", ");
244 }
245 llvm::interleaveComma(
246 c: method.getArguments(), os,
247 each_fn: [&](const InterfaceMethod::Argument &arg) { os << arg.type; });
248 os << ");\n";
249 }
250
251 // Insert a field containing a concept for each of the base interfaces.
252 auto baseInterfaces = interface.getBaseInterfaces();
253 if (!baseInterfaces.empty()) {
254 os << " /// The base classes of this interface.\n";
255 for (const auto &base : interface.getBaseInterfaces()) {
256 os << " const " << base.getFullyQualifiedName() << "::Concept *impl"
257 << base.getName() << " = nullptr;\n";
258 }
259
260 // Define an "initialize" method that allows for the initialization of the
261 // base class concepts.
262 os << "\n void initializeInterfaceConcept(::mlir::detail::InterfaceMap "
263 "&interfaceMap) {\n";
264 std::string interfaceQualName = interface.getFullyQualifiedName();
265 for (const auto &base : interface.getBaseInterfaces()) {
266 StringRef baseName = base.getName();
267 std::string baseQualName = base.getFullyQualifiedName();
268 os << " impl" << baseName << " = interfaceMap.lookup<"
269 << baseQualName << ">();\n"
270 << " assert(impl" << baseName << " && \"`" << interfaceQualName
271 << "` expected its base interface `" << baseQualName
272 << "` to be registered\");\n";
273 }
274 os << " }\n";
275 }
276
277 os << " };\n";
278}
279
280void InterfaceGenerator::emitModelDecl(const Interface &interface) {
281 // Emit the basic model and the fallback model.
282 for (const char *modelClass : {"Model", "FallbackModel"}) {
283 os << " template<typename " << valueTemplate << ">\n";
284 os << " class " << modelClass << " : public Concept {\n public:\n";
285 os << " using Interface = " << interface.getFullyQualifiedName()
286 << ";\n";
287 os << " " << modelClass << "() : Concept{";
288 llvm::interleaveComma(
289 c: interface.getMethods(), os,
290 each_fn: [&](const InterfaceMethod &method) { os << method.getName(); });
291 os << "} {}\n\n";
292
293 // Insert each of the virtual method overrides.
294 for (auto &method : interface.getMethods()) {
295 emitCPPType(type: method.getReturnType(), os&: os << " static inline ");
296 emitMethodNameAndArgs(method, os, valueType,
297 /*addThisArg=*/!method.isStatic(),
298 /*addConst=*/false);
299 os << ";\n";
300 }
301 os << " };\n";
302 }
303
304 // Emit the template for the external model.
305 os << " template<typename ConcreteModel, typename " << valueTemplate
306 << ">\n";
307 os << " class ExternalModel : public FallbackModel<ConcreteModel> {\n";
308 os << " public:\n";
309 os << " using ConcreteEntity = " << valueTemplate << ";\n";
310
311 // Emit declarations for methods that have default implementations. Other
312 // methods are expected to be implemented by the concrete derived model.
313 for (auto &method : interface.getMethods()) {
314 if (!method.getDefaultImplementation())
315 continue;
316 os << " ";
317 if (method.isStatic())
318 os << "static ";
319 emitCPPType(type: method.getReturnType(), os);
320 os << method.getName() << "(";
321 if (!method.isStatic()) {
322 emitCPPType(type: valueType, os);
323 os << "tablegen_opaque_val";
324 if (!method.arg_empty())
325 os << ", ";
326 }
327 llvm::interleaveComma(c: method.getArguments(), os,
328 each_fn: [&](const InterfaceMethod::Argument &arg) {
329 emitCPPType(type: arg.type, os);
330 os << arg.name;
331 });
332 os << ")";
333 if (!method.isStatic())
334 os << " const";
335 os << ";\n";
336 }
337 os << " };\n";
338}
339
340void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
341 llvm::SmallVector<StringRef, 2> namespaces;
342 llvm::SplitString(Source: interface.getCppNamespace(), OutFragments&: namespaces, Delimiters: "::");
343 for (StringRef ns : namespaces)
344 os << "namespace " << ns << " {\n";
345
346 for (auto &method : interface.getMethods()) {
347 os << "template<typename " << valueTemplate << ">\n";
348 emitCPPType(type: method.getReturnType(), os);
349 os << "detail::" << interface.getName() << "InterfaceTraits::Model<"
350 << valueTemplate << ">::";
351 emitMethodNameAndArgs(method, os, valueType,
352 /*addThisArg=*/!method.isStatic(),
353 /*addConst=*/false);
354 os << " {\n ";
355
356 // Check for a provided body to the function.
357 if (std::optional<StringRef> body = method.getBody()) {
358 if (method.isStatic())
359 os << body->trim();
360 else
361 os << tblgen::tgfmt(fmt: body->trim(), ctx: &nonStaticMethodFmt);
362 os << "\n}\n";
363 continue;
364 }
365
366 // Forward to the method on the concrete operation type.
367 if (method.isStatic())
368 os << "return " << valueTemplate << "::";
369 else
370 os << tblgen::tgfmt(fmt: "return $_self.", ctx: &nonStaticMethodFmt);
371
372 // Add the arguments to the call.
373 os << method.getName() << '(';
374 llvm::interleaveComma(
375 c: method.getArguments(), os,
376 each_fn: [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
377 os << ");\n}\n";
378 }
379
380 for (auto &method : interface.getMethods()) {
381 os << "template<typename " << valueTemplate << ">\n";
382 emitCPPType(type: method.getReturnType(), os);
383 os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<"
384 << valueTemplate << ">::";
385 emitMethodNameAndArgs(method, os, valueType,
386 /*addThisArg=*/!method.isStatic(),
387 /*addConst=*/false);
388 os << " {\n ";
389
390 // Forward to the method on the concrete Model implementation.
391 if (method.isStatic())
392 os << "return " << valueTemplate << "::";
393 else
394 os << "return static_cast<const " << valueTemplate << " *>(impl)->";
395
396 // Add the arguments to the call.
397 os << method.getName() << '(';
398 if (!method.isStatic())
399 os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
400 llvm::interleaveComma(
401 c: method.getArguments(), os,
402 each_fn: [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
403 os << ");\n}\n";
404 }
405
406 // Emit default implementations for the external model.
407 for (auto &method : interface.getMethods()) {
408 if (!method.getDefaultImplementation())
409 continue;
410 os << "template<typename ConcreteModel, typename " << valueTemplate
411 << ">\n";
412 emitCPPType(type: method.getReturnType(), os);
413 os << "detail::" << interface.getName()
414 << "InterfaceTraits::ExternalModel<ConcreteModel, " << valueTemplate
415 << ">::";
416
417 os << method.getName() << "(";
418 if (!method.isStatic()) {
419 emitCPPType(type: valueType, os);
420 os << "tablegen_opaque_val";
421 if (!method.arg_empty())
422 os << ", ";
423 }
424 llvm::interleaveComma(c: method.getArguments(), os,
425 each_fn: [&](const InterfaceMethod::Argument &arg) {
426 emitCPPType(type: arg.type, os);
427 os << arg.name;
428 });
429 os << ")";
430 if (!method.isStatic())
431 os << " const";
432
433 os << " {\n";
434
435 // Use the empty context for static methods.
436 tblgen::FmtContext ctx;
437 os << tblgen::tgfmt(fmt: method.getDefaultImplementation()->trim(),
438 ctx: method.isStatic() ? &ctx : &nonStaticMethodFmt);
439 os << "\n}\n";
440 }
441
442 for (StringRef ns : llvm::reverse(C&: namespaces))
443 os << "} // namespace " << ns << "\n";
444}
445
446void InterfaceGenerator::emitTraitDecl(const Interface &interface,
447 StringRef interfaceName,
448 StringRef interfaceTraitsName) {
449 os << llvm::formatv(Fmt: " template <typename {3}>\n"
450 " struct {0}Trait : public ::mlir::{2}<{0},"
451 " detail::{1}>::Trait<{3}> {{\n",
452 Vals&: interfaceName, Vals&: interfaceTraitsName, Vals&: interfaceBaseType,
453 Vals&: valueTemplate);
454
455 // Insert the default implementation for any methods.
456 bool isOpInterface = isa<OpInterface>(Val: interface);
457 for (auto &method : interface.getMethods()) {
458 // Flag interface methods named verifyTrait.
459 if (method.getName() == "verifyTrait")
460 PrintFatalError(
461 Msg: formatv(Fmt: "'verifyTrait' method cannot be specified as interface "
462 "method for '{0}'; use the 'verify' field instead",
463 Vals&: interfaceName));
464 auto defaultImpl = method.getDefaultImplementation();
465 if (!defaultImpl)
466 continue;
467
468 emitInterfaceMethodDoc(method, os, prefix: " ");
469 os << " " << (method.isStatic() ? "static " : "");
470 emitCPPType(type: method.getReturnType(), os);
471 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
472 /*addConst=*/!isOpInterface && !method.isStatic());
473 os << " {\n " << tblgen::tgfmt(fmt: defaultImpl->trim(), ctx: &traitMethodFmt)
474 << "\n }\n";
475 }
476
477 if (auto verify = interface.getVerify()) {
478 assert(isa<OpInterface>(interface) && "only OpInterface supports 'verify'");
479
480 tblgen::FmtContext verifyCtx;
481 verifyCtx.addSubst(placeholder: "_op", subst: "op");
482 os << llvm::formatv(
483 Fmt: " static ::mlir::LogicalResult {0}(::mlir::Operation *op) ",
484 Vals: (interface.verifyWithRegions() ? "verifyRegionTrait"
485 : "verifyTrait"))
486 << "{\n " << tblgen::tgfmt(fmt: verify->trim(), ctx: &verifyCtx)
487 << "\n }\n";
488 }
489 if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
490 os << tblgen::tgfmt(fmt: *extraTraitDecls, ctx: &traitMethodFmt) << "\n";
491 if (auto extraTraitDecls = interface.getExtraSharedClassDeclaration())
492 os << tblgen::tgfmt(fmt: *extraTraitDecls, ctx: &traitMethodFmt) << "\n";
493
494 os << " };\n";
495}
496
497static void emitInterfaceDeclMethods(const Interface &interface,
498 raw_ostream &os, StringRef valueType,
499 bool isOpInterface,
500 tblgen::FmtContext &extraDeclsFmt) {
501 for (auto &method : interface.getMethods()) {
502 emitInterfaceMethodDoc(method, os, prefix: " ");
503 emitCPPType(type: method.getReturnType(), os&: os << " ");
504 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
505 /*addConst=*/!isOpInterface);
506 os << ";\n";
507 }
508
509 // Emit any extra declarations.
510 if (std::optional<StringRef> extraDecls =
511 interface.getExtraClassDeclaration())
512 os << extraDecls->rtrim() << "\n";
513 if (std::optional<StringRef> extraDecls =
514 interface.getExtraSharedClassDeclaration())
515 os << tblgen::tgfmt(fmt: extraDecls->rtrim(), ctx: &extraDeclsFmt) << "\n";
516}
517
518void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
519 llvm::SmallVector<StringRef, 2> namespaces;
520 llvm::SplitString(Source: interface.getCppNamespace(), OutFragments&: namespaces, Delimiters: "::");
521 for (StringRef ns : namespaces)
522 os << "namespace " << ns << " {\n";
523
524 StringRef interfaceName = interface.getName();
525 auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
526
527 // Emit a forward declaration of the interface class so that it becomes usable
528 // in the signature of its methods.
529 os << "class " << interfaceName << ";\n";
530
531 // Emit the traits struct containing the concept and model declarations.
532 os << "namespace detail {\n"
533 << "struct " << interfaceTraitsName << " {\n";
534 emitConceptDecl(interface);
535 emitModelDecl(interface);
536 os << "};\n";
537
538 // Emit the derived trait for the interface.
539 os << "template <typename " << valueTemplate << ">\n";
540 os << "struct " << interface.getName() << "Trait;\n";
541
542 os << "\n} // namespace detail\n";
543
544 // Emit the main interface class declaration.
545 os << llvm::formatv(Fmt: "class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n"
546 "public:\n"
547 " using ::mlir::{3}<{1}, detail::{2}>::{3};\n",
548 Vals&: interfaceName, Vals&: interfaceName, Vals&: interfaceTraitsName,
549 Vals&: interfaceBaseType);
550
551 // Emit a utility wrapper trait class.
552 os << llvm::formatv(Fmt: " template <typename {1}>\n"
553 " struct Trait : public detail::{0}Trait<{1}> {{};\n",
554 Vals&: interfaceName, Vals&: valueTemplate);
555
556 // Insert the method declarations.
557 bool isOpInterface = isa<OpInterface>(Val: interface);
558 emitInterfaceDeclMethods(interface, os, valueType, isOpInterface,
559 extraDeclsFmt);
560
561 // Insert the method declarations for base classes.
562 for (auto &base : interface.getBaseInterfaces()) {
563 std::string baseQualName = base.getFullyQualifiedName();
564 os << " //"
565 "===---------------------------------------------------------------"
566 "-===//\n"
567 << " // Inherited from " << baseQualName << "\n"
568 << " //"
569 "===---------------------------------------------------------------"
570 "-===//\n\n";
571
572 // Allow implicit conversion to the base interface.
573 os << " operator " << baseQualName << " () const {\n"
574 << " if (!*this) return nullptr;\n"
575 << " return " << baseQualName << "(*this, getImpl()->impl"
576 << base.getName() << ");\n"
577 << " }\n\n";
578
579 // Inherit the base interface's methods.
580 emitInterfaceDeclMethods(interface: base, os, valueType, isOpInterface, extraDeclsFmt);
581 }
582
583 // Emit classof code if necessary.
584 if (std::optional<StringRef> extraClassOf = interface.getExtraClassOf()) {
585 auto extraClassOfFmt = tblgen::FmtContext();
586 extraClassOfFmt.addSubst(placeholder: substVar, subst: "odsInterfaceInstance");
587 os << " static bool classof(" << valueType << " base) {\n"
588 << " auto* interface = getInterfaceFor(base);\n"
589 << " if (!interface)\n"
590 " return false;\n"
591 " " << interfaceName << " odsInterfaceInstance(base, interface);\n"
592 << " " << tblgen::tgfmt(fmt: extraClassOf->trim(), ctx: &extraClassOfFmt)
593 << "\n }\n";
594 }
595
596 os << "};\n";
597
598 os << "namespace detail {\n";
599 emitTraitDecl(interface, interfaceName, interfaceTraitsName);
600 os << "}// namespace detail\n";
601
602 for (StringRef ns : llvm::reverse(C&: namespaces))
603 os << "} // namespace " << ns << "\n";
604}
605
606bool InterfaceGenerator::emitInterfaceDecls() {
607 llvm::emitSourceFileHeader(Desc: "Interface Declarations", OS&: os);
608 // Sort according to ID, so defs are emitted in the order in which they appear
609 // in the Tablegen file.
610 std::vector<llvm::Record *> sortedDefs(defs);
611 llvm::sort(C&: sortedDefs, Comp: [](llvm::Record *lhs, llvm::Record *rhs) {
612 return lhs->getID() < rhs->getID();
613 });
614 for (const llvm::Record *def : sortedDefs)
615 emitInterfaceDecl(interface: Interface(def));
616 for (const llvm::Record *def : sortedDefs)
617 emitModelMethodsDef(interface: Interface(def));
618 return false;
619}
620
621//===----------------------------------------------------------------------===//
622// GEN: Interface documentation
623//===----------------------------------------------------------------------===//
624
625static void emitInterfaceDoc(const llvm::Record &interfaceDef,
626 raw_ostream &os) {
627 Interface interface(&interfaceDef);
628
629 // Emit the interface name followed by the description.
630 os << "## " << interface.getName() << " (`" << interfaceDef.getName()
631 << "`)\n\n";
632 if (auto description = interface.getDescription())
633 mlir::tblgen::emitDescription(description: *description, os);
634
635 // Emit the methods required by the interface.
636 os << "\n### Methods:\n";
637 for (const auto &method : interface.getMethods()) {
638 // Emit the method name.
639 os << "#### `" << method.getName() << "`\n\n```c++\n";
640
641 // Emit the method signature.
642 if (method.isStatic())
643 os << "static ";
644 emitCPPType(type: method.getReturnType(), os) << method.getName() << '(';
645 llvm::interleaveComma(c: method.getArguments(), os,
646 each_fn: [&](const InterfaceMethod::Argument &arg) {
647 emitCPPType(type: arg.type, os) << arg.name;
648 });
649 os << ");\n```\n";
650
651 // Emit the description.
652 if (auto description = method.getDescription())
653 mlir::tblgen::emitDescription(description: *description, os);
654
655 // If the body is not provided, this method must be provided by the user.
656 if (!method.getBody())
657 os << "\nNOTE: This method *must* be implemented by the user.";
658
659 os << "\n\n";
660 }
661}
662
663bool InterfaceGenerator::emitInterfaceDocs() {
664 os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
665 os << "# " << interfaceBaseType << " definitions\n";
666
667 for (const auto *def : defs)
668 emitInterfaceDoc(interfaceDef: *def, os);
669 return false;
670}
671
672//===----------------------------------------------------------------------===//
673// GEN: Interface registration hooks
674//===----------------------------------------------------------------------===//
675
676namespace {
677template <typename GeneratorT>
678struct InterfaceGenRegistration {
679 InterfaceGenRegistration(StringRef genArg, StringRef genDesc)
680 : genDeclArg(("gen-" + genArg + "-interface-decls").str()),
681 genDefArg(("gen-" + genArg + "-interface-defs").str()),
682 genDocArg(("gen-" + genArg + "-interface-docs").str()),
683 genDeclDesc(("Generate " + genDesc + " interface declarations").str()),
684 genDefDesc(("Generate " + genDesc + " interface definitions").str()),
685 genDocDesc(("Generate " + genDesc + " interface documentation").str()),
686 genDecls(genDeclArg, genDeclDesc,
687 [](const llvm::RecordKeeper &records, raw_ostream &os) {
688 return GeneratorT(records, os).emitInterfaceDecls();
689 }),
690 genDefs(genDefArg, genDefDesc,
691 [](const llvm::RecordKeeper &records, raw_ostream &os) {
692 return GeneratorT(records, os).emitInterfaceDefs();
693 }),
694 genDocs(genDocArg, genDocDesc,
695 [](const llvm::RecordKeeper &records, raw_ostream &os) {
696 return GeneratorT(records, os).emitInterfaceDocs();
697 }) {}
698
699 std::string genDeclArg, genDefArg, genDocArg;
700 std::string genDeclDesc, genDefDesc, genDocDesc;
701 mlir::GenRegistration genDecls, genDefs, genDocs;
702};
703} // namespace
704
705static InterfaceGenRegistration<AttrInterfaceGenerator> attrGen("attr",
706 "attribute");
707static InterfaceGenRegistration<OpInterfaceGenerator> opGen("op", "op");
708static InterfaceGenRegistration<TypeInterfaceGenerator> typeGen("type", "type");
709

// Source: mlir/tools/mlir-tblgen/OpInterfacesGen.cpp