1//===- EnumPythonBindingGen.cpp - Generator of Python API for ODS enums ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// EnumPythonBindingGen uses ODS specification of MLIR enum attributes to
10// generate the corresponding Python binding classes.
11//
12//===----------------------------------------------------------------------===//
13#include "OpGenHelpers.h"
14
15#include "mlir/TableGen/AttrOrTypeDef.h"
16#include "mlir/TableGen/Attribute.h"
17#include "mlir/TableGen/Dialect.h"
18#include "mlir/TableGen/EnumInfo.h"
19#include "mlir/TableGen/GenInfo.h"
20#include "llvm/Support/FormatVariadic.h"
21#include "llvm/TableGen/Record.h"
22
23using namespace mlir;
24using namespace mlir::tblgen;
25using llvm::formatv;
26using llvm::Record;
27using llvm::RecordKeeper;
28
/// Preamble written at the very top of every generated Python module: the
/// autogeneration banner followed by the imports that the emitted enum classes
/// (`IntEnum`, `IntFlag`, `auto`) and attribute builders
/// (`register_attribute_builder`, `_ods_ir`) refer to.
static constexpr char fileHeader[] = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from enum import IntEnum, auto, IntFlag
from ._ods_common import _cext as _ods_cext
from ..ir import register_attribute_builder
_ods_ir = _ods_cext.ir

)Py";
39
40/// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE.
41static std::string makePythonEnumCaseName(StringRef name) {
42 if (isPythonReserved(str: name.str()))
43 return (name + "_").str();
44 return name.str();
45}
46
47/// Emits the Python class for the given enum.
48static void emitEnumClass(EnumInfo enumInfo, raw_ostream &os) {
49 os << formatv(Fmt: "class {0}({1}):\n", Vals: enumInfo.getEnumClassName(),
50 Vals: enumInfo.isBitEnum() ? "IntFlag" : "IntEnum");
51 if (!enumInfo.getSummary().empty())
52 os << formatv(Fmt: " \"\"\"{0}\"\"\"\n", Vals: enumInfo.getSummary());
53 os << "\n";
54
55 for (const EnumCase &enumCase : enumInfo.getAllCases()) {
56 os << formatv(Fmt: " {0} = {1}\n",
57 Vals: makePythonEnumCaseName(name: enumCase.getSymbol()),
58 Vals: enumCase.getValue() >= 0 ? std::to_string(val: enumCase.getValue())
59 : "auto()");
60 }
61
62 os << "\n";
63
64 if (enumInfo.isBitEnum()) {
65 os << formatv(Fmt: " def __iter__(self):\n"
66 " return iter([case for case in type(self) if "
67 "(self & case) is case])\n");
68 os << formatv(Fmt: " def __len__(self):\n"
69 " return bin(self).count(\"1\")\n");
70 os << "\n";
71 }
72
73 os << formatv(Fmt: " def __str__(self):\n");
74 if (enumInfo.isBitEnum())
75 os << formatv(Fmt: " if len(self) > 1:\n"
76 " return \"{0}\".join(map(str, self))\n",
77 Vals: enumInfo.getDef().getValueAsString(FieldName: "separator"));
78 for (const EnumCase &enumCase : enumInfo.getAllCases()) {
79 os << formatv(Fmt: " if self is {0}.{1}:\n", Vals: enumInfo.getEnumClassName(),
80 Vals: makePythonEnumCaseName(name: enumCase.getSymbol()));
81 os << formatv(Fmt: " return \"{0}\"\n", Vals: enumCase.getStr());
82 }
83 os << formatv(Fmt: " raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
84 Vals: enumInfo.getEnumClassName());
85 os << "\n";
86}
87
88/// Emits an attribute builder for the given enum attribute to support automatic
89/// conversion between enum values and attributes in Python. Returns
90/// `false` on success, `true` on failure.
91static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) {
92 std::optional<Attribute> enumAttrInfo = enumInfo.asEnumAttr();
93 if (!enumAttrInfo)
94 return false;
95
96 int64_t bitwidth = enumInfo.getBitwidth();
97 os << formatv(Fmt: "@register_attribute_builder(\"{0}\")\n",
98 Vals: enumAttrInfo->getAttrDefName());
99 os << formatv(Fmt: "def _{0}(x, context):\n",
100 Vals: enumAttrInfo->getAttrDefName().lower());
101 os << formatv(Fmt: " return "
102 "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
103 "context=context), int(x))\n\n",
104 Vals&: bitwidth);
105 return false;
106}
107
108/// Emits an attribute builder for the given dialect enum attribute to support
109/// automatic conversion between enum values and attributes in Python. Returns
110/// `false` on success, `true` on failure.
111static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
112 StringRef formatString,
113 raw_ostream &os) {
114 os << formatv(Fmt: "@register_attribute_builder(\"{0}\")\n", Vals&: attrDefName);
115 os << formatv(Fmt: "def _{0}(x, context):\n", Vals: attrDefName.lower());
116 os << formatv(Fmt: " return "
117 "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
118 Vals&: formatString);
119 return false;
120}
121
122/// Emits Python bindings for all enums in the record keeper. Returns
123/// `false` on success, `true` on failure.
124static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
125 os << fileHeader;
126 for (const Record *it :
127 records.getAllDerivedDefinitionsIfDefined(ClassName: "EnumInfo")) {
128 EnumInfo enumInfo(*it);
129 emitEnumClass(enumInfo, os);
130 emitAttributeBuilder(enumInfo, os);
131 }
132 for (const Record *it :
133 records.getAllDerivedDefinitionsIfDefined(ClassName: "EnumAttr")) {
134 AttrOrTypeDef attr(&*it);
135 if (!attr.getMnemonic()) {
136 llvm::errs() << "enum case " << attr
137 << " needs mnemonic for python enum bindings generation";
138 return true;
139 }
140 StringRef mnemonic = attr.getMnemonic().value();
141 std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
142 StringRef dialect = attr.getDialect().getName();
143 if (assemblyFormat == "`<` $value `>`") {
144 emitDialectEnumAttributeBuilder(
145 attrDefName: attr.getName(),
146 formatString: formatv(Fmt: "#{0}.{1}<{{str(x)}>", Vals&: dialect, Vals&: mnemonic).str(), os);
147 } else if (assemblyFormat == "$value") {
148 emitDialectEnumAttributeBuilder(
149 attrDefName: attr.getName(),
150 formatString: formatv(Fmt: "#{0}<{1} {{str(x)}>", Vals&: dialect, Vals&: mnemonic).str(), os);
151 } else {
152 llvm::errs()
153 << "unsupported assembly format for python enum bindings generation";
154 return true;
155 }
156 }
157
158 return false;
159}
160
161// Registers the enum utility generator to mlir-tblgen.
162static mlir::GenRegistration
163 genPythonEnumBindings("gen-python-enum-bindings",
164 "Generate Python bindings for enum attributes",
165 &emitPythonEnums);
166

source code of mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp