1//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10// binding classes wrapping a generic operation API.
11//
12//===----------------------------------------------------------------------===//
13
14#include "OpGenHelpers.h"
15
16#include "mlir/TableGen/GenInfo.h"
17#include "mlir/TableGen/Operator.h"
18#include "llvm/ADT/StringSet.h"
19#include "llvm/Support/CommandLine.h"
20#include "llvm/Support/FormatVariadic.h"
21#include "llvm/TableGen/Error.h"
22#include "llvm/TableGen/Record.h"
23
24using namespace mlir;
25using namespace mlir::tblgen;
26
/// Python file header and imports shared by the generated bindings module.
/// The text contains no `{N}` format placeholders.
/// `builtins` is imported because the accessor templates decorate with
/// `@builtins.property` (presumably so that a generated member named
/// `property` cannot shadow the builtin).
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import (
    equally_sized_accessor as _ods_equally_sized_accessor,
    get_default_loc_context as _ods_get_default_loc_context,
    get_op_result_or_op_results as _get_op_result_or_op_results,
    get_op_result_or_value as _get_op_result_or_value,
    get_op_results_or_values as _get_op_results_or_values,
    segmented_accessor as _ods_segmented_accessor,
)
_ods_ir = _ods_cext.ir

import builtins
from typing import Sequence as _Sequence, Union as _Union

)Py";
47
/// Template for dialect class:
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
)Py";

/// Template for a dialect extension module, which imports the `_Dialect`
/// class from the sibling `_{0}_ops_gen` module instead of defining one:
///   {0} is presumably the namespace of the dialect being extended (it forms
///       the module name) -- confirm against the caller.
constexpr const char *dialectExtensionTemplate = R"Py(
from ._{0}_ops_gen import _Dialect
)Py";

/// Template for operation class, registered with the `_Dialect` class:
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";
68
/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT";
///   {1} is the segment spec.
/// Each segment spec is either None (default) or an array of integers
/// where:
///    1 = single element (expect non sequence operand/result)
///    0 = optional element (expect a value or None)
///   -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for class level declarations of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions;
///   {1} is the Python bool literal for hasNoVariadicRegions.
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";
88
/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent). The
/// variable group's length is the actual element count minus the number of
/// other (single-element) groups.
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// Returns a slice covering the whole variadic group.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
138
/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The emitted code unpacks `start` (position of the group's first element)
/// and `pg` (number of elements per group) from `_ods_equally_sized_accessor`.
/// One of the "second part" templates completes the method body.
/// The element list must be read through `self.operation`: there is no bare
/// `operation` name in scope inside the generated property, so using it would
/// raise a NameError at access time.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
149
/// Second part of the template for equally-sized case, accessing a single
/// element at the `start` position computed by the first part:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group of `pg` elements beginning at `start`:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";

/// Template for an attribute-sized group accessor. The group boundaries are
/// read from the `{1}SegmentSizes` attribute of the operation:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}SegmentSizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case; yields None when the segment is empty:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
184
/// Template for an operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.attributes["{1}"]
)Py";

/// Template for an optional operation attribute getter; returns None when the
/// attribute is absent:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{1}" not in self.operation.attributes:
      return None
    return self.operation.attributes["{1}"]
)Py";

/// Template for a getter of a unit operation attribute. It returns True if the
/// unit attribute is present and False otherwise (unit attributes have meaning
/// by mere presence):
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";

/// Template for an operation attribute setter; rejects None because the
/// attribute is mandatory:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute; setting to None
/// removes the attribute (if present):
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute; any truthy value sets
/// a fresh UnitAttr, while None/False removes the attribute (if present):
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";
262
/// Template for a region accessor property:
///   {0} is the name of the accessor;
///   {1} is the index of the region in `self.regions`.
constexpr const char *regionAccessorTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.regions[{1}]
)Py";

/// Template for a module-level function that constructs an operation and
/// returns its result(s) through `_get_op_result_or_op_results`:
///   {0} is the name of the function;
///   {1} is the callable that constructs the operation (presumably the
///       generated op class -- confirm against the caller);
///   {2} is the function's parameter list;
///   {3} is the argument list forwarded to {1};
///   {4} is the return type annotation.
constexpr const char *valueBuilderTemplate = R"Py(
def {0}({2}) -> {4}:
  return _get_op_result_or_op_results({1}({3}))
)Py";
273
274static llvm::cl::OptionCategory
275 clOpPythonBindingCat("Options for -gen-python-op-bindings");
276
277static llvm::cl::opt<std::string>
278 clDialectName("bind-dialect",
279 llvm::cl::desc("The dialect to run the generator for"),
280 llvm::cl::init(Val: ""), llvm::cl::cat(clOpPythonBindingCat));
281
282static llvm::cl::opt<std::string> clDialectExtensionName(
283 "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
284 llvm::cl::init(Val: ""), llvm::cl::cat(clOpPythonBindingCat));
285
286using AttributeClasses = DenseMap<StringRef, StringRef>;
287
288/// Checks whether `str` would shadow a generated variable or attribute
289/// part of the OpView API.
290static bool isODSReserved(StringRef str) {
291 static llvm::StringSet<> reserved(
292 {"attributes", "create", "context", "ip", "operands", "print", "get_asm",
293 "loc", "verify", "regions", "results", "self", "operation",
294 "DIALECT_NAMESPACE", "OPERATION_NAME"});
295 return str.starts_with(Prefix: "_ods_") || str.ends_with(Suffix: "_ods") ||
296 reserved.contains(key: str);
297}
298
299/// Modifies the `name` in a way that it becomes suitable for Python bindings
300/// (does not change the `name` if it already is suitable) and returns the
301/// modified version.
302static std::string sanitizeName(StringRef name) {
303 std::string processedStr = name.str();
304 std::replace_if(
305 first: processedStr.begin(), last: processedStr.end(),
306 pred: [](char c) { return !llvm::isAlnum(C: c); }, new_value: '_');
307
308 if (llvm::isDigit(C: *processedStr.begin()))
309 return "_" + processedStr;
310
311 if (isPythonReserved(str: processedStr) || isODSReserved(str: processedStr))
312 return processedStr + "_";
313 return processedStr;
314}
315
316static std::string attrSizedTraitForKind(const char *kind) {
317 return llvm::formatv(Fmt: "::mlir::OpTrait::AttrSized{0}{1}Segments",
318 Vals: llvm::StringRef(kind).take_front().upper(),
319 Vals: llvm::StringRef(kind).drop_front());
320}
321
/// Emits accessors to "elements" of an Op definition. Currently, the supported
/// elements are operands and results, indicated by `kind`, which must be either
/// `operand` or `result` and is used verbatim in the emitted code.
///
/// The callbacks abstract over the operand- or result-specific queries on
/// `op`:
///   - `getNumVariableLength` counts the variable-length (optional or
///     variadic) groups;
///   - `getNumElements` counts all groups;
///   - `getElement` returns the group at a position.
///
/// Three layouts are handled:
///   - at most one variable-length group: positions are derived from the
///     total number of elements at runtime;
///   - several variable-length groups and the SameVariadic<Kind>Size trait:
///     positions come from `_ods_equally_sized_accessor`;
///   - several variable-length groups and the AttrSized<Kind>Segments trait:
///     positions come from the `<kind>SegmentSizes` attribute through
///     `_ods_segmented_accessor`.
/// Any other layout is reported with PrintFatalError. Unnamed elements get no
/// accessor but still count towards the positions of later elements.
static void emitElementAccessors(
    const Operator &op, raw_ostream &os, const char *kind,
    llvm::function_ref<unsigned(const Operator &)> getNumVariableLength,
    llvm::function_ref<int(const Operator &)> getNumElements,
    llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
        getElement) {
  assert(llvm::is_contained(
             llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
         "unsupported kind");

  // Traits indicating how to process variadic elements.
  std::string sameSizeTrait =
      llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
                    llvm::StringRef(kind).take_front().upper(),
                    llvm::StringRef(kind).drop_front());
  std::string attrSizedTrait = attrSizedTraitForKind(kind);

  unsigned numVariableLength = getNumVariableLength(op);

  // If there is only one variable-length element group, its size can be
  // inferred from the total number of elements. If there are none, the
  // generation is straightforward.
  if (numVariableLength <= 1) {
    bool seenVariableLength = false;
    for (int i = 0, e = getNumElements(op); i < e; ++i) {
      const NamedTypeConstraint &element = getElement(op, i);
      // Record the variable-length group even when it is unnamed: it still
      // shifts the positions of the elements that follow it.
      if (element.isVariableLength())
        seenVariableLength = true;
      if (element.name.empty())
        continue;
      if (element.isVariableLength()) {
        os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
                                                 : opOneVariadicTemplate,
                            sanitizeName(element.name), kind,
                            getNumElements(op), i);
      } else if (seenVariableLength) {
        // Position depends on the runtime length of the preceding group.
        os << llvm::formatv(opSingleAfterVariableTemplate,
                            sanitizeName(element.name), kind,
                            getNumElements(op), i);
      } else {
        // No variable-length group before this element: static position.
        os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
                            i);
      }
    }
    return;
  }

  // Handle the operations where variadic groups have the same size.
  if (op.getTrait(sameSizeTrait)) {
    int numPrecedingSimple = 0;
    int numPrecedingVariadic = 0;
    for (int i = 0, e = getNumElements(op); i < e; ++i) {
      const NamedTypeConstraint &element = getElement(op, i);
      if (!element.name.empty()) {
        // The prefix computes `start`/`pg`; the suffix returns one element or
        // a slice depending on the element kind.
        os << llvm::formatv(opVariadicEqualPrefixTemplate,
                            sanitizeName(element.name), kind, numVariableLength,
                            numPrecedingSimple, numPrecedingVariadic);
        os << llvm::formatv(element.isVariableLength()
                                ? opVariadicEqualVariadicTemplate
                                : opVariadicEqualSimpleTemplate,
                            kind);
      }
      // Counters advance for unnamed elements too.
      if (element.isVariableLength())
        ++numPrecedingVariadic;
      else
        ++numPrecedingSimple;
    }
    return;
  }

  // Handle the operations where the size of groups (variadic or not) is
  // provided as an attribute. For non-variadic elements, make sure to return
  // an element rather than a singleton container.
  if (op.getTrait(attrSizedTrait)) {
    for (int i = 0, e = getNumElements(op); i < e; ++i) {
      const NamedTypeConstraint &element = getElement(op, i);
      if (element.name.empty())
        continue;
      std::string trailing;
      if (!element.isVariableLength())
        trailing = "[0]";
      else if (element.isOptional())
        trailing = std::string(
            llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
      os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
                          kind, i, trailing);
    }
    return;
  }

  llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
}
417
418/// Free function helpers accessing Operator components.
419static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
420static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
421 return op.getOperand(index: i);
422}
423static int getNumResults(const Operator &op) { return op.getNumResults(); }
424static const NamedTypeConstraint &getResult(const Operator &op, int i) {
425 return op.getResult(index: i);
426}
427
428/// Emits accessors to Op operands.
429static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
430 auto getNumVariableLengthOperands = [](const Operator &oper) {
431 return oper.getNumVariableLengthOperands();
432 };
433 emitElementAccessors(op, os, kind: "operand", getNumVariableLength: getNumVariableLengthOperands,
434 getNumElements: getNumOperands, getElement: getOperand);
435}
436
437/// Emits accessors Op results.
438static void emitResultAccessors(const Operator &op, raw_ostream &os) {
439 auto getNumVariableLengthResults = [](const Operator &oper) {
440 return oper.getNumVariableLengthResults();
441 };
442 emitElementAccessors(op, os, kind: "result", getNumVariableLength: getNumVariableLengthResults,
443 getNumElements: getNumResults, getElement: getResult);
444}
445
446/// Emits accessors to Op attributes.
447static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
448 for (const auto &namedAttr : op.getAttributes()) {
449 // Skip "derived" attributes because they are just C++ functions that we
450 // don't currently expose.
451 if (namedAttr.attr.isDerivedAttr())
452 continue;
453
454 if (namedAttr.name.empty())
455 continue;
456
457 std::string sanitizedName = sanitizeName(name: namedAttr.name);
458
459 // Unit attributes are handled specially.
460 if (namedAttr.attr.getStorageType().trim().equals(RHS: "::mlir::UnitAttr")) {
461 os << llvm::formatv(Fmt: unitAttributeGetterTemplate, Vals&: sanitizedName,
462 Vals: namedAttr.name);
463 os << llvm::formatv(Fmt: unitAttributeSetterTemplate, Vals&: sanitizedName,
464 Vals: namedAttr.name);
465 os << llvm::formatv(Fmt: attributeDeleterTemplate, Vals&: sanitizedName,
466 Vals: namedAttr.name);
467 continue;
468 }
469
470 if (namedAttr.attr.isOptional()) {
471 os << llvm::formatv(Fmt: optionalAttributeGetterTemplate, Vals&: sanitizedName,
472 Vals: namedAttr.name);
473 os << llvm::formatv(Fmt: optionalAttributeSetterTemplate, Vals&: sanitizedName,
474 Vals: namedAttr.name);
475 os << llvm::formatv(Fmt: attributeDeleterTemplate, Vals&: sanitizedName,
476 Vals: namedAttr.name);
477 } else {
478 os << llvm::formatv(Fmt: attributeGetterTemplate, Vals&: sanitizedName,
479 Vals: namedAttr.name);
480 os << llvm::formatv(Fmt: attributeSetterTemplate, Vals&: sanitizedName,
481 Vals: namedAttr.name);
482 // Non-optional attributes cannot be deleted.
483 }
484 }
485}
486
/// Template for the default auto-generated builder.
///   {0} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {1} is the code populating `operands`, `results` and `attributes`,
///       `successors` fields;
///   {2} is the argument list passed to `self.build_generic`.
/// `{{` is formatv's escape for a literal `{`, so the emitted line reads
/// `attributes = {}`.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    regions = None
    {1}
    super().__init__(self.build_generic({2}))
)Py";
501
/// Template for appending a single element to the operand/result list.
///   {0} is the field name.
constexpr const char *singleOperandAppendTemplate =
    "operands.append(_get_op_result_or_value({0}))";
constexpr const char *singleResultAppendTemplate = "results.append({0})";

/// Template for appending an optional element to the operand/result list.
///   {0} is the field name.
/// The plain variants skip an absent (None) element. The attribute-sized
/// operand variant always appends, using None for an absent operand, so that
/// every operand group keeps exactly one entry in `operands`.
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
    "None)";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

/// Template for appending a list of elements to the operand/result list.
///   {0} is the field name.
/// The "extend" variants splice the elements into the flat list. The "Pack"
/// variant appends the whole list as a single entry (one entry per group).
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";
525
/// Template for attribute builder from raw input in the operation builder.
///   {0} is the builder argument name;
///   {1} is the name of the attribute (its key in `attributes`);
///   {2} is the name under which an attribute builder may be registered in
///       `_ods_ir.AttrBuilder`.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute. The emitted code
/// uses `_ods_context`, which must be defined earlier in the builder body.
constexpr const char *initAttributeWithBuilderTemplate =
    R"Py(attributes["{1}"] = ({0} if (
    isinstance({0}, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('{2}')) else
      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";

/// Template for attribute builder from raw input for optional attribute in the
/// operation builder. Same as above, but nothing is set when the argument is
/// None.
///   {0} is the builder argument name;
///   {1} is the name of the attribute (its key in `attributes`);
///   {2} is the name under which an attribute builder may be registered in
///       `_ods_ir.AttrBuilder`.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initOptionalAttributeWithBuilderTemplate =
    R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
    isinstance({0}, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('{2}')) else
      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";

/// Template setting a unit attribute when the builder argument is truthy:
///   {0} is the name of the attribute;
///   {1} is the builder argument name.
/// Note that the placeholder order is the reverse of the templates above.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";
554
/// Template to initialize the `_ods_successors` variable in the builder. It is
/// emitted whether or not the op has successors.
///   {0} is the value to initialize the successors list to (e.g. "None" or
///       "[]").
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";

/// Template to append or extend the list of successors in the builder.
///   {0} is the list method ('append' or 'extend');
///   {1} is the value to add.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
564
565/// Returns true if the SameArgumentAndResultTypes trait can be used to infer
566/// result types of the given operation.
567static bool hasSameArgumentAndResultTypes(const Operator &op) {
568 return op.getTrait(trait: "::mlir::OpTrait::SameOperandsAndResultType") &&
569 op.getNumVariableLengthResults() == 0;
570}
571
572/// Returns true if the FirstAttrDerivedResultType trait can be used to infer
573/// result types of the given operation.
574static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
575 return op.getTrait(trait: "::mlir::OpTrait::FirstAttrDerivedResultType") &&
576 op.getNumVariableLengthResults() == 0;
577}
578
579/// Returns true if the InferTypeOpInterface can be used to infer result types
580/// of the given operation.
581static bool hasInferTypeInterface(const Operator &op) {
582 return op.getTrait(trait: "::mlir::InferTypeOpInterface::Trait") &&
583 op.getNumRegions() == 0;
584}
585
586/// Returns true if there is a trait or interface that can be used to infer
587/// result types of the given operation.
588static bool canInferType(const Operator &op) {
589 return hasSameArgumentAndResultTypes(op) ||
590 hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op);
591}
592
593/// Populates `builderArgs` with result names if the builder is expected to
594/// accept them as arguments.
595static void
596populateBuilderArgsResults(const Operator &op,
597 llvm::SmallVectorImpl<std::string> &builderArgs) {
598 if (canInferType(op))
599 return;
600
601 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
602 std::string name = op.getResultName(index: i).str();
603 if (name.empty()) {
604 if (op.getNumResults() == 1) {
605 // Special case for one result, make the default name be 'result'
606 // to properly match the built-in result accessor.
607 name = "result";
608 } else {
609 name = llvm::formatv(Fmt: "_gen_res_{0}", Vals&: i);
610 }
611 }
612 name = sanitizeName(name);
613 builderArgs.push_back(Elt: name);
614 }
615}
616
617/// Populates `builderArgs` with the Python-compatible names of builder function
618/// arguments using intermixed attributes and operands in the same order as they
619/// appear in the `arguments` field of the op definition. Additionally,
620/// `operandNames` is populated with names of operands in their order of
621/// appearance.
622static void
623populateBuilderArgs(const Operator &op,
624 llvm::SmallVectorImpl<std::string> &builderArgs,
625 llvm::SmallVectorImpl<std::string> &operandNames) {
626 for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
627 std::string name = op.getArgName(index: i).str();
628 if (name.empty())
629 name = llvm::formatv(Fmt: "_gen_arg_{0}", Vals&: i);
630 name = sanitizeName(name);
631 builderArgs.push_back(Elt: name);
632 if (!op.getArg(index: i).is<NamedAttribute *>())
633 operandNames.push_back(Elt: name);
634 }
635}
636
637/// Populates `builderArgs` with the Python-compatible names of builder function
638/// successor arguments. Additionally, `successorArgNames` is also populated.
639static void populateBuilderArgsSuccessors(
640 const Operator &op, llvm::SmallVectorImpl<std::string> &builderArgs,
641 llvm::SmallVectorImpl<std::string> &successorArgNames) {
642
643 for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
644 NamedSuccessor successor = op.getSuccessor(index: i);
645 std::string name = std::string(successor.name);
646 if (name.empty())
647 name = llvm::formatv(Fmt: "_gen_successor_{0}", Vals&: i);
648 name = sanitizeName(name);
649 builderArgs.push_back(Elt: name);
650 successorArgNames.push_back(Elt: name);
651 }
652}
653
654/// Populates `builderLines` with additional lines that are required in the
655/// builder to set up operation attributes. `argNames` is expected to contain
656/// the names of builder arguments that correspond to op arguments, i.e. to the
657/// operands and attributes in the same order as they appear in the `arguments`
658/// field.
659static void
660populateBuilderLinesAttr(const Operator &op,
661 llvm::ArrayRef<std::string> argNames,
662 llvm::SmallVectorImpl<std::string> &builderLines) {
663 builderLines.push_back(Elt: "_ods_context = _ods_get_default_loc_context(loc)");
664 for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
665 Argument arg = op.getArg(index: i);
666 auto *attribute = llvm::dyn_cast_if_present<NamedAttribute *>(Val&: arg);
667 if (!attribute)
668 continue;
669
670 // Unit attributes are handled specially.
671 if (attribute->attr.getStorageType().trim().equals(RHS: "::mlir::UnitAttr")) {
672 builderLines.push_back(Elt: llvm::formatv(Fmt: initUnitAttributeTemplate,
673 Vals&: attribute->name, Vals: argNames[i]));
674 continue;
675 }
676
677 builderLines.push_back(Elt: llvm::formatv(
678 Fmt: attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
679 ? initOptionalAttributeWithBuilderTemplate
680 : initAttributeWithBuilderTemplate,
681 Vals: argNames[i], Vals&: attribute->name, Vals: attribute->attr.getAttrDefName()));
682 }
683}
684
685/// Populates `builderLines` with additional lines that are required in the
686/// builder to set up successors. successorArgNames is expected to correspond
687/// to the Python argument name for each successor on the op.
688static void populateBuilderLinesSuccessors(
689 const Operator &op, llvm::ArrayRef<std::string> successorArgNames,
690 llvm::SmallVectorImpl<std::string> &builderLines) {
691 if (successorArgNames.empty()) {
692 builderLines.push_back(Elt: llvm::formatv(Fmt: initSuccessorsTemplate, Vals: "None"));
693 return;
694 }
695
696 builderLines.push_back(Elt: llvm::formatv(Fmt: initSuccessorsTemplate, Vals: "[]"));
697 for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
698 auto &argName = successorArgNames[i];
699 const NamedSuccessor &successor = op.getSuccessor(index: i);
700 builderLines.push_back(
701 Elt: llvm::formatv(Fmt: addSuccessorTemplate,
702 Vals: successor.isVariadic() ? "extend" : "append", Vals: argName));
703 }
704}
705
706/// Populates `builderLines` with additional lines that are required in the
707/// builder to set up op operands.
708static void
709populateBuilderLinesOperand(const Operator &op,
710 llvm::ArrayRef<std::string> names,
711 llvm::SmallVectorImpl<std::string> &builderLines) {
712 bool sizedSegments = op.getTrait(trait: attrSizedTraitForKind(kind: "operand")) != nullptr;
713
714 // For each element, find or generate a name.
715 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
716 const NamedTypeConstraint &element = op.getOperand(index: i);
717 std::string name = names[i];
718
719 // Choose the formatting string based on the element kind.
720 llvm::StringRef formatString;
721 if (!element.isVariableLength()) {
722 formatString = singleOperandAppendTemplate;
723 } else if (element.isOptional()) {
724 if (sizedSegments) {
725 formatString = optionalAppendAttrSizedOperandsTemplate;
726 } else {
727 formatString = optionalAppendOperandTemplate;
728 }
729 } else {
730 assert(element.isVariadic() && "unhandled element group type");
731 // If emitting with sizedSegments, then we add the actual list-typed
732 // element. Otherwise, we extend the actual operands.
733 if (sizedSegments) {
734 formatString = multiOperandAppendPackTemplate;
735 } else {
736 formatString = multiOperandAppendTemplate;
737 }
738 }
739
740 builderLines.push_back(Elt: llvm::formatv(Fmt: formatString.data(), Vals&: name));
741 }
742}
743
/// Python code template for deriving the operation result types from its
/// attribute:
///   - {0} is the name of the attribute from which to derive the types.
/// If the attribute is a TypeAttr, the type it wraps is used; otherwise the
/// attribute's own type is used. The result is bound to
/// `_ods_derived_result_type`.
constexpr const char *deriveTypeFromAttrTemplate =
    R"Py(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
    _ods_ir.TypeAttr(_ods_result_type_source_attr).value
    if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
    _ods_result_type_source_attr.type))Py";

/// Python code template appending {0} type {1} times to the results list.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
756
757/// Appends the given multiline string as individual strings into
758/// `builderLines`.
759static void appendLineByLine(StringRef string,
760 llvm::SmallVectorImpl<std::string> &builderLines) {
761
762 std::pair<StringRef, StringRef> split = std::make_pair(x&: string, y&: string);
763 do {
764 split = split.second.split(Separator: '\n');
765 builderLines.push_back(Elt: split.first.str());
766 } while (!split.second.empty());
767}
768
769/// Populates `builderLines` with additional lines that are required in the
770/// builder to set up op results.
771static void
772populateBuilderLinesResult(const Operator &op,
773 llvm::ArrayRef<std::string> names,
774 llvm::SmallVectorImpl<std::string> &builderLines) {
775 bool sizedSegments = op.getTrait(trait: attrSizedTraitForKind(kind: "result")) != nullptr;
776
777 if (hasSameArgumentAndResultTypes(op)) {
778 builderLines.push_back(Elt: llvm::formatv(
779 Fmt: appendSameResultsTemplate, Vals: "operands[0].type", Vals: op.getNumResults()));
780 return;
781 }
782
783 if (hasFirstAttrDerivedResultTypes(op)) {
784 const NamedAttribute &firstAttr = op.getAttribute(index: 0);
785 assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
786 "from which the type is derived");
787 appendLineByLine(
788 string: llvm::formatv(Fmt: deriveTypeFromAttrTemplate, Vals: firstAttr.name).str(),
789 builderLines);
790 builderLines.push_back(Elt: llvm::formatv(Fmt: appendSameResultsTemplate,
791 Vals: "_ods_derived_result_type",
792 Vals: op.getNumResults()));
793 return;
794 }
795
796 if (hasInferTypeInterface(op))
797 return;
798
799 // For each element, find or generate a name.
800 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
801 const NamedTypeConstraint &element = op.getResult(index: i);
802 std::string name = names[i];
803
804 // Choose the formatting string based on the element kind.
805 llvm::StringRef formatString;
806 if (!element.isVariableLength()) {
807 formatString = singleResultAppendTemplate;
808 } else if (element.isOptional()) {
809 formatString = optionalAppendResultTemplate;
810 } else {
811 assert(element.isVariadic() && "unhandled element group type");
812 // If emitting with sizedSegments, then we add the actual list-typed
813 // element. Otherwise, we extend the actual operands.
814 if (sizedSegments) {
815 formatString = singleResultAppendTemplate;
816 } else {
817 formatString = multiResultAppendTemplate;
818 }
819 }
820
821 builderLines.push_back(Elt: llvm::formatv(Fmt: formatString.data(), Vals&: name));
822 }
823}
824
825/// If the operation has variadic regions, adds a builder argument to specify
826/// the number of those regions and builder lines to forward it to the generic
827/// constructor.
828static void
829populateBuilderRegions(const Operator &op,
830 llvm::SmallVectorImpl<std::string> &builderArgs,
831 llvm::SmallVectorImpl<std::string> &builderLines) {
832 if (op.hasNoVariadicRegions())
833 return;
834
835 // This is currently enforced when Operator is constructed.
836 assert(op.getNumVariadicRegions() == 1 &&
837 op.getRegion(op.getNumRegions() - 1).isVariadic() &&
838 "expected the last region to be varidic");
839
840 const NamedRegion &region = op.getRegion(index: op.getNumRegions() - 1);
841 std::string name =
842 ("num_" + region.name.take_front().lower() + region.name.drop_front())
843 .str();
844 builderArgs.push_back(Elt: name);
845 builderLines.push_back(
846 Elt: llvm::formatv(Fmt: "regions = {0} + {1}", Vals: op.getNumRegions() - 1, Vals&: name));
847}
848
849/// Emits a default builder constructing an operation from the list of its
850/// result types, followed by a list of its operands. Returns vector
851/// of fully built functionArgs for downstream users (to save having to
852/// rebuild anew).
853static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
854 raw_ostream &os) {
855 llvm::SmallVector<std::string> builderArgs;
856 llvm::SmallVector<std::string> builderLines;
857 llvm::SmallVector<std::string> operandArgNames;
858 llvm::SmallVector<std::string> successorArgNames;
859 builderArgs.reserve(N: op.getNumOperands() + op.getNumResults() +
860 op.getNumNativeAttributes() + op.getNumSuccessors());
861 populateBuilderArgsResults(op, builderArgs);
862 size_t numResultArgs = builderArgs.size();
863 populateBuilderArgs(op, builderArgs, operandNames&: operandArgNames);
864 size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
865 populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
866
867 populateBuilderLinesOperand(op, names: operandArgNames, builderLines);
868 populateBuilderLinesAttr(
869 op, argNames: llvm::ArrayRef(builderArgs).drop_front(N: numResultArgs), builderLines);
870 populateBuilderLinesResult(
871 op, names: llvm::ArrayRef(builderArgs).take_front(N: numResultArgs), builderLines);
872 populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
873 populateBuilderRegions(op, builderArgs, builderLines);
874
875 // Layout of builderArgs vector elements:
876 // [ result_args operand_attr_args successor_args regions ]
877
878 // Determine whether the argument corresponding to a given index into the
879 // builderArgs vector is a python keyword argument or not.
880 auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool {
881 // All result, successor, and region arguments are positional arguments.
882 if ((builderArgIndex < numResultArgs) ||
883 (builderArgIndex >= (numResultArgs + numOperandAttrArgs)))
884 return false;
885 // Keyword arguments:
886 // - optional named attributes (including unit attributes)
887 // - default-valued named attributes
888 // - optional operands
889 Argument a = op.getArg(index: builderArgIndex - numResultArgs);
890 if (auto *nattr = llvm::dyn_cast_if_present<NamedAttribute *>(Val&: a))
891 return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
892 if (auto *ntype = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: a))
893 return ntype->isOptional();
894 return false;
895 };
896
897 // StringRefs in functionArgs refer to strings allocated by builderArgs.
898 llvm::SmallVector<llvm::StringRef> functionArgs;
899
900 // Add positional arguments.
901 for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
902 if (!isKeywordArgFn(i))
903 functionArgs.push_back(Elt: builderArgs[i]);
904 }
905
906 // Add a bare '*' to indicate that all following arguments must be keyword
907 // arguments.
908 functionArgs.push_back(Elt: "*");
909
910 // Add a default 'None' value to each keyword arg string, and then add to the
911 // function args list.
912 for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
913 if (isKeywordArgFn(i)) {
914 builderArgs[i].append(s: "=None");
915 functionArgs.push_back(Elt: builderArgs[i]);
916 }
917 }
918 functionArgs.push_back(Elt: "loc=None");
919 functionArgs.push_back(Elt: "ip=None");
920
921 SmallVector<std::string> initArgs;
922 initArgs.push_back(Elt: "attributes=attributes");
923 if (!hasInferTypeInterface(op))
924 initArgs.push_back(Elt: "results=results");
925 initArgs.push_back(Elt: "operands=operands");
926 initArgs.push_back(Elt: "successors=_ods_successors");
927 initArgs.push_back(Elt: "regions=regions");
928 initArgs.push_back(Elt: "loc=loc");
929 initArgs.push_back(Elt: "ip=ip");
930
931 os << llvm::formatv(Fmt: initTemplate, Vals: llvm::join(R&: functionArgs, Separator: ", "),
932 Vals: llvm::join(R&: builderLines, Separator: "\n "),
933 Vals: llvm::join(R&: initArgs, Separator: ", "));
934 return llvm::to_vector<8>(
935 Range: llvm::map_range(C&: functionArgs, F: [](llvm::StringRef s) { return s.str(); }));
936}
937
938static void emitSegmentSpec(
939 const Operator &op, const char *kind,
940 llvm::function_ref<int(const Operator &)> getNumElements,
941 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
942 getElement,
943 raw_ostream &os) {
944 std::string segmentSpec("[");
945 for (int i = 0, e = getNumElements(op); i < e; ++i) {
946 const NamedTypeConstraint &element = getElement(op, i);
947 if (element.isOptional()) {
948 segmentSpec.append(s: "0,");
949 } else if (element.isVariadic()) {
950 segmentSpec.append(s: "-1,");
951 } else {
952 segmentSpec.append(s: "1,");
953 }
954 }
955 segmentSpec.append(s: "]");
956
957 os << llvm::formatv(Fmt: opClassSizedSegmentsTemplate, Vals&: kind, Vals&: segmentSpec);
958}
959
960static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
961 // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
962 // Note that the base OpView class defines this as (0, True).
963 unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
964 os << llvm::formatv(Fmt: opClassRegionSpecTemplate, Vals&: minRegionCount,
965 Vals: op.hasNoVariadicRegions() ? "True" : "False");
966}
967
968/// Emits named accessors to regions.
969static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
970 for (const auto &en : llvm::enumerate(First: op.getRegions())) {
971 const NamedRegion &region = en.value();
972 if (region.name.empty())
973 continue;
974
975 assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
976 "expected only the last region to be variadic");
977 os << llvm::formatv(Fmt: regionAccessorTemplate, Vals: sanitizeName(name: region.name),
978 Vals: std::to_string(val: en.index()) +
979 (region.isVariadic() ? ":" : ""));
980 }
981}
982
983/// Emits builder that extracts results from op
984static void emitValueBuilder(const Operator &op,
985 llvm::SmallVector<std::string> functionArgs,
986 raw_ostream &os) {
987 // Params with (possibly) default args.
988 auto valueBuilderParams =
989 llvm::map_range(C&: functionArgs, F: [](const std::string &argAndMaybeDefault) {
990 llvm::SmallVector<llvm::StringRef> argMaybeDefault =
991 llvm::to_vector<2>(Range: llvm::split(Str: argAndMaybeDefault, Separator: "="));
992 auto arg = llvm::convertToSnakeFromCamelCase(input: argMaybeDefault[0]);
993 if (argMaybeDefault.size() == 2)
994 return arg + "=" + argMaybeDefault[1].str();
995 return arg;
996 });
997 // Actual args passed to op builder (e.g., opParam=op_param).
998 auto opBuilderArgs = llvm::map_range(
999 C: llvm::make_filter_range(Range&: functionArgs,
1000 Pred: [](const std::string &s) { return s != "*"; }),
1001 F: [](const std::string &arg) {
1002 auto lhs = *llvm::split(Str: arg, Separator: "=").begin();
1003 return (lhs + "=" + llvm::convertToSnakeFromCamelCase(input: lhs)).str();
1004 });
1005 std::string nameWithoutDialect =
1006 op.getOperationName().substr(pos: op.getOperationName().find(c: '.') + 1);
1007 os << llvm::formatv(
1008 Fmt: valueBuilderTemplate, Vals: sanitizeName(name: nameWithoutDialect),
1009 Vals: op.getCppClassName(), Vals: llvm::join(R&: valueBuilderParams, Separator: ", "),
1010 Vals: llvm::join(R&: opBuilderArgs, Separator: ", "),
1011 Vals: (op.getNumResults() > 1
1012 ? "_Sequence[_ods_ir.Value]"
1013 : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")));
1014}
1015
1016/// Emits bindings for a specific Op to the given output stream.
1017static void emitOpBindings(const Operator &op, raw_ostream &os) {
1018 os << llvm::formatv(Fmt: opClassTemplate, Vals: op.getCppClassName(),
1019 Vals: op.getOperationName());
1020
1021 // Sized segments.
1022 if (op.getTrait(trait: attrSizedTraitForKind(kind: "operand")) != nullptr) {
1023 emitSegmentSpec(op, kind: "OPERAND", getNumElements: getNumOperands, getElement: getOperand, os);
1024 }
1025 if (op.getTrait(trait: attrSizedTraitForKind(kind: "result")) != nullptr) {
1026 emitSegmentSpec(op, kind: "RESULT", getNumElements: getNumResults, getElement: getResult, os);
1027 }
1028
1029 emitRegionAttributes(op, os);
1030 llvm::SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os);
1031 emitOperandAccessors(op, os);
1032 emitAttributeAccessors(op, os);
1033 emitResultAccessors(op, os);
1034 emitRegionAccessors(op, os);
1035 emitValueBuilder(op, functionArgs, os);
1036}
1037
1038/// Emits bindings for the dialect specified in the command line, including file
1039/// headers and utilities. Returns `false` on success to comply with Tablegen
1040/// registration requirements.
1041static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
1042 if (clDialectName.empty())
1043 llvm::PrintFatalError(Msg: "dialect name not provided");
1044
1045 os << fileHeader;
1046 if (!clDialectExtensionName.empty())
1047 os << llvm::formatv(Fmt: dialectExtensionTemplate, Vals&: clDialectName.getValue());
1048 else
1049 os << llvm::formatv(Fmt: dialectClassTemplate, Vals&: clDialectName.getValue());
1050
1051 for (const llvm::Record *rec : records.getAllDerivedDefinitions(ClassName: "Op")) {
1052 Operator op(rec);
1053 if (op.getDialectName() == clDialectName.getValue())
1054 emitOpBindings(op, os);
1055 }
1056 return false;
1057}
1058
1059static GenRegistration
1060 genPythonBindings("gen-python-op-bindings",
1061 "Generate Python bindings for MLIR Ops", &emitAllOps);
1062

source code of mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp