1 | //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python |
10 | // binding classes wrapping a generic operation API. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "OpGenHelpers.h" |
15 | |
16 | #include "mlir/TableGen/GenInfo.h" |
17 | #include "mlir/TableGen/Operator.h" |
18 | #include "llvm/ADT/StringSet.h" |
19 | #include "llvm/Support/CommandLine.h" |
20 | #include "llvm/Support/FormatVariadic.h" |
21 | #include "llvm/TableGen/Error.h" |
22 | #include "llvm/TableGen/Record.h" |
23 | |
24 | using namespace mlir; |
25 | using namespace mlir::tblgen; |
26 | |
/// File header and imports emitted at the top of every generated Python
/// module. The template contains no `{N}` placeholders and no literal braces,
/// so emitting it verbatim or through llvm::formatv yields the same text.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import (
    equally_sized_accessor as _ods_equally_sized_accessor,
    get_default_loc_context as _ods_get_default_loc_context,
    get_op_result_or_op_results as _get_op_result_or_op_results,
    get_op_result_or_value as _get_op_result_or_value,
    get_op_results_or_values as _get_op_results_or_values,
    segmented_accessor as _ods_segmented_accessor,
)
_ods_ir = _ods_cext.ir

import builtins
from typing import Sequence as _Sequence, Union as _Union

)Py";
47 | |
/// Template for the dialect class registered with the Python extension:
///   {0} is the dialect namespace.
/// `DIALECT_NAMESPACE` is a class-body member and must be indented.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
)Py";

/// Template for a dialect-extension module, which re-imports `_Dialect` from
/// another generated module:
///   {0} names the `_{0}_ops_gen` module to import from (presumably the
///       namespace of the extended dialect — confirm at the call site).
constexpr const char *dialectExtensionTemplate = R"Py(
from ._{0}_ops_gen import _Dialect
)Py";
59 | |
/// Template for the operation class declaration:
///   {0} is the Python class name;
///   {1} is the operation name.
/// `OPERATION_NAME` is a class-body member and must be indented.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";

/// Template for class-level declarations of operand and result segment specs
/// (emitted inside the op class body, hence the indentation):
///   {0} is either "OPERAND" or "RESULT";
///   {1} is the segment spec.
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional element (expect a value or None)
///   -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for the class-level declaration of the _ODS_REGIONS spec
/// (emitted inside the op class body):
///   {0} is the minimum number of regions;
///   {1} is the Python bool literal for hasNoVariadicRegions.
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";
88 | |
/// Template for a single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
/// Like all accessor templates below, this defines a method of the generated
/// op class: the decorator and `def` are indented by two spaces, the body by
/// four.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for a single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works if there is only one variable-length group (and it is the
/// optional operand/result): the element is absent exactly when
/// `len(self.operation.{1}s)` is smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
138 | |
/// First part of the template for an equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The generated code binds `start` (index of the group's first element) and
/// `pg` (the group length). It has no trailing newline and must be followed by
/// one of the two "second part" templates below, which add the `return`.
/// The element list is read through `self.operation`; a bare `operation` would
/// be an undefined name inside the generated method.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";

/// Second part of the template for the equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for the equally-sized case, accessing a
/// variadic group:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";
163 | |
/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
/// The segment sizes are read from the `{1}SegmentSizes` attribute.
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
        self.operation.{1}s,
        self.operation.attributes["{1}SegmentSizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case (yields None for an empty segment):
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
184 | |
/// Template for an operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
/// Like the other accessor templates, this defines an (indented) method of
/// the generated op class.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.attributes["{1}"]
)Py";

/// Template for an optional operation attribute getter; the property reads
/// as None when the attribute is absent:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{1}" not in self.operation.attributes:
      return None
    return self.operation.attributes["{1}"]
)Py";

/// Template for a getter of a unit operation attribute; returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";
215 | |
/// Template for an operation attribute setter; assigning None is rejected
/// because the attribute is mandatory:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute; setting to None
/// removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute; any value whose
/// `bool()` is False (including None) removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute;
/// removes the attribute from the operation:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";
262 | |
/// Template for a region accessor (a method of the generated op class):
///   {0} is the name of the accessor;
///   {1} is the index of the region.
constexpr const char *regionAccessorTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.regions[{1}]
)Py";

/// Template for a module-level "value builder" function that constructs an op
/// and returns its result(s) via `_get_op_result_or_op_results`:
///   {0} is the function name;
///   {1} is the callable invoked to build the op (presumably the op class);
///   {2} is the function's parameter list;
///   {3} is the argument list forwarded to {1};
///   {4} is the return type annotation.
constexpr const char *valueBuilderTemplate = R"Py(
def {0}({2}) -> {4}:
  return _get_op_result_or_op_results({1}({3}))
)Py";
273 | |
274 | static llvm::cl::OptionCategory |
275 | clOpPythonBindingCat("Options for -gen-python-op-bindings" ); |
276 | |
277 | static llvm::cl::opt<std::string> |
278 | clDialectName("bind-dialect" , |
279 | llvm::cl::desc("The dialect to run the generator for" ), |
280 | llvm::cl::init(Val: "" ), llvm::cl::cat(clOpPythonBindingCat)); |
281 | |
282 | static llvm::cl::opt<std::string> clDialectExtensionName( |
283 | "dialect-extension" , llvm::cl::desc("The prefix of the dialect extension" ), |
284 | llvm::cl::init(Val: "" ), llvm::cl::cat(clOpPythonBindingCat)); |
285 | |
286 | using AttributeClasses = DenseMap<StringRef, StringRef>; |
287 | |
288 | /// Checks whether `str` would shadow a generated variable or attribute |
289 | /// part of the OpView API. |
290 | static bool isODSReserved(StringRef str) { |
291 | static llvm::StringSet<> reserved( |
292 | {"attributes" , "create" , "context" , "ip" , "operands" , "print" , "get_asm" , |
293 | "loc" , "verify" , "regions" , "results" , "self" , "operation" , |
294 | "DIALECT_NAMESPACE" , "OPERATION_NAME" }); |
295 | return str.starts_with(Prefix: "_ods_" ) || str.ends_with(Suffix: "_ods" ) || |
296 | reserved.contains(key: str); |
297 | } |
298 | |
299 | /// Modifies the `name` in a way that it becomes suitable for Python bindings |
300 | /// (does not change the `name` if it already is suitable) and returns the |
301 | /// modified version. |
302 | static std::string sanitizeName(StringRef name) { |
303 | std::string processedStr = name.str(); |
304 | std::replace_if( |
305 | first: processedStr.begin(), last: processedStr.end(), |
306 | pred: [](char c) { return !llvm::isAlnum(C: c); }, new_value: '_'); |
307 | |
308 | if (llvm::isDigit(C: *processedStr.begin())) |
309 | return "_" + processedStr; |
310 | |
311 | if (isPythonReserved(str: processedStr) || isODSReserved(str: processedStr)) |
312 | return processedStr + "_" ; |
313 | return processedStr; |
314 | } |
315 | |
316 | static std::string attrSizedTraitForKind(const char *kind) { |
317 | return llvm::formatv(Fmt: "::mlir::OpTrait::AttrSized{0}{1}Segments" , |
318 | Vals: llvm::StringRef(kind).take_front().upper(), |
319 | Vals: llvm::StringRef(kind).drop_front()); |
320 | } |
321 | |
322 | /// Emits accessors to "elements" of an Op definition. Currently, the supported |
323 | /// elements are operands and results, indicated by `kind`, which must be either |
324 | /// `operand` or `result` and is used verbatim in the emitted code. |
325 | static void emitElementAccessors( |
326 | const Operator &op, raw_ostream &os, const char *kind, |
327 | llvm::function_ref<unsigned(const Operator &)> getNumVariableLength, |
328 | llvm::function_ref<int(const Operator &)> getNumElements, |
329 | llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> |
330 | getElement) { |
331 | assert(llvm::is_contained( |
332 | llvm::SmallVector<StringRef, 2>{"operand" , "result" }, kind) && |
333 | "unsupported kind" ); |
334 | |
335 | // Traits indicating how to process variadic elements. |
336 | std::string sameSizeTrait = |
337 | llvm::formatv(Fmt: "::mlir::OpTrait::SameVariadic{0}{1}Size" , |
338 | Vals: llvm::StringRef(kind).take_front().upper(), |
339 | Vals: llvm::StringRef(kind).drop_front()); |
340 | std::string attrSizedTrait = attrSizedTraitForKind(kind); |
341 | |
342 | unsigned numVariableLength = getNumVariableLength(op); |
343 | |
344 | // If there is only one variable-length element group, its size can be |
345 | // inferred from the total number of elements. If there are none, the |
346 | // generation is straightforward. |
347 | if (numVariableLength <= 1) { |
348 | bool seenVariableLength = false; |
349 | for (int i = 0, e = getNumElements(op); i < e; ++i) { |
350 | const NamedTypeConstraint &element = getElement(op, i); |
351 | if (element.isVariableLength()) |
352 | seenVariableLength = true; |
353 | if (element.name.empty()) |
354 | continue; |
355 | if (element.isVariableLength()) { |
356 | os << llvm::formatv(Fmt: element.isOptional() ? opOneOptionalTemplate |
357 | : opOneVariadicTemplate, |
358 | Vals: sanitizeName(name: element.name), Vals&: kind, |
359 | Vals: getNumElements(op), Vals&: i); |
360 | } else if (seenVariableLength) { |
361 | os << llvm::formatv(Fmt: opSingleAfterVariableTemplate, |
362 | Vals: sanitizeName(name: element.name), Vals&: kind, |
363 | Vals: getNumElements(op), Vals&: i); |
364 | } else { |
365 | os << llvm::formatv(Fmt: opSingleTemplate, Vals: sanitizeName(name: element.name), Vals&: kind, |
366 | Vals&: i); |
367 | } |
368 | } |
369 | return; |
370 | } |
371 | |
372 | // Handle the operations where variadic groups have the same size. |
373 | if (op.getTrait(trait: sameSizeTrait)) { |
374 | int numPrecedingSimple = 0; |
375 | int numPrecedingVariadic = 0; |
376 | for (int i = 0, e = getNumElements(op); i < e; ++i) { |
377 | const NamedTypeConstraint &element = getElement(op, i); |
378 | if (!element.name.empty()) { |
379 | os << llvm::formatv(Fmt: opVariadicEqualPrefixTemplate, |
380 | Vals: sanitizeName(name: element.name), Vals&: kind, Vals&: numVariableLength, |
381 | Vals&: numPrecedingSimple, Vals&: numPrecedingVariadic); |
382 | os << llvm::formatv(Fmt: element.isVariableLength() |
383 | ? opVariadicEqualVariadicTemplate |
384 | : opVariadicEqualSimpleTemplate, |
385 | Vals&: kind); |
386 | } |
387 | if (element.isVariableLength()) |
388 | ++numPrecedingVariadic; |
389 | else |
390 | ++numPrecedingSimple; |
391 | } |
392 | return; |
393 | } |
394 | |
395 | // Handle the operations where the size of groups (variadic or not) is |
396 | // provided as an attribute. For non-variadic elements, make sure to return |
397 | // an element rather than a singleton container. |
398 | if (op.getTrait(trait: attrSizedTrait)) { |
399 | for (int i = 0, e = getNumElements(op); i < e; ++i) { |
400 | const NamedTypeConstraint &element = getElement(op, i); |
401 | if (element.name.empty()) |
402 | continue; |
403 | std::string trailing; |
404 | if (!element.isVariableLength()) |
405 | trailing = "[0]" ; |
406 | else if (element.isOptional()) |
407 | trailing = std::string( |
408 | llvm::formatv(Fmt: opVariadicSegmentOptionalTrailingTemplate, Vals&: kind)); |
409 | os << llvm::formatv(Fmt: opVariadicSegmentTemplate, Vals: sanitizeName(name: element.name), |
410 | Vals&: kind, Vals&: i, Vals&: trailing); |
411 | } |
412 | return; |
413 | } |
414 | |
415 | llvm::PrintFatalError(Msg: "unsupported " + llvm::Twine(kind) + " structure" ); |
416 | } |
417 | |
418 | /// Free function helpers accessing Operator components. |
419 | static int getNumOperands(const Operator &op) { return op.getNumOperands(); } |
420 | static const NamedTypeConstraint &getOperand(const Operator &op, int i) { |
421 | return op.getOperand(index: i); |
422 | } |
423 | static int getNumResults(const Operator &op) { return op.getNumResults(); } |
424 | static const NamedTypeConstraint &getResult(const Operator &op, int i) { |
425 | return op.getResult(index: i); |
426 | } |
427 | |
428 | /// Emits accessors to Op operands. |
429 | static void emitOperandAccessors(const Operator &op, raw_ostream &os) { |
430 | auto getNumVariableLengthOperands = [](const Operator &oper) { |
431 | return oper.getNumVariableLengthOperands(); |
432 | }; |
433 | emitElementAccessors(op, os, kind: "operand" , getNumVariableLength: getNumVariableLengthOperands, |
434 | getNumElements: getNumOperands, getElement: getOperand); |
435 | } |
436 | |
437 | /// Emits accessors Op results. |
438 | static void emitResultAccessors(const Operator &op, raw_ostream &os) { |
439 | auto getNumVariableLengthResults = [](const Operator &oper) { |
440 | return oper.getNumVariableLengthResults(); |
441 | }; |
442 | emitElementAccessors(op, os, kind: "result" , getNumVariableLength: getNumVariableLengthResults, |
443 | getNumElements: getNumResults, getElement: getResult); |
444 | } |
445 | |
446 | /// Emits accessors to Op attributes. |
447 | static void emitAttributeAccessors(const Operator &op, raw_ostream &os) { |
448 | for (const auto &namedAttr : op.getAttributes()) { |
449 | // Skip "derived" attributes because they are just C++ functions that we |
450 | // don't currently expose. |
451 | if (namedAttr.attr.isDerivedAttr()) |
452 | continue; |
453 | |
454 | if (namedAttr.name.empty()) |
455 | continue; |
456 | |
457 | std::string sanitizedName = sanitizeName(name: namedAttr.name); |
458 | |
459 | // Unit attributes are handled specially. |
460 | if (namedAttr.attr.getStorageType().trim().equals(RHS: "::mlir::UnitAttr" )) { |
461 | os << llvm::formatv(Fmt: unitAttributeGetterTemplate, Vals&: sanitizedName, |
462 | Vals: namedAttr.name); |
463 | os << llvm::formatv(Fmt: unitAttributeSetterTemplate, Vals&: sanitizedName, |
464 | Vals: namedAttr.name); |
465 | os << llvm::formatv(Fmt: attributeDeleterTemplate, Vals&: sanitizedName, |
466 | Vals: namedAttr.name); |
467 | continue; |
468 | } |
469 | |
470 | if (namedAttr.attr.isOptional()) { |
471 | os << llvm::formatv(Fmt: optionalAttributeGetterTemplate, Vals&: sanitizedName, |
472 | Vals: namedAttr.name); |
473 | os << llvm::formatv(Fmt: optionalAttributeSetterTemplate, Vals&: sanitizedName, |
474 | Vals: namedAttr.name); |
475 | os << llvm::formatv(Fmt: attributeDeleterTemplate, Vals&: sanitizedName, |
476 | Vals: namedAttr.name); |
477 | } else { |
478 | os << llvm::formatv(Fmt: attributeGetterTemplate, Vals&: sanitizedName, |
479 | Vals: namedAttr.name); |
480 | os << llvm::formatv(Fmt: attributeSetterTemplate, Vals&: sanitizedName, |
481 | Vals: namedAttr.name); |
482 | // Non-optional attributes cannot be deleted. |
483 | } |
484 | } |
485 | } |
486 | |
/// Template for the default auto-generated builder (`__init__`, a method of
/// the generated op class):
///   {0} is a comma-separated list of builder arguments, including the
///       trailing `loc` and `ip`;
///   {1} is the code populating the `operands`, `results`, `attributes` and
///       `regions` locals (and successors); it is substituted after four
///       spaces of indentation, so any further lines it contains must carry
///       that indentation themselves;
///   {2} is the argument list forwarded to `self.build_generic`.
/// `{{` is formatv's escape for a literal `{`, so `attributes = {{}` expands
/// to `attributes = {}`.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    regions = None
    {1}
    super().__init__(self.build_generic({2}))
)Py";
501 | |
/// Templates for appending a single (non-variadic) element in the builder:
///   {0} is the builder argument name.
/// Operands are passed through `_get_op_result_or_value` before being
/// appended; results are appended as given.
constexpr const char *singleOperandAppendTemplate =
    "operands.append(_get_op_result_or_value({0}))";
