1//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10// binding classes wrapping a generic operation API.
11//
12//===----------------------------------------------------------------------===//
13
14#include "OpGenHelpers.h"
15
16#include "mlir/TableGen/GenInfo.h"
17#include "mlir/TableGen/Operator.h"
18#include "llvm/ADT/StringSet.h"
19#include "llvm/Support/CommandLine.h"
20#include "llvm/Support/FormatVariadic.h"
21#include "llvm/TableGen/Error.h"
22#include "llvm/TableGen/Record.h"
23
24using namespace mlir;
25using namespace mlir::tblgen;
26using llvm::formatv;
27using llvm::Record;
28using llvm::RecordKeeper;
29
/// File header and imports emitted at the top of every generated module.
/// The text contains no `{N}` replacement fields (nothing in it refers to the
/// dialect namespace). It imports the shared `_ods_common` helpers under
/// private, underscore-prefixed aliases that the generated accessors and
/// builders refer to.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import (
    equally_sized_accessor as _ods_equally_sized_accessor,
    get_default_loc_context as _ods_get_default_loc_context,
    get_op_result_or_op_results as _get_op_result_or_op_results,
    get_op_results_or_values as _get_op_results_or_values,
    segmented_accessor as _ods_segmented_accessor,
)
_ods_ir = _ods_cext.ir

import builtins
from typing import Sequence as _Sequence, Union as _Union

)Py";
49
/// Template for dialect class:
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
)Py";

/// Template that re-exports `_Dialect` from another generated module:
///   {0} is the name that forms the `_{0}_ops_gen` module name (presumably
///       the namespace of the dialect being extended — confirm at the call
///       site, which is outside this view).
constexpr const char *dialectExtensionTemplate = R"Py(
from ._{0}_ops_gen import _Dialect
)Py";
61
/// Template for operation class:
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";

/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT";
///   {1} is the segment spec.
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional element (expect a value or None)
///   -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for class level declarations of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions;
///   {1} is the Python bool literal for hasNoVariadicRegions.
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";
90
/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent). Every
/// other group holds exactly one element, so the variable-length group
/// spans `len - ({2} - 1)` elements.
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// Returns a slice covering the whole variadic group.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";

/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of non-variadic groups;
///   {3} is the total number of variadic groups;
///   {4} is the number of non-variadic groups preceding the current group;
///   {5} is the number of variadic groups preceding the current group.
/// The string deliberately ends without a newline. It must be followed by
/// opVariadicEqualSimpleTemplate or opVariadicEqualVariadicTemplate, which
/// supply the `return` line that uses `start` and `elements_per_group`.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + elements_per_group]
)Py";

/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
/// Group sizes are read from the `{1}SegmentSizes` attribute.
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}SegmentSizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case; an empty segment yields None:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
187
/// Template for an operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.attributes["{1}"]
)Py";

/// Template for an optional operation attribute getter; returns None when the
/// attribute is absent:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{1}" not in self.operation.attributes:
      return None
    return self.operation.attributes["{1}"]
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";

/// Template for an operation attribute setter; assigning None raises
/// ValueError because the attribute is mandatory:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute (if present):
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute, setting to None or
/// False (any falsy value) removes the attribute; a truthy value stores a
/// UnitAttr:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";

/// Template for a region accessor:
///   {0} is the name of the accessor;
///   {1} is the index into `self.regions`.
constexpr const char *regionAccessorTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.regions[{1}]
)Py";
271
/// Template for a free-function builder that wraps a call and returns its
/// value:
///   {0} is the name of the generated function;
///   {1} is the callable that is invoked (presumably the op class — TODO
///       confirm against the caller, which is outside this view);
///   {2} is the function's parameter list;
///   {3} is the argument list forwarded to {1};
///   {4} is the return type annotation;
///   {5} is a suffix appended to the call expression.
constexpr const char *valueBuilderTemplate = R"Py(
def {0}({2}) -> {4}:
  return {1}({3}){5}
)Py";

/// Same placeholders as valueBuilderTemplate, except that {5} is not used:
/// the call's result is instead passed through `_get_op_result_or_op_results`
/// (imported by fileHeader).
constexpr const char *valueBuilderVariadicTemplate = R"Py(
def {0}({2}) -> {4}:
  return _get_op_result_or_op_results({1}({3}))
)Py";
281
282static llvm::cl::OptionCategory
283 clOpPythonBindingCat("Options for -gen-python-op-bindings");
284
285static llvm::cl::opt<std::string>
286 clDialectName("bind-dialect",
287 llvm::cl::desc("The dialect to run the generator for"),
288 llvm::cl::init(Val: ""), llvm::cl::cat(clOpPythonBindingCat));
289
290static llvm::cl::opt<std::string> clDialectExtensionName(
291 "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
292 llvm::cl::init(Val: ""), llvm::cl::cat(clOpPythonBindingCat));
293
294using AttributeClasses = DenseMap<StringRef, StringRef>;
295
296/// Checks whether `str` would shadow a generated variable or attribute
297/// part of the OpView API.
298static bool isODSReserved(StringRef str) {
299 static llvm::StringSet<> reserved(
300 {"attributes", "create", "context", "ip", "operands", "print", "get_asm",
301 "loc", "verify", "regions", "results", "self", "operation",
302 "DIALECT_NAMESPACE", "OPERATION_NAME"});
303 return str.starts_with(Prefix: "_ods_") || str.ends_with(Suffix: "_ods") ||
304 reserved.contains(key: str);
305}
306
307/// Modifies the `name` in a way that it becomes suitable for Python bindings
308/// (does not change the `name` if it already is suitable) and returns the
309/// modified version.
310static std::string sanitizeName(StringRef name) {
311 std::string processedStr = name.str();
312 std::replace_if(
313 first: processedStr.begin(), last: processedStr.end(),
314 pred: [](char c) { return !llvm::isAlnum(C: c); }, new_value: '_');
315
316 if (llvm::isDigit(C: *processedStr.begin()))
317 return "_" + processedStr;
318
319 if (isPythonReserved(str: processedStr) || isODSReserved(str: processedStr))
320 return processedStr + "_";
321 return processedStr;
322}
323
324static std::string attrSizedTraitForKind(const char *kind) {
325 return formatv(Fmt: "::mlir::OpTrait::AttrSized{0}{1}Segments",
326 Vals: StringRef(kind).take_front().upper(),
327 Vals: StringRef(kind).drop_front());
328}
329
330/// Emits accessors to "elements" of an Op definition. Currently, the supported
331/// elements are operands and results, indicated by `kind`, which must be either
332/// `operand` or `result` and is used verbatim in the emitted code.
333static void emitElementAccessors(
334 const Operator &op, raw_ostream &os, const char *kind,
335 unsigned numVariadicGroups, unsigned numElements,
336 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
337 getElement) {
338 assert(llvm::is_contained(SmallVector<StringRef, 2>{"operand", "result"},
339 kind) &&
340 "unsupported kind");
341
342 // Traits indicating how to process variadic elements.
343 std::string sameSizeTrait = formatv(Fmt: "::mlir::OpTrait::SameVariadic{0}{1}Size",
344 Vals: StringRef(kind).take_front().upper(),
345 Vals: StringRef(kind).drop_front());
346 std::string attrSizedTrait = attrSizedTraitForKind(kind);
347
348 // If there is only one variable-length element group, its size can be
349 // inferred from the total number of elements. If there are none, the
350 // generation is straightforward.
351 if (numVariadicGroups <= 1) {
352 bool seenVariableLength = false;
353 for (unsigned i = 0; i < numElements; ++i) {
354 const NamedTypeConstraint &element = getElement(op, i);
355 if (element.isVariableLength())
356 seenVariableLength = true;
357 if (element.name.empty())
358 continue;
359 if (element.isVariableLength()) {
360 os << formatv(Fmt: element.isOptional() ? opOneOptionalTemplate
361 : opOneVariadicTemplate,
362 Vals: sanitizeName(name: element.name), Vals&: kind, Vals&: numElements, Vals&: i);
363 } else if (seenVariableLength) {
364 os << formatv(Fmt: opSingleAfterVariableTemplate, Vals: sanitizeName(name: element.name),
365 Vals&: kind, Vals&: numElements, Vals&: i);
366 } else {
367 os << formatv(Fmt: opSingleTemplate, Vals: sanitizeName(name: element.name), Vals&: kind, Vals&: i);
368 }
369 }
370 return;
371 }
372
373 // Handle the operations where variadic groups have the same size.
374 if (op.getTrait(trait: sameSizeTrait)) {
375 // Count the number of simple elements
376 unsigned numSimpleLength = 0;
377 for (unsigned i = 0; i < numElements; ++i) {
378 const NamedTypeConstraint &element = getElement(op, i);
379 if (!element.isVariableLength()) {
380 ++numSimpleLength;
381 }
382 }
383
384 // Generate the accessors
385 int numPrecedingSimple = 0;
386 int numPrecedingVariadic = 0;
387 for (unsigned i = 0; i < numElements; ++i) {
388 const NamedTypeConstraint &element = getElement(op, i);
389 if (!element.name.empty()) {
390 os << formatv(Fmt: opVariadicEqualPrefixTemplate, Vals: sanitizeName(name: element.name),
391 Vals&: kind, Vals&: numSimpleLength, Vals&: numVariadicGroups,
392 Vals&: numPrecedingSimple, Vals&: numPrecedingVariadic);
393 os << formatv(Fmt: element.isVariableLength()
394 ? opVariadicEqualVariadicTemplate
395 : opVariadicEqualSimpleTemplate,
396 Vals&: kind);
397 }
398 if (element.isVariableLength())
399 ++numPrecedingVariadic;
400 else
401 ++numPrecedingSimple;
402 }
403 return;
404 }
405
406 // Handle the operations where the size of groups (variadic or not) is
407 // provided as an attribute. For non-variadic elements, make sure to return
408 // an element rather than a singleton container.
409 if (op.getTrait(trait: attrSizedTrait)) {
410 for (unsigned i = 0; i < numElements; ++i) {
411 const NamedTypeConstraint &element = getElement(op, i);
412 if (element.name.empty())
413 continue;
414 std::string trailing;
415 if (!element.isVariableLength())
416 trailing = "[0]";
417 else if (element.isOptional())
418 trailing = std::string(
419 formatv(Fmt: opVariadicSegmentOptionalTrailingTemplate, Vals&: kind));
420 os << formatv(Fmt: opVariadicSegmentTemplate, Vals: sanitizeName(name: element.name), Vals&: kind,
421 Vals&: i, Vals&: trailing);
422 }
423 return;
424 }
425
426 llvm::PrintFatalError(Msg: "unsupported " + llvm::Twine(kind) + " structure");
427}
428
429/// Free function helpers accessing Operator components.
430static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
431static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
432 return op.getOperand(index: i);
433}
434static int getNumResults(const Operator &op) { return op.getNumResults(); }
435static const NamedTypeConstraint &getResult(const Operator &op, int i) {
436 return op.getResult(index: i);
437}
438
439/// Emits accessors to Op operands.
440static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
441 emitElementAccessors(op, os, kind: "operand", numVariadicGroups: op.getNumVariableLengthOperands(),
442 numElements: getNumOperands(op), getElement: getOperand);
443}
444
445/// Emits accessors Op results.
446static void emitResultAccessors(const Operator &op, raw_ostream &os) {
447 emitElementAccessors(op, os, kind: "result", numVariadicGroups: op.getNumVariableLengthResults(),
448 numElements: getNumResults(op), getElement: getResult);
449}
450
451/// Emits accessors to Op attributes.
452static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
453 for (const auto &namedAttr : op.getAttributes()) {
454 // Skip "derived" attributes because they are just C++ functions that we
455 // don't currently expose.
456 if (namedAttr.attr.isDerivedAttr())
457 continue;
458
459 if (namedAttr.name.empty())
460 continue;
461
462 std::string sanitizedName = sanitizeName(name: namedAttr.name);
463
464 // Unit attributes are handled specially.
465 if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") {
466 os << formatv(Fmt: unitAttributeGetterTemplate, Vals&: sanitizedName, Vals: namedAttr.name);
467 os << formatv(Fmt: unitAttributeSetterTemplate, Vals&: sanitizedName, Vals: namedAttr.name);
468 os << formatv(Fmt: attributeDeleterTemplate, Vals&: sanitizedName, Vals: namedAttr.name);
469 continue;
470 }
471
472 if (namedAttr.attr.isOptional()) {
473 os << formatv(Fmt: optionalAttributeGetterTemplate, Vals&: sanitizedName,
474 Vals: namedAttr.name);
475 os << formatv(Fmt: optionalAttributeSetterTemplate, Vals&: sanitizedName,
476 Vals: namedAttr.name);
477 os << formatv(Fmt: attributeDeleterTemplate, Vals&: sanitizedName, Vals: namedAttr.name);
478 } else {
479 os << formatv(Fmt: attributeGetterTemplate, Vals&: sanitizedName, Vals: namedAttr.name);
480 os << formatv(Fmt: attributeSetterTemplate, Vals&: sanitizedName, Vals: namedAttr.name);
481 // Non-optional attributes cannot be deleted.
482 }
483 }
484}
485
/// Template for the default auto-generated builder.
///   {0} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {1} is the code populating `operands`, `results` and `attributes`,
///       `successors` fields;
///   {2} is the argument list forwarded to `super().__init__`.
/// `{{` is formatv's escape for a literal `{`, so the emitted line reads
/// `attributes = {}`.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    regions = None
    {1}
    super().__init__({2})
)Py";

/// Template for appending a single element to the operand/result list.
///   {0} is the field name.
constexpr const char *singleOperandAppendTemplate = "operands.append({0})";
constexpr const char *singleResultAppendTemplate = "results.append({0})";

/// Template for appending an optional element to the operand/result list.
///   {0} is the field name.
/// The AttrSized variant appends unconditionally (possibly None), so the entry
/// lines up with its segment, where 0 marks an optional element.
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append({0})";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    "operands.append({0})";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

/// Template for appending a list of elements to the operand/result list.
///   {0} is the field name.
/// The "Pack" variant appends the whole list as a single entry (one per
/// segment); the others splice its values into the flat list.
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";

/// Template for attribute builder from raw input in the operation builder.
///   {0} is the builder argument name;
///   {1} is the attribute name, used as the key in `attributes`;
///   {2} is the attribute definition name, used as the lookup key in
///       `_ods_ir.AttrBuilder`.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initAttributeWithBuilderTemplate =
    R"Py(attributes["{1}"] = ({0} if (
    isinstance({0}, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('{2}')) else
      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";

/// Template for attribute builder from raw input for optional attribute in the
/// operation builder. The attribute is only set when the argument is not None.
///   {0} is the builder argument name;
///   {1} is the attribute name, used as the key in `attributes`;
///   {2} is the attribute definition name, used as the lookup key in
///       `_ods_ir.AttrBuilder`.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initOptionalAttributeWithBuilderTemplate =
    R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
    isinstance({0}, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('{2}')) else
      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";

/// Template setting a unit attribute when the builder argument is truthy.
///   {0} is the attribute name, used as the key in `attributes`;
///   {1} is the builder argument name.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";

/// Template to initialize the successors list in the builder if there are any
/// successors.
///   {0} is the value to initialize the successors list to.
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";

/// Template to append or extend the list of successors in the builder.
///   {0} is the list method ('append' or 'extend');
///   {1} is the value to add.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
561
562/// Returns true if the SameArgumentAndResultTypes trait can be used to infer
563/// result types of the given operation.
564static bool hasSameArgumentAndResultTypes(const Operator &op) {
565 return op.getTrait(trait: "::mlir::OpTrait::SameOperandsAndResultType") &&
566 op.getNumVariableLengthResults() == 0;
567}
568
569/// Returns true if the FirstAttrDerivedResultType trait can be used to infer
570/// result types of the given operation.
571static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
572 return op.getTrait(trait: "::mlir::OpTrait::FirstAttrDerivedResultType") &&
573 op.getNumVariableLengthResults() == 0;
574}
575
576/// Returns true if the InferTypeOpInterface can be used to infer result types
577/// of the given operation.
578static bool hasInferTypeInterface(const Operator &op) {
579 return op.getTrait(trait: "::mlir::InferTypeOpInterface::Trait") &&
580 op.getNumRegions() == 0;
581}
582
583/// Returns true if there is a trait or interface that can be used to infer
584/// result types of the given operation.
585static bool canInferType(const Operator &op) {
586 return hasSameArgumentAndResultTypes(op) ||
587 hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op);
588}
589
590/// Populates `builderArgs` with result names if the builder is expected to
591/// accept them as arguments.
592static void
593populateBuilderArgsResults(const Operator &op,
594 SmallVectorImpl<std::string> &builderArgs) {
595 if (canInferType(op))
596 return;
597
598 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
599 std::string name = op.getResultName(index: i).str();
600 if (name.empty()) {
601 if (op.getNumResults() == 1) {
602 // Special case for one result, make the default name be 'result'
603 // to properly match the built-in result accessor.
604 name = "result";
605 } else {
606 name = formatv(Fmt: "_gen_res_{0}", Vals&: i);
607 }
608 }
609 name = sanitizeName(name);
610 builderArgs.push_back(Elt: name);
611 }
612}
613
614/// Populates `builderArgs` with the Python-compatible names of builder function
615/// arguments using intermixed attributes and operands in the same order as they
616/// appear in the `arguments` field of the op definition. Additionally,
617/// `operandNames` is populated with names of operands in their order of
618/// appearance.
619static void populateBuilderArgs(const Operator &op,
620 SmallVectorImpl<std::string> &builderArgs,
621 SmallVectorImpl<std::string> &operandNames) {
622 for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
623 std::string name = op.getArgName(index: i).str();
624 if (name.empty())
625 name = formatv(Fmt: "_gen_arg_{0}", Vals&: i);
626 name = sanitizeName(name);
627 builderArgs.push_back(Elt: name);
628 if (!isa<NamedAttribute *>(Val: op.getArg(index: i)))
629 operandNames.push_back(Elt: name);
630 }
631}
632
633/// Populates `builderArgs` with the Python-compatible names of builder function
634/// successor arguments. Additionally, `successorArgNames` is also populated.
635static void
636populateBuilderArgsSuccessors(const Operator &op,
637 SmallVectorImpl<std::string> &builderArgs,
638 SmallVectorImpl<std::string> &successorArgNames) {
639
640 for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
641 NamedSuccessor successor = op.getSuccessor(index: i);
642 std::string name = std::string(successor.name);
643 if (name.empty())
644 name = formatv(Fmt: "_gen_successor_{0}", Vals&: i);
645 name = sanitizeName(name);
646 builderArgs.push_back(Elt: name);
647 successorArgNames.push_back(Elt: name);
648 }
649}
650
651/// Populates `builderLines` with additional lines that are required in the
652/// builder to set up operation attributes. `argNames` is expected to contain
653/// the names of builder arguments that correspond to op arguments, i.e. to the
654/// operands and attributes in the same order as they appear in the `arguments`
655/// field.
656static void
657populateBuilderLinesAttr(const Operator &op, ArrayRef<std::string> argNames,
658 SmallVectorImpl<std::string> &builderLines) {
659 builderLines.push_back(Elt: "_ods_context = _ods_get_default_loc_context(loc)");
660 for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
661 Argument arg = op.getArg(index: i);
662 auto *attribute = llvm::dyn_cast_if_present<NamedAttribute *>(Val&: arg);
663 if (!attribute)
664 continue;
665
666 // Unit attributes are handled specially.
667 if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr") {
668 builderLines.push_back(
669 Elt: formatv(Fmt: initUnitAttributeTemplate, Vals&: attribute->name, Vals: argNames[i]));
670 continue;
671 }
672
673 builderLines.push_back(Elt: formatv(
674 Fmt: attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
675 ? initOptionalAttributeWithBuilderTemplate
676 : initAttributeWithBuilderTemplate,
677 Vals: argNames[i], Vals&: attribute->name, Vals: attribute->attr.getAttrDefName()));
678 }
679}
680
681/// Populates `builderLines` with additional lines that are required in the
682/// builder to set up successors. successorArgNames is expected to correspond
683/// to the Python argument name for each successor on the op.
684static void
685populateBuilderLinesSuccessors(const Operator &op,
686 ArrayRef<std::string> successorArgNames,
687 SmallVectorImpl<std::string> &builderLines) {
688 if (successorArgNames.empty()) {
689 builderLines.push_back(Elt: formatv(Fmt: initSuccessorsTemplate, Vals: "None"));
690 return;
691 }
692
693 builderLines.push_back(Elt: formatv(Fmt: initSuccessorsTemplate, Vals: "[]"));
694 for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
695 auto &argName = successorArgNames[i];
696 const NamedSuccessor &successor = op.getSuccessor(index: i);
697 builderLines.push_back(Elt: formatv(Fmt: addSuccessorTemplate,
698 Vals: successor.isVariadic() ? "extend" : "append",
699 Vals: argName));
700 }
701}
702
703/// Populates `builderLines` with additional lines that are required in the
704/// builder to set up op operands.
705static void
706populateBuilderLinesOperand(const Operator &op, ArrayRef<std::string> names,
707 SmallVectorImpl<std::string> &builderLines) {
708 bool sizedSegments = op.getTrait(trait: attrSizedTraitForKind(kind: "operand")) != nullptr;
709
710 // For each element, find or generate a name.
711 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
712 const NamedTypeConstraint &element = op.getOperand(index: i);
713 std::string name = names[i];
714
715 // Choose the formatting string based on the element kind.
716 StringRef formatString;
717 if (!element.isVariableLength()) {
718 formatString = singleOperandAppendTemplate;
719 } else if (element.isOptional()) {
720 if (sizedSegments) {
721 formatString = optionalAppendAttrSizedOperandsTemplate;
722 } else {
723 formatString = optionalAppendOperandTemplate;
724 }
725 } else {
726 assert(element.isVariadic() && "unhandled element group type");
727 // If emitting with sizedSegments, then we add the actual list-typed
728 // element. Otherwise, we extend the actual operands.
729 if (sizedSegments) {
730 formatString = multiOperandAppendPackTemplate;
731 } else {
732 formatString = multiOperandAppendTemplate;
733 }
734 }
735
736 builderLines.push_back(Elt: formatv(Fmt: formatString.data(), Vals&: name));
737 }
738}
739
/// Python code template for deriving the operation result types from its
/// attribute:
///   - {0} is the name of the attribute from which to derive the types.
/// If the attribute is a TypeAttr, the type it wraps is used; otherwise the
/// attribute's own `.type` is used. The template spans several lines and is
/// split with appendLineByLine. Only its first two lines start statements;
/// the remaining lines are continuations inside the parentheses.
constexpr const char *deriveTypeFromAttrTemplate =
    R"Py(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
    _ods_ir.TypeAttr(_ods_result_type_source_attr).value
    if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
    _ods_result_type_source_attr.type))Py";

/// Python code template appending {0} type {1} times to the results list.
///   {0} is a Python expression evaluating to a type;
///   {1} is the number of results.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
752
753/// Appends the given multiline string as individual strings into
754/// `builderLines`.
755static void appendLineByLine(StringRef string,
756 SmallVectorImpl<std::string> &builderLines) {
757
758 std::pair<StringRef, StringRef> split = std::make_pair(x&: string, y&: string);
759 do {
760 split = split.second.split(Separator: '\n');
761 builderLines.push_back(Elt: split.first.str());
762 } while (!split.second.empty());
763}
764
765/// Populates `builderLines` with additional lines that are required in the
766/// builder to set up op results.
767static void
768populateBuilderLinesResult(const Operator &op, ArrayRef<std::string> names,
769 SmallVectorImpl<std::string> &builderLines) {
770 bool sizedSegments = op.getTrait(trait: attrSizedTraitForKind(kind: "result")) != nullptr;
771
772 if (hasSameArgumentAndResultTypes(op)) {
773 builderLines.push_back(Elt: formatv(Fmt: appendSameResultsTemplate,
774 Vals: "operands[0].type", Vals: op.getNumResults()));
775 return;
776 }
777
778 if (hasFirstAttrDerivedResultTypes(op)) {
779 const NamedAttribute &firstAttr = op.getAttribute(index: 0);
780 assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
781 "from which the type is derived");
782 appendLineByLine(string: formatv(Fmt: deriveTypeFromAttrTemplate, Vals: firstAttr.name).str(),
783 builderLines);
784 builderLines.push_back(Elt: formatv(Fmt: appendSameResultsTemplate,
785 Vals: "_ods_derived_result_type",
786 Vals: op.getNumResults()));
787 return;
788 }
789
790 if (hasInferTypeInterface(op))
791 return;
792
793 // For each element, find or generate a name.
794 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
795 const NamedTypeConstraint &element = op.getResult(index: i);
796 std::string name = names[i];
797
798 // Choose the formatting string based on the element kind.
799 StringRef formatString;
800 if (!element.isVariableLength()) {
801 formatString = singleResultAppendTemplate;
802 } else if (element.isOptional()) {
803 formatString = optionalAppendResultTemplate;
804 } else {
805 assert(element.isVariadic() && "unhandled element group type");
806 // If emitting with sizedSegments, then we add the actual list-typed
807 // element. Otherwise, we extend the actual operands.
808 if (sizedSegments) {
809 formatString = singleResultAppendTemplate;
810 } else {
811 formatString = multiResultAppendTemplate;
812 }
813 }
814
815 builderLines.push_back(Elt: formatv(Fmt: formatString.data(), Vals&: name));
816 }
817}
818
819/// If the operation has variadic regions, adds a builder argument to specify
820/// the number of those regions and builder lines to forward it to the generic
821/// constructor.
822static void populateBuilderRegions(const Operator &op,
823 SmallVectorImpl<std::string> &builderArgs,
824 SmallVectorImpl<std::string> &builderLines) {
825 if (op.hasNoVariadicRegions())
826 return;
827
828 // This is currently enforced when Operator is constructed.
829 assert(op.getNumVariadicRegions() == 1 &&
830 op.getRegion(op.getNumRegions() - 1).isVariadic() &&
831 "expected the last region to be varidic");
832
833 const NamedRegion &region = op.getRegion(index: op.getNumRegions() - 1);
834 std::string name =
835 ("num_" + region.name.take_front().lower() + region.name.drop_front())
836 .str();
837 builderArgs.push_back(Elt: name);
838 builderLines.push_back(
839 Elt: formatv(Fmt: "regions = {0} + {1}", Vals: op.getNumRegions() - 1, Vals&: name));
840}
841
842/// Emits a default builder constructing an operation from the list of its
843/// result types, followed by a list of its operands. Returns vector
844/// of fully built functionArgs for downstream users (to save having to
845/// rebuild anew).
846static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
847 raw_ostream &os) {
848 SmallVector<std::string> builderArgs;
849 SmallVector<std::string> builderLines;
850 SmallVector<std::string> operandArgNames;
851 SmallVector<std::string> successorArgNames;
852 builderArgs.reserve(N: op.getNumOperands() + op.getNumResults() +
853 op.getNumNativeAttributes() + op.getNumSuccessors());
854 populateBuilderArgsResults(op, builderArgs);
855 size_t numResultArgs = builderArgs.size();
856 populateBuilderArgs(op, builderArgs, operandNames&: operandArgNames);
857 size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
858 populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
859
860 populateBuilderLinesOperand(op, names: operandArgNames, builderLines);
861 populateBuilderLinesAttr(op, argNames: ArrayRef(builderArgs).drop_front(N: numResultArgs),
862 builderLines);
863 populateBuilderLinesResult(
864 op, names: ArrayRef(builderArgs).take_front(N: numResultArgs), builderLines);
865 populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
866 populateBuilderRegions(op, builderArgs, builderLines);
867
868 // Layout of builderArgs vector elements:
869 // [ result_args operand_attr_args successor_args regions ]
870
871 // Determine whether the argument corresponding to a given index into the
872 // builderArgs vector is a python keyword argument or not.
873 auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool {
874 // All result, successor, and region arguments are positional arguments.
875 if ((builderArgIndex < numResultArgs) ||
876 (builderArgIndex >= (numResultArgs + numOperandAttrArgs)))
877 return false;
878 // Keyword arguments:
879 // - optional named attributes (including unit attributes)
880 // - default-valued named attributes
881 // - optional operands
882 Argument a = op.getArg(index: builderArgIndex - numResultArgs);
883 if (auto *nattr = llvm::dyn_cast_if_present<NamedAttribute *>(Val&: a))
884 return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
885 if (auto *ntype = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: a))
886 return ntype->isOptional();
887 return false;
888 };
889
890 // StringRefs in functionArgs refer to strings allocated by builderArgs.
891 SmallVector<StringRef> functionArgs;
892
893 // Add positional arguments.
894 for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
895 if (!isKeywordArgFn(i))
896 functionArgs.push_back(Elt: builderArgs[i]);
897 }
898
899 // Add a bare '*' to indicate that all following arguments must be keyword
900 // arguments.
901 functionArgs.push_back(Elt: "*");
902
903 // Add a default 'None' value to each keyword arg string, and then add to the
904 // function args list.
905 for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
906 if (isKeywordArgFn(i)) {
907 builderArgs[i].append(s: "=None");
908 functionArgs.push_back(Elt: builderArgs[i]);
909 }
910 }
911 functionArgs.push_back(Elt: "loc=None");
912 functionArgs.push_back(Elt: "ip=None");
913
914 SmallVector<std::string> initArgs;
915 initArgs.push_back(Elt: "self.OPERATION_NAME");
916 initArgs.push_back(Elt: "self._ODS_REGIONS");
917 initArgs.push_back(Elt: "self._ODS_OPERAND_SEGMENTS");
918 initArgs.push_back(Elt: "self._ODS_RESULT_SEGMENTS");
919 initArgs.push_back(Elt: "attributes=attributes");
920 if (!hasInferTypeInterface(op))
921 initArgs.push_back(Elt: "results=results");
922 initArgs.push_back(Elt: "operands=operands");
923 initArgs.push_back(Elt: "successors=_ods_successors");
924 initArgs.push_back(Elt: "regions=regions");
925 initArgs.push_back(Elt: "loc=loc");
926 initArgs.push_back(Elt: "ip=ip");
927
928 os << formatv(Fmt: initTemplate, Vals: llvm::join(R&: functionArgs, Separator: ", "),
929 Vals: llvm::join(R&: builderLines, Separator: "\n "), Vals: llvm::join(R&: initArgs, Separator: ", "));
930 return llvm::to_vector<8>(
931 Range: llvm::map_range(C&: functionArgs, F: [](StringRef s) { return s.str(); }));
932}
933
934static void emitSegmentSpec(
935 const Operator &op, const char *kind,
936 llvm::function_ref<int(const Operator &)> getNumElements,
937 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
938 getElement,
939 raw_ostream &os) {
940 std::string segmentSpec("[");
941 for (int i = 0, e = getNumElements(op); i < e; ++i) {
942 const NamedTypeConstraint &element = getElement(op, i);
943 if (element.isOptional()) {
944 segmentSpec.append(s: "0,");
945 } else if (element.isVariadic()) {
946 segmentSpec.append(s: "-1,");
947 } else {
948 segmentSpec.append(s: "1,");
949 }
950 }
951 segmentSpec.append(s: "]");
952
953 os << formatv(Fmt: opClassSizedSegmentsTemplate, Vals&: kind, Vals&: segmentSpec);
954}
955
956static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
957 // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
958 // Note that the base OpView class defines this as (0, True).
959 unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
960 os << formatv(Fmt: opClassRegionSpecTemplate, Vals&: minRegionCount,
961 Vals: op.hasNoVariadicRegions() ? "True" : "False");
962}
963
964/// Emits named accessors to regions.
965static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
966 for (const auto &en : llvm::enumerate(First: op.getRegions())) {
967 const NamedRegion &region = en.value();
968 if (region.name.empty())
969 continue;
970
971 assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
972 "expected only the last region to be variadic");
973 os << formatv(Fmt: regionAccessorTemplate, Vals: sanitizeName(name: region.name),
974 Vals: std::to_string(val: en.index()) +
975 (region.isVariadic() ? ":" : ""));
976 }
977}
978
979/// Emits builder that extracts results from op
980static void emitValueBuilder(const Operator &op,
981 SmallVector<std::string> functionArgs,
982 raw_ostream &os) {
983 // Params with (possibly) default args.
984 auto valueBuilderParams =
985 llvm::map_range(C&: functionArgs, F: [](const std::string &argAndMaybeDefault) {
986 SmallVector<StringRef> argMaybeDefault =
987 llvm::to_vector<2>(Range: llvm::split(Str: argAndMaybeDefault, Separator: "="));
988 auto arg = llvm::convertToSnakeFromCamelCase(input: argMaybeDefault[0]);
989 if (argMaybeDefault.size() == 2)
990 return arg + "=" + argMaybeDefault[1].str();
991 return arg;
992 });
993 // Actual args passed to op builder (e.g., opParam=op_param).
994 auto opBuilderArgs = llvm::map_range(
995 C: llvm::make_filter_range(Range&: functionArgs,
996 Pred: [](const std::string &s) { return s != "*"; }),
997 F: [](const std::string &arg) {
998 auto lhs = *llvm::split(Str: arg, Separator: "=").begin();
999 return (lhs + "=" + llvm::convertToSnakeFromCamelCase(input: lhs)).str();
1000 });
1001 std::string nameWithoutDialect = sanitizeName(
1002 name: op.getOperationName().substr(pos: op.getOperationName().find(c: '.') + 1));
1003 if (nameWithoutDialect == op.getCppClassName())
1004 nameWithoutDialect += "_";
1005 std::string params = llvm::join(R&: valueBuilderParams, Separator: ", ");
1006 std::string args = llvm::join(R&: opBuilderArgs, Separator: ", ");
1007 const char *type =
1008 (op.getNumResults() > 1
1009 ? "_Sequence[_ods_ir.Value]"
1010 : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation"));
1011 if (op.getNumVariableLengthResults() > 0) {
1012 os << formatv(Fmt: valueBuilderVariadicTemplate, Vals&: nameWithoutDialect,
1013 Vals: op.getCppClassName(), Vals&: params, Vals&: args, Vals&: type);
1014 } else {
1015 const char *results;
1016 if (op.getNumResults() == 0) {
1017 results = "";
1018 } else if (op.getNumResults() == 1) {
1019 results = ".result";
1020 } else {
1021 results = ".results";
1022 }
1023 os << formatv(Fmt: valueBuilderTemplate, Vals&: nameWithoutDialect,
1024 Vals: op.getCppClassName(), Vals&: params, Vals&: args, Vals&: type, Vals&: results);
1025 }
1026}
1027
1028/// Emits bindings for a specific Op to the given output stream.
1029static void emitOpBindings(const Operator &op, raw_ostream &os) {
1030 os << formatv(Fmt: opClassTemplate, Vals: op.getCppClassName(), Vals: op.getOperationName());
1031
1032 // Sized segments.
1033 if (op.getTrait(trait: attrSizedTraitForKind(kind: "operand")) != nullptr) {
1034 emitSegmentSpec(op, kind: "OPERAND", getNumElements: getNumOperands, getElement: getOperand, os);
1035 }
1036 if (op.getTrait(trait: attrSizedTraitForKind(kind: "result")) != nullptr) {
1037 emitSegmentSpec(op, kind: "RESULT", getNumElements: getNumResults, getElement: getResult, os);
1038 }
1039
1040 emitRegionAttributes(op, os);
1041 SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os);
1042 emitOperandAccessors(op, os);
1043 emitAttributeAccessors(op, os);
1044 emitResultAccessors(op, os);
1045 emitRegionAccessors(op, os);
1046 emitValueBuilder(op, functionArgs, os);
1047}
1048
1049/// Emits bindings for the dialect specified in the command line, including file
1050/// headers and utilities. Returns `false` on success to comply with Tablegen
1051/// registration requirements.
1052static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
1053 if (clDialectName.empty())
1054 llvm::PrintFatalError(Msg: "dialect name not provided");
1055
1056 os << fileHeader;
1057 if (!clDialectExtensionName.empty())
1058 os << formatv(Fmt: dialectExtensionTemplate, Vals&: clDialectName.getValue());
1059 else
1060 os << formatv(Fmt: dialectClassTemplate, Vals&: clDialectName.getValue());
1061
1062 for (const Record *rec : records.getAllDerivedDefinitions(ClassName: "Op")) {
1063 Operator op(rec);
1064 if (op.getDialectName() == clDialectName.getValue())
1065 emitOpBindings(op, os);
1066 }
1067 return false;
1068}
1069
1070static GenRegistration
1071 genPythonBindings("gen-python-op-bindings",
1072 "Generate Python bindings for MLIR Ops", &emitAllOps);
1073

source code of mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp