1 | //===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // OpInterfacesGen generates definitions for operation interfaces. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "DocGenUtilities.h" |
14 | #include "mlir/TableGen/Format.h" |
15 | #include "mlir/TableGen/GenInfo.h" |
16 | #include "mlir/TableGen/Interfaces.h" |
17 | #include "llvm/ADT/SmallVector.h" |
18 | #include "llvm/ADT/StringExtras.h" |
19 | #include "llvm/Support/FormatVariadic.h" |
20 | #include "llvm/Support/raw_ostream.h" |
21 | #include "llvm/TableGen/Error.h" |
22 | #include "llvm/TableGen/Record.h" |
23 | #include "llvm/TableGen/TableGenBackend.h" |
24 | |
25 | using namespace mlir; |
26 | using mlir::tblgen::Interface; |
27 | using mlir::tblgen::InterfaceMethod; |
28 | using mlir::tblgen::OpInterface; |
29 | |
30 | /// Emit a string corresponding to a C++ type, followed by a space if necessary. |
31 | static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) { |
32 | type = type.trim(); |
33 | os << type; |
34 | if (type.back() != '&' && type.back() != '*') |
35 | os << " " ; |
36 | return os; |
37 | } |
38 | |
39 | /// Emit the method name and argument list for the given method. If 'addThisArg' |
40 | /// is true, then an argument is added to the beginning of the argument list for |
41 | /// the concrete value. |
42 | static void emitMethodNameAndArgs(const InterfaceMethod &method, |
43 | raw_ostream &os, StringRef valueType, |
44 | bool addThisArg, bool addConst) { |
45 | os << method.getName() << '('; |
46 | if (addThisArg) { |
47 | if (addConst) |
48 | os << "const " ; |
49 | os << "const Concept *impl, " ; |
50 | emitCPPType(type: valueType, os) |
51 | << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", " ); |
52 | } |
53 | llvm::interleaveComma(c: method.getArguments(), os, |
54 | each_fn: [&](const InterfaceMethod::Argument &arg) { |
55 | os << arg.type << " " << arg.name; |
56 | }); |
57 | os << ')'; |
58 | if (addConst) |
59 | os << " const" ; |
60 | } |
61 | |
62 | /// Get an array of all OpInterface definitions but exclude those subclassing |
63 | /// "DeclareOpInterfaceMethods". |
64 | static std::vector<llvm::Record *> |
65 | getAllInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper, |
66 | StringRef name) { |
67 | std::vector<llvm::Record *> defs = |
68 | recordKeeper.getAllDerivedDefinitions(ClassName: (name + "Interface" ).str()); |
69 | |
70 | std::string declareName = ("Declare" + name + "InterfaceMethods" ).str(); |
71 | llvm::erase_if(C&: defs, P: [&](const llvm::Record *def) { |
72 | // Ignore any "declare methods" interfaces. |
73 | if (def->isSubClassOf(Name: declareName)) |
74 | return true; |
75 | // Ignore interfaces defined outside of the top-level file. |
76 | return llvm::SrcMgr.FindBufferContainingLoc(Loc: def->getLoc()[0]) != |
77 | llvm::SrcMgr.getMainFileID(); |
78 | }); |
79 | return defs; |
80 | } |
81 | |
82 | namespace { |
83 | /// This struct is the base generator used when processing tablegen interfaces. |
84 | class InterfaceGenerator { |
85 | public: |
86 | bool emitInterfaceDefs(); |
87 | bool emitInterfaceDecls(); |
88 | bool emitInterfaceDocs(); |
89 | |
90 | protected: |
91 | InterfaceGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os) |
92 | : defs(std::move(defs)), os(os) {} |
93 | |
94 | void emitConceptDecl(const Interface &interface); |
95 | void emitModelDecl(const Interface &interface); |
96 | void emitModelMethodsDef(const Interface &interface); |
97 | void emitTraitDecl(const Interface &interface, StringRef interfaceName, |
98 | StringRef interfaceTraitsName); |
99 | void emitInterfaceDecl(const Interface &interface); |
100 | |
101 | /// The set of interface records to emit. |
102 | std::vector<llvm::Record *> defs; |
103 | // The stream to emit to. |
104 | raw_ostream &os; |
105 | /// The C++ value type of the interface, e.g. Operation*. |
106 | StringRef valueType; |
107 | /// The C++ base interface type. |
108 | StringRef interfaceBaseType; |
109 | /// The name of the typename for the value template. |
110 | StringRef valueTemplate; |
111 | /// The name of the substituion variable for the value. |
112 | StringRef substVar; |
113 | /// The format context to use for methods. |
114 | tblgen::FmtContext nonStaticMethodFmt; |
115 | tblgen::FmtContext traitMethodFmt; |
116 | tblgen::FmtContext ; |
117 | }; |
118 | |
119 | /// A specialized generator for attribute interfaces. |
120 | struct AttrInterfaceGenerator : public InterfaceGenerator { |
121 | AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) |
122 | : InterfaceGenerator(getAllInterfaceDefinitions(recordKeeper: records, name: "Attr" ), os) { |
123 | valueType = "::mlir::Attribute" ; |
124 | interfaceBaseType = "AttributeInterface" ; |
125 | valueTemplate = "ConcreteAttr" ; |
126 | substVar = "_attr" ; |
127 | StringRef castCode = "(::llvm::cast<ConcreteAttr>(tablegen_opaque_val))" ; |
128 | nonStaticMethodFmt.addSubst(placeholder: substVar, subst: castCode).withSelf(subst: castCode); |
129 | traitMethodFmt.addSubst(placeholder: substVar, |
130 | subst: "(*static_cast<const ConcreteAttr *>(this))" ); |
131 | extraDeclsFmt.addSubst(placeholder: substVar, subst: "(*this)" ); |
132 | } |
133 | }; |
134 | /// A specialized generator for operation interfaces. |
135 | struct OpInterfaceGenerator : public InterfaceGenerator { |
136 | OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) |
137 | : InterfaceGenerator(getAllInterfaceDefinitions(recordKeeper: records, name: "Op" ), os) { |
138 | valueType = "::mlir::Operation *" ; |
139 | interfaceBaseType = "OpInterface" ; |
140 | valueTemplate = "ConcreteOp" ; |
141 | substVar = "_op" ; |
142 | StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))" ; |
143 | nonStaticMethodFmt.addSubst(placeholder: "_this" , subst: "impl" ) |
144 | .addSubst(placeholder: substVar, subst: castCode) |
145 | .withSelf(subst: castCode); |
146 | traitMethodFmt.addSubst(placeholder: substVar, subst: "(*static_cast<ConcreteOp *>(this))" ); |
147 | extraDeclsFmt.addSubst(placeholder: substVar, subst: "(*this)" ); |
148 | } |
149 | }; |
150 | /// A specialized generator for type interfaces. |
151 | struct TypeInterfaceGenerator : public InterfaceGenerator { |
152 | TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) |
153 | : InterfaceGenerator(getAllInterfaceDefinitions(recordKeeper: records, name: "Type" ), os) { |
154 | valueType = "::mlir::Type" ; |
155 | interfaceBaseType = "TypeInterface" ; |
156 | valueTemplate = "ConcreteType" ; |
157 | substVar = "_type" ; |
158 | StringRef castCode = "(::llvm::cast<ConcreteType>(tablegen_opaque_val))" ; |
159 | nonStaticMethodFmt.addSubst(placeholder: substVar, subst: castCode).withSelf(subst: castCode); |
160 | traitMethodFmt.addSubst(placeholder: substVar, |
161 | subst: "(*static_cast<const ConcreteType *>(this))" ); |
162 | extraDeclsFmt.addSubst(placeholder: substVar, subst: "(*this)" ); |
163 | } |
164 | }; |
165 | } // namespace |
166 | |
167 | //===----------------------------------------------------------------------===// |
168 | // GEN: Interface definitions |
169 | //===----------------------------------------------------------------------===// |
170 | |
171 | static void emitInterfaceMethodDoc(const InterfaceMethod &method, |
172 | raw_ostream &os, StringRef prefix = "" ) { |
173 | if (std::optional<StringRef> description = method.getDescription()) |
174 | tblgen::emitDescriptionComment(description: *description, os, prefix); |
175 | } |
176 | static void emitInterfaceDefMethods(StringRef interfaceQualName, |
177 | const Interface &interface, |
178 | StringRef valueType, const Twine &implValue, |
179 | raw_ostream &os, bool isOpInterface) { |
180 | for (auto &method : interface.getMethods()) { |
181 | emitInterfaceMethodDoc(method, os); |
182 | emitCPPType(type: method.getReturnType(), os); |
183 | os << interfaceQualName << "::" ; |
184 | emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, |
185 | /*addConst=*/!isOpInterface); |
186 | |
187 | // Forward to the method on the concrete operation type. |
188 | os << " {\n return " << implValue << "->" << method.getName() << '('; |
189 | if (!method.isStatic()) { |
190 | os << implValue << ", " ; |
191 | os << (isOpInterface ? "getOperation()" : "*this" ); |
192 | os << (method.arg_empty() ? "" : ", " ); |
193 | } |
194 | llvm::interleaveComma( |
195 | c: method.getArguments(), os, |
196 | each_fn: [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); |
197 | os << ");\n }\n" ; |
198 | } |
199 | } |
200 | |
201 | static void emitInterfaceDef(const Interface &interface, StringRef valueType, |
202 | raw_ostream &os) { |
203 | std::string interfaceQualNameStr = interface.getFullyQualifiedName(); |
204 | StringRef interfaceQualName = interfaceQualNameStr; |
205 | interfaceQualName.consume_front(Prefix: "::" ); |
206 | |
207 | // Insert the method definitions. |
208 | bool isOpInterface = isa<OpInterface>(Val: interface); |
209 | emitInterfaceDefMethods(interfaceQualName, interface, valueType, implValue: "getImpl()" , |
210 | os, isOpInterface); |
211 | |
212 | // Insert the method definitions for base classes. |
213 | for (auto &base : interface.getBaseInterfaces()) { |
214 | emitInterfaceDefMethods(interfaceQualName, interface: base, valueType, |
215 | implValue: "getImpl()->impl" + base.getName(), os, |
216 | isOpInterface); |
217 | } |
218 | } |
219 | |
220 | bool InterfaceGenerator::emitInterfaceDefs() { |
221 | llvm::emitSourceFileHeader(Desc: "Interface Definitions" , OS&: os); |
222 | |
223 | for (const auto *def : defs) |
224 | emitInterfaceDef(interface: Interface(def), valueType, os); |
225 | return false; |
226 | } |
227 | |
228 | //===----------------------------------------------------------------------===// |
229 | // GEN: Interface declarations |
230 | //===----------------------------------------------------------------------===// |
231 | |
232 | void InterfaceGenerator::emitConceptDecl(const Interface &interface) { |
233 | os << " struct Concept {\n" ; |
234 | |
235 | // Insert each of the pure virtual concept methods. |
236 | os << " /// The methods defined by the interface.\n" ; |
237 | for (auto &method : interface.getMethods()) { |
238 | os << " " ; |
239 | emitCPPType(type: method.getReturnType(), os); |
240 | os << "(*" << method.getName() << ")(" ; |
241 | if (!method.isStatic()) { |
242 | os << "const Concept *impl, " ; |
243 | emitCPPType(type: valueType, os) << (method.arg_empty() ? "" : ", " ); |
244 | } |
245 | llvm::interleaveComma( |
246 | c: method.getArguments(), os, |
247 | each_fn: [&](const InterfaceMethod::Argument &arg) { os << arg.type; }); |
248 | os << ");\n" ; |
249 | } |
250 | |
251 | // Insert a field containing a concept for each of the base interfaces. |
252 | auto baseInterfaces = interface.getBaseInterfaces(); |
253 | if (!baseInterfaces.empty()) { |
254 | os << " /// The base classes of this interface.\n" ; |
255 | for (const auto &base : interface.getBaseInterfaces()) { |
256 | os << " const " << base.getFullyQualifiedName() << "::Concept *impl" |
257 | << base.getName() << " = nullptr;\n" ; |
258 | } |
259 | |
260 | // Define an "initialize" method that allows for the initialization of the |
261 | // base class concepts. |
262 | os << "\n void initializeInterfaceConcept(::mlir::detail::InterfaceMap " |
263 | "&interfaceMap) {\n" ; |
264 | std::string interfaceQualName = interface.getFullyQualifiedName(); |
265 | for (const auto &base : interface.getBaseInterfaces()) { |
266 | StringRef baseName = base.getName(); |
267 | std::string baseQualName = base.getFullyQualifiedName(); |
268 | os << " impl" << baseName << " = interfaceMap.lookup<" |
269 | << baseQualName << ">();\n" |
270 | << " assert(impl" << baseName << " && \"`" << interfaceQualName |
271 | << "` expected its base interface `" << baseQualName |
272 | << "` to be registered\");\n" ; |
273 | } |
274 | os << " }\n" ; |
275 | } |
276 | |
277 | os << " };\n" ; |
278 | } |
279 | |
280 | void InterfaceGenerator::emitModelDecl(const Interface &interface) { |
281 | // Emit the basic model and the fallback model. |
282 | for (const char *modelClass : {"Model" , "FallbackModel" }) { |
283 | os << " template<typename " << valueTemplate << ">\n" ; |
284 | os << " class " << modelClass << " : public Concept {\n public:\n" ; |
285 | os << " using Interface = " << interface.getFullyQualifiedName() |
286 | << ";\n" ; |
287 | os << " " << modelClass << "() : Concept{" ; |
288 | llvm::interleaveComma( |
289 | c: interface.getMethods(), os, |
290 | each_fn: [&](const InterfaceMethod &method) { os << method.getName(); }); |
291 | os << "} {}\n\n" ; |
292 | |
293 | // Insert each of the virtual method overrides. |
294 | for (auto &method : interface.getMethods()) { |
295 | emitCPPType(type: method.getReturnType(), os&: os << " static inline " ); |
296 | emitMethodNameAndArgs(method, os, valueType, |
297 | /*addThisArg=*/!method.isStatic(), |
298 | /*addConst=*/false); |
299 | os << ";\n" ; |
300 | } |
301 | os << " };\n" ; |
302 | } |
303 | |
304 | // Emit the template for the external model. |
305 | os << " template<typename ConcreteModel, typename " << valueTemplate |
306 | << ">\n" ; |
307 | os << " class ExternalModel : public FallbackModel<ConcreteModel> {\n" ; |
308 | os << " public:\n" ; |
309 | os << " using ConcreteEntity = " << valueTemplate << ";\n" ; |
310 | |
311 | // Emit declarations for methods that have default implementations. Other |
312 | // methods are expected to be implemented by the concrete derived model. |
313 | for (auto &method : interface.getMethods()) { |
314 | if (!method.getDefaultImplementation()) |
315 | continue; |
316 | os << " " ; |
317 | if (method.isStatic()) |
318 | os << "static " ; |
319 | emitCPPType(type: method.getReturnType(), os); |
320 | os << method.getName() << "(" ; |
321 | if (!method.isStatic()) { |
322 | emitCPPType(type: valueType, os); |
323 | os << "tablegen_opaque_val" ; |
324 | if (!method.arg_empty()) |
325 | os << ", " ; |
326 | } |
327 | llvm::interleaveComma(c: method.getArguments(), os, |
328 | each_fn: [&](const InterfaceMethod::Argument &arg) { |
329 | emitCPPType(type: arg.type, os); |
330 | os << arg.name; |
331 | }); |
332 | os << ")" ; |
333 | if (!method.isStatic()) |
334 | os << " const" ; |
335 | os << ";\n" ; |
336 | } |
337 | os << " };\n" ; |
338 | } |
339 | |
340 | void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { |
341 | llvm::SmallVector<StringRef, 2> namespaces; |
342 | llvm::SplitString(Source: interface.getCppNamespace(), OutFragments&: namespaces, Delimiters: "::" ); |
343 | for (StringRef ns : namespaces) |
344 | os << "namespace " << ns << " {\n" ; |
345 | |
346 | for (auto &method : interface.getMethods()) { |
347 | os << "template<typename " << valueTemplate << ">\n" ; |
348 | emitCPPType(type: method.getReturnType(), os); |
349 | os << "detail::" << interface.getName() << "InterfaceTraits::Model<" |
350 | << valueTemplate << ">::" ; |
351 | emitMethodNameAndArgs(method, os, valueType, |
352 | /*addThisArg=*/!method.isStatic(), |
353 | /*addConst=*/false); |
354 | os << " {\n " ; |
355 | |
356 | // Check for a provided body to the function. |
357 | if (std::optional<StringRef> body = method.getBody()) { |
358 | if (method.isStatic()) |
359 | os << body->trim(); |
360 | else |
361 | os << tblgen::tgfmt(fmt: body->trim(), ctx: &nonStaticMethodFmt); |
362 | os << "\n}\n" ; |
363 | continue; |
364 | } |
365 | |
366 | // Forward to the method on the concrete operation type. |
367 | if (method.isStatic()) |
368 | os << "return " << valueTemplate << "::" ; |
369 | else |
370 | os << tblgen::tgfmt(fmt: "return $_self." , ctx: &nonStaticMethodFmt); |
371 | |
372 | // Add the arguments to the call. |
373 | os << method.getName() << '('; |
374 | llvm::interleaveComma( |
375 | c: method.getArguments(), os, |
376 | each_fn: [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); |
377 | os << ");\n}\n" ; |
378 | } |
379 | |
380 | for (auto &method : interface.getMethods()) { |
381 | os << "template<typename " << valueTemplate << ">\n" ; |
382 | emitCPPType(type: method.getReturnType(), os); |
383 | os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<" |
384 | << valueTemplate << ">::" ; |
385 | emitMethodNameAndArgs(method, os, valueType, |
386 | /*addThisArg=*/!method.isStatic(), |
387 | /*addConst=*/false); |
388 | os << " {\n " ; |
389 | |
390 | // Forward to the method on the concrete Model implementation. |
391 | if (method.isStatic()) |
392 | os << "return " << valueTemplate << "::" ; |
393 | else |
394 | os << "return static_cast<const " << valueTemplate << " *>(impl)->" ; |
395 | |
396 | // Add the arguments to the call. |
397 | os << method.getName() << '('; |
398 | if (!method.isStatic()) |
399 | os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", " ); |
400 | llvm::interleaveComma( |
401 | c: method.getArguments(), os, |
402 | each_fn: [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); |
403 | os << ");\n}\n" ; |
404 | } |
405 | |
406 | // Emit default implementations for the external model. |
407 | for (auto &method : interface.getMethods()) { |
408 | if (!method.getDefaultImplementation()) |
409 | continue; |
410 | os << "template<typename ConcreteModel, typename " << valueTemplate |
411 | << ">\n" ; |
412 | emitCPPType(type: method.getReturnType(), os); |
413 | os << "detail::" << interface.getName() |
414 | << "InterfaceTraits::ExternalModel<ConcreteModel, " << valueTemplate |
415 | << ">::" ; |
416 | |
417 | os << method.getName() << "(" ; |
418 | if (!method.isStatic()) { |
419 | emitCPPType(type: valueType, os); |
420 | os << "tablegen_opaque_val" ; |
421 | if (!method.arg_empty()) |
422 | os << ", " ; |
423 | } |
424 | llvm::interleaveComma(c: method.getArguments(), os, |
425 | each_fn: [&](const InterfaceMethod::Argument &arg) { |
426 | emitCPPType(type: arg.type, os); |
427 | os << arg.name; |
428 | }); |
429 | os << ")" ; |
430 | if (!method.isStatic()) |
431 | os << " const" ; |
432 | |
433 | os << " {\n" ; |
434 | |
435 | // Use the empty context for static methods. |
436 | tblgen::FmtContext ctx; |
437 | os << tblgen::tgfmt(fmt: method.getDefaultImplementation()->trim(), |
438 | ctx: method.isStatic() ? &ctx : &nonStaticMethodFmt); |
439 | os << "\n}\n" ; |
440 | } |
441 | |
442 | for (StringRef ns : llvm::reverse(C&: namespaces)) |
443 | os << "} // namespace " << ns << "\n" ; |
444 | } |
445 | |
446 | void InterfaceGenerator::emitTraitDecl(const Interface &interface, |
447 | StringRef interfaceName, |
448 | StringRef interfaceTraitsName) { |
449 | os << llvm::formatv(Fmt: " template <typename {3}>\n" |
450 | " struct {0}Trait : public ::mlir::{2}<{0}," |
451 | " detail::{1}>::Trait<{3}> {{\n" , |
452 | Vals&: interfaceName, Vals&: interfaceTraitsName, Vals&: interfaceBaseType, |
453 | Vals&: valueTemplate); |
454 | |
455 | // Insert the default implementation for any methods. |
456 | bool isOpInterface = isa<OpInterface>(Val: interface); |
457 | for (auto &method : interface.getMethods()) { |
458 | // Flag interface methods named verifyTrait. |
459 | if (method.getName() == "verifyTrait" ) |
460 | PrintFatalError( |
461 | Msg: formatv(Fmt: "'verifyTrait' method cannot be specified as interface " |
462 | "method for '{0}'; use the 'verify' field instead" , |
463 | Vals&: interfaceName)); |
464 | auto defaultImpl = method.getDefaultImplementation(); |
465 | if (!defaultImpl) |
466 | continue; |
467 | |
468 | emitInterfaceMethodDoc(method, os, prefix: " " ); |
469 | os << " " << (method.isStatic() ? "static " : "" ); |
470 | emitCPPType(type: method.getReturnType(), os); |
471 | emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, |
472 | /*addConst=*/!isOpInterface && !method.isStatic()); |
473 | os << " {\n " << tblgen::tgfmt(fmt: defaultImpl->trim(), ctx: &traitMethodFmt) |
474 | << "\n }\n" ; |
475 | } |
476 | |
477 | if (auto verify = interface.getVerify()) { |
478 | assert(isa<OpInterface>(interface) && "only OpInterface supports 'verify'" ); |
479 | |
480 | tblgen::FmtContext verifyCtx; |
481 | verifyCtx.addSubst(placeholder: "_op" , subst: "op" ); |
482 | os << llvm::formatv( |
483 | Fmt: " static ::mlir::LogicalResult {0}(::mlir::Operation *op) " , |
484 | Vals: (interface.verifyWithRegions() ? "verifyRegionTrait" |
485 | : "verifyTrait" )) |
486 | << "{\n " << tblgen::tgfmt(fmt: verify->trim(), ctx: &verifyCtx) |
487 | << "\n }\n" ; |
488 | } |
489 | if (auto = interface.getExtraTraitClassDeclaration()) |
490 | os << tblgen::tgfmt(fmt: *extraTraitDecls, ctx: &traitMethodFmt) << "\n" ; |
491 | if (auto = interface.getExtraSharedClassDeclaration()) |
492 | os << tblgen::tgfmt(fmt: *extraTraitDecls, ctx: &traitMethodFmt) << "\n" ; |
493 | |
494 | os << " };\n" ; |
495 | } |
496 | |
497 | static void emitInterfaceDeclMethods(const Interface &interface, |
498 | raw_ostream &os, StringRef valueType, |
499 | bool isOpInterface, |
500 | tblgen::FmtContext &) { |
501 | for (auto &method : interface.getMethods()) { |
502 | emitInterfaceMethodDoc(method, os, prefix: " " ); |
503 | emitCPPType(type: method.getReturnType(), os&: os << " " ); |
504 | emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, |
505 | /*addConst=*/!isOpInterface); |
506 | os << ";\n" ; |
507 | } |
508 | |
509 | // Emit any extra declarations. |
510 | if (std::optional<StringRef> = |
511 | interface.getExtraClassDeclaration()) |
512 | os << extraDecls->rtrim() << "\n" ; |
513 | if (std::optional<StringRef> = |
514 | interface.getExtraSharedClassDeclaration()) |
515 | os << tblgen::tgfmt(fmt: extraDecls->rtrim(), ctx: &extraDeclsFmt) << "\n" ; |
516 | } |
517 | |
518 | void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { |
519 | llvm::SmallVector<StringRef, 2> namespaces; |
520 | llvm::SplitString(Source: interface.getCppNamespace(), OutFragments&: namespaces, Delimiters: "::" ); |
521 | for (StringRef ns : namespaces) |
522 | os << "namespace " << ns << " {\n" ; |
523 | |
524 | StringRef interfaceName = interface.getName(); |
525 | auto interfaceTraitsName = (interfaceName + "InterfaceTraits" ).str(); |
526 | |
527 | // Emit a forward declaration of the interface class so that it becomes usable |
528 | // in the signature of its methods. |
529 | os << "class " << interfaceName << ";\n" ; |
530 | |
531 | // Emit the traits struct containing the concept and model declarations. |
532 | os << "namespace detail {\n" |
533 | << "struct " << interfaceTraitsName << " {\n" ; |
534 | emitConceptDecl(interface); |
535 | emitModelDecl(interface); |
536 | os << "};\n" ; |
537 | |
538 | // Emit the derived trait for the interface. |
539 | os << "template <typename " << valueTemplate << ">\n" ; |
540 | os << "struct " << interface.getName() << "Trait;\n" ; |
541 | |
542 | os << "\n} // namespace detail\n" ; |
543 | |
544 | // Emit the main interface class declaration. |
545 | os << llvm::formatv(Fmt: "class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n" |
546 | "public:\n" |
547 | " using ::mlir::{3}<{1}, detail::{2}>::{3};\n" , |
548 | Vals&: interfaceName, Vals&: interfaceName, Vals&: interfaceTraitsName, |
549 | Vals&: interfaceBaseType); |
550 | |
551 | // Emit a utility wrapper trait class. |
552 | os << llvm::formatv(Fmt: " template <typename {1}>\n" |
553 | " struct Trait : public detail::{0}Trait<{1}> {{};\n" , |
554 | Vals&: interfaceName, Vals&: valueTemplate); |
555 | |
556 | // Insert the method declarations. |
557 | bool isOpInterface = isa<OpInterface>(Val: interface); |
558 | emitInterfaceDeclMethods(interface, os, valueType, isOpInterface, |
559 | extraDeclsFmt); |
560 | |
561 | // Insert the method declarations for base classes. |
562 | for (auto &base : interface.getBaseInterfaces()) { |
563 | std::string baseQualName = base.getFullyQualifiedName(); |
564 | os << " //" |
565 | "===---------------------------------------------------------------" |
566 | "-===//\n" |
567 | << " // Inherited from " << baseQualName << "\n" |
568 | << " //" |
569 | "===---------------------------------------------------------------" |
570 | "-===//\n\n" ; |
571 | |
572 | // Allow implicit conversion to the base interface. |
573 | os << " operator " << baseQualName << " () const {\n" |
574 | << " if (!*this) return nullptr;\n" |
575 | << " return " << baseQualName << "(*this, getImpl()->impl" |
576 | << base.getName() << ");\n" |
577 | << " }\n\n" ; |
578 | |
579 | // Inherit the base interface's methods. |
580 | emitInterfaceDeclMethods(interface: base, os, valueType, isOpInterface, extraDeclsFmt); |
581 | } |
582 | |
583 | // Emit classof code if necessary. |
584 | if (std::optional<StringRef> = interface.getExtraClassOf()) { |
585 | auto = tblgen::FmtContext(); |
586 | extraClassOfFmt.addSubst(placeholder: substVar, subst: "odsInterfaceInstance" ); |
587 | os << " static bool classof(" << valueType << " base) {\n" |
588 | << " auto* interface = getInterfaceFor(base);\n" |
589 | << " if (!interface)\n" |
590 | " return false;\n" |
591 | " " << interfaceName << " odsInterfaceInstance(base, interface);\n" |
592 | << " " << tblgen::tgfmt(fmt: extraClassOf->trim(), ctx: &extraClassOfFmt) |
593 | << "\n }\n" ; |
594 | } |
595 | |
596 | os << "};\n" ; |
597 | |
598 | os << "namespace detail {\n" ; |
599 | emitTraitDecl(interface, interfaceName, interfaceTraitsName); |
600 | os << "}// namespace detail\n" ; |
601 | |
602 | for (StringRef ns : llvm::reverse(C&: namespaces)) |
603 | os << "} // namespace " << ns << "\n" ; |
604 | } |
605 | |
606 | bool InterfaceGenerator::emitInterfaceDecls() { |
607 | llvm::emitSourceFileHeader(Desc: "Interface Declarations" , OS&: os); |
608 | // Sort according to ID, so defs are emitted in the order in which they appear |
609 | // in the Tablegen file. |
610 | std::vector<llvm::Record *> sortedDefs(defs); |
611 | llvm::sort(C&: sortedDefs, Comp: [](llvm::Record *lhs, llvm::Record *rhs) { |
612 | return lhs->getID() < rhs->getID(); |
613 | }); |
614 | for (const llvm::Record *def : sortedDefs) |
615 | emitInterfaceDecl(interface: Interface(def)); |
616 | for (const llvm::Record *def : sortedDefs) |
617 | emitModelMethodsDef(interface: Interface(def)); |
618 | return false; |
619 | } |
620 | |
621 | //===----------------------------------------------------------------------===// |
622 | // GEN: Interface documentation |
623 | //===----------------------------------------------------------------------===// |
624 | |
625 | static void emitInterfaceDoc(const llvm::Record &interfaceDef, |
626 | raw_ostream &os) { |
627 | Interface interface(&interfaceDef); |
628 | |
629 | // Emit the interface name followed by the description. |
630 | os << "## " << interface.getName() << " (`" << interfaceDef.getName() |
631 | << "`)\n\n" ; |
632 | if (auto description = interface.getDescription()) |
633 | mlir::tblgen::emitDescription(description: *description, os); |
634 | |
635 | // Emit the methods required by the interface. |
636 | os << "\n### Methods:\n" ; |
637 | for (const auto &method : interface.getMethods()) { |
638 | // Emit the method name. |
639 | os << "#### `" << method.getName() << "`\n\n```c++\n" ; |
640 | |
641 | // Emit the method signature. |
642 | if (method.isStatic()) |
643 | os << "static " ; |
644 | emitCPPType(type: method.getReturnType(), os) << method.getName() << '('; |
645 | llvm::interleaveComma(c: method.getArguments(), os, |
646 | each_fn: [&](const InterfaceMethod::Argument &arg) { |
647 | emitCPPType(type: arg.type, os) << arg.name; |
648 | }); |
649 | os << ");\n```\n" ; |
650 | |
651 | // Emit the description. |
652 | if (auto description = method.getDescription()) |
653 | mlir::tblgen::emitDescription(description: *description, os); |
654 | |
655 | // If the body is not provided, this method must be provided by the user. |
656 | if (!method.getBody()) |
657 | os << "\nNOTE: This method *must* be implemented by the user." ; |
658 | |
659 | os << "\n\n" ; |
660 | } |
661 | } |
662 | |
663 | bool InterfaceGenerator::emitInterfaceDocs() { |
664 | os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n" ; |
665 | os << "# " << interfaceBaseType << " definitions\n" ; |
666 | |
667 | for (const auto *def : defs) |
668 | emitInterfaceDoc(interfaceDef: *def, os); |
669 | return false; |
670 | } |
671 | |
672 | //===----------------------------------------------------------------------===// |
673 | // GEN: Interface registration hooks |
674 | //===----------------------------------------------------------------------===// |
675 | |
676 | namespace { |
677 | template <typename GeneratorT> |
678 | struct InterfaceGenRegistration { |
679 | InterfaceGenRegistration(StringRef genArg, StringRef genDesc) |
680 | : genDeclArg(("gen-" + genArg + "-interface-decls" ).str()), |
681 | genDefArg(("gen-" + genArg + "-interface-defs" ).str()), |
682 | genDocArg(("gen-" + genArg + "-interface-docs" ).str()), |
683 | genDeclDesc(("Generate " + genDesc + " interface declarations" ).str()), |
684 | genDefDesc(("Generate " + genDesc + " interface definitions" ).str()), |
685 | genDocDesc(("Generate " + genDesc + " interface documentation" ).str()), |
686 | genDecls(genDeclArg, genDeclDesc, |
687 | [](const llvm::RecordKeeper &records, raw_ostream &os) { |
688 | return GeneratorT(records, os).emitInterfaceDecls(); |
689 | }), |
690 | genDefs(genDefArg, genDefDesc, |
691 | [](const llvm::RecordKeeper &records, raw_ostream &os) { |
692 | return GeneratorT(records, os).emitInterfaceDefs(); |
693 | }), |
694 | genDocs(genDocArg, genDocDesc, |
695 | [](const llvm::RecordKeeper &records, raw_ostream &os) { |
696 | return GeneratorT(records, os).emitInterfaceDocs(); |
697 | }) {} |
698 | |
699 | std::string genDeclArg, genDefArg, genDocArg; |
700 | std::string genDeclDesc, genDefDesc, genDocDesc; |
701 | mlir::GenRegistration genDecls, genDefs, genDocs; |
702 | }; |
703 | } // namespace |
704 | |
705 | static InterfaceGenRegistration<AttrInterfaceGenerator> attrGen("attr" , |
706 | "attribute" ); |
707 | static InterfaceGenRegistration<OpInterfaceGenerator> opGen("op" , "op" ); |
708 | static InterfaceGenRegistration<TypeInterfaceGenerator> typeGen("type" , "type" ); |
709 | |