1 | //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python |
10 | // binding classes wrapping a generic operation API. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "OpGenHelpers.h" |
15 | |
16 | #include "mlir/TableGen/GenInfo.h" |
17 | #include "mlir/TableGen/Operator.h" |
18 | #include "llvm/ADT/StringSet.h" |
19 | #include "llvm/Support/CommandLine.h" |
20 | #include "llvm/Support/FormatVariadic.h" |
21 | #include "llvm/TableGen/Error.h" |
22 | #include "llvm/TableGen/Record.h" |
23 | |
24 | using namespace mlir; |
25 | using namespace mlir::tblgen; |
26 | using llvm::formatv; |
27 | using llvm::Record; |
28 | using llvm::RecordKeeper; |
29 | |
/// File header and Python imports for the generated bindings module.
/// The template contains no formatv placeholders.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import (
equally_sized_accessor as _ods_equally_sized_accessor,
get_default_loc_context as _ods_get_default_loc_context,
get_op_result_or_op_results as _get_op_result_or_op_results,
get_op_results_or_values as _get_op_results_or_values,
segmented_accessor as _ods_segmented_accessor,
)
_ods_ir = _ods_cext.ir

import builtins
from typing import Sequence as _Sequence, Union as _Union

)Py";
49 | |
/// Template for dialect class:
///   {0} is the dialect namespace.
/// The class attribute is indented so the emitted class body is valid Python.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
)Py";
57 | |
/// Template importing `_Dialect` from the generated module `_{0}_ops_gen`.
///   {0} is presumably the dialect namespace (call site not visible here).
/// NOTE(review): likely emitted for dialect extensions instead of
/// dialectClassTemplate; confirm at the call site.
constexpr const char *dialectExtensionTemplate =
    "\nfrom ._{0}_ops_gen import _Dialect\n";
61 | |
/// Template for operation class:
///   {0} is the Python class name;
///   {1} is the operation name.
/// The class attribute is indented so the emitted class body is valid Python.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";
70 | |
/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT";
///   {1} is the segment spec.
/// Each segment spec is either None (default) or an array of integers
/// where:
///    1 = single element (expect non sequence operand/result)
///    0 = optional element (expect a value or None)
///   -1 = operand/result is a sequence corresponding to a variadic
/// Indented as a class-level attribute of the generated op class.
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for class level declarations of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions;
///   {1} is the Python bool literal for hasNoVariadicRegions.
/// Indented as a class-level attribute of the generated op class.
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";
90 | |
/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
/// Indented as a method of the generated op class.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works if there is only one variable-length group and it is the
/// optional operand/result. The element is absent exactly when
/// `len(operation.{1}s)` is smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
140 | |
/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of non-variadic groups;
///   {3} is the total number of variadic groups;
///   {4} is the number of non-variadic groups preceding the current group;
///   {5} is the number of variadic groups preceding the current group.
/// Deliberately ends without a newline: one of the two "second part"
/// templates below completes the method body.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
/// Indented to continue the method body opened by the prefix template.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
/// Indented to continue the method body opened by the prefix template.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + elements_per_group]
)Py";
166 | |
/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix: "[0]" for a single element, empty for a
///       variadic group, and the expansion of
///       opVariadicSegmentOptionalTrailingTemplate for an optional element.
/// Indented as a method of the generated op class.
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
        self.operation.{1}s,
        self.operation.attributes["{1}SegmentSizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case; yields None when the segment is empty:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
187 | |
/// Template for an operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
/// Indented as a method of the generated op class.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.attributes["{1}"]
)Py";

/// Template for an optional operation attribute getter; returns None when the
/// attribute is absent:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{1}" not in self.operation.attributes:
      return None
    return self.operation.attributes["{1}"]
)Py";

/// Template for a getter of a unit operation attribute. It returns True if the
/// unit attribute is present and False otherwise, because a unit attribute's
/// meaning is its presence:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";
218 | |
/// Template for an operation attribute setter; rejects None because the
/// attribute is mandatory:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
/// Indented as a method of the generated op class.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute. Setting the
/// value to None removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute. Any falsy value
/// (including None and False) removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, which
/// removes the attribute from the operation:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";
265 | |
/// Template for a region accessor:
///   {0} is the name of the accessor;
///   {1} is the index into `self.regions`.
/// Indented as a method of the generated op class.
constexpr const char *regionAccessorTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.regions[{1}]
)Py";

/// Template for a module-level value-builder function:
///   {0} is the name of the generated function;
///   {1} is the callable that is invoked (presumably the op class — confirm at
///       the call site);
///   {2} is the function's parameter list;
///   {3} is the argument list forwarded to {1};
///   {4} is the return type annotation;
///   {5} is a suffix appended to the call expression (NOTE(review): presumably
///       selects the result; confirm at the call site).
constexpr const char *valueBuilderTemplate = R"Py(
def {0}({2}) -> {4}:
  return {1}({3}){5}
)Py";

/// Variant of valueBuilderTemplate that wraps the created object with
/// `_get_op_result_or_op_results`; same placeholders except that {5} is
/// unused.
constexpr const char *valueBuilderVariadicTemplate = R"Py(
def {0}({2}) -> {4}:
  return _get_op_result_or_op_results({1}({3}))
)Py";
281 | |
282 | static llvm::cl::OptionCategory |
283 | clOpPythonBindingCat("Options for -gen-python-op-bindings" ); |
284 | |
285 | static llvm::cl::opt<std::string> |
286 | clDialectName("bind-dialect" , |
287 | llvm::cl::desc("The dialect to run the generator for" ), |
288 | llvm::cl::init(Val: "" ), llvm::cl::cat(clOpPythonBindingCat)); |
289 | |
290 | static llvm::cl::opt<std::string> clDialectExtensionName( |
291 | "dialect-extension" , llvm::cl::desc("The prefix of the dialect extension" ), |
292 | llvm::cl::init(Val: "" ), llvm::cl::cat(clOpPythonBindingCat)); |
293 | |
294 | using AttributeClasses = DenseMap<StringRef, StringRef>; |
295 | |
296 | /// Checks whether `str` would shadow a generated variable or attribute |
297 | /// part of the OpView API. |
298 | static bool isODSReserved(StringRef str) { |
299 | static llvm::StringSet<> reserved( |
300 | {"attributes" , "create" , "context" , "ip" , "operands" , "print" , "get_asm" , |
301 | "loc" , "verify" , "regions" , "results" , "self" , "operation" , |
302 | "DIALECT_NAMESPACE" , "OPERATION_NAME" }); |
303 | return str.starts_with(Prefix: "_ods_" ) || str.ends_with(Suffix: "_ods" ) || |
304 | reserved.contains(key: str); |
305 | } |
306 | |
307 | /// Modifies the `name` in a way that it becomes suitable for Python bindings |
308 | /// (does not change the `name` if it already is suitable) and returns the |
309 | /// modified version. |
310 | static std::string sanitizeName(StringRef name) { |
311 | std::string processedStr = name.str(); |
312 | std::replace_if( |
313 | first: processedStr.begin(), last: processedStr.end(), |
314 | pred: [](char c) { return !llvm::isAlnum(C: c); }, new_value: '_'); |
315 | |
316 | if (llvm::isDigit(C: *processedStr.begin())) |
317 | return "_" + processedStr; |
318 | |
319 | if (isPythonReserved(str: processedStr) || isODSReserved(str: processedStr)) |
320 | return processedStr + "_" ; |
321 | return processedStr; |
322 | } |
323 | |
324 | static std::string attrSizedTraitForKind(const char *kind) { |
325 | return formatv(Fmt: "::mlir::OpTrait::AttrSized{0}{1}Segments" , |
326 | Vals: StringRef(kind).take_front().upper(), |
327 | Vals: StringRef(kind).drop_front()); |
328 | } |
329 | |
330 | /// Emits accessors to "elements" of an Op definition. Currently, the supported |
331 | /// elements are operands and results, indicated by `kind`, which must be either |
332 | /// `operand` or `result` and is used verbatim in the emitted code. |
333 | static void emitElementAccessors( |
334 | const Operator &op, raw_ostream &os, const char *kind, |
335 | unsigned numVariadicGroups, unsigned numElements, |
336 | llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> |
337 | getElement) { |
338 | assert(llvm::is_contained(SmallVector<StringRef, 2>{"operand" , "result" }, |
339 | kind) && |
340 | "unsupported kind" ); |
341 | |
342 | // Traits indicating how to process variadic elements. |
343 | std::string sameSizeTrait = formatv(Fmt: "::mlir::OpTrait::SameVariadic{0}{1}Size" , |
344 | Vals: StringRef(kind).take_front().upper(), |
345 | Vals: StringRef(kind).drop_front()); |
346 | std::string attrSizedTrait = attrSizedTraitForKind(kind); |
347 | |
348 | // If there is only one variable-length element group, its size can be |
349 | // inferred from the total number of elements. If there are none, the |
350 | // generation is straightforward. |
351 | if (numVariadicGroups <= 1) { |
352 | bool seenVariableLength = false; |
353 | for (unsigned i = 0; i < numElements; ++i) { |
354 | const NamedTypeConstraint &element = getElement(op, i); |
355 | if (element.isVariableLength()) |
356 | seenVariableLength = true; |
357 | if (element.name.empty()) |
358 | continue; |
359 | if (element.isVariableLength()) { |
360 | os << formatv(Fmt: element.isOptional() ? opOneOptionalTemplate |
361 | : opOneVariadicTemplate, |
362 | Vals: sanitizeName(name: element.name), Vals&: kind, Vals&: numElements, Vals&: i); |
363 | } else if (seenVariableLength) { |
364 | os << formatv(Fmt: opSingleAfterVariableTemplate, Vals: sanitizeName(name: element.name), |
365 | Vals&: kind, Vals&: numElements, Vals&: i); |
366 | } else { |
367 | os << formatv(Fmt: opSingleTemplate, Vals: sanitizeName(name: element.name), Vals&: kind, Vals&: i); |
368 | } |
369 | } |
370 | return; |
371 | } |
372 | |
373 | // Handle the operations where variadic groups have the same size. |
374 | if (op.getTrait(trait: sameSizeTrait)) { |
375 | // Count the number of simple elements |
376 | unsigned numSimpleLength = 0; |
377 | for (unsigned i = 0; i < numElements; ++i) { |
378 | const NamedTypeConstraint &element = getElement(op, i); |
379 | if (!element.isVariableLength()) { |
380 | ++numSimpleLength; |
381 | } |
382 | } |
383 | |
384 | // Generate the accessors |
385 | int numPrecedingSimple = 0; |
386 | int numPrecedingVariadic = 0; |
387 | for (unsigned i = 0; i < numElements; ++i) { |
388 | const NamedTypeConstraint &element = getElement(op, i); |
389 | if (!element.name.empty()) { |
390 | os << formatv(Fmt: opVariadicEqualPrefixTemplate, Vals: sanitizeName(name: element.name), |
391 | Vals&: kind, Vals&: numSimpleLength, Vals&: numVariadicGroups, |
392 | Vals&: numPrecedingSimple, Vals&: numPrecedingVariadic); |
393 | os << formatv(Fmt: element.isVariableLength() |
394 | ? opVariadicEqualVariadicTemplate |
395 | : opVariadicEqualSimpleTemplate, |
396 | Vals&: kind); |
397 | } |
398 | if (element.isVariableLength()) |
399 | ++numPrecedingVariadic; |
400 | else |
401 | ++numPrecedingSimple; |
402 | } |
403 | return; |
404 | } |
405 | |
406 | // Handle the operations where the size of groups (variadic or not) is |
407 | // provided as an attribute. For non-variadic elements, make sure to return |
408 | // an element rather than a singleton container. |
409 | if (op.getTrait(trait: attrSizedTrait)) { |
410 | for (unsigned i = 0; i < numElements; ++i) { |
411 | const NamedTypeConstraint &element = getElement(op, i); |
412 | if (element.name.empty()) |
413 | continue; |
414 | std::string trailing; |
415 | if (!element.isVariableLength()) |
416 | trailing = "[0]" ; |
417 | else if (element.isOptional()) |
418 | trailing = std::string( |
419 | formatv(Fmt: opVariadicSegmentOptionalTrailingTemplate, Vals&: kind)); |
420 | os << formatv(Fmt: opVariadicSegmentTemplate, Vals: sanitizeName(name: element.name), Vals&: kind, |
421 | Vals&: i, Vals&: trailing); |
422 | } |
423 | return; |
424 | } |
425 | |
426 | llvm::PrintFatalError(Msg: "unsupported " + llvm::Twine(kind) + " structure" ); |
427 | } |
428 | |
429 | /// Free function helpers accessing Operator components. |
430 | static int getNumOperands(const Operator &op) { return op.getNumOperands(); } |
431 | static const NamedTypeConstraint &getOperand(const Operator &op, int i) { |
432 | return op.getOperand(index: i); |
433 | } |
434 | static int getNumResults(const Operator &op) { return op.getNumResults(); } |
435 | static const NamedTypeConstraint &getResult(const Operator &op, int i) { |
436 | return op.getResult(index: i); |
437 | } |
438 | |
439 | /// Emits accessors to Op operands. |
440 | static void emitOperandAccessors(const Operator &op, raw_ostream &os) { |
441 | emitElementAccessors(op, os, kind: "operand" , numVariadicGroups: op.getNumVariableLengthOperands(), |
442 | numElements: getNumOperands(op), getElement: getOperand); |
443 | } |
444 | |
445 | /// Emits accessors Op results. |
446 | static void emitResultAccessors(const Operator &op, raw_ostream &os) { |
447 | emitElementAccessors(op, os, kind: "result" , numVariadicGroups: op.getNumVariableLengthResults(), |
448 | numElements: getNumResults(op), getElement: getResult); |
449 | } |
450 | |
451 | /// Emits accessors to Op attributes. |
452 | static void emitAttributeAccessors(const Operator &op, raw_ostream &os) { |
453 | for (const auto &namedAttr : op.getAttributes()) { |
454 | // Skip "derived" attributes because they are just C++ functions that we |
455 | // don't currently expose. |
456 | if (namedAttr.attr.isDerivedAttr()) |
457 | continue; |
458 | |
459 | if (namedAttr.name.empty()) |
460 | continue; |
461 | |
462 | std::string sanitizedName = sanitizeName(name: namedAttr.name); |
463 | |
464 | // Unit attributes are handled specially. |
465 | if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr" ) { |
466 | os << formatv(Fmt: unitAttributeGetterTemplate, Vals&: sanitizedName, Vals: namedAttr.name); |
467 | os << formatv(Fmt: unitAttributeSetterTemplate, Vals&: sanitizedName, Vals: namedAttr.name); |
468 | os << formatv(Fmt: attributeDeleterTemplate, Vals&: sanitizedName, Vals: namedAttr.name); |
469 | continue; |
470 | } |
471 | |
472 | if (namedAttr.attr.isOptional()) { |
473 | os << formatv(Fmt: optionalAttributeGetterTemplate, Vals&: sanitizedName, |
474 | Vals: namedAttr.name); |
475 | os << formatv(Fmt: optionalAttributeSetterTemplate, Vals&: sanitizedName, |
476 | Vals: namedAttr.name); |
477 | os << formatv(Fmt: attributeDeleterTemplate, Vals&: sanitizedName, Vals: namedAttr.name); |
478 | } else { |
479 | os << formatv(Fmt: attributeGetterTemplate, Vals&: sanitizedName, Vals: namedAttr.name); |
480 | os << formatv(Fmt: attributeSetterTemplate, Vals&: sanitizedName, Vals: namedAttr.name); |
481 | // Non-optional attributes cannot be deleted. |
482 | } |
483 | } |
484 | } |
485 | |
/// Template for the default auto-generated builder:
///   {0} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {1} is the code populating the `operands`, `results`, `attributes` and
///       successors fields. NOTE(review): any line after the first must be
///       joined with the four-space method-body indent; confirm at the call
///       site.
///   {2} is the argument list passed to `super().__init__`.
/// `{{}` is the formatv escape for a literal `{}`. The `def` is indented as a
/// method of the generated op class, and its body by four spaces.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    regions = None
    {1}
    super().__init__({2})
)Py";
500 | |
/// Single-line builder statements adding an element to the operand or result
/// list; {0} is the builder argument (field) name in every template here.

/// Unconditionally append one operand / one result.
constexpr const char *singleOperandAppendTemplate = R"Py(operands.append({0}))Py";
constexpr const char *singleResultAppendTemplate = R"Py(results.append({0}))Py";

/// Optional operand/result: appended only when the argument is not None.
constexpr const char *optionalAppendOperandTemplate =
    R"Py(if {0} is not None: operands.append({0}))Py";
/// Optional operand when segment sizes are attribute-provided: appended
/// unconditionally, even if None. NOTE(review): presumably so the operand
/// keeps its slot in the segment layout; confirm.
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    R"Py(operands.append({0}))Py";
constexpr const char *optionalAppendResultTemplate =
    R"Py(if {0} is not None: results.append({0}))Py";

/// Variadic operand/result lists. `extend` flattens the values into the list.
/// The `Pack` form appends the whole converted list as one entry instead.
constexpr const char *multiOperandAppendTemplate =
    R"Py(operands.extend(_get_op_results_or_values({0})))Py";
constexpr const char *multiOperandAppendPackTemplate =
    R"Py(operands.append(_get_op_results_or_values({0})))Py";
constexpr const char *multiResultAppendTemplate = R"Py(results.extend({0}))Py";
522 | |
/// Template for attribute builder from raw input in the operation builder:
///   {0} is the builder argument name;
///   {1} is the attribute name, i.e. the key in `attributes`;
///   {2} is the attribute definition name used as the AttrBuilder key.
/// The value the user passed in is used as-is if it is already an Attribute or
/// if no builder is registered for {2}. Otherwise it is converted with the
/// registered builder.
constexpr const char *initAttributeWithBuilderTemplate =
    R"Py(attributes["{1}"] = ({0} if (
isinstance({0}, _ods_ir.Attribute) or
not _ods_ir.AttrBuilder.contains('{2}')) else
_ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";

/// Template for attribute builder from raw input for an optional (or
/// defaulted) attribute in the operation builder. The attribute is set only
/// when the argument is not None:
///   {0} is the builder argument name;
///   {1} is the attribute name, i.e. the key in `attributes`;
///   {2} is the attribute definition name used as the AttrBuilder key.
/// The conversion rules are the same as for the mandatory-attribute template.
constexpr const char *initOptionalAttributeWithBuilderTemplate =
    R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
isinstance({0}, _ods_ir.Attribute) or
not _ods_ir.AttrBuilder.contains('{2}')) else
_ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";

/// Template setting a unit attribute in the builder when the argument is
/// truthy:
///   {0} is the attribute name, i.e. the key in `attributes`;
///   {1} is the builder argument name.
/// Note the argument order is the reverse of the templates above.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
_ods_get_default_loc_context(loc)))Py";
551 | |
/// Builder statement initializing the local successors list:
///   {0} is the initial value, e.g. `None` or `[]`.
constexpr const char *initSuccessorsTemplate = "_ods_successors = {0}";

/// Builder statement adding to the successors list:
///   {0} is the list method, `append` or `extend`;
///   {1} is the value being added.
constexpr const char *addSuccessorTemplate = "_ods_successors.{0}({1})";
561 | |
562 | /// Returns true if the SameArgumentAndResultTypes trait can be used to infer |
563 | /// result types of the given operation. |
564 | static bool hasSameArgumentAndResultTypes(const Operator &op) { |
565 | return op.getTrait(trait: "::mlir::OpTrait::SameOperandsAndResultType" ) && |
566 | op.getNumVariableLengthResults() == 0; |
567 | } |
568 | |
569 | /// Returns true if the FirstAttrDerivedResultType trait can be used to infer |
570 | /// result types of the given operation. |
571 | static bool hasFirstAttrDerivedResultTypes(const Operator &op) { |
572 | return op.getTrait(trait: "::mlir::OpTrait::FirstAttrDerivedResultType" ) && |
573 | op.getNumVariableLengthResults() == 0; |
574 | } |
575 | |
576 | /// Returns true if the InferTypeOpInterface can be used to infer result types |
577 | /// of the given operation. |
578 | static bool hasInferTypeInterface(const Operator &op) { |
579 | return op.getTrait(trait: "::mlir::InferTypeOpInterface::Trait" ) && |
580 | op.getNumRegions() == 0; |
581 | } |
582 | |
583 | /// Returns true if there is a trait or interface that can be used to infer |
584 | /// result types of the given operation. |
585 | static bool canInferType(const Operator &op) { |
586 | return hasSameArgumentAndResultTypes(op) || |
587 | hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op); |
588 | } |
589 | |
590 | /// Populates `builderArgs` with result names if the builder is expected to |
591 | /// accept them as arguments. |
592 | static void |
593 | populateBuilderArgsResults(const Operator &op, |
594 | SmallVectorImpl<std::string> &builderArgs) { |
595 | if (canInferType(op)) |
596 | return; |
597 | |
598 | for (int i = 0, e = op.getNumResults(); i < e; ++i) { |
599 | std::string name = op.getResultName(index: i).str(); |
600 | if (name.empty()) { |
601 | if (op.getNumResults() == 1) { |
602 | // Special case for one result, make the default name be 'result' |
603 | // to properly match the built-in result accessor. |
604 | name = "result" ; |
605 | } else { |
606 | name = formatv(Fmt: "_gen_res_{0}" , Vals&: i); |
607 | } |
608 | } |
609 | name = sanitizeName(name); |
610 | builderArgs.push_back(Elt: name); |
611 | } |
612 | } |
613 | |
614 | /// Populates `builderArgs` with the Python-compatible names of builder function |
615 | /// arguments using intermixed attributes and operands in the same order as they |
616 | /// appear in the `arguments` field of the op definition. Additionally, |
617 | /// `operandNames` is populated with names of operands in their order of |
618 | /// appearance. |
619 | static void populateBuilderArgs(const Operator &op, |
620 | SmallVectorImpl<std::string> &builderArgs, |
621 | SmallVectorImpl<std::string> &operandNames) { |
622 | for (int i = 0, e = op.getNumArgs(); i < e; ++i) { |
623 | std::string name = op.getArgName(index: i).str(); |
624 | if (name.empty()) |
625 | name = formatv(Fmt: "_gen_arg_{0}" , Vals&: i); |
626 | name = sanitizeName(name); |
627 | builderArgs.push_back(Elt: name); |
628 | if (!isa<NamedAttribute *>(Val: op.getArg(index: i))) |
629 | operandNames.push_back(Elt: name); |
630 | } |
631 | } |
632 | |
633 | /// Populates `builderArgs` with the Python-compatible names of builder function |
634 | /// successor arguments. Additionally, `successorArgNames` is also populated. |
635 | static void |
636 | populateBuilderArgsSuccessors(const Operator &op, |
637 | SmallVectorImpl<std::string> &builderArgs, |
638 | SmallVectorImpl<std::string> &successorArgNames) { |
639 | |
640 | for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) { |
641 | NamedSuccessor successor = op.getSuccessor(index: i); |
642 | std::string name = std::string(successor.name); |
643 | if (name.empty()) |
644 | name = formatv(Fmt: "_gen_successor_{0}" , Vals&: i); |
645 | name = sanitizeName(name); |
646 | builderArgs.push_back(Elt: name); |
647 | successorArgNames.push_back(Elt: name); |
648 | } |
649 | } |
650 | |
651 | /// Populates `builderLines` with additional lines that are required in the |
652 | /// builder to set up operation attributes. `argNames` is expected to contain |
653 | /// the names of builder arguments that correspond to op arguments, i.e. to the |
654 | /// operands and attributes in the same order as they appear in the `arguments` |
655 | /// field. |
656 | static void |
657 | populateBuilderLinesAttr(const Operator &op, ArrayRef<std::string> argNames, |
658 | SmallVectorImpl<std::string> &builderLines) { |
659 | builderLines.push_back(Elt: "_ods_context = _ods_get_default_loc_context(loc)" ); |
660 | for (int i = 0, e = op.getNumArgs(); i < e; ++i) { |
661 | Argument arg = op.getArg(index: i); |
662 | auto *attribute = llvm::dyn_cast_if_present<NamedAttribute *>(Val&: arg); |
663 | if (!attribute) |
664 | continue; |
665 | |
666 | // Unit attributes are handled specially. |
667 | if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr" ) { |
668 | builderLines.push_back( |
669 | Elt: formatv(Fmt: initUnitAttributeTemplate, Vals&: attribute->name, Vals: argNames[i])); |
670 | continue; |
671 | } |
672 | |
673 | builderLines.push_back(Elt: formatv( |
674 | Fmt: attribute->attr.isOptional() || attribute->attr.hasDefaultValue() |
675 | ? initOptionalAttributeWithBuilderTemplate |
676 | : initAttributeWithBuilderTemplate, |
677 | Vals: argNames[i], Vals&: attribute->name, Vals: attribute->attr.getAttrDefName())); |
678 | } |
679 | } |
680 | |
681 | /// Populates `builderLines` with additional lines that are required in the |
682 | /// builder to set up successors. successorArgNames is expected to correspond |
683 | /// to the Python argument name for each successor on the op. |
684 | static void |
685 | populateBuilderLinesSuccessors(const Operator &op, |
686 | ArrayRef<std::string> successorArgNames, |
687 | SmallVectorImpl<std::string> &builderLines) { |
688 | if (successorArgNames.empty()) { |
689 | builderLines.push_back(Elt: formatv(Fmt: initSuccessorsTemplate, Vals: "None" )); |
690 | return; |
691 | } |
692 | |
693 | builderLines.push_back(Elt: formatv(Fmt: initSuccessorsTemplate, Vals: "[]" )); |
694 | for (int i = 0, e = successorArgNames.size(); i < e; ++i) { |
695 | auto &argName = successorArgNames[i]; |
696 | const NamedSuccessor &successor = op.getSuccessor(index: i); |
697 | builderLines.push_back(Elt: formatv(Fmt: addSuccessorTemplate, |
698 | Vals: successor.isVariadic() ? "extend" : "append" , |
699 | Vals: argName)); |
700 | } |
701 | } |
702 | |
703 | /// Populates `builderLines` with additional lines that are required in the |
704 | /// builder to set up op operands. |
705 | static void |
706 | populateBuilderLinesOperand(const Operator &op, ArrayRef<std::string> names, |
707 | SmallVectorImpl<std::string> &builderLines) { |
708 | bool sizedSegments = op.getTrait(trait: attrSizedTraitForKind(kind: "operand" )) != nullptr; |
709 | |
710 | // For each element, find or generate a name. |
711 | for (int i = 0, e = op.getNumOperands(); i < e; ++i) { |
712 | const NamedTypeConstraint &element = op.getOperand(index: i); |
713 | std::string name = names[i]; |
714 | |
715 | // Choose the formatting string based on the element kind. |
716 | StringRef formatString; |
717 | if (!element.isVariableLength()) { |
718 | formatString = singleOperandAppendTemplate; |
719 | } else if (element.isOptional()) { |
720 | if (sizedSegments) { |
721 | formatString = optionalAppendAttrSizedOperandsTemplate; |
722 | } else { |
723 | formatString = optionalAppendOperandTemplate; |
724 | } |
725 | } else { |
726 | assert(element.isVariadic() && "unhandled element group type" ); |
727 | // If emitting with sizedSegments, then we add the actual list-typed |
728 | // element. Otherwise, we extend the actual operands. |
729 | if (sizedSegments) { |
730 | formatString = multiOperandAppendPackTemplate; |
731 | } else { |
732 | formatString = multiOperandAppendTemplate; |
733 | } |
734 | } |
735 | |
736 | builderLines.push_back(Elt: formatv(Fmt: formatString.data(), Vals&: name)); |
737 | } |
738 | } |
739 | |
/// Python code template for deriving the operation result types from its
/// attribute:
/// - {0} is the name of the attribute from which to derive the types.
/// The snippet binds `_ods_derived_result_type` as follows:
/// - if the looked-up attribute is a TypeAttr, to the type it wraps;
/// - otherwise, to the attribute's own `.type`.
/// The snippet spans several lines, so the caller splits it with
/// appendLineByLine before adding it to the builder body.
constexpr const char *deriveTypeFromAttrTemplate =
    R"Py(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
    _ods_ir.TypeAttr(_ods_result_type_source_attr).value
    if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
    _ods_result_type_source_attr.type))Py";
749 | |
/// Python code template appending {0} type {1} times to the results list.
/// - {0} is a Python expression evaluating to the type to repeat (callers in
///   this file pass e.g. "operands[0].type" or "_ods_derived_result_type").
/// - {1} is the repeat count (callers pass the op's number of results).
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
752 | |
753 | /// Appends the given multiline string as individual strings into |
754 | /// `builderLines`. |
755 | static void appendLineByLine(StringRef string, |
756 | SmallVectorImpl<std::string> &builderLines) { |
757 | |
758 | std::pair<StringRef, StringRef> split = std::make_pair(x&: string, y&: string); |
759 | do { |
760 | split = split.second.split(Separator: '\n'); |
761 | builderLines.push_back(Elt: split.first.str()); |
762 | } while (!split.second.empty()); |
763 | } |
764 | |
765 | /// Populates `builderLines` with additional lines that are required in the |
766 | /// builder to set up op results. |
767 | static void |
768 | populateBuilderLinesResult(const Operator &op, ArrayRef<std::string> names, |
769 | SmallVectorImpl<std::string> &builderLines) { |
770 | bool sizedSegments = op.getTrait(trait: attrSizedTraitForKind(kind: "result" )) != nullptr; |
771 | |
772 | if (hasSameArgumentAndResultTypes(op)) { |
773 | builderLines.push_back(Elt: formatv(Fmt: appendSameResultsTemplate, |
774 | Vals: "operands[0].type" , Vals: op.getNumResults())); |
775 | return; |
776 | } |
777 | |
778 | if (hasFirstAttrDerivedResultTypes(op)) { |
779 | const NamedAttribute &firstAttr = op.getAttribute(index: 0); |
780 | assert(!firstAttr.name.empty() && "unexpected empty name for the attribute " |
781 | "from which the type is derived" ); |
782 | appendLineByLine(string: formatv(Fmt: deriveTypeFromAttrTemplate, Vals: firstAttr.name).str(), |
783 | builderLines); |
784 | builderLines.push_back(Elt: formatv(Fmt: appendSameResultsTemplate, |
785 | Vals: "_ods_derived_result_type" , |
786 | Vals: op.getNumResults())); |
787 | return; |
788 | } |
789 | |
790 | if (hasInferTypeInterface(op)) |
791 | return; |
792 | |
793 | // For each element, find or generate a name. |
794 | for (int i = 0, e = op.getNumResults(); i < e; ++i) { |
795 | const NamedTypeConstraint &element = op.getResult(index: i); |
796 | std::string name = names[i]; |
797 | |
798 | // Choose the formatting string based on the element kind. |
799 | StringRef formatString; |
800 | if (!element.isVariableLength()) { |
801 | formatString = singleResultAppendTemplate; |
802 | } else if (element.isOptional()) { |
803 | formatString = optionalAppendResultTemplate; |
804 | } else { |
805 | assert(element.isVariadic() && "unhandled element group type" ); |
806 | // If emitting with sizedSegments, then we add the actual list-typed |
807 | // element. Otherwise, we extend the actual operands. |
808 | if (sizedSegments) { |
809 | formatString = singleResultAppendTemplate; |
810 | } else { |
811 | formatString = multiResultAppendTemplate; |
812 | } |
813 | } |
814 | |
815 | builderLines.push_back(Elt: formatv(Fmt: formatString.data(), Vals&: name)); |
816 | } |
817 | } |
818 | |
819 | /// If the operation has variadic regions, adds a builder argument to specify |
820 | /// the number of those regions and builder lines to forward it to the generic |
821 | /// constructor. |
822 | static void populateBuilderRegions(const Operator &op, |
823 | SmallVectorImpl<std::string> &builderArgs, |
824 | SmallVectorImpl<std::string> &builderLines) { |
825 | if (op.hasNoVariadicRegions()) |
826 | return; |
827 | |
828 | // This is currently enforced when Operator is constructed. |
829 | assert(op.getNumVariadicRegions() == 1 && |
830 | op.getRegion(op.getNumRegions() - 1).isVariadic() && |
831 | "expected the last region to be varidic" ); |
832 | |
833 | const NamedRegion ®ion = op.getRegion(index: op.getNumRegions() - 1); |
834 | std::string name = |
835 | ("num_" + region.name.take_front().lower() + region.name.drop_front()) |
836 | .str(); |
837 | builderArgs.push_back(Elt: name); |
838 | builderLines.push_back( |
839 | Elt: formatv(Fmt: "regions = {0} + {1}" , Vals: op.getNumRegions() - 1, Vals&: name)); |
840 | } |
841 | |
842 | /// Emits a default builder constructing an operation from the list of its |
843 | /// result types, followed by a list of its operands. Returns vector |
844 | /// of fully built functionArgs for downstream users (to save having to |
845 | /// rebuild anew). |
846 | static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op, |
847 | raw_ostream &os) { |
848 | SmallVector<std::string> builderArgs; |
849 | SmallVector<std::string> builderLines; |
850 | SmallVector<std::string> operandArgNames; |
851 | SmallVector<std::string> successorArgNames; |
852 | builderArgs.reserve(N: op.getNumOperands() + op.getNumResults() + |
853 | op.getNumNativeAttributes() + op.getNumSuccessors()); |
854 | populateBuilderArgsResults(op, builderArgs); |
855 | size_t numResultArgs = builderArgs.size(); |
856 | populateBuilderArgs(op, builderArgs, operandNames&: operandArgNames); |
857 | size_t numOperandAttrArgs = builderArgs.size() - numResultArgs; |
858 | populateBuilderArgsSuccessors(op, builderArgs, successorArgNames); |
859 | |
860 | populateBuilderLinesOperand(op, names: operandArgNames, builderLines); |
861 | populateBuilderLinesAttr(op, argNames: ArrayRef(builderArgs).drop_front(N: numResultArgs), |
862 | builderLines); |
863 | populateBuilderLinesResult( |
864 | op, names: ArrayRef(builderArgs).take_front(N: numResultArgs), builderLines); |
865 | populateBuilderLinesSuccessors(op, successorArgNames, builderLines); |
866 | populateBuilderRegions(op, builderArgs, builderLines); |
867 | |
868 | // Layout of builderArgs vector elements: |
869 | // [ result_args operand_attr_args successor_args regions ] |
870 | |
871 | // Determine whether the argument corresponding to a given index into the |
872 | // builderArgs vector is a python keyword argument or not. |
873 | auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool { |
874 | // All result, successor, and region arguments are positional arguments. |
875 | if ((builderArgIndex < numResultArgs) || |
876 | (builderArgIndex >= (numResultArgs + numOperandAttrArgs))) |
877 | return false; |
878 | // Keyword arguments: |
879 | // - optional named attributes (including unit attributes) |
880 | // - default-valued named attributes |
881 | // - optional operands |
882 | Argument a = op.getArg(index: builderArgIndex - numResultArgs); |
883 | if (auto *nattr = llvm::dyn_cast_if_present<NamedAttribute *>(Val&: a)) |
884 | return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue()); |
885 | if (auto *ntype = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: a)) |
886 | return ntype->isOptional(); |
887 | return false; |
888 | }; |
889 | |
890 | // StringRefs in functionArgs refer to strings allocated by builderArgs. |
891 | SmallVector<StringRef> functionArgs; |
892 | |
893 | // Add positional arguments. |
894 | for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) { |
895 | if (!isKeywordArgFn(i)) |
896 | functionArgs.push_back(Elt: builderArgs[i]); |
897 | } |
898 | |
899 | // Add a bare '*' to indicate that all following arguments must be keyword |
900 | // arguments. |
901 | functionArgs.push_back(Elt: "*" ); |
902 | |
903 | // Add a default 'None' value to each keyword arg string, and then add to the |
904 | // function args list. |
905 | for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) { |
906 | if (isKeywordArgFn(i)) { |
907 | builderArgs[i].append(s: "=None" ); |
908 | functionArgs.push_back(Elt: builderArgs[i]); |
909 | } |
910 | } |
911 | functionArgs.push_back(Elt: "loc=None" ); |
912 | functionArgs.push_back(Elt: "ip=None" ); |
913 | |
914 | SmallVector<std::string> initArgs; |
915 | initArgs.push_back(Elt: "self.OPERATION_NAME" ); |
916 | initArgs.push_back(Elt: "self._ODS_REGIONS" ); |
917 | initArgs.push_back(Elt: "self._ODS_OPERAND_SEGMENTS" ); |
918 | initArgs.push_back(Elt: "self._ODS_RESULT_SEGMENTS" ); |
919 | initArgs.push_back(Elt: "attributes=attributes" ); |
920 | if (!hasInferTypeInterface(op)) |
921 | initArgs.push_back(Elt: "results=results" ); |
922 | initArgs.push_back(Elt: "operands=operands" ); |
923 | initArgs.push_back(Elt: "successors=_ods_successors" ); |
924 | initArgs.push_back(Elt: "regions=regions" ); |
925 | initArgs.push_back(Elt: "loc=loc" ); |
926 | initArgs.push_back(Elt: "ip=ip" ); |
927 | |
928 | os << formatv(Fmt: initTemplate, Vals: llvm::join(R&: functionArgs, Separator: ", " ), |
929 | Vals: llvm::join(R&: builderLines, Separator: "\n " ), Vals: llvm::join(R&: initArgs, Separator: ", " )); |
930 | return llvm::to_vector<8>( |
931 | Range: llvm::map_range(C&: functionArgs, F: [](StringRef s) { return s.str(); })); |
932 | } |
933 | |
934 | static void emitSegmentSpec( |
935 | const Operator &op, const char *kind, |
936 | llvm::function_ref<int(const Operator &)> getNumElements, |
937 | llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> |
938 | getElement, |
939 | raw_ostream &os) { |
940 | std::string segmentSpec("[" ); |
941 | for (int i = 0, e = getNumElements(op); i < e; ++i) { |
942 | const NamedTypeConstraint &element = getElement(op, i); |
943 | if (element.isOptional()) { |
944 | segmentSpec.append(s: "0," ); |
945 | } else if (element.isVariadic()) { |
946 | segmentSpec.append(s: "-1," ); |
947 | } else { |
948 | segmentSpec.append(s: "1," ); |
949 | } |
950 | } |
951 | segmentSpec.append(s: "]" ); |
952 | |
953 | os << formatv(Fmt: opClassSizedSegmentsTemplate, Vals&: kind, Vals&: segmentSpec); |
954 | } |
955 | |
956 | static void emitRegionAttributes(const Operator &op, raw_ostream &os) { |
957 | // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions). |
958 | // Note that the base OpView class defines this as (0, True). |
959 | unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions(); |
960 | os << formatv(Fmt: opClassRegionSpecTemplate, Vals&: minRegionCount, |
961 | Vals: op.hasNoVariadicRegions() ? "True" : "False" ); |
962 | } |
963 | |
964 | /// Emits named accessors to regions. |
965 | static void emitRegionAccessors(const Operator &op, raw_ostream &os) { |
966 | for (const auto &en : llvm::enumerate(First: op.getRegions())) { |
967 | const NamedRegion ®ion = en.value(); |
968 | if (region.name.empty()) |
969 | continue; |
970 | |
971 | assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) && |
972 | "expected only the last region to be variadic" ); |
973 | os << formatv(Fmt: regionAccessorTemplate, Vals: sanitizeName(name: region.name), |
974 | Vals: std::to_string(val: en.index()) + |
975 | (region.isVariadic() ? ":" : "" )); |
976 | } |
977 | } |
978 | |
979 | /// Emits builder that extracts results from op |
980 | static void emitValueBuilder(const Operator &op, |
981 | SmallVector<std::string> functionArgs, |
982 | raw_ostream &os) { |
983 | // Params with (possibly) default args. |
984 | auto valueBuilderParams = |
985 | llvm::map_range(C&: functionArgs, F: [](const std::string &argAndMaybeDefault) { |
986 | SmallVector<StringRef> argMaybeDefault = |
987 | llvm::to_vector<2>(Range: llvm::split(Str: argAndMaybeDefault, Separator: "=" )); |
988 | auto arg = llvm::convertToSnakeFromCamelCase(input: argMaybeDefault[0]); |
989 | if (argMaybeDefault.size() == 2) |
990 | return arg + "=" + argMaybeDefault[1].str(); |
991 | return arg; |
992 | }); |
993 | // Actual args passed to op builder (e.g., opParam=op_param). |
994 | auto opBuilderArgs = llvm::map_range( |
995 | C: llvm::make_filter_range(Range&: functionArgs, |
996 | Pred: [](const std::string &s) { return s != "*" ; }), |
997 | F: [](const std::string &arg) { |
998 | auto lhs = *llvm::split(Str: arg, Separator: "=" ).begin(); |
999 | return (lhs + "=" + llvm::convertToSnakeFromCamelCase(input: lhs)).str(); |
1000 | }); |
1001 | std::string nameWithoutDialect = sanitizeName( |
1002 | name: op.getOperationName().substr(pos: op.getOperationName().find(c: '.') + 1)); |
1003 | if (nameWithoutDialect == op.getCppClassName()) |
1004 | nameWithoutDialect += "_" ; |
1005 | std::string params = llvm::join(R&: valueBuilderParams, Separator: ", " ); |
1006 | std::string args = llvm::join(R&: opBuilderArgs, Separator: ", " ); |
1007 | const char *type = |
1008 | (op.getNumResults() > 1 |
1009 | ? "_Sequence[_ods_ir.Value]" |
1010 | : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation" )); |
1011 | if (op.getNumVariableLengthResults() > 0) { |
1012 | os << formatv(Fmt: valueBuilderVariadicTemplate, Vals&: nameWithoutDialect, |
1013 | Vals: op.getCppClassName(), Vals&: params, Vals&: args, Vals&: type); |
1014 | } else { |
1015 | const char *results; |
1016 | if (op.getNumResults() == 0) { |
1017 | results = "" ; |
1018 | } else if (op.getNumResults() == 1) { |
1019 | results = ".result" ; |
1020 | } else { |
1021 | results = ".results" ; |
1022 | } |
1023 | os << formatv(Fmt: valueBuilderTemplate, Vals&: nameWithoutDialect, |
1024 | Vals: op.getCppClassName(), Vals&: params, Vals&: args, Vals&: type, Vals&: results); |
1025 | } |
1026 | } |
1027 | |
1028 | /// Emits bindings for a specific Op to the given output stream. |
1029 | static void emitOpBindings(const Operator &op, raw_ostream &os) { |
1030 | os << formatv(Fmt: opClassTemplate, Vals: op.getCppClassName(), Vals: op.getOperationName()); |
1031 | |
1032 | // Sized segments. |
1033 | if (op.getTrait(trait: attrSizedTraitForKind(kind: "operand" )) != nullptr) { |
1034 | emitSegmentSpec(op, kind: "OPERAND" , getNumElements: getNumOperands, getElement: getOperand, os); |
1035 | } |
1036 | if (op.getTrait(trait: attrSizedTraitForKind(kind: "result" )) != nullptr) { |
1037 | emitSegmentSpec(op, kind: "RESULT" , getNumElements: getNumResults, getElement: getResult, os); |
1038 | } |
1039 | |
1040 | emitRegionAttributes(op, os); |
1041 | SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os); |
1042 | emitOperandAccessors(op, os); |
1043 | emitAttributeAccessors(op, os); |
1044 | emitResultAccessors(op, os); |
1045 | emitRegionAccessors(op, os); |
1046 | emitValueBuilder(op, functionArgs, os); |
1047 | } |
1048 | |
1049 | /// Emits bindings for the dialect specified in the command line, including file |
1050 | /// headers and utilities. Returns `false` on success to comply with Tablegen |
1051 | /// registration requirements. |
1052 | static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) { |
1053 | if (clDialectName.empty()) |
1054 | llvm::PrintFatalError(Msg: "dialect name not provided" ); |
1055 | |
1056 | os << fileHeader; |
1057 | if (!clDialectExtensionName.empty()) |
1058 | os << formatv(Fmt: dialectExtensionTemplate, Vals&: clDialectName.getValue()); |
1059 | else |
1060 | os << formatv(Fmt: dialectClassTemplate, Vals&: clDialectName.getValue()); |
1061 | |
1062 | for (const Record *rec : records.getAllDerivedDefinitions(ClassName: "Op" )) { |
1063 | Operator op(rec); |
1064 | if (op.getDialectName() == clDialectName.getValue()) |
1065 | emitOpBindings(op, os); |
1066 | } |
1067 | return false; |
1068 | } |
1069 | |
// Registers emitAllOps with mlir-tblgen as the "gen-python-op-bindings"
// generator.
static GenRegistration
    genPythonBindings("gen-python-op-bindings",
                      "Generate Python bindings for MLIR Ops", &emitAllOps);
1073 | |