1 | //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include <cstdint> |
10 | #include <optional> |
11 | #include <pybind11/cast.h> |
12 | #include <pybind11/detail/common.h> |
13 | #include <pybind11/pybind11.h> |
14 | #include <pybind11/pytypes.h> |
15 | #include <string> |
16 | #include <utility> |
17 | #include <vector> |
18 | |
19 | #include "IRModule.h" |
20 | #include "mlir-c/BuiltinAttributes.h" |
21 | #include "mlir-c/IR.h" |
22 | #include "mlir-c/Interfaces.h" |
23 | #include "mlir-c/Support.h" |
24 | #include "llvm/ADT/STLExtras.h" |
25 | #include "llvm/ADT/SmallVector.h" |
26 | |
27 | namespace py = pybind11; |
28 | |
29 | namespace mlir { |
30 | namespace python { |
31 | |
// Docstrings attached to the Python bindings registered in this file. Line
// breaks inside the text are spelled out as "\n" so that the exact bytes
// exposed to Python are visible in the source.

/// Docstring of the interface class constructor.
static constexpr const char *constructorDoc =
    "Creates an interface from a given operation/opview object or from a\n"
    "subclass of OpView. Raises ValueError if the operation does not "
    "implement the\n"
    "interface.";

/// Docstring of the `operation` read-only property.
static constexpr const char *operationDoc =
    "Returns an Operation for which the interface was constructed.";

/// Docstring of the `opview` read-only property.
static constexpr const char *opviewDoc =
    "Returns an OpView subclass _instance_ for which the interface was\n"
    "constructed";

/// Docstring of `InferTypeOpInterface.inferReturnTypes`.
static constexpr const char *inferReturnTypesDoc =
    "Given the arguments required to build an operation, attempts to infer\n"
    "its return types. Raises ValueError on failure.";

/// Docstring of `InferShapedTypeOpInterface.inferReturnTypeComponents`.
static constexpr const char *inferReturnTypeComponentsDoc =
    "Given the arguments required to build an operation, attempts to infer\n"
    "its return shaped type components. Raises ValueError on failure.";
51 | |
52 | namespace { |
53 | |
54 | /// Takes in an optional ist of operands and converts them into a SmallVector |
55 | /// of MlirVlaues. Returns an empty SmallVector if the list is empty. |
56 | llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) { |
57 | llvm::SmallVector<MlirValue> mlirOperands; |
58 | |
59 | if (!operandList || operandList->empty()) { |
60 | return mlirOperands; |
61 | } |
62 | |
63 | // Note: as the list may contain other lists this may not be final size. |
64 | mlirOperands.reserve(operandList->size()); |
65 | for (const auto &&it : llvm::enumerate(*operandList)) { |
66 | if (it.value().is_none()) |
67 | continue; |
68 | |
69 | PyValue *val; |
70 | try { |
71 | val = py::cast<PyValue *>(it.value()); |
72 | if (!val) |
73 | throw py::cast_error(); |
74 | mlirOperands.push_back(val->get()); |
75 | continue; |
76 | } catch (py::cast_error &err) { |
77 | // Intentionally unhandled to try sequence below first. |
78 | (void)err; |
79 | } |
80 | |
81 | try { |
82 | auto vals = py::cast<py::sequence>(it.value()); |
83 | for (py::object v : vals) { |
84 | try { |
85 | val = py::cast<PyValue *>(v); |
86 | if (!val) |
87 | throw py::cast_error(); |
88 | mlirOperands.push_back(val->get()); |
89 | } catch (py::cast_error &err) { |
90 | throw py::value_error( |
91 | (llvm::Twine("Operand " ) + llvm::Twine(it.index()) + |
92 | " must be a Value or Sequence of Values (" + err.what() + ")" ) |
93 | .str()); |
94 | } |
95 | } |
96 | continue; |
97 | } catch (py::cast_error &err) { |
98 | throw py::value_error((llvm::Twine("Operand " ) + llvm::Twine(it.index()) + |
99 | " must be a Value or Sequence of Values (" + |
100 | err.what() + ")" ) |
101 | .str()); |
102 | } |
103 | |
104 | throw py::cast_error(); |
105 | } |
106 | |
107 | return mlirOperands; |
108 | } |
109 | |
110 | /// Takes in an optional vector of PyRegions and returns a SmallVector of |
111 | /// MlirRegion. Returns an empty SmallVector if the list is empty. |
112 | llvm::SmallVector<MlirRegion> |
113 | wrapRegions(std::optional<std::vector<PyRegion>> regions) { |
114 | llvm::SmallVector<MlirRegion> mlirRegions; |
115 | |
116 | if (regions) { |
117 | mlirRegions.reserve(regions->size()); |
118 | for (PyRegion ®ion : *regions) { |
119 | mlirRegions.push_back(region); |
120 | } |
121 | } |
122 | |
123 | return mlirRegions; |
124 | } |
125 | |
126 | } // namespace |
127 | |
128 | /// CRTP base class for Python classes representing MLIR Op interfaces. |
129 | /// Interface hierarchies are flat so no base class is expected here. The |
130 | /// derived class is expected to define the following static fields: |
131 | /// - `const char *pyClassName` - the name of the Python class to create; |
132 | /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID |
133 | /// of the interface. |
134 | /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind |
135 | /// interface-specific methods. |
136 | /// |
137 | /// An interface class may be constructed from either an Operation/OpView object |
138 | /// or from a subclass of OpView. In the latter case, only the static interface |
139 | /// methods are available, similarly to calling ConcereteOp::staticMethod on the |
140 | /// C++ side. Implementations of concrete interfaces can use the `isStatic` |
141 | /// method to check whether the interface object was constructed from a class or |
142 | /// an operation/opview instance. The `getOpName` always succeeds and returns a |
143 | /// canonical name of the operation suitable for lookups. |
144 | template <typename ConcreteIface> |
145 | class PyConcreteOpInterface { |
146 | protected: |
147 | using ClassTy = py::class_<ConcreteIface>; |
148 | using GetTypeIDFunctionTy = MlirTypeID (*)(); |
149 | |
150 | public: |
151 | /// Constructs an interface instance from an object that is either an |
152 | /// operation or a subclass of OpView. In the latter case, only the static |
153 | /// methods of the interface are accessible to the caller. |
154 | PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context) |
155 | : obj(std::move(object)) { |
156 | try { |
157 | operation = &py::cast<PyOperation &>(obj); |
158 | } catch (py::cast_error &) { |
159 | // Do nothing. |
160 | } |
161 | |
162 | try { |
163 | operation = &py::cast<PyOpView &>(obj).getOperation(); |
164 | } catch (py::cast_error &) { |
165 | // Do nothing. |
166 | } |
167 | |
168 | if (operation != nullptr) { |
169 | if (!mlirOperationImplementsInterface(*operation, |
170 | ConcreteIface::getInterfaceID())) { |
171 | std::string msg = "the operation does not implement " ; |
172 | throw py::value_error(msg + ConcreteIface::pyClassName); |
173 | } |
174 | |
175 | MlirIdentifier identifier = mlirOperationGetName(*operation); |
176 | MlirStringRef stringRef = mlirIdentifierStr(identifier); |
177 | opName = std::string(stringRef.data, stringRef.length); |
178 | } else { |
179 | try { |
180 | opName = obj.attr("OPERATION_NAME" ).template cast<std::string>(); |
181 | } catch (py::cast_error &) { |
182 | throw py::type_error( |
183 | "Op interface does not refer to an operation or OpView class" ); |
184 | } |
185 | |
186 | if (!mlirOperationImplementsInterfaceStatic( |
187 | mlirStringRefCreate(opName.data(), opName.length()), |
188 | context.resolve().get(), ConcreteIface::getInterfaceID())) { |
189 | std::string msg = "the operation does not implement " ; |
190 | throw py::value_error(msg + ConcreteIface::pyClassName); |
191 | } |
192 | } |
193 | } |
194 | |
195 | /// Creates the Python bindings for this class in the given module. |
196 | static void bind(py::module &m) { |
197 | py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName, |
198 | py::module_local()); |
199 | cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object" ), |
200 | py::arg("context" ) = py::none(), constructorDoc) |
201 | .def_property_readonly("operation" , |
202 | &PyConcreteOpInterface::getOperationObject, |
203 | operationDoc) |
204 | .def_property_readonly("opview" , &PyConcreteOpInterface::getOpView, |
205 | opviewDoc); |
206 | ConcreteIface::bindDerived(cls); |
207 | } |
208 | |
209 | /// Hook for derived classes to add class-specific bindings. |
210 | static void bindDerived(ClassTy &cls) {} |
211 | |
212 | /// Returns `true` if this object was constructed from a subclass of OpView |
213 | /// rather than from an operation instance. |
214 | bool isStatic() { return operation == nullptr; } |
215 | |
216 | /// Returns the operation instance from which this object was constructed. |
217 | /// Throws a type error if this object was constructed from a subclass of |
218 | /// OpView. |
219 | py::object getOperationObject() { |
220 | if (operation == nullptr) { |
221 | throw py::type_error("Cannot get an operation from a static interface" ); |
222 | } |
223 | |
224 | return operation->getRef().releaseObject(); |
225 | } |
226 | |
227 | /// Returns the opview of the operation instance from which this object was |
228 | /// constructed. Throws a type error if this object was constructed form a |
229 | /// subclass of OpView. |
230 | py::object getOpView() { |
231 | if (operation == nullptr) { |
232 | throw py::type_error("Cannot get an opview from a static interface" ); |
233 | } |
234 | |
235 | return operation->createOpView(); |
236 | } |
237 | |
238 | /// Returns the canonical name of the operation this interface is constructed |
239 | /// from. |
240 | const std::string &getOpName() { return opName; } |
241 | |
242 | private: |
243 | PyOperation *operation = nullptr; |
244 | std::string opName; |
245 | py::object obj; |
246 | }; |
247 | |
248 | /// Python wrapper for InferTypeOpInterface. This interface has only static |
249 | /// methods. |
250 | class PyInferTypeOpInterface |
251 | : public PyConcreteOpInterface<PyInferTypeOpInterface> { |
252 | public: |
253 | using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface; |
254 | |
255 | constexpr static const char *pyClassName = "InferTypeOpInterface" ; |
256 | constexpr static GetTypeIDFunctionTy getInterfaceID = |
257 | &mlirInferTypeOpInterfaceTypeID; |
258 | |
259 | /// C-style user-data structure for type appending callback. |
260 | struct AppendResultsCallbackData { |
261 | std::vector<PyType> &inferredTypes; |
262 | PyMlirContext &pyMlirContext; |
263 | }; |
264 | |
265 | /// Appends the types provided as the two first arguments to the user-data |
266 | /// structure (expects AppendResultsCallbackData). |
267 | static void appendResultsCallback(intptr_t nTypes, MlirType *types, |
268 | void *userData) { |
269 | auto *data = static_cast<AppendResultsCallbackData *>(userData); |
270 | data->inferredTypes.reserve(n: data->inferredTypes.size() + nTypes); |
271 | for (intptr_t i = 0; i < nTypes; ++i) { |
272 | data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]); |
273 | } |
274 | } |
275 | |
276 | /// Given the arguments required to build an operation, attempts to infer its |
277 | /// return types. Throws value_error on failure. |
278 | std::vector<PyType> |
279 | inferReturnTypes(std::optional<py::list> operandList, |
280 | std::optional<PyAttribute> attributes, void *properties, |
281 | std::optional<std::vector<PyRegion>> regions, |
282 | DefaultingPyMlirContext context, |
283 | DefaultingPyLocation location) { |
284 | llvm::SmallVector<MlirValue> mlirOperands = |
285 | wrapOperands(std::move(operandList)); |
286 | llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions)); |
287 | |
288 | std::vector<PyType> inferredTypes; |
289 | PyMlirContext &pyContext = context.resolve(); |
290 | AppendResultsCallbackData data{.inferredTypes: inferredTypes, .pyMlirContext: pyContext}; |
291 | MlirStringRef opNameRef = |
292 | mlirStringRefCreate(getOpName().data(), getOpName().length()); |
293 | MlirAttribute attributeDict = |
294 | attributes ? attributes->get() : mlirAttributeGetNull(); |
295 | |
296 | MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes( |
297 | opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), |
298 | mlirOperands.data(), attributeDict, properties, mlirRegions.size(), |
299 | mlirRegions.data(), &appendResultsCallback, &data); |
300 | |
301 | if (mlirLogicalResultIsFailure(result)) { |
302 | throw py::value_error("Failed to infer result types" ); |
303 | } |
304 | |
305 | return inferredTypes; |
306 | } |
307 | |
308 | static void bindDerived(ClassTy &cls) { |
309 | cls.def("inferReturnTypes" , &PyInferTypeOpInterface::inferReturnTypes, |
310 | py::arg("operands" ) = py::none(), |
311 | py::arg("attributes" ) = py::none(), |
312 | py::arg("properties" ) = py::none(), py::arg("regions" ) = py::none(), |
313 | py::arg("context" ) = py::none(), py::arg("loc" ) = py::none(), |
314 | inferReturnTypesDoc); |
315 | } |
316 | }; |
317 | |
318 | /// Wrapper around an shaped type components. |
319 | class PyShapedTypeComponents { |
320 | public: |
321 | PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {} |
322 | PyShapedTypeComponents(py::list shape, MlirType elementType) |
323 | : shape(std::move(shape)), elementType(elementType), ranked(true) {} |
324 | PyShapedTypeComponents(py::list shape, MlirType elementType, |
325 | MlirAttribute attribute) |
326 | : shape(std::move(shape)), elementType(elementType), attribute(attribute), |
327 | ranked(true) {} |
328 | PyShapedTypeComponents(PyShapedTypeComponents &) = delete; |
329 | PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept |
330 | : shape(other.shape), elementType(other.elementType), |
331 | attribute(other.attribute), ranked(other.ranked) {} |
332 | |
333 | static void bind(py::module &m) { |
334 | py::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents" , |
335 | py::module_local()) |
336 | .def_property_readonly( |
337 | "element_type" , |
338 | [](PyShapedTypeComponents &self) { return self.elementType; }, |
339 | "Returns the element type of the shaped type components." ) |
340 | .def_static( |
341 | "get" , |
342 | [](PyType &elementType) { |
343 | return PyShapedTypeComponents(elementType); |
344 | }, |
345 | py::arg("element_type" ), |
346 | "Create an shaped type components object with only the element " |
347 | "type." ) |
348 | .def_static( |
349 | "get" , |
350 | [](py::list shape, PyType &elementType) { |
351 | return PyShapedTypeComponents(std::move(shape), elementType); |
352 | }, |
353 | py::arg("shape" ), py::arg("element_type" ), |
354 | "Create a ranked shaped type components object." ) |
355 | .def_static( |
356 | "get" , |
357 | [](py::list shape, PyType &elementType, PyAttribute &attribute) { |
358 | return PyShapedTypeComponents(std::move(shape), elementType, |
359 | attribute); |
360 | }, |
361 | py::arg("shape" ), py::arg("element_type" ), py::arg("attribute" ), |
362 | "Create a ranked shaped type components object with attribute." ) |
363 | .def_property_readonly( |
364 | "has_rank" , |
365 | [](PyShapedTypeComponents &self) -> bool { return self.ranked; }, |
366 | "Returns whether the given shaped type component is ranked." ) |
367 | .def_property_readonly( |
368 | "rank" , |
369 | [](PyShapedTypeComponents &self) -> py::object { |
370 | if (!self.ranked) { |
371 | return py::none(); |
372 | } |
373 | return py::int_(self.shape.size()); |
374 | }, |
375 | "Returns the rank of the given ranked shaped type components. If " |
376 | "the shaped type components does not have a rank, None is " |
377 | "returned." ) |
378 | .def_property_readonly( |
379 | "shape" , |
380 | [](PyShapedTypeComponents &self) -> py::object { |
381 | if (!self.ranked) { |
382 | return py::none(); |
383 | } |
384 | return py::list(self.shape); |
385 | }, |
386 | "Returns the shape of the ranked shaped type components as a list " |
387 | "of integers. Returns none if the shaped type component does not " |
388 | "have a rank." ); |
389 | } |
390 | |
391 | pybind11::object getCapsule(); |
392 | static PyShapedTypeComponents createFromCapsule(pybind11::object capsule); |
393 | |
394 | private: |
395 | py::list shape; |
396 | MlirType elementType; |
397 | MlirAttribute attribute; |
398 | bool ranked{false}; |
399 | }; |
400 | |
401 | /// Python wrapper for InferShapedTypeOpInterface. This interface has only |
402 | /// static methods. |
403 | class PyInferShapedTypeOpInterface |
404 | : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> { |
405 | public: |
406 | using PyConcreteOpInterface< |
407 | PyInferShapedTypeOpInterface>::PyConcreteOpInterface; |
408 | |
409 | constexpr static const char *pyClassName = "InferShapedTypeOpInterface" ; |
410 | constexpr static GetTypeIDFunctionTy getInterfaceID = |
411 | &mlirInferShapedTypeOpInterfaceTypeID; |
412 | |
413 | /// C-style user-data structure for type appending callback. |
414 | struct AppendResultsCallbackData { |
415 | std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents; |
416 | }; |
417 | |
418 | /// Appends the shaped type components provided as unpacked shape, element |
419 | /// type, attribute to the user-data. |
420 | static void appendResultsCallback(bool hasRank, intptr_t rank, |
421 | const int64_t *shape, MlirType elementType, |
422 | MlirAttribute attribute, void *userData) { |
423 | auto *data = static_cast<AppendResultsCallbackData *>(userData); |
424 | if (!hasRank) { |
425 | data->inferredShapedTypeComponents.emplace_back(elementType); |
426 | } else { |
427 | py::list shapeList; |
428 | for (intptr_t i = 0; i < rank; ++i) { |
429 | shapeList.append(shape[i]); |
430 | } |
431 | data->inferredShapedTypeComponents.emplace_back(shapeList, elementType, |
432 | attribute); |
433 | } |
434 | } |
435 | |
436 | /// Given the arguments required to build an operation, attempts to infer the |
437 | /// shaped type components. Throws value_error on failure. |
438 | std::vector<PyShapedTypeComponents> inferReturnTypeComponents( |
439 | std::optional<py::list> operandList, |
440 | std::optional<PyAttribute> attributes, void *properties, |
441 | std::optional<std::vector<PyRegion>> regions, |
442 | DefaultingPyMlirContext context, DefaultingPyLocation location) { |
443 | llvm::SmallVector<MlirValue> mlirOperands = |
444 | wrapOperands(std::move(operandList)); |
445 | llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions)); |
446 | |
447 | std::vector<PyShapedTypeComponents> inferredShapedTypeComponents; |
448 | PyMlirContext &pyContext = context.resolve(); |
449 | AppendResultsCallbackData data{.inferredShapedTypeComponents: inferredShapedTypeComponents}; |
450 | MlirStringRef opNameRef = |
451 | mlirStringRefCreate(getOpName().data(), getOpName().length()); |
452 | MlirAttribute attributeDict = |
453 | attributes ? attributes->get() : mlirAttributeGetNull(); |
454 | |
455 | MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes( |
456 | opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), |
457 | mlirOperands.data(), attributeDict, properties, mlirRegions.size(), |
458 | mlirRegions.data(), &appendResultsCallback, &data); |
459 | |
460 | if (mlirLogicalResultIsFailure(result)) { |
461 | throw py::value_error("Failed to infer result shape type components" ); |
462 | } |
463 | |
464 | return inferredShapedTypeComponents; |
465 | } |
466 | |
467 | static void bindDerived(ClassTy &cls) { |
468 | cls.def("inferReturnTypeComponents" , |
469 | &PyInferShapedTypeOpInterface::inferReturnTypeComponents, |
470 | py::arg("operands" ) = py::none(), |
471 | py::arg("attributes" ) = py::none(), py::arg("regions" ) = py::none(), |
472 | py::arg("properties" ) = py::none(), py::arg("context" ) = py::none(), |
473 | py::arg("loc" ) = py::none(), inferReturnTypeComponentsDoc); |
474 | } |
475 | }; |
476 | |
477 | void populateIRInterfaces(py::module &m) { |
478 | PyInferTypeOpInterface::bind(m); |
479 | PyShapedTypeComponents::bind(m); |
480 | PyInferShapedTypeOpInterface::bind(m); |
481 | } |
482 | |
483 | } // namespace python |
484 | } // namespace mlir |
485 | |