constexpr const char *singleResultAppendTemplate = "results.append({0})";

/// Templates for appending an optional element in the builder:
///   {0} is the builder argument name.
/// The plain variants skip the element when it is None. The attribute-sized
/// operand variant always appends an entry (None when absent), presumably so
/// every operand group keeps its slot for segment sizing — TODO confirm.
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
    "None)";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

/// Templates for adding a list of elements in the builder:
///   {0} is the builder argument name.
/// `extend` flattens the values into the list; the `Pack` variant appends the
/// whole list as one entry (populateBuilderLinesOperand selects it when
/// operand segment sizes come from an attribute).
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";
525 | |
/// Template for initializing a mandatory attribute in the operation builder:
///   {0} is the builder argument name;
///   {1} is the attribute name (the key in the `attributes` dict);
///   {2} is the attribute definition name used to look up a converter in
///       `_ods_ir.AttrBuilder`.
/// The user's value is used as-is if it already is an Attribute or if no
/// builder is registered for {2}; otherwise the registered builder converts
/// it, using `_ods_context`.
constexpr const char *initAttributeWithBuilderTemplate =
    R"Py(attributes["{1}"] = ({0} if (
isinstance({0}, _ods_ir.Attribute) or
not _ods_ir.AttrBuilder.contains('{2}')) else
_ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";

/// Template for initializing an optional (or defaulted) attribute in the
/// operation builder; the attribute is only set when {0} is not None:
///   {0} is the builder argument name;
///   {1} is the attribute name (the key in the `attributes` dict);
///   {2} is the attribute definition name used to look up a converter in
///       `_ods_ir.AttrBuilder`.
/// Conversion follows the same rules as for mandatory attributes.
constexpr const char *initOptionalAttributeWithBuilderTemplate =
    R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
isinstance({0}, _ods_ir.Attribute) or
not _ods_ir.AttrBuilder.contains('{2}')) else
_ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";

/// Template for setting a unit attribute in the operation builder when the
/// argument is truthy:
///   {0} is the attribute name;
///   {1} is the builder argument name.
/// Note the placeholder order is swapped relative to the templates above.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
_ods_get_default_loc_context(loc)))Py";
554 | |
/// Template to initialize the `_ods_successors` local in the builder:
///   {0} is the Python expression to initialize it with (this file uses
///       `None` when the op has no successors and `[]` otherwise).
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";

/// Template to add successor(s) to the `_ods_successors` list in the builder:
///   {0} is the list method ('append' for a single successor, 'extend' for a
///       variadic one);
///   {1} is the builder argument holding the value(s) to add.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
564 | |
565 | /// Returns true if the SameArgumentAndResultTypes trait can be used to infer |
566 | /// result types of the given operation. |
567 | static bool hasSameArgumentAndResultTypes(const Operator &op) { |
568 | return op.getTrait(trait: "::mlir::OpTrait::SameOperandsAndResultType" ) && |
569 | op.getNumVariableLengthResults() == 0; |
570 | } |
571 | |
572 | /// Returns true if the FirstAttrDerivedResultType trait can be used to infer |
573 | /// result types of the given operation. |
574 | static bool hasFirstAttrDerivedResultTypes(const Operator &op) { |
575 | return op.getTrait(trait: "::mlir::OpTrait::FirstAttrDerivedResultType" ) && |
576 | op.getNumVariableLengthResults() == 0; |
577 | } |
578 | |
579 | /// Returns true if the InferTypeOpInterface can be used to infer result types |
580 | /// of the given operation. |
581 | static bool hasInferTypeInterface(const Operator &op) { |
582 | return op.getTrait(trait: "::mlir::InferTypeOpInterface::Trait" ) && |
583 | op.getNumRegions() == 0; |
584 | } |
585 | |
586 | /// Returns true if there is a trait or interface that can be used to infer |
587 | /// result types of the given operation. |
588 | static bool canInferType(const Operator &op) { |
589 | return hasSameArgumentAndResultTypes(op) || |
590 | hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op); |
591 | } |
592 | |
593 | /// Populates `builderArgs` with result names if the builder is expected to |
594 | /// accept them as arguments. |
595 | static void |
596 | populateBuilderArgsResults(const Operator &op, |
597 | llvm::SmallVectorImpl<std::string> &builderArgs) { |
598 | if (canInferType(op)) |
599 | return; |
600 | |
601 | for (int i = 0, e = op.getNumResults(); i < e; ++i) { |
602 | std::string name = op.getResultName(index: i).str(); |
603 | if (name.empty()) { |
604 | if (op.getNumResults() == 1) { |
605 | // Special case for one result, make the default name be 'result' |
606 | // to properly match the built-in result accessor. |
607 | name = "result" ; |
608 | } else { |
609 | name = llvm::formatv(Fmt: "_gen_res_{0}" , Vals&: i); |
610 | } |
611 | } |
612 | name = sanitizeName(name); |
613 | builderArgs.push_back(Elt: name); |
614 | } |
615 | } |
616 | |
617 | /// Populates `builderArgs` with the Python-compatible names of builder function |
618 | /// arguments using intermixed attributes and operands in the same order as they |
619 | /// appear in the `arguments` field of the op definition. Additionally, |
620 | /// `operandNames` is populated with names of operands in their order of |
621 | /// appearance. |
622 | static void |
623 | populateBuilderArgs(const Operator &op, |
624 | llvm::SmallVectorImpl<std::string> &builderArgs, |
625 | llvm::SmallVectorImpl<std::string> &operandNames) { |
626 | for (int i = 0, e = op.getNumArgs(); i < e; ++i) { |
627 | std::string name = op.getArgName(index: i).str(); |
628 | if (name.empty()) |
629 | name = llvm::formatv(Fmt: "_gen_arg_{0}" , Vals&: i); |
630 | name = sanitizeName(name); |
631 | builderArgs.push_back(Elt: name); |
632 | if (!op.getArg(index: i).is<NamedAttribute *>()) |
633 | operandNames.push_back(Elt: name); |
634 | } |
635 | } |
636 | |
637 | /// Populates `builderArgs` with the Python-compatible names of builder function |
638 | /// successor arguments. Additionally, `successorArgNames` is also populated. |
639 | static void populateBuilderArgsSuccessors( |
640 | const Operator &op, llvm::SmallVectorImpl<std::string> &builderArgs, |
641 | llvm::SmallVectorImpl<std::string> &successorArgNames) { |
642 | |
643 | for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) { |
644 | NamedSuccessor successor = op.getSuccessor(index: i); |
645 | std::string name = std::string(successor.name); |
646 | if (name.empty()) |
647 | name = llvm::formatv(Fmt: "_gen_successor_{0}" , Vals&: i); |
648 | name = sanitizeName(name); |
649 | builderArgs.push_back(Elt: name); |
650 | successorArgNames.push_back(Elt: name); |
651 | } |
652 | } |
653 | |
654 | /// Populates `builderLines` with additional lines that are required in the |
655 | /// builder to set up operation attributes. `argNames` is expected to contain |
656 | /// the names of builder arguments that correspond to op arguments, i.e. to the |
657 | /// operands and attributes in the same order as they appear in the `arguments` |
658 | /// field. |
659 | static void |
660 | populateBuilderLinesAttr(const Operator &op, |
661 | llvm::ArrayRef<std::string> argNames, |
662 | llvm::SmallVectorImpl<std::string> &builderLines) { |
663 | builderLines.push_back(Elt: "_ods_context = _ods_get_default_loc_context(loc)" ); |
664 | for (int i = 0, e = op.getNumArgs(); i < e; ++i) { |
665 | Argument arg = op.getArg(index: i); |
666 | auto *attribute = llvm::dyn_cast_if_present<NamedAttribute *>(Val&: arg); |
667 | if (!attribute) |
668 | continue; |
669 | |
670 | // Unit attributes are handled specially. |
671 | if (attribute->attr.getStorageType().trim().equals(RHS: "::mlir::UnitAttr" )) { |
672 | builderLines.push_back(Elt: llvm::formatv(Fmt: initUnitAttributeTemplate, |
673 | Vals&: attribute->name, Vals: argNames[i])); |
674 | continue; |
675 | } |
676 | |
677 | builderLines.push_back(Elt: llvm::formatv( |
678 | Fmt: attribute->attr.isOptional() || attribute->attr.hasDefaultValue() |
679 | ? initOptionalAttributeWithBuilderTemplate |
680 | : initAttributeWithBuilderTemplate, |
681 | Vals: argNames[i], Vals&: attribute->name, Vals: attribute->attr.getAttrDefName())); |
682 | } |
683 | } |
684 | |
685 | /// Populates `builderLines` with additional lines that are required in the |
686 | /// builder to set up successors. successorArgNames is expected to correspond |
687 | /// to the Python argument name for each successor on the op. |
688 | static void populateBuilderLinesSuccessors( |
689 | const Operator &op, llvm::ArrayRef<std::string> successorArgNames, |
690 | llvm::SmallVectorImpl<std::string> &builderLines) { |
691 | if (successorArgNames.empty()) { |
692 | builderLines.push_back(Elt: llvm::formatv(Fmt: initSuccessorsTemplate, Vals: "None" )); |
693 | return; |
694 | } |
695 | |
696 | builderLines.push_back(Elt: llvm::formatv(Fmt: initSuccessorsTemplate, Vals: "[]" )); |
697 | for (int i = 0, e = successorArgNames.size(); i < e; ++i) { |
698 | auto &argName = successorArgNames[i]; |
699 | const NamedSuccessor &successor = op.getSuccessor(index: i); |
700 | builderLines.push_back( |
701 | Elt: llvm::formatv(Fmt: addSuccessorTemplate, |
702 | Vals: successor.isVariadic() ? "extend" : "append" , Vals: argName)); |
703 | } |
704 | } |
705 | |
706 | /// Populates `builderLines` with additional lines that are required in the |
707 | /// builder to set up op operands. |
708 | static void |
709 | populateBuilderLinesOperand(const Operator &op, |
710 | llvm::ArrayRef<std::string> names, |
711 | llvm::SmallVectorImpl<std::string> &builderLines) { |
712 | bool sizedSegments = op.getTrait(trait: attrSizedTraitForKind(kind: "operand" )) != nullptr; |
713 | |
714 | // For each element, find or generate a name. |
715 | for (int i = 0, e = op.getNumOperands(); i < e; ++i) { |
716 | const NamedTypeConstraint &element = op.getOperand(index: i); |
717 | std::string name = names[i]; |
718 | |
719 | // Choose the formatting string based on the element kind. |
720 | llvm::StringRef formatString; |
721 | if (!element.isVariableLength()) { |
722 | formatString = singleOperandAppendTemplate; |
723 | } else if (element.isOptional()) { |
724 | if (sizedSegments) { |
725 | formatString = optionalAppendAttrSizedOperandsTemplate; |
726 | } else { |
727 | formatString = optionalAppendOperandTemplate; |
728 | } |
729 | } else { |
730 | assert(element.isVariadic() && "unhandled element group type" ); |
731 | // If emitting with sizedSegments, then we add the actual list-typed |
732 | // element. Otherwise, we extend the actual operands. |
733 | if (sizedSegments) { |
734 | formatString = multiOperandAppendPackTemplate; |
735 | } else { |
736 | formatString = multiOperandAppendTemplate; |
737 | } |
738 | } |
739 | |
740 | builderLines.push_back(Elt: llvm::formatv(Fmt: formatString.data(), Vals&: name)); |
741 | } |
742 | } |
743 | |
/// Python code template for deriving the operation result types from its
/// attribute:
///   - {0} is the name of the attribute from which to derive the types.
/// The snippet reads the attribute from the generated `attributes` dict. If it
/// is a TypeAttr, the wrapped type (`.value`) is used; otherwise the
/// attribute's own `.type` is used. The chosen type is bound to
/// `_ods_derived_result_type`. The conditional expression is parenthesised so
/// its continuation lines form a single Python statement. This template is
/// split into lines by appendLineByLine before being added to the builder.
constexpr const char *deriveTypeFromAttrTemplate =
    R"Py(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
_ods_ir.TypeAttr(_ods_result_type_source_attr).value
if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
_ods_result_type_source_attr.type))Py";
753 | |
/// Python code template appending {0} type {1} times to the results list.
///   - {0} is a Python expression evaluating to the result type.
///   - {1} is the number of results to append.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
756 | |
757 | /// Appends the given multiline string as individual strings into |
758 | /// `builderLines`. |
759 | static void appendLineByLine(StringRef string, |
760 | llvm::SmallVectorImpl<std::string> &builderLines) { |
761 | |
762 | std::pair<StringRef, StringRef> split = std::make_pair(x&: string, y&: string); |
763 | do { |
764 | split = split.second.split(Separator: '\n'); |
765 | builderLines.push_back(Elt: split.first.str()); |
766 | } while (!split.second.empty()); |
767 | } |
768 | |
769 | /// Populates `builderLines` with additional lines that are required in the |
770 | /// builder to set up op results. |
771 | static void |
772 | populateBuilderLinesResult(const Operator &op, |
773 | llvm::ArrayRef<std::string> names, |
774 | llvm::SmallVectorImpl<std::string> &builderLines) { |
775 | bool sizedSegments = op.getTrait(trait: attrSizedTraitForKind(kind: "result" )) != nullptr; |
776 | |
777 | if (hasSameArgumentAndResultTypes(op)) { |
778 | builderLines.push_back(Elt: llvm::formatv( |
779 | Fmt: appendSameResultsTemplate, Vals: "operands[0].type" , Vals: op.getNumResults())); |
780 | return; |
781 | } |
782 | |
783 | if (hasFirstAttrDerivedResultTypes(op)) { |
784 | const NamedAttribute &firstAttr = op.getAttribute(index: 0); |
785 | assert(!firstAttr.name.empty() && "unexpected empty name for the attribute " |
786 | "from which the type is derived" ); |
787 | appendLineByLine( |
788 | string: llvm::formatv(Fmt: deriveTypeFromAttrTemplate, Vals: firstAttr.name).str(), |
789 | builderLines); |
790 | builderLines.push_back(Elt: llvm::formatv(Fmt: appendSameResultsTemplate, |
791 | Vals: "_ods_derived_result_type" , |
792 | Vals: op.getNumResults())); |
793 | return; |
794 | } |
795 | |
796 | if (hasInferTypeInterface(op)) |
797 | return; |
798 | |
799 | // For each element, find or generate a name. |
800 | for (int i = 0, e = op.getNumResults(); i < e; ++i) { |
801 | const NamedTypeConstraint &element = op.getResult(index: i); |
802 | std::string name = names[i]; |
803 | |
804 | // Choose the formatting string based on the element kind. |
805 | llvm::StringRef formatString; |
806 | if (!element.isVariableLength()) { |
807 | formatString = singleResultAppendTemplate; |
808 | } else if (element.isOptional()) { |
809 | formatString = optionalAppendResultTemplate; |
810 | } else { |
811 | assert(element.isVariadic() && "unhandled element group type" ); |
812 | // If emitting with sizedSegments, then we add the actual list-typed |
813 | // element. Otherwise, we extend the actual operands. |
814 | if (sizedSegments) { |
815 | formatString = singleResultAppendTemplate; |
816 | } else { |
817 | formatString = multiResultAppendTemplate; |
818 | } |
819 | } |
820 | |
821 | builderLines.push_back(Elt: llvm::formatv(Fmt: formatString.data(), Vals&: name)); |
822 | } |
823 | } |
824 | |
825 | /// If the operation has variadic regions, adds a builder argument to specify |
826 | /// the number of those regions and builder lines to forward it to the generic |
827 | /// constructor. |
828 | static void |
829 | populateBuilderRegions(const Operator &op, |
830 | llvm::SmallVectorImpl<std::string> &builderArgs, |
831 | llvm::SmallVectorImpl<std::string> &builderLines) { |
832 | if (op.hasNoVariadicRegions()) |
833 | return; |
834 | |
835 | // This is currently enforced when Operator is constructed. |
836 | assert(op.getNumVariadicRegions() == 1 && |
837 | op.getRegion(op.getNumRegions() - 1).isVariadic() && |
838 | "expected the last region to be varidic" ); |
839 | |
840 | const NamedRegion ®ion = op.getRegion(index: op.getNumRegions() - 1); |
841 | std::string name = |
842 | ("num_" + region.name.take_front().lower() + region.name.drop_front()) |
843 | .str(); |
844 | builderArgs.push_back(Elt: name); |
845 | builderLines.push_back( |
846 | Elt: llvm::formatv(Fmt: "regions = {0} + {1}" , Vals: op.getNumRegions() - 1, Vals&: name)); |
847 | } |
848 | |
849 | /// Emits a default builder constructing an operation from the list of its |
850 | /// result types, followed by a list of its operands. Returns vector |
851 | /// of fully built functionArgs for downstream users (to save having to |
852 | /// rebuild anew). |
853 | static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op, |
854 | raw_ostream &os) { |
855 | llvm::SmallVector<std::string> builderArgs; |
856 | llvm::SmallVector<std::string> builderLines; |
857 | llvm::SmallVector<std::string> operandArgNames; |
858 | llvm::SmallVector<std::string> successorArgNames; |
859 | builderArgs.reserve(N: op.getNumOperands() + op.getNumResults() + |
860 | op.getNumNativeAttributes() + op.getNumSuccessors()); |
861 | populateBuilderArgsResults(op, builderArgs); |
862 | size_t numResultArgs = builderArgs.size(); |
863 | populateBuilderArgs(op, builderArgs, operandNames&: operandArgNames); |
864 | size_t numOperandAttrArgs = builderArgs.size() - numResultArgs; |
865 | populateBuilderArgsSuccessors(op, builderArgs, successorArgNames); |
866 | |
867 | populateBuilderLinesOperand(op, names: operandArgNames, builderLines); |
868 | populateBuilderLinesAttr( |
869 | op, argNames: llvm::ArrayRef(builderArgs).drop_front(N: numResultArgs), builderLines); |
870 | populateBuilderLinesResult( |
871 | op, names: llvm::ArrayRef(builderArgs).take_front(N: numResultArgs), builderLines); |
872 | populateBuilderLinesSuccessors(op, successorArgNames, builderLines); |
873 | populateBuilderRegions(op, builderArgs, builderLines); |
874 | |
875 | // Layout of builderArgs vector elements: |
876 | // [ result_args operand_attr_args successor_args regions ] |
877 | |
878 | // Determine whether the argument corresponding to a given index into the |
879 | // builderArgs vector is a python keyword argument or not. |
880 | auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool { |
881 | // All result, successor, and region arguments are positional arguments. |
882 | if ((builderArgIndex < numResultArgs) || |
883 | (builderArgIndex >= (numResultArgs + numOperandAttrArgs))) |
884 | return false; |
885 | // Keyword arguments: |
886 | // - optional named attributes (including unit attributes) |
887 | // - default-valued named attributes |
888 | // - optional operands |
889 | Argument a = op.getArg(index: builderArgIndex - numResultArgs); |
890 | if (auto *nattr = llvm::dyn_cast_if_present<NamedAttribute *>(Val&: a)) |
891 | return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue()); |
892 | if (auto *ntype = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: a)) |
893 | return ntype->isOptional(); |
894 | return false; |
895 | }; |
896 | |
897 | // StringRefs in functionArgs refer to strings allocated by builderArgs. |
898 | llvm::SmallVector<llvm::StringRef> functionArgs; |
899 | |
900 | // Add positional arguments. |
901 | for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) { |
902 | if (!isKeywordArgFn(i)) |
903 | functionArgs.push_back(Elt: builderArgs[i]); |
904 | } |
905 | |
906 | // Add a bare '*' to indicate that all following arguments must be keyword |
907 | // arguments. |
908 | functionArgs.push_back(Elt: "*" ); |
909 | |
910 | // Add a default 'None' value to each keyword arg string, and then add to the |
911 | // function args list. |
912 | for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) { |
913 | if (isKeywordArgFn(i)) { |
914 | builderArgs[i].append(s: "=None" ); |
915 | functionArgs.push_back(Elt: builderArgs[i]); |
916 | } |
917 | } |
918 | functionArgs.push_back(Elt: "loc=None" ); |
919 | functionArgs.push_back(Elt: "ip=None" ); |
920 | |
921 | SmallVector<std::string> initArgs; |
922 | initArgs.push_back(Elt: "attributes=attributes" ); |
923 | if (!hasInferTypeInterface(op)) |
924 | initArgs.push_back(Elt: "results=results" ); |
925 | initArgs.push_back(Elt: "operands=operands" ); |
926 | initArgs.push_back(Elt: "successors=_ods_successors" ); |
927 | initArgs.push_back(Elt: "regions=regions" ); |
928 | initArgs.push_back(Elt: "loc=loc" ); |
929 | initArgs.push_back(Elt: "ip=ip" ); |
930 | |
931 | os << llvm::formatv(Fmt: initTemplate, Vals: llvm::join(R&: functionArgs, Separator: ", " ), |
932 | Vals: llvm::join(R&: builderLines, Separator: "\n " ), |
933 | Vals: llvm::join(R&: initArgs, Separator: ", " )); |
934 | return llvm::to_vector<8>( |
935 | Range: llvm::map_range(C&: functionArgs, F: [](llvm::StringRef s) { return s.str(); })); |
936 | } |
937 | |
938 | static void emitSegmentSpec( |
939 | const Operator &op, const char *kind, |
940 | llvm::function_ref<int(const Operator &)> getNumElements, |
941 | llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> |
942 | getElement, |
943 | raw_ostream &os) { |
944 | std::string segmentSpec("[" ); |
945 | for (int i = 0, e = getNumElements(op); i < e; ++i) { |
946 | const NamedTypeConstraint &element = getElement(op, i); |
947 | if (element.isOptional()) { |
948 | segmentSpec.append(s: "0," ); |
949 | } else if (element.isVariadic()) { |
950 | segmentSpec.append(s: "-1," ); |
951 | } else { |
952 | segmentSpec.append(s: "1," ); |
953 | } |
954 | } |
955 | segmentSpec.append(s: "]" ); |
956 | |
957 | os << llvm::formatv(Fmt: opClassSizedSegmentsTemplate, Vals&: kind, Vals&: segmentSpec); |
958 | } |
959 | |
960 | static void emitRegionAttributes(const Operator &op, raw_ostream &os) { |
961 | // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions). |
962 | // Note that the base OpView class defines this as (0, True). |
963 | unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions(); |
964 | os << llvm::formatv(Fmt: opClassRegionSpecTemplate, Vals&: minRegionCount, |
965 | Vals: op.hasNoVariadicRegions() ? "True" : "False" ); |
966 | } |
967 | |
968 | /// Emits named accessors to regions. |
969 | static void emitRegionAccessors(const Operator &op, raw_ostream &os) { |
970 | for (const auto &en : llvm::enumerate(First: op.getRegions())) { |
971 | const NamedRegion ®ion = en.value(); |
972 | if (region.name.empty()) |
973 | continue; |
974 | |
975 | assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) && |
976 | "expected only the last region to be variadic" ); |
977 | os << llvm::formatv(Fmt: regionAccessorTemplate, Vals: sanitizeName(name: region.name), |
978 | Vals: std::to_string(val: en.index()) + |
979 | (region.isVariadic() ? ":" : "" )); |
980 | } |
981 | } |
982 | |
983 | /// Emits builder that extracts results from op |
984 | static void emitValueBuilder(const Operator &op, |
985 | llvm::SmallVector<std::string> functionArgs, |
986 | raw_ostream &os) { |
987 | // Params with (possibly) default args. |
988 | auto valueBuilderParams = |
989 | llvm::map_range(C&: functionArgs, F: [](const std::string &argAndMaybeDefault) { |
990 | llvm::SmallVector<llvm::StringRef> argMaybeDefault = |
991 | llvm::to_vector<2>(Range: llvm::split(Str: argAndMaybeDefault, Separator: "=" )); |
992 | auto arg = llvm::convertToSnakeFromCamelCase(input: argMaybeDefault[0]); |
993 | if (argMaybeDefault.size() == 2) |
994 | return arg + "=" + argMaybeDefault[1].str(); |
995 | return arg; |
996 | }); |
997 | // Actual args passed to op builder (e.g., opParam=op_param). |
998 | auto opBuilderArgs = llvm::map_range( |
999 | C: llvm::make_filter_range(Range&: functionArgs, |
1000 | Pred: [](const std::string &s) { return s != "*" ; }), |
1001 | F: [](const std::string &arg) { |
1002 | auto lhs = *llvm::split(Str: arg, Separator: "=" ).begin(); |
1003 | return (lhs + "=" + llvm::convertToSnakeFromCamelCase(input: lhs)).str(); |
1004 | }); |
1005 | std::string nameWithoutDialect = |
1006 | op.getOperationName().substr(pos: op.getOperationName().find(c: '.') + 1); |
1007 | os << llvm::formatv( |
1008 | Fmt: valueBuilderTemplate, Vals: sanitizeName(name: nameWithoutDialect), |
1009 | Vals: op.getCppClassName(), Vals: llvm::join(R&: valueBuilderParams, Separator: ", " ), |
1010 | Vals: llvm::join(R&: opBuilderArgs, Separator: ", " ), |
1011 | Vals: (op.getNumResults() > 1 |
1012 | ? "_Sequence[_ods_ir.Value]" |
1013 | : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation" ))); |
1014 | } |
1015 | |
1016 | /// Emits bindings for a specific Op to the given output stream. |
1017 | static void emitOpBindings(const Operator &op, raw_ostream &os) { |
1018 | os << llvm::formatv(Fmt: opClassTemplate, Vals: op.getCppClassName(), |
1019 | Vals: op.getOperationName()); |
1020 | |
1021 | // Sized segments. |
1022 | if (op.getTrait(trait: attrSizedTraitForKind(kind: "operand" )) != nullptr) { |
1023 | emitSegmentSpec(op, kind: "OPERAND" , getNumElements: getNumOperands, getElement: getOperand, os); |
1024 | } |
1025 | if (op.getTrait(trait: attrSizedTraitForKind(kind: "result" )) != nullptr) { |
1026 | emitSegmentSpec(op, kind: "RESULT" , getNumElements: getNumResults, getElement: getResult, os); |
1027 | } |
1028 | |
1029 | emitRegionAttributes(op, os); |
1030 | llvm::SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os); |
1031 | emitOperandAccessors(op, os); |
1032 | emitAttributeAccessors(op, os); |
1033 | emitResultAccessors(op, os); |
1034 | emitRegionAccessors(op, os); |
1035 | emitValueBuilder(op, functionArgs, os); |
1036 | } |
1037 | |
1038 | /// Emits bindings for the dialect specified in the command line, including file |
1039 | /// headers and utilities. Returns `false` on success to comply with Tablegen |
1040 | /// registration requirements. |
1041 | static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) { |
1042 | if (clDialectName.empty()) |
1043 | llvm::PrintFatalError(Msg: "dialect name not provided" ); |
1044 | |
1045 | os << fileHeader; |
1046 | if (!clDialectExtensionName.empty()) |
1047 | os << llvm::formatv(Fmt: dialectExtensionTemplate, Vals&: clDialectName.getValue()); |
1048 | else |
1049 | os << llvm::formatv(Fmt: dialectClassTemplate, Vals&: clDialectName.getValue()); |
1050 | |
1051 | for (const llvm::Record *rec : records.getAllDerivedDefinitions(ClassName: "Op" )) { |
1052 | Operator op(rec); |
1053 | if (op.getDialectName() == clDialectName.getValue()) |
1054 | emitOpBindings(op, os); |
1055 | } |
1056 | return false; |
1057 | } |
1058 | |
1059 | static GenRegistration |
1060 | genPythonBindings("gen-python-op-bindings" , |
1061 | "Generate Python bindings for MLIR Ops" , &emitAllOps); |
1062 | |