| 1 | //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python |
| 10 | // binding classes wrapping a generic operation API. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "OpGenHelpers.h" |
| 15 | |
| 16 | #include "mlir/TableGen/GenInfo.h" |
| 17 | #include "mlir/TableGen/Operator.h" |
| 18 | #include "llvm/ADT/StringSet.h" |
| 19 | #include "llvm/Support/CommandLine.h" |
| 20 | #include "llvm/Support/FormatVariadic.h" |
| 21 | #include "llvm/TableGen/Error.h" |
| 22 | #include "llvm/TableGen/Record.h" |
| 23 | |
| 24 | using namespace mlir; |
| 25 | using namespace mlir::tblgen; |
| 26 | using llvm::formatv; |
| 27 | using llvm::Record; |
| 28 | using llvm::RecordKeeper; |
| 29 | |
/// File header and imports written at the top of every generated Python
/// module.
/// NOTE(review): the text below contains no `{0}` placeholder (and no braces
/// at all), even though the dialect namespace was previously documented as
/// {0}; confirm how the emitter formats this template.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import (
    equally_sized_accessor as _ods_equally_sized_accessor,
    get_default_loc_context as _ods_get_default_loc_context,
    get_op_result_or_op_results as _get_op_result_or_op_results,
    get_op_results_or_values as _get_op_results_or_values,
    segmented_accessor as _ods_segmented_accessor,
)
_ods_ir = _ods_cext.ir

import builtins
from typing import Sequence as _Sequence, Union as _Union

)Py";
| 49 | |
// Python code templates.
//
// Templates that emit members of a generated class indent decorators and
// `def`s by two spaces and statement bodies by four, so the output nests under
// the class header emitted by dialectClassTemplate / opClassTemplate. Python
// requires a class body to be indented, so this layout is not cosmetic.
// Module-level templates (valueBuilder*) start at column 0.

/// Template for dialect class:
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
)Py";

/// Template importing `_Dialect` from another generated ops module:
///   {0} is the dialect whose `_{0}_ops_gen` module defines `_Dialect`.
constexpr const char *dialectExtensionTemplate = R"Py(
from ._{0}_ops_gen import _Dialect
)Py";

/// Template for operation class:
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";

/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT";
///   {1} is the segment spec.
/// Each segment spec is either None (default) or an array of integers
/// where:
///    1 = single element (expect non sequence operand/result)
///    0 = optional element (expect a value or None)
///   -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for class level declarations of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions;
///   {1} is the Python bool literal for hasNoVariadicRegions.
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";

/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";

/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of non-variadic groups;
///   {3} is the total number of variadic groups;
///   {4} is the number of non-variadic groups preceding the current group;
///   {5} is the number of variadic groups preceding the current group.
/// It deliberately ends without a newline: one of the two "second part"
/// templates below starts with one and completes the method body.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + elements_per_group]
)Py";

/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}SegmentSizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";

/// Template for an operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.attributes["{1}"]
)Py";

/// Template for an optional operation attribute getter, returning None when
/// the attribute is absent:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{1}" not in self.operation.attributes:
      return None
    return self.operation.attributes["{1}"]
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";

/// Template for an operation attribute setter; rejects None because the
/// attribute is mandatory:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";

/// Template for a region accessor:
///   {0} is the name of the accessor;
///   {1} is the index of the region in `self.regions`.
constexpr const char *regionAccessorTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.regions[{1}]
)Py";

/// Template for a module-level value builder function:
///   {0} is the function name;
///   {1} is the callable invoked to build the op (presumably the op class);
///   {2} is the parameter list;
///   {3} is the argument list forwarded to {1};
///   {4} is the return type annotation;
///   {5} is a suffix appended to the call expression.
constexpr const char *valueBuilderTemplate = R"Py(
def {0}({2}) -> {4}:
  return {1}({3}){5}
)Py";

/// Same as valueBuilderTemplate, but the built op is passed through
/// `_get_op_result_or_op_results` instead of taking a suffix.
constexpr const char *valueBuilderVariadicTemplate = R"Py(
def {0}({2}) -> {4}:
  return _get_op_result_or_op_results({1}({3}))
)Py";
| 281 | |
| 282 | static llvm::cl::OptionCategory |
| 283 | clOpPythonBindingCat("Options for -gen-python-op-bindings" ); |
| 284 | |
| 285 | static llvm::cl::opt<std::string> |
| 286 | clDialectName("bind-dialect" , |
| 287 | llvm::cl::desc("The dialect to run the generator for" ), |
| 288 | llvm::cl::init(Val: "" ), llvm::cl::cat(clOpPythonBindingCat)); |
| 289 | |
| 290 | static llvm::cl::opt<std::string> clDialectExtensionName( |
| 291 | "dialect-extension" , llvm::cl::desc("The prefix of the dialect extension" ), |
| 292 | llvm::cl::init(Val: "" ), llvm::cl::cat(clOpPythonBindingCat)); |
| 293 | |
| 294 | using AttributeClasses = DenseMap<StringRef, StringRef>; |
| 295 | |
| 296 | /// Checks whether `str` would shadow a generated variable or attribute |
| 297 | /// part of the OpView API. |
| 298 | static bool isODSReserved(StringRef str) { |
| 299 | static llvm::StringSet<> reserved( |
| 300 | {"attributes" , "create" , "context" , "ip" , "operands" , "print" , "get_asm" , |
| 301 | "loc" , "verify" , "regions" , "results" , "self" , "operation" , |
| 302 | "DIALECT_NAMESPACE" , "OPERATION_NAME" }); |
| 303 | return str.starts_with(Prefix: "_ods_" ) || str.ends_with(Suffix: "_ods" ) || |
| 304 | reserved.contains(key: str); |
| 305 | } |
| 306 | |
| 307 | /// Modifies the `name` in a way that it becomes suitable for Python bindings |
| 308 | /// (does not change the `name` if it already is suitable) and returns the |
| 309 | /// modified version. |
| 310 | static std::string sanitizeName(StringRef name) { |
| 311 | std::string processedStr = name.str(); |
| 312 | std::replace_if( |
| 313 | first: processedStr.begin(), last: processedStr.end(), |
| 314 | pred: [](char c) { return !llvm::isAlnum(C: c); }, new_value: '_'); |
| 315 | |
| 316 | if (llvm::isDigit(C: *processedStr.begin())) |
| 317 | return "_" + processedStr; |
| 318 | |
| 319 | if (isPythonReserved(str: processedStr) || isODSReserved(str: processedStr)) |
| 320 | return processedStr + "_" ; |
| 321 | return processedStr; |
| 322 | } |
| 323 | |
| 324 | static std::string attrSizedTraitForKind(const char *kind) { |
| 325 | return formatv(Fmt: "::mlir::OpTrait::AttrSized{0}{1}Segments" , |
| 326 | Vals: StringRef(kind).take_front().upper(), |
| 327 | Vals: StringRef(kind).drop_front()); |
| 328 | } |
| 329 | |
| 330 | /// Emits accessors to "elements" of an Op definition. Currently, the supported |
| 331 | /// elements are operands and results, indicated by `kind`, which must be either |
| 332 | /// `operand` or `result` and is used verbatim in the emitted code. |
| 333 | static void emitElementAccessors( |
| 334 | const Operator &op, raw_ostream &os, const char *kind, |
| 335 | unsigned numVariadicGroups, unsigned numElements, |
| 336 | llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> |
| 337 | getElement) { |
| 338 | assert(llvm::is_contained(SmallVector<StringRef, 2>{"operand" , "result" }, |
| 339 | kind) && |
| 340 | "unsupported kind" ); |
| 341 | |
| 342 | // Traits indicating how to process variadic elements. |
| 343 | std::string sameSizeTrait = formatv(Fmt: "::mlir::OpTrait::SameVariadic{0}{1}Size" , |
| 344 | Vals: StringRef(kind).take_front().upper(), |
| 345 | Vals: StringRef(kind).drop_front()); |
| 346 | std::string attrSizedTrait = attrSizedTraitForKind(kind); |
| 347 | |
| 348 | // If there is only one variable-length element group, its size can be |
| 349 | // inferred from the total number of elements. If there are none, the |
| 350 | // generation is straightforward. |
| 351 | if (numVariadicGroups <= 1) { |
| 352 | bool seenVariableLength = false; |
| 353 | for (unsigned i = 0; i < numElements; ++i) { |
| 354 | const NamedTypeConstraint &element = getElement(op, i); |
| 355 | if (element.isVariableLength()) |
| 356 | seenVariableLength = true; |
| 357 | if (element.name.empty()) |
| 358 | continue; |
| 359 | if (element.isVariableLength()) { |
| 360 | os << formatv(Fmt: element.isOptional() ? opOneOptionalTemplate |
| 361 | : opOneVariadicTemplate, |
| 362 | Vals: sanitizeName(name: element.name), Vals&: kind, Vals&: numElements, Vals&: i); |
| 363 | } else if (seenVariableLength) { |
| 364 | os << formatv(Fmt: opSingleAfterVariableTemplate, Vals: sanitizeName(name: element.name), |
| 365 | Vals&: kind, Vals&: numElements, Vals&: i); |
| 366 | } else { |
| 367 | os << formatv(Fmt: opSingleTemplate, Vals: sanitizeName(name: element.name), Vals&: kind, Vals&: i); |
| 368 | } |
| 369 | } |
| 370 | return; |
| 371 | } |
| 372 | |
| 373 | // Handle the operations where variadic groups have the same size. |
| 374 | if (op.getTrait(trait: sameSizeTrait)) { |
| 375 | // Count the number of simple elements |
| 376 | unsigned numSimpleLength = 0; |
| 377 | for (unsigned i = 0; i < numElements; ++i) { |
| 378 | const NamedTypeConstraint &element = getElement(op, i); |
| 379 | if (!element.isVariableLength()) { |
| 380 | ++numSimpleLength; |
| 381 | } |
| 382 | } |
| 383 | |
| 384 | // Generate the accessors |
| 385 | int numPrecedingSimple = 0; |
| 386 | int numPrecedingVariadic = 0; |
| 387 | for (unsigned i = 0; i < numElements; ++i) { |
| 388 | const NamedTypeConstraint &element = getElement(op, i); |
| 389 | if (!element.name.empty()) { |
| 390 | os << formatv(Fmt: opVariadicEqualPrefixTemplate, Vals: sanitizeName(name: element.name), |
| 391 | Vals&: kind, Vals&: numSimpleLength, Vals&: numVariadicGroups, |
| 392 | Vals&: numPrecedingSimple, Vals&: numPrecedingVariadic); |
| 393 | os << formatv(Fmt: element.isVariableLength() |
| 394 | ? opVariadicEqualVariadicTemplate |
| 395 | : opVariadicEqualSimpleTemplate, |
| 396 | Vals&: kind); |
| 397 | } |
| 398 | if (element.isVariableLength()) |
| 399 | ++numPrecedingVariadic; |
| 400 | else |
| 401 | ++numPrecedingSimple; |
| 402 | } |
| 403 | return; |
| 404 | } |
| 405 | |
| 406 | // Handle the operations where the size of groups (variadic or not) is |
| 407 | // provided as an attribute. For non-variadic elements, make sure to return |
| 408 | // an element rather than a singleton container. |
| 409 | if (op.getTrait(trait: attrSizedTrait)) { |
| 410 | for (unsigned i = 0; i < numElements; ++i) { |
| 411 | const NamedTypeConstraint &element = getElement(op, i); |
| 412 | if (element.name.empty()) |
| 413 | continue; |
| 414 | std::string trailing; |
| 415 | if (!element.isVariableLength()) |
| 416 | trailing = "[0]" ; |
| 417 | else if (element.isOptional()) |
| 418 | trailing = std::string( |
| 419 | formatv(Fmt: opVariadicSegmentOptionalTrailingTemplate, Vals&: kind)); |
| 420 | os << formatv(Fmt: opVariadicSegmentTemplate, Vals: sanitizeName(name: element.name), Vals&: kind, |
| 421 | Vals&: i, Vals&: trailing); |
| 422 | } |
| 423 | return; |
| 424 | } |
| 425 | |
| 426 | llvm::PrintFatalError(Msg: "unsupported " + llvm::Twine(kind) + " structure" ); |
| 427 | } |
| 428 | |
| 429 | /// Free function helpers accessing Operator components. |
| 430 | static int getNumOperands(const Operator &op) { return op.getNumOperands(); } |
| 431 | static const NamedTypeConstraint &getOperand(const Operator &op, int i) { |
| 432 | return op.getOperand(index: i); |
| 433 | } |
| 434 | static int getNumResults(const Operator &op) { return op.getNumResults(); } |
| 435 | static const NamedTypeConstraint &getResult(const Operator &op, int i) { |
| 436 | return op.getResult(index: i); |
| 437 | } |
| 438 | |
| 439 | /// Emits accessors to Op operands. |
| 440 | static void emitOperandAccessors(const Operator &op, raw_ostream &os) { |
| 441 | emitElementAccessors(op, os, kind: "operand" , numVariadicGroups: op.getNumVariableLengthOperands(), |
| 442 | numElements: getNumOperands(op), getElement: getOperand); |
| 443 | } |
| 444 | |
| 445 | /// Emits accessors Op results. |
| 446 | static void emitResultAccessors(const Operator &op, raw_ostream &os) { |
| 447 | emitElementAccessors(op, os, kind: "result" , numVariadicGroups: op.getNumVariableLengthResults(), |
| 448 | numElements: getNumResults(op), getElement: getResult); |
| 449 | } |
| 450 | |
| 451 | /// Emits accessors to Op attributes. |
| 452 | static void emitAttributeAccessors(const Operator &op, raw_ostream &os) { |
| 453 | for (const auto &namedAttr : op.getAttributes()) { |
| 454 | // Skip "derived" attributes because they are just C++ functions that we |
| 455 | // don't currently expose. |
| 456 | if (namedAttr.attr.isDerivedAttr()) |
| 457 | continue; |
| 458 | |
| 459 | if (namedAttr.name.empty()) |
| 460 | continue; |
| 461 | |
| 462 | std::string sanitizedName = sanitizeName(name: namedAttr.name); |
| 463 | |
| 464 | // Unit attributes are handled specially. |
| 465 | if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr" ) { |
| 466 | os << formatv(Fmt: unitAttributeGetterTemplate, Vals&: sanitizedName, Vals: namedAttr.name); |
| 467 | os << formatv(Fmt: unitAttributeSetterTemplate, Vals&: sanitizedName, Vals: namedAttr.name); |
| 468 | os << formatv(Fmt: attributeDeleterTemplate, Vals&: sanitizedName, Vals: namedAttr.name); |
| 469 | continue; |
| 470 | } |
| 471 | |
| 472 | if (namedAttr.attr.isOptional()) { |
| 473 | os << formatv(Fmt: optionalAttributeGetterTemplate, Vals&: sanitizedName, |
| 474 | Vals: namedAttr.name); |
| 475 | os << formatv(Fmt: optionalAttributeSetterTemplate, Vals&: sanitizedName, |
| 476 | Vals: namedAttr.name); |
| 477 | os << formatv(Fmt: attributeDeleterTemplate, Vals&: sanitizedName, Vals: namedAttr.name); |
| 478 | } else { |
| 479 | os << formatv(Fmt: attributeGetterTemplate, Vals&: sanitizedName, Vals: namedAttr.name); |
| 480 | os << formatv(Fmt: attributeSetterTemplate, Vals&: sanitizedName, Vals: namedAttr.name); |
| 481 | // Non-optional attributes cannot be deleted. |
| 482 | } |
| 483 | } |
| 484 | } |
| 485 | |
/// Template for the default auto-generated builder, emitted as a method of the
/// op class (two-space member indentation, four-space body):
///   {0} is a comma-separated list of builder arguments, including the
///       trailing `loc` and `ip`;
///   {1} is the code populating `operands`, `results` and `attributes`,
///       `successors` fields; continuation lines must carry the same
///       four-space indentation;
///   {2} is the argument list forwarded to `super().__init__`.
/// `{{}` is a formatv escape that produces an empty dict literal `{}`.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    regions = None
    {1}
    super().__init__({2})
)Py";
| 500 | |
// One-line Python statements that add a builder argument to the `operands` or
// `results` list created by the __init__ template. In every one of them, {0}
// is replaced with the Python name of the builder argument.

// Non-variadic element: always appended as-is.
constexpr auto singleOperandAppendTemplate = "operands.append({0})";
constexpr auto singleResultAppendTemplate = "results.append({0})";

// Optional element: appended only when the argument is not None. The variant
// for operands with attribute-sized segments appends the argument
// unconditionally, None included.
constexpr auto optionalAppendOperandTemplate =
    "if {0} is not None: operands.append({0})";
constexpr auto optionalAppendAttrSizedOperandsTemplate =
    "operands.append({0})";
constexpr auto optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

// Variadic element: operands are normalized through
// `_get_op_results_or_values` and either spliced in (extend) or kept as one
// packed entry (append); results are spliced in directly.
constexpr auto multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr auto multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr auto multiResultAppendTemplate = "results.extend({0})";
| 522 | |
| 523 | /// Template for attribute builder from raw input in the operation builder. |
| 524 | /// {0} is the builder argument name; |
| 525 | /// {1} is the attribute builder from raw; |
| 526 | /// {2} is the attribute builder from raw. |
| 527 | /// Use the value the user passed in if either it is already an Attribute or |
| 528 | /// there is no method registered to make it an Attribute. |
| 529 | constexpr const char *initAttributeWithBuilderTemplate = |
| 530 | R"Py(attributes["{1}"] = ({0} if ( |
| 531 | isinstance({0}, _ods_ir.Attribute) or |
| 532 | not _ods_ir.AttrBuilder.contains('{2}')) else |
| 533 | _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py" ; |
| 534 | |
| 535 | /// Template for attribute builder from raw input for optional attribute in the |
| 536 | /// operation builder. |
| 537 | /// {0} is the builder argument name; |
| 538 | /// {1} is the attribute builder from raw; |
| 539 | /// {2} is the attribute builder from raw. |
| 540 | /// Use the value the user passed in if either it is already an Attribute or |
| 541 | /// there is no method registered to make it an Attribute. |
| 542 | constexpr const char *initOptionalAttributeWithBuilderTemplate = |
| 543 | R"Py(if {0} is not None: attributes["{1}"] = ({0} if ( |
| 544 | isinstance({0}, _ods_ir.Attribute) or |
| 545 | not _ods_ir.AttrBuilder.contains('{2}')) else |
| 546 | _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py" ; |
| 547 | |
| 548 | constexpr const char *initUnitAttributeTemplate = |
| 549 | R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get( |
| 550 | _ods_get_default_loc_context(loc)))Py" ; |
| 551 | |
| 552 | /// Template to initialize the successors list in the builder if there are any |
| 553 | /// successors. |
| 554 | /// {0} is the value to initialize the successors list to. |
| 555 | constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py" ; |
| 556 | |
| 557 | /// Template to append or extend the list of successors in the builder. |
| 558 | /// {0} is the list method ('append' or 'extend'); |
| 559 | /// {1} is the value to add. |
| 560 | constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py" ; |
| 561 | |
| 562 | /// Returns true if the SameArgumentAndResultTypes trait can be used to infer |
| 563 | /// result types of the given operation. |
| 564 | static bool hasSameArgumentAndResultTypes(const Operator &op) { |
| 565 | return op.getTrait(trait: "::mlir::OpTrait::SameOperandsAndResultType" ) && |
| 566 | op.getNumVariableLengthResults() == 0; |
| 567 | } |
| 568 | |
| 569 | /// Returns true if the FirstAttrDerivedResultType trait can be used to infer |
| 570 | /// result types of the given operation. |
| 571 | static bool hasFirstAttrDerivedResultTypes(const Operator &op) { |
| 572 | return op.getTrait(trait: "::mlir::OpTrait::FirstAttrDerivedResultType" ) && |
| 573 | op.getNumVariableLengthResults() == 0; |
| 574 | } |
| 575 | |
| 576 | /// Returns true if the InferTypeOpInterface can be used to infer result types |
| 577 | /// of the given operation. |
| 578 | static bool hasInferTypeInterface(const Operator &op) { |
| 579 | return op.getTrait(trait: "::mlir::InferTypeOpInterface::Trait" ) && |
| 580 | op.getNumRegions() == 0; |
| 581 | } |
| 582 | |
| 583 | /// Returns true if there is a trait or interface that can be used to infer |
| 584 | /// result types of the given operation. |
| 585 | static bool canInferType(const Operator &op) { |
| 586 | return hasSameArgumentAndResultTypes(op) || |
| 587 | hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op); |
| 588 | } |
| 589 | |
| 590 | /// Populates `builderArgs` with result names if the builder is expected to |
| 591 | /// accept them as arguments. |
| 592 | static void |
| 593 | populateBuilderArgsResults(const Operator &op, |
| 594 | SmallVectorImpl<std::string> &builderArgs) { |
| 595 | if (canInferType(op)) |
| 596 | return; |
| 597 | |
| 598 | for (int i = 0, e = op.getNumResults(); i < e; ++i) { |
| 599 | std::string name = op.getResultName(index: i).str(); |
| 600 | if (name.empty()) { |
| 601 | if (op.getNumResults() == 1) { |
| 602 | // Special case for one result, make the default name be 'result' |
| 603 | // to properly match the built-in result accessor. |
| 604 | name = "result" ; |
| 605 | } else { |
| 606 | name = formatv(Fmt: "_gen_res_{0}" , Vals&: i); |
| 607 | } |
| 608 | } |
| 609 | name = sanitizeName(name); |
| 610 | builderArgs.push_back(Elt: name); |
| 611 | } |
| 612 | } |
| 613 | |
| 614 | /// Populates `builderArgs` with the Python-compatible names of builder function |
| 615 | /// arguments using intermixed attributes and operands in the same order as they |
| 616 | /// appear in the `arguments` field of the op definition. Additionally, |
| 617 | /// `operandNames` is populated with names of operands in their order of |
| 618 | /// appearance. |
| 619 | static void populateBuilderArgs(const Operator &op, |
| 620 | SmallVectorImpl<std::string> &builderArgs, |
| 621 | SmallVectorImpl<std::string> &operandNames) { |
| 622 | for (int i = 0, e = op.getNumArgs(); i < e; ++i) { |
| 623 | std::string name = op.getArgName(index: i).str(); |
| 624 | if (name.empty()) |
| 625 | name = formatv(Fmt: "_gen_arg_{0}" , Vals&: i); |
| 626 | name = sanitizeName(name); |
| 627 | builderArgs.push_back(Elt: name); |
| 628 | if (!isa<NamedAttribute *>(Val: op.getArg(index: i))) |
| 629 | operandNames.push_back(Elt: name); |
| 630 | } |
| 631 | } |
| 632 | |
| 633 | /// Populates `builderArgs` with the Python-compatible names of builder function |
| 634 | /// successor arguments. Additionally, `successorArgNames` is also populated. |
| 635 | static void |
| 636 | populateBuilderArgsSuccessors(const Operator &op, |
| 637 | SmallVectorImpl<std::string> &builderArgs, |
| 638 | SmallVectorImpl<std::string> &successorArgNames) { |
| 639 | |
| 640 | for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) { |
| 641 | NamedSuccessor successor = op.getSuccessor(index: i); |
| 642 | std::string name = std::string(successor.name); |
| 643 | if (name.empty()) |
| 644 | name = formatv(Fmt: "_gen_successor_{0}" , Vals&: i); |
| 645 | name = sanitizeName(name); |
| 646 | builderArgs.push_back(Elt: name); |
| 647 | successorArgNames.push_back(Elt: name); |
| 648 | } |
| 649 | } |
| 650 | |
| 651 | /// Populates `builderLines` with additional lines that are required in the |
| 652 | /// builder to set up operation attributes. `argNames` is expected to contain |
| 653 | /// the names of builder arguments that correspond to op arguments, i.e. to the |
| 654 | /// operands and attributes in the same order as they appear in the `arguments` |
| 655 | /// field. |
| 656 | static void |
| 657 | populateBuilderLinesAttr(const Operator &op, ArrayRef<std::string> argNames, |
| 658 | SmallVectorImpl<std::string> &builderLines) { |
| 659 | builderLines.push_back(Elt: "_ods_context = _ods_get_default_loc_context(loc)" ); |
| 660 | for (int i = 0, e = op.getNumArgs(); i < e; ++i) { |
| 661 | Argument arg = op.getArg(index: i); |
| 662 | auto *attribute = llvm::dyn_cast_if_present<NamedAttribute *>(Val&: arg); |
| 663 | if (!attribute) |
| 664 | continue; |
| 665 | |
| 666 | // Unit attributes are handled specially. |
| 667 | if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr" ) { |
| 668 | builderLines.push_back( |
| 669 | Elt: formatv(Fmt: initUnitAttributeTemplate, Vals&: attribute->name, Vals: argNames[i])); |
| 670 | continue; |
| 671 | } |
| 672 | |
| 673 | builderLines.push_back(Elt: formatv( |
| 674 | Fmt: attribute->attr.isOptional() || attribute->attr.hasDefaultValue() |
| 675 | ? initOptionalAttributeWithBuilderTemplate |
| 676 | : initAttributeWithBuilderTemplate, |
| 677 | Vals: argNames[i], Vals&: attribute->name, Vals: attribute->attr.getAttrDefName())); |
| 678 | } |
| 679 | } |
| 680 | |
| 681 | /// Populates `builderLines` with additional lines that are required in the |
| 682 | /// builder to set up successors. successorArgNames is expected to correspond |
| 683 | /// to the Python argument name for each successor on the op. |
| 684 | static void |
| 685 | populateBuilderLinesSuccessors(const Operator &op, |
| 686 | ArrayRef<std::string> successorArgNames, |
| 687 | SmallVectorImpl<std::string> &builderLines) { |
| 688 | if (successorArgNames.empty()) { |
| 689 | builderLines.push_back(Elt: formatv(Fmt: initSuccessorsTemplate, Vals: "None" )); |
| 690 | return; |
| 691 | } |
| 692 | |
| 693 | builderLines.push_back(Elt: formatv(Fmt: initSuccessorsTemplate, Vals: "[]" )); |
| 694 | for (int i = 0, e = successorArgNames.size(); i < e; ++i) { |
| 695 | auto &argName = successorArgNames[i]; |
| 696 | const NamedSuccessor &successor = op.getSuccessor(index: i); |
| 697 | builderLines.push_back(Elt: formatv(Fmt: addSuccessorTemplate, |
| 698 | Vals: successor.isVariadic() ? "extend" : "append" , |
| 699 | Vals: argName)); |
| 700 | } |
| 701 | } |
| 702 | |
| 703 | /// Populates `builderLines` with additional lines that are required in the |
| 704 | /// builder to set up op operands. |
| 705 | static void |
| 706 | populateBuilderLinesOperand(const Operator &op, ArrayRef<std::string> names, |
| 707 | SmallVectorImpl<std::string> &builderLines) { |
| 708 | bool sizedSegments = op.getTrait(trait: attrSizedTraitForKind(kind: "operand" )) != nullptr; |
| 709 | |
| 710 | // For each element, find or generate a name. |
| 711 | for (int i = 0, e = op.getNumOperands(); i < e; ++i) { |
| 712 | const NamedTypeConstraint &element = op.getOperand(index: i); |
| 713 | std::string name = names[i]; |
| 714 | |
| 715 | // Choose the formatting string based on the element kind. |
| 716 | StringRef formatString; |
| 717 | if (!element.isVariableLength()) { |
| 718 | formatString = singleOperandAppendTemplate; |
| 719 | } else if (element.isOptional()) { |
| 720 | if (sizedSegments) { |
| 721 | formatString = optionalAppendAttrSizedOperandsTemplate; |
| 722 | } else { |
| 723 | formatString = optionalAppendOperandTemplate; |
| 724 | } |
| 725 | } else { |
| 726 | assert(element.isVariadic() && "unhandled element group type" ); |
| 727 | // If emitting with sizedSegments, then we add the actual list-typed |
| 728 | // element. Otherwise, we extend the actual operands. |
| 729 | if (sizedSegments) { |
| 730 | formatString = multiOperandAppendPackTemplate; |
| 731 | } else { |
| 732 | formatString = multiOperandAppendTemplate; |
| 733 | } |
| 734 | } |
| 735 | |
| 736 | builderLines.push_back(Elt: formatv(Fmt: formatString.data(), Vals&: name)); |
| 737 | } |
| 738 | } |
| 739 | |
/// Python code template for deriving the operation result types from its
/// first attribute: the attribute is read from `attributes`; if it is a
/// TypeAttr its wrapped type is used, otherwise the attribute's own type. The
/// derived type is bound to `_ods_derived_result_type`.
/// - {0} is the name of the attribute from which to derive the types.
/// The template is split on '\n' and each line becomes a separate builder
/// line (see `appendLineByLine`). Keep the two Python statements unindented
/// so that they line up with the other builder lines; the lines inside the
/// parentheses are continuation lines.
constexpr const char *deriveTypeFromAttrTemplate =
R"Py(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
_ods_ir.TypeAttr(_ods_result_type_source_attr).value
if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
_ods_result_type_source_attr.type))Py" ;

/// Python code template appending {0} type {1} times to the results list.
/// - {0} is a Python expression evaluating to the shared result type
///   (e.g. "operands[0].type" or "_ods_derived_result_type").
/// - {1} is the number of results of the op.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})" ;
| 752 | |
| 753 | /// Appends the given multiline string as individual strings into |
| 754 | /// `builderLines`. |
| 755 | static void appendLineByLine(StringRef string, |
| 756 | SmallVectorImpl<std::string> &builderLines) { |
| 757 | |
| 758 | std::pair<StringRef, StringRef> split = std::make_pair(x&: string, y&: string); |
| 759 | do { |
| 760 | split = split.second.split(Separator: '\n'); |
| 761 | builderLines.push_back(Elt: split.first.str()); |
| 762 | } while (!split.second.empty()); |
| 763 | } |
| 764 | |
| 765 | /// Populates `builderLines` with additional lines that are required in the |
| 766 | /// builder to set up op results. |
| 767 | static void |
| 768 | populateBuilderLinesResult(const Operator &op, ArrayRef<std::string> names, |
| 769 | SmallVectorImpl<std::string> &builderLines) { |
| 770 | bool sizedSegments = op.getTrait(trait: attrSizedTraitForKind(kind: "result" )) != nullptr; |
| 771 | |
| 772 | if (hasSameArgumentAndResultTypes(op)) { |
| 773 | builderLines.push_back(Elt: formatv(Fmt: appendSameResultsTemplate, |
| 774 | Vals: "operands[0].type" , Vals: op.getNumResults())); |
| 775 | return; |
| 776 | } |
| 777 | |
| 778 | if (hasFirstAttrDerivedResultTypes(op)) { |
| 779 | const NamedAttribute &firstAttr = op.getAttribute(index: 0); |
| 780 | assert(!firstAttr.name.empty() && "unexpected empty name for the attribute " |
| 781 | "from which the type is derived" ); |
| 782 | appendLineByLine(string: formatv(Fmt: deriveTypeFromAttrTemplate, Vals: firstAttr.name).str(), |
| 783 | builderLines); |
| 784 | builderLines.push_back(Elt: formatv(Fmt: appendSameResultsTemplate, |
| 785 | Vals: "_ods_derived_result_type" , |
| 786 | Vals: op.getNumResults())); |
| 787 | return; |
| 788 | } |
| 789 | |
| 790 | if (hasInferTypeInterface(op)) |
| 791 | return; |
| 792 | |
| 793 | // For each element, find or generate a name. |
| 794 | for (int i = 0, e = op.getNumResults(); i < e; ++i) { |
| 795 | const NamedTypeConstraint &element = op.getResult(index: i); |
| 796 | std::string name = names[i]; |
| 797 | |
| 798 | // Choose the formatting string based on the element kind. |
| 799 | StringRef formatString; |
| 800 | if (!element.isVariableLength()) { |
| 801 | formatString = singleResultAppendTemplate; |
| 802 | } else if (element.isOptional()) { |
| 803 | formatString = optionalAppendResultTemplate; |
| 804 | } else { |
| 805 | assert(element.isVariadic() && "unhandled element group type" ); |
| 806 | // If emitting with sizedSegments, then we add the actual list-typed |
| 807 | // element. Otherwise, we extend the actual operands. |
| 808 | if (sizedSegments) { |
| 809 | formatString = singleResultAppendTemplate; |
| 810 | } else { |
| 811 | formatString = multiResultAppendTemplate; |
| 812 | } |
| 813 | } |
| 814 | |
| 815 | builderLines.push_back(Elt: formatv(Fmt: formatString.data(), Vals&: name)); |
| 816 | } |
| 817 | } |
| 818 | |
| 819 | /// If the operation has variadic regions, adds a builder argument to specify |
| 820 | /// the number of those regions and builder lines to forward it to the generic |
| 821 | /// constructor. |
| 822 | static void populateBuilderRegions(const Operator &op, |
| 823 | SmallVectorImpl<std::string> &builderArgs, |
| 824 | SmallVectorImpl<std::string> &builderLines) { |
| 825 | if (op.hasNoVariadicRegions()) |
| 826 | return; |
| 827 | |
| 828 | // This is currently enforced when Operator is constructed. |
| 829 | assert(op.getNumVariadicRegions() == 1 && |
| 830 | op.getRegion(op.getNumRegions() - 1).isVariadic() && |
| 831 | "expected the last region to be varidic" ); |
| 832 | |
| 833 | const NamedRegion ®ion = op.getRegion(index: op.getNumRegions() - 1); |
| 834 | std::string name = |
| 835 | ("num_" + region.name.take_front().lower() + region.name.drop_front()) |
| 836 | .str(); |
| 837 | builderArgs.push_back(Elt: name); |
| 838 | builderLines.push_back( |
| 839 | Elt: formatv(Fmt: "regions = {0} + {1}" , Vals: op.getNumRegions() - 1, Vals&: name)); |
| 840 | } |
| 841 | |
| 842 | /// Emits a default builder constructing an operation from the list of its |
| 843 | /// result types, followed by a list of its operands. Returns vector |
| 844 | /// of fully built functionArgs for downstream users (to save having to |
| 845 | /// rebuild anew). |
| 846 | static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op, |
| 847 | raw_ostream &os) { |
| 848 | SmallVector<std::string> builderArgs; |
| 849 | SmallVector<std::string> builderLines; |
| 850 | SmallVector<std::string> operandArgNames; |
| 851 | SmallVector<std::string> successorArgNames; |
| 852 | builderArgs.reserve(N: op.getNumOperands() + op.getNumResults() + |
| 853 | op.getNumNativeAttributes() + op.getNumSuccessors()); |
| 854 | populateBuilderArgsResults(op, builderArgs); |
| 855 | size_t numResultArgs = builderArgs.size(); |
| 856 | populateBuilderArgs(op, builderArgs, operandNames&: operandArgNames); |
| 857 | size_t numOperandAttrArgs = builderArgs.size() - numResultArgs; |
| 858 | populateBuilderArgsSuccessors(op, builderArgs, successorArgNames); |
| 859 | |
| 860 | populateBuilderLinesOperand(op, names: operandArgNames, builderLines); |
| 861 | populateBuilderLinesAttr(op, argNames: ArrayRef(builderArgs).drop_front(N: numResultArgs), |
| 862 | builderLines); |
| 863 | populateBuilderLinesResult( |
| 864 | op, names: ArrayRef(builderArgs).take_front(N: numResultArgs), builderLines); |
| 865 | populateBuilderLinesSuccessors(op, successorArgNames, builderLines); |
| 866 | populateBuilderRegions(op, builderArgs, builderLines); |
| 867 | |
| 868 | // Layout of builderArgs vector elements: |
| 869 | // [ result_args operand_attr_args successor_args regions ] |
| 870 | |
| 871 | // Determine whether the argument corresponding to a given index into the |
| 872 | // builderArgs vector is a python keyword argument or not. |
| 873 | auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool { |
| 874 | // All result, successor, and region arguments are positional arguments. |
| 875 | if ((builderArgIndex < numResultArgs) || |
| 876 | (builderArgIndex >= (numResultArgs + numOperandAttrArgs))) |
| 877 | return false; |
| 878 | // Keyword arguments: |
| 879 | // - optional named attributes (including unit attributes) |
| 880 | // - default-valued named attributes |
| 881 | // - optional operands |
| 882 | Argument a = op.getArg(index: builderArgIndex - numResultArgs); |
| 883 | if (auto *nattr = llvm::dyn_cast_if_present<NamedAttribute *>(Val&: a)) |
| 884 | return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue()); |
| 885 | if (auto *ntype = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: a)) |
| 886 | return ntype->isOptional(); |
| 887 | return false; |
| 888 | }; |
| 889 | |
| 890 | // StringRefs in functionArgs refer to strings allocated by builderArgs. |
| 891 | SmallVector<StringRef> functionArgs; |
| 892 | |
| 893 | // Add positional arguments. |
| 894 | for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) { |
| 895 | if (!isKeywordArgFn(i)) |
| 896 | functionArgs.push_back(Elt: builderArgs[i]); |
| 897 | } |
| 898 | |
| 899 | // Add a bare '*' to indicate that all following arguments must be keyword |
| 900 | // arguments. |
| 901 | functionArgs.push_back(Elt: "*" ); |
| 902 | |
| 903 | // Add a default 'None' value to each keyword arg string, and then add to the |
| 904 | // function args list. |
| 905 | for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) { |
| 906 | if (isKeywordArgFn(i)) { |
| 907 | builderArgs[i].append(s: "=None" ); |
| 908 | functionArgs.push_back(Elt: builderArgs[i]); |
| 909 | } |
| 910 | } |
| 911 | functionArgs.push_back(Elt: "loc=None" ); |
| 912 | functionArgs.push_back(Elt: "ip=None" ); |
| 913 | |
| 914 | SmallVector<std::string> initArgs; |
| 915 | initArgs.push_back(Elt: "self.OPERATION_NAME" ); |
| 916 | initArgs.push_back(Elt: "self._ODS_REGIONS" ); |
| 917 | initArgs.push_back(Elt: "self._ODS_OPERAND_SEGMENTS" ); |
| 918 | initArgs.push_back(Elt: "self._ODS_RESULT_SEGMENTS" ); |
| 919 | initArgs.push_back(Elt: "attributes=attributes" ); |
| 920 | if (!hasInferTypeInterface(op)) |
| 921 | initArgs.push_back(Elt: "results=results" ); |
| 922 | initArgs.push_back(Elt: "operands=operands" ); |
| 923 | initArgs.push_back(Elt: "successors=_ods_successors" ); |
| 924 | initArgs.push_back(Elt: "regions=regions" ); |
| 925 | initArgs.push_back(Elt: "loc=loc" ); |
| 926 | initArgs.push_back(Elt: "ip=ip" ); |
| 927 | |
| 928 | os << formatv(Fmt: initTemplate, Vals: llvm::join(R&: functionArgs, Separator: ", " ), |
| 929 | Vals: llvm::join(R&: builderLines, Separator: "\n " ), Vals: llvm::join(R&: initArgs, Separator: ", " )); |
| 930 | return llvm::to_vector<8>( |
| 931 | Range: llvm::map_range(C&: functionArgs, F: [](StringRef s) { return s.str(); })); |
| 932 | } |
| 933 | |
| 934 | static void emitSegmentSpec( |
| 935 | const Operator &op, const char *kind, |
| 936 | llvm::function_ref<int(const Operator &)> getNumElements, |
| 937 | llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> |
| 938 | getElement, |
| 939 | raw_ostream &os) { |
| 940 | std::string segmentSpec("[" ); |
| 941 | for (int i = 0, e = getNumElements(op); i < e; ++i) { |
| 942 | const NamedTypeConstraint &element = getElement(op, i); |
| 943 | if (element.isOptional()) { |
| 944 | segmentSpec.append(s: "0," ); |
| 945 | } else if (element.isVariadic()) { |
| 946 | segmentSpec.append(s: "-1," ); |
| 947 | } else { |
| 948 | segmentSpec.append(s: "1," ); |
| 949 | } |
| 950 | } |
| 951 | segmentSpec.append(s: "]" ); |
| 952 | |
| 953 | os << formatv(Fmt: opClassSizedSegmentsTemplate, Vals&: kind, Vals&: segmentSpec); |
| 954 | } |
| 955 | |
| 956 | static void emitRegionAttributes(const Operator &op, raw_ostream &os) { |
| 957 | // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions). |
| 958 | // Note that the base OpView class defines this as (0, True). |
| 959 | unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions(); |
| 960 | os << formatv(Fmt: opClassRegionSpecTemplate, Vals&: minRegionCount, |
| 961 | Vals: op.hasNoVariadicRegions() ? "True" : "False" ); |
| 962 | } |
| 963 | |
| 964 | /// Emits named accessors to regions. |
| 965 | static void emitRegionAccessors(const Operator &op, raw_ostream &os) { |
| 966 | for (const auto &en : llvm::enumerate(First: op.getRegions())) { |
| 967 | const NamedRegion ®ion = en.value(); |
| 968 | if (region.name.empty()) |
| 969 | continue; |
| 970 | |
| 971 | assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) && |
| 972 | "expected only the last region to be variadic" ); |
| 973 | os << formatv(Fmt: regionAccessorTemplate, Vals: sanitizeName(name: region.name), |
| 974 | Vals: std::to_string(val: en.index()) + |
| 975 | (region.isVariadic() ? ":" : "" )); |
| 976 | } |
| 977 | } |
| 978 | |
| 979 | /// Emits builder that extracts results from op |
| 980 | static void emitValueBuilder(const Operator &op, |
| 981 | SmallVector<std::string> functionArgs, |
| 982 | raw_ostream &os) { |
| 983 | // Params with (possibly) default args. |
| 984 | auto valueBuilderParams = |
| 985 | llvm::map_range(C&: functionArgs, F: [](const std::string &argAndMaybeDefault) { |
| 986 | SmallVector<StringRef> argMaybeDefault = |
| 987 | llvm::to_vector<2>(Range: llvm::split(Str: argAndMaybeDefault, Separator: "=" )); |
| 988 | auto arg = llvm::convertToSnakeFromCamelCase(input: argMaybeDefault[0]); |
| 989 | if (argMaybeDefault.size() == 2) |
| 990 | return arg + "=" + argMaybeDefault[1].str(); |
| 991 | return arg; |
| 992 | }); |
| 993 | // Actual args passed to op builder (e.g., opParam=op_param). |
| 994 | auto opBuilderArgs = llvm::map_range( |
| 995 | C: llvm::make_filter_range(Range&: functionArgs, |
| 996 | Pred: [](const std::string &s) { return s != "*" ; }), |
| 997 | F: [](const std::string &arg) { |
| 998 | auto lhs = *llvm::split(Str: arg, Separator: "=" ).begin(); |
| 999 | return (lhs + "=" + llvm::convertToSnakeFromCamelCase(input: lhs)).str(); |
| 1000 | }); |
| 1001 | std::string nameWithoutDialect = sanitizeName( |
| 1002 | name: op.getOperationName().substr(pos: op.getOperationName().find(c: '.') + 1)); |
| 1003 | if (nameWithoutDialect == op.getCppClassName()) |
| 1004 | nameWithoutDialect += "_" ; |
| 1005 | std::string params = llvm::join(R&: valueBuilderParams, Separator: ", " ); |
| 1006 | std::string args = llvm::join(R&: opBuilderArgs, Separator: ", " ); |
| 1007 | const char *type = |
| 1008 | (op.getNumResults() > 1 |
| 1009 | ? "_Sequence[_ods_ir.Value]" |
| 1010 | : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation" )); |
| 1011 | if (op.getNumVariableLengthResults() > 0) { |
| 1012 | os << formatv(Fmt: valueBuilderVariadicTemplate, Vals&: nameWithoutDialect, |
| 1013 | Vals: op.getCppClassName(), Vals&: params, Vals&: args, Vals&: type); |
| 1014 | } else { |
| 1015 | const char *results; |
| 1016 | if (op.getNumResults() == 0) { |
| 1017 | results = "" ; |
| 1018 | } else if (op.getNumResults() == 1) { |
| 1019 | results = ".result" ; |
| 1020 | } else { |
| 1021 | results = ".results" ; |
| 1022 | } |
| 1023 | os << formatv(Fmt: valueBuilderTemplate, Vals&: nameWithoutDialect, |
| 1024 | Vals: op.getCppClassName(), Vals&: params, Vals&: args, Vals&: type, Vals&: results); |
| 1025 | } |
| 1026 | } |
| 1027 | |
| 1028 | /// Emits bindings for a specific Op to the given output stream. |
| 1029 | static void emitOpBindings(const Operator &op, raw_ostream &os) { |
| 1030 | os << formatv(Fmt: opClassTemplate, Vals: op.getCppClassName(), Vals: op.getOperationName()); |
| 1031 | |
| 1032 | // Sized segments. |
| 1033 | if (op.getTrait(trait: attrSizedTraitForKind(kind: "operand" )) != nullptr) { |
| 1034 | emitSegmentSpec(op, kind: "OPERAND" , getNumElements: getNumOperands, getElement: getOperand, os); |
| 1035 | } |
| 1036 | if (op.getTrait(trait: attrSizedTraitForKind(kind: "result" )) != nullptr) { |
| 1037 | emitSegmentSpec(op, kind: "RESULT" , getNumElements: getNumResults, getElement: getResult, os); |
| 1038 | } |
| 1039 | |
| 1040 | emitRegionAttributes(op, os); |
| 1041 | SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os); |
| 1042 | emitOperandAccessors(op, os); |
| 1043 | emitAttributeAccessors(op, os); |
| 1044 | emitResultAccessors(op, os); |
| 1045 | emitRegionAccessors(op, os); |
| 1046 | emitValueBuilder(op, functionArgs, os); |
| 1047 | } |
| 1048 | |
| 1049 | /// Emits bindings for the dialect specified in the command line, including file |
| 1050 | /// headers and utilities. Returns `false` on success to comply with Tablegen |
| 1051 | /// registration requirements. |
| 1052 | static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) { |
| 1053 | if (clDialectName.empty()) |
| 1054 | llvm::PrintFatalError(Msg: "dialect name not provided" ); |
| 1055 | |
| 1056 | os << fileHeader; |
| 1057 | if (!clDialectExtensionName.empty()) |
| 1058 | os << formatv(Fmt: dialectExtensionTemplate, Vals&: clDialectName.getValue()); |
| 1059 | else |
| 1060 | os << formatv(Fmt: dialectClassTemplate, Vals&: clDialectName.getValue()); |
| 1061 | |
| 1062 | for (const Record *rec : records.getAllDerivedDefinitions(ClassName: "Op" )) { |
| 1063 | Operator op(rec); |
| 1064 | if (op.getDialectName() == clDialectName.getValue()) |
| 1065 | emitOpBindings(op, os); |
| 1066 | } |
| 1067 | return false; |
| 1068 | } |
| 1069 | |
| 1070 | static GenRegistration |
| 1071 | genPythonBindings("gen-python-op-bindings" , |
| 1072 | "Generate Python bindings for MLIR Ops" , &emitAllOps); |
| 1073 | |