1//===- EnumPythonBindingGen.cpp - Generator of Python API for ODS enums ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// EnumPythonBindingGen uses ODS specification of MLIR enum attributes to
10// generate the corresponding Python binding classes.
11//
12//===----------------------------------------------------------------------===//
13#include "OpGenHelpers.h"
14
15#include "mlir/TableGen/AttrOrTypeDef.h"
16#include "mlir/TableGen/Attribute.h"
17#include "mlir/TableGen/Dialect.h"
18#include "mlir/TableGen/GenInfo.h"
19#include "llvm/Support/FormatVariadic.h"
20#include "llvm/TableGen/Record.h"
21
22using namespace mlir;
23using namespace mlir::tblgen;
24
/// Preamble written once at the top of every generated Python enum module,
/// before any enum class or attribute builder. It imports the `enum` helpers
/// the emitted classes rely on (`IntEnum`, `IntFlag`, and `auto` for cases
/// without an explicit value), the `register_attribute_builder` decorator
/// applied to each emitted builder, and binds `_ods_ir` to the `ir` module
/// reached through `_ods_common._cext`, which the builders use to construct
/// attributes.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from enum import IntEnum, auto, IntFlag
from ._ods_common import _cext as _ods_cext
from ..ir import register_attribute_builder
_ods_ir = _ods_cext.ir

)Py";
35
36/// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE.
37static std::string makePythonEnumCaseName(StringRef name) {
38 if (isPythonReserved(str: name.str()))
39 return (name + "_").str();
40 return name.str();
41}
42
43/// Emits the Python class for the given enum.
44static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
45 os << llvm::formatv(Fmt: "class {0}({1}):\n", Vals: enumAttr.getEnumClassName(),
46 Vals: enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
47 if (!enumAttr.getSummary().empty())
48 os << llvm::formatv(Fmt: " \"\"\"{0}\"\"\"\n", Vals: enumAttr.getSummary());
49 os << "\n";
50
51 for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
52 os << llvm::formatv(
53 Fmt: " {0} = {1}\n", Vals: makePythonEnumCaseName(name: enumCase.getSymbol()),
54 Vals: enumCase.getValue() >= 0 ? std::to_string(val: enumCase.getValue())
55 : "auto()");
56 }
57
58 os << "\n";
59
60 if (enumAttr.isBitEnum()) {
61 os << llvm::formatv(Fmt: " def __iter__(self):\n"
62 " return iter([case for case in type(self) if "
63 "(self & case) is case])\n");
64 os << llvm::formatv(Fmt: " def __len__(self):\n"
65 " return bin(self).count(\"1\")\n");
66 os << "\n";
67 }
68
69 os << llvm::formatv(Fmt: " def __str__(self):\n");
70 if (enumAttr.isBitEnum())
71 os << llvm::formatv(Fmt: " if len(self) > 1:\n"
72 " return \"{0}\".join(map(str, self))\n",
73 Vals: enumAttr.getDef().getValueAsString(FieldName: "separator"));
74 for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
75 os << llvm::formatv(Fmt: " if self is {0}.{1}:\n",
76 Vals: enumAttr.getEnumClassName(),
77 Vals: makePythonEnumCaseName(name: enumCase.getSymbol()));
78 os << llvm::formatv(Fmt: " return \"{0}\"\n", Vals: enumCase.getStr());
79 }
80 os << llvm::formatv(
81 Fmt: " raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
82 Vals: enumAttr.getEnumClassName());
83 os << "\n";
84}
85
86/// Attempts to extract the bitwidth B from string "uintB_t" describing the
87/// type. This bitwidth information is not readily available in ODS. Returns
88/// `false` on success, `true` on failure.
89static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) {
90 if (!uintType.consume_front(Prefix: "uint"))
91 return true;
92 if (!uintType.consume_back(Suffix: "_t"))
93 return true;
94 return uintType.getAsInteger(/*Radix=*/10, Result&: bitwidth);
95}
96
97/// Emits an attribute builder for the given enum attribute to support automatic
98/// conversion between enum values and attributes in Python. Returns
99/// `false` on success, `true` on failure.
100static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
101 int64_t bitwidth;
102 if (extractUIntBitwidth(uintType: enumAttr.getUnderlyingType(), bitwidth)) {
103 llvm::errs() << "failed to identify bitwidth of "
104 << enumAttr.getUnderlyingType();
105 return true;
106 }
107
108 os << llvm::formatv(Fmt: "@register_attribute_builder(\"{0}\")\n",
109 Vals: enumAttr.getAttrDefName());
110 os << llvm::formatv(Fmt: "def _{0}(x, context):\n",
111 Vals: enumAttr.getAttrDefName().lower());
112 os << llvm::formatv(
113 Fmt: " return "
114 "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
115 "context=context), int(x))\n\n",
116 Vals&: bitwidth);
117 return false;
118}
119
120/// Emits an attribute builder for the given dialect enum attribute to support
121/// automatic conversion between enum values and attributes in Python. Returns
122/// `false` on success, `true` on failure.
123static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
124 StringRef formatString,
125 raw_ostream &os) {
126 os << llvm::formatv(Fmt: "@register_attribute_builder(\"{0}\")\n", Vals&: attrDefName);
127 os << llvm::formatv(Fmt: "def _{0}(x, context):\n", Vals: attrDefName.lower());
128 os << llvm::formatv(Fmt: " return "
129 "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
130 Vals&: formatString);
131 return false;
132}
133
134/// Emits Python bindings for all enums in the record keeper. Returns
135/// `false` on success, `true` on failure.
136static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
137 raw_ostream &os) {
138 os << fileHeader;
139 for (auto &it :
140 recordKeeper.getAllDerivedDefinitionsIfDefined(ClassName: "EnumAttrInfo")) {
141 EnumAttr enumAttr(*it);
142 emitEnumClass(enumAttr, os);
143 emitAttributeBuilder(enumAttr, os);
144 }
145 for (auto &it : recordKeeper.getAllDerivedDefinitionsIfDefined(ClassName: "EnumAttr")) {
146 AttrOrTypeDef attr(&*it);
147 if (!attr.getMnemonic()) {
148 llvm::errs() << "enum case " << attr
149 << " needs mnemonic for python enum bindings generation";
150 return true;
151 }
152 StringRef mnemonic = attr.getMnemonic().value();
153 std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
154 StringRef dialect = attr.getDialect().getName();
155 if (assemblyFormat == "`<` $value `>`") {
156 emitDialectEnumAttributeBuilder(
157 attrDefName: attr.getName(),
158 formatString: llvm::formatv(Fmt: "#{0}.{1}<{{str(x)}>", Vals&: dialect, Vals&: mnemonic).str(), os);
159 } else if (assemblyFormat == "$value") {
160 emitDialectEnumAttributeBuilder(
161 attrDefName: attr.getName(),
162 formatString: llvm::formatv(Fmt: "#{0}<{1} {{str(x)}>", Vals&: dialect, Vals&: mnemonic).str(), os);
163 } else {
164 llvm::errs()
165 << "unsupported assembly format for python enum bindings generation";
166 return true;
167 }
168 }
169
170 return false;
171}
172
173// Registers the enum utility generator to mlir-tblgen.
174static mlir::GenRegistration
175 genPythonEnumBindings("gen-python-enum-bindings",
176 "Generate Python bindings for enum attributes",
177 &emitPythonEnums);
178

source code of mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp