//===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include <cstdint>
10#include <optional>
11#include <pybind11/cast.h>
12#include <pybind11/detail/common.h>
13#include <pybind11/pybind11.h>
14#include <pybind11/pytypes.h>
15#include <string>
16#include <utility>
17#include <vector>
18
19#include "IRModule.h"
20#include "mlir-c/BuiltinAttributes.h"
21#include "mlir-c/IR.h"
22#include "mlir-c/Interfaces.h"
23#include "mlir-c/Support.h"
24#include "llvm/ADT/STLExtras.h"
25#include "llvm/ADT/SmallVector.h"
26
27namespace py = pybind11;
28
29namespace mlir {
30namespace python {
31
/// Python docstring for the interface constructor (`__init__`).
constexpr static const char *constructorDoc =
    R"(Creates an interface from a given operation/opview object or from a
subclass of OpView. Raises ValueError if the operation does not implement the
interface.)";

/// Python docstring for the read-only `operation` property.
constexpr static const char *operationDoc =
    R"(Returns an Operation for which the interface was constructed.)";

/// Python docstring for the read-only `opview` property.
constexpr static const char *opviewDoc =
    R"(Returns an OpView subclass _instance_ for which the interface was
constructed)";

/// Python docstring for `InferTypeOpInterface.inferReturnTypes`.
constexpr static const char *inferReturnTypesDoc =
    R"(Given the arguments required to build an operation, attempts to infer
its return types. Raises ValueError on failure.)";

/// Python docstring for `InferShapedTypeOpInterface.inferReturnTypeComponents`.
constexpr static const char *inferReturnTypeComponentsDoc =
    R"(Given the arguments required to build an operation, attempts to infer
its return shaped type components. Raises ValueError on failure.)";
51
52namespace {
53
54/// Takes in an optional ist of operands and converts them into a SmallVector
55/// of MlirVlaues. Returns an empty SmallVector if the list is empty.
56llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) {
57 llvm::SmallVector<MlirValue> mlirOperands;
58
59 if (!operandList || operandList->empty()) {
60 return mlirOperands;
61 }
62
63 // Note: as the list may contain other lists this may not be final size.
64 mlirOperands.reserve(operandList->size());
65 for (const auto &&it : llvm::enumerate(*operandList)) {
66 if (it.value().is_none())
67 continue;
68
69 PyValue *val;
70 try {
71 val = py::cast<PyValue *>(it.value());
72 if (!val)
73 throw py::cast_error();
74 mlirOperands.push_back(val->get());
75 continue;
76 } catch (py::cast_error &err) {
77 // Intentionally unhandled to try sequence below first.
78 (void)err;
79 }
80
81 try {
82 auto vals = py::cast<py::sequence>(it.value());
83 for (py::object v : vals) {
84 try {
85 val = py::cast<PyValue *>(v);
86 if (!val)
87 throw py::cast_error();
88 mlirOperands.push_back(val->get());
89 } catch (py::cast_error &err) {
90 throw py::value_error(
91 (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
92 " must be a Value or Sequence of Values (" + err.what() + ")")
93 .str());
94 }
95 }
96 continue;
97 } catch (py::cast_error &err) {
98 throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) +
99 " must be a Value or Sequence of Values (" +
100 err.what() + ")")
101 .str());
102 }
103
104 throw py::cast_error();
105 }
106
107 return mlirOperands;
108}
109
110/// Takes in an optional vector of PyRegions and returns a SmallVector of
111/// MlirRegion. Returns an empty SmallVector if the list is empty.
112llvm::SmallVector<MlirRegion>
113wrapRegions(std::optional<std::vector<PyRegion>> regions) {
114 llvm::SmallVector<MlirRegion> mlirRegions;
115
116 if (regions) {
117 mlirRegions.reserve(regions->size());
118 for (PyRegion &region : *regions) {
119 mlirRegions.push_back(region);
120 }
121 }
122
123 return mlirRegions;
124}
125
126} // namespace
127
128/// CRTP base class for Python classes representing MLIR Op interfaces.
129/// Interface hierarchies are flat so no base class is expected here. The
130/// derived class is expected to define the following static fields:
131/// - `const char *pyClassName` - the name of the Python class to create;
132/// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
133/// of the interface.
134/// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
135/// interface-specific methods.
136///
137/// An interface class may be constructed from either an Operation/OpView object
138/// or from a subclass of OpView. In the latter case, only the static interface
139/// methods are available, similarly to calling ConcereteOp::staticMethod on the
140/// C++ side. Implementations of concrete interfaces can use the `isStatic`
141/// method to check whether the interface object was constructed from a class or
142/// an operation/opview instance. The `getOpName` always succeeds and returns a
143/// canonical name of the operation suitable for lookups.
144template <typename ConcreteIface>
145class PyConcreteOpInterface {
146protected:
147 using ClassTy = py::class_<ConcreteIface>;
148 using GetTypeIDFunctionTy = MlirTypeID (*)();
149
150public:
151 /// Constructs an interface instance from an object that is either an
152 /// operation or a subclass of OpView. In the latter case, only the static
153 /// methods of the interface are accessible to the caller.
154 PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
155 : obj(std::move(object)) {
156 try {
157 operation = &py::cast<PyOperation &>(obj);
158 } catch (py::cast_error &) {
159 // Do nothing.
160 }
161
162 try {
163 operation = &py::cast<PyOpView &>(obj).getOperation();
164 } catch (py::cast_error &) {
165 // Do nothing.
166 }
167
168 if (operation != nullptr) {
169 if (!mlirOperationImplementsInterface(*operation,
170 ConcreteIface::getInterfaceID())) {
171 std::string msg = "the operation does not implement ";
172 throw py::value_error(msg + ConcreteIface::pyClassName);
173 }
174
175 MlirIdentifier identifier = mlirOperationGetName(*operation);
176 MlirStringRef stringRef = mlirIdentifierStr(identifier);
177 opName = std::string(stringRef.data, stringRef.length);
178 } else {
179 try {
180 opName = obj.attr("OPERATION_NAME").template cast<std::string>();
181 } catch (py::cast_error &) {
182 throw py::type_error(
183 "Op interface does not refer to an operation or OpView class");
184 }
185
186 if (!mlirOperationImplementsInterfaceStatic(
187 mlirStringRefCreate(opName.data(), opName.length()),
188 context.resolve().get(), ConcreteIface::getInterfaceID())) {
189 std::string msg = "the operation does not implement ";
190 throw py::value_error(msg + ConcreteIface::pyClassName);
191 }
192 }
193 }
194
195 /// Creates the Python bindings for this class in the given module.
196 static void bind(py::module &m) {
197 py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName,
198 py::module_local());
199 cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
200 py::arg("context") = py::none(), constructorDoc)
201 .def_property_readonly("operation",
202 &PyConcreteOpInterface::getOperationObject,
203 operationDoc)
204 .def_property_readonly("opview", &PyConcreteOpInterface::getOpView,
205 opviewDoc);
206 ConcreteIface::bindDerived(cls);
207 }
208
209 /// Hook for derived classes to add class-specific bindings.
210 static void bindDerived(ClassTy &cls) {}
211
212 /// Returns `true` if this object was constructed from a subclass of OpView
213 /// rather than from an operation instance.
214 bool isStatic() { return operation == nullptr; }
215
216 /// Returns the operation instance from which this object was constructed.
217 /// Throws a type error if this object was constructed from a subclass of
218 /// OpView.
219 py::object getOperationObject() {
220 if (operation == nullptr) {
221 throw py::type_error("Cannot get an operation from a static interface");
222 }
223
224 return operation->getRef().releaseObject();
225 }
226
227 /// Returns the opview of the operation instance from which this object was
228 /// constructed. Throws a type error if this object was constructed form a
229 /// subclass of OpView.
230 py::object getOpView() {
231 if (operation == nullptr) {
232 throw py::type_error("Cannot get an opview from a static interface");
233 }
234
235 return operation->createOpView();
236 }
237
238 /// Returns the canonical name of the operation this interface is constructed
239 /// from.
240 const std::string &getOpName() { return opName; }
241
242private:
243 PyOperation *operation = nullptr;
244 std::string opName;
245 py::object obj;
246};
247
248/// Python wrapper for InferTypeOpInterface. This interface has only static
249/// methods.
250class PyInferTypeOpInterface
251 : public PyConcreteOpInterface<PyInferTypeOpInterface> {
252public:
253 using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface;
254
255 constexpr static const char *pyClassName = "InferTypeOpInterface";
256 constexpr static GetTypeIDFunctionTy getInterfaceID =
257 &mlirInferTypeOpInterfaceTypeID;
258
259 /// C-style user-data structure for type appending callback.
260 struct AppendResultsCallbackData {
261 std::vector<PyType> &inferredTypes;
262 PyMlirContext &pyMlirContext;
263 };
264
265 /// Appends the types provided as the two first arguments to the user-data
266 /// structure (expects AppendResultsCallbackData).
267 static void appendResultsCallback(intptr_t nTypes, MlirType *types,
268 void *userData) {
269 auto *data = static_cast<AppendResultsCallbackData *>(userData);
270 data->inferredTypes.reserve(n: data->inferredTypes.size() + nTypes);
271 for (intptr_t i = 0; i < nTypes; ++i) {
272 data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]);
273 }
274 }
275
276 /// Given the arguments required to build an operation, attempts to infer its
277 /// return types. Throws value_error on failure.
278 std::vector<PyType>
279 inferReturnTypes(std::optional<py::list> operandList,
280 std::optional<PyAttribute> attributes, void *properties,
281 std::optional<std::vector<PyRegion>> regions,
282 DefaultingPyMlirContext context,
283 DefaultingPyLocation location) {
284 llvm::SmallVector<MlirValue> mlirOperands =
285 wrapOperands(std::move(operandList));
286 llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
287
288 std::vector<PyType> inferredTypes;
289 PyMlirContext &pyContext = context.resolve();
290 AppendResultsCallbackData data{.inferredTypes: inferredTypes, .pyMlirContext: pyContext};
291 MlirStringRef opNameRef =
292 mlirStringRefCreate(getOpName().data(), getOpName().length());
293 MlirAttribute attributeDict =
294 attributes ? attributes->get() : mlirAttributeGetNull();
295
296 MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes(
297 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
298 mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
299 mlirRegions.data(), &appendResultsCallback, &data);
300
301 if (mlirLogicalResultIsFailure(result)) {
302 throw py::value_error("Failed to infer result types");
303 }
304
305 return inferredTypes;
306 }
307
308 static void bindDerived(ClassTy &cls) {
309 cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
310 py::arg("operands") = py::none(),
311 py::arg("attributes") = py::none(),
312 py::arg("properties") = py::none(), py::arg("regions") = py::none(),
313 py::arg("context") = py::none(), py::arg("loc") = py::none(),
314 inferReturnTypesDoc);
315 }
316};
317
318/// Wrapper around an shaped type components.
319class PyShapedTypeComponents {
320public:
321 PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {}
322 PyShapedTypeComponents(py::list shape, MlirType elementType)
323 : shape(std::move(shape)), elementType(elementType), ranked(true) {}
324 PyShapedTypeComponents(py::list shape, MlirType elementType,
325 MlirAttribute attribute)
326 : shape(std::move(shape)), elementType(elementType), attribute(attribute),
327 ranked(true) {}
328 PyShapedTypeComponents(PyShapedTypeComponents &) = delete;
329 PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept
330 : shape(other.shape), elementType(other.elementType),
331 attribute(other.attribute), ranked(other.ranked) {}
332
333 static void bind(py::module &m) {
334 py::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents",
335 py::module_local())
336 .def_property_readonly(
337 "element_type",
338 [](PyShapedTypeComponents &self) { return self.elementType; },
339 "Returns the element type of the shaped type components.")
340 .def_static(
341 "get",
342 [](PyType &elementType) {
343 return PyShapedTypeComponents(elementType);
344 },
345 py::arg("element_type"),
346 "Create an shaped type components object with only the element "
347 "type.")
348 .def_static(
349 "get",
350 [](py::list shape, PyType &elementType) {
351 return PyShapedTypeComponents(std::move(shape), elementType);
352 },
353 py::arg("shape"), py::arg("element_type"),
354 "Create a ranked shaped type components object.")
355 .def_static(
356 "get",
357 [](py::list shape, PyType &elementType, PyAttribute &attribute) {
358 return PyShapedTypeComponents(std::move(shape), elementType,
359 attribute);
360 },
361 py::arg("shape"), py::arg("element_type"), py::arg("attribute"),
362 "Create a ranked shaped type components object with attribute.")
363 .def_property_readonly(
364 "has_rank",
365 [](PyShapedTypeComponents &self) -> bool { return self.ranked; },
366 "Returns whether the given shaped type component is ranked.")
367 .def_property_readonly(
368 "rank",
369 [](PyShapedTypeComponents &self) -> py::object {
370 if (!self.ranked) {
371 return py::none();
372 }
373 return py::int_(self.shape.size());
374 },
375 "Returns the rank of the given ranked shaped type components. If "
376 "the shaped type components does not have a rank, None is "
377 "returned.")
378 .def_property_readonly(
379 "shape",
380 [](PyShapedTypeComponents &self) -> py::object {
381 if (!self.ranked) {
382 return py::none();
383 }
384 return py::list(self.shape);
385 },
386 "Returns the shape of the ranked shaped type components as a list "
387 "of integers. Returns none if the shaped type component does not "
388 "have a rank.");
389 }
390
391 pybind11::object getCapsule();
392 static PyShapedTypeComponents createFromCapsule(pybind11::object capsule);
393
394private:
395 py::list shape;
396 MlirType elementType;
397 MlirAttribute attribute;
398 bool ranked{false};
399};
400
401/// Python wrapper for InferShapedTypeOpInterface. This interface has only
402/// static methods.
403class PyInferShapedTypeOpInterface
404 : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> {
405public:
406 using PyConcreteOpInterface<
407 PyInferShapedTypeOpInterface>::PyConcreteOpInterface;
408
409 constexpr static const char *pyClassName = "InferShapedTypeOpInterface";
410 constexpr static GetTypeIDFunctionTy getInterfaceID =
411 &mlirInferShapedTypeOpInterfaceTypeID;
412
413 /// C-style user-data structure for type appending callback.
414 struct AppendResultsCallbackData {
415 std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents;
416 };
417
418 /// Appends the shaped type components provided as unpacked shape, element
419 /// type, attribute to the user-data.
420 static void appendResultsCallback(bool hasRank, intptr_t rank,
421 const int64_t *shape, MlirType elementType,
422 MlirAttribute attribute, void *userData) {
423 auto *data = static_cast<AppendResultsCallbackData *>(userData);
424 if (!hasRank) {
425 data->inferredShapedTypeComponents.emplace_back(elementType);
426 } else {
427 py::list shapeList;
428 for (intptr_t i = 0; i < rank; ++i) {
429 shapeList.append(shape[i]);
430 }
431 data->inferredShapedTypeComponents.emplace_back(shapeList, elementType,
432 attribute);
433 }
434 }
435
436 /// Given the arguments required to build an operation, attempts to infer the
437 /// shaped type components. Throws value_error on failure.
438 std::vector<PyShapedTypeComponents> inferReturnTypeComponents(
439 std::optional<py::list> operandList,
440 std::optional<PyAttribute> attributes, void *properties,
441 std::optional<std::vector<PyRegion>> regions,
442 DefaultingPyMlirContext context, DefaultingPyLocation location) {
443 llvm::SmallVector<MlirValue> mlirOperands =
444 wrapOperands(std::move(operandList));
445 llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
446
447 std::vector<PyShapedTypeComponents> inferredShapedTypeComponents;
448 PyMlirContext &pyContext = context.resolve();
449 AppendResultsCallbackData data{.inferredShapedTypeComponents: inferredShapedTypeComponents};
450 MlirStringRef opNameRef =
451 mlirStringRefCreate(getOpName().data(), getOpName().length());
452 MlirAttribute attributeDict =
453 attributes ? attributes->get() : mlirAttributeGetNull();
454
455 MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes(
456 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
457 mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
458 mlirRegions.data(), &appendResultsCallback, &data);
459
460 if (mlirLogicalResultIsFailure(result)) {
461 throw py::value_error("Failed to infer result shape type components");
462 }
463
464 return inferredShapedTypeComponents;
465 }
466
467 static void bindDerived(ClassTy &cls) {
468 cls.def("inferReturnTypeComponents",
469 &PyInferShapedTypeOpInterface::inferReturnTypeComponents,
470 py::arg("operands") = py::none(),
471 py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
472 py::arg("properties") = py::none(), py::arg("context") = py::none(),
473 py::arg("loc") = py::none(), inferReturnTypeComponentsDoc);
474 }
475};
476
477void populateIRInterfaces(py::module &m) {
478 PyInferTypeOpInterface::bind(m);
479 PyShapedTypeComponents::bind(m);
480 PyInferShapedTypeOpInterface::bind(m);
481}
482
483} // namespace python
484} // namespace mlir
485

// source code of mlir/lib/Bindings/Python/IRInterfaces.cpp