1//===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cstdint>
10#include <optional>
11#include <string>
12#include <utility>
13#include <vector>
14
15#include "IRModule.h"
16#include "mlir-c/BuiltinAttributes.h"
17#include "mlir-c/IR.h"
18#include "mlir-c/Interfaces.h"
19#include "mlir-c/Support.h"
20#include "mlir/Bindings/Python/Nanobind.h"
21#include "llvm/ADT/STLExtras.h"
22#include "llvm/ADT/SmallVector.h"
23
24namespace nb = nanobind;
25
26namespace mlir {
27namespace python {
28
// Python docstrings attached to the interface classes bound below. The raw
// string literals are exposed verbatim to Python (including their embedded
// newlines), so their continuation lines intentionally start at column 0.

/// Docstring for the `__init__` of every op interface class.
constexpr static const char *constructorDoc =
    R"(Creates an interface from a given operation/opview object or from a
subclass of OpView. Raises ValueError if the operation does not implement the
interface.)";

/// Docstring for the `operation` read-only property.
constexpr static const char *operationDoc =
    R"(Returns an Operation for which the interface was constructed.)";

/// Docstring for the `opview` read-only property.
constexpr static const char *opviewDoc =
    R"(Returns an OpView subclass _instance_ for which the interface was
constructed)";

/// Docstring for `InferTypeOpInterface.inferReturnTypes`.
constexpr static const char *inferReturnTypesDoc =
    R"(Given the arguments required to build an operation, attempts to infer
its return types. Raises ValueError on failure.)";

/// Docstring for `InferShapedTypeOpInterface.inferReturnTypeComponents`.
constexpr static const char *inferReturnTypeComponentsDoc =
    R"(Given the arguments required to build an operation, attempts to infer
its return shaped type components. Raises ValueError on failure.)";
48
49namespace {
50
51/// Takes in an optional ist of operands and converts them into a SmallVector
52/// of MlirVlaues. Returns an empty SmallVector if the list is empty.
53llvm::SmallVector<MlirValue> wrapOperands(std::optional<nb::list> operandList) {
54 llvm::SmallVector<MlirValue> mlirOperands;
55
56 if (!operandList || operandList->size() == 0) {
57 return mlirOperands;
58 }
59
60 // Note: as the list may contain other lists this may not be final size.
61 mlirOperands.reserve(operandList->size());
62 for (const auto &&it : llvm::enumerate(*operandList)) {
63 if (it.value().is_none())
64 continue;
65
66 PyValue *val;
67 try {
68 val = nb::cast<PyValue *>(it.value());
69 if (!val)
70 throw nb::cast_error();
71 mlirOperands.push_back(val->get());
72 continue;
73 } catch (nb::cast_error &err) {
74 // Intentionally unhandled to try sequence below first.
75 (void)err;
76 }
77
78 try {
79 auto vals = nb::cast<nb::sequence>(it.value());
80 for (nb::handle v : vals) {
81 try {
82 val = nb::cast<PyValue *>(v);
83 if (!val)
84 throw nb::cast_error();
85 mlirOperands.push_back(val->get());
86 } catch (nb::cast_error &err) {
87 throw nb::value_error(
88 (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
89 " must be a Value or Sequence of Values (" + err.what() + ")")
90 .str()
91 .c_str());
92 }
93 }
94 continue;
95 } catch (nb::cast_error &err) {
96 throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) +
97 " must be a Value or Sequence of Values (" +
98 err.what() + ")")
99 .str()
100 .c_str());
101 }
102
103 throw nb::cast_error();
104 }
105
106 return mlirOperands;
107}
108
109/// Takes in an optional vector of PyRegions and returns a SmallVector of
110/// MlirRegion. Returns an empty SmallVector if the list is empty.
111llvm::SmallVector<MlirRegion>
112wrapRegions(std::optional<std::vector<PyRegion>> regions) {
113 llvm::SmallVector<MlirRegion> mlirRegions;
114
115 if (regions) {
116 mlirRegions.reserve(regions->size());
117 for (PyRegion &region : *regions) {
118 mlirRegions.push_back(region);
119 }
120 }
121
122 return mlirRegions;
123}
124
125} // namespace
126
127/// CRTP base class for Python classes representing MLIR Op interfaces.
128/// Interface hierarchies are flat so no base class is expected here. The
129/// derived class is expected to define the following static fields:
130/// - `const char *pyClassName` - the name of the Python class to create;
131/// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
132/// of the interface.
133/// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
134/// interface-specific methods.
135///
136/// An interface class may be constructed from either an Operation/OpView object
137/// or from a subclass of OpView. In the latter case, only the static interface
138/// methods are available, similarly to calling ConcereteOp::staticMethod on the
139/// C++ side. Implementations of concrete interfaces can use the `isStatic`
140/// method to check whether the interface object was constructed from a class or
141/// an operation/opview instance. The `getOpName` always succeeds and returns a
142/// canonical name of the operation suitable for lookups.
143template <typename ConcreteIface>
144class PyConcreteOpInterface {
145protected:
146 using ClassTy = nb::class_<ConcreteIface>;
147 using GetTypeIDFunctionTy = MlirTypeID (*)();
148
149public:
150 /// Constructs an interface instance from an object that is either an
151 /// operation or a subclass of OpView. In the latter case, only the static
152 /// methods of the interface are accessible to the caller.
153 PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context)
154 : obj(std::move(object)) {
155 try {
156 operation = &nb::cast<PyOperation &>(obj);
157 } catch (nb::cast_error &) {
158 // Do nothing.
159 }
160
161 try {
162 operation = &nb::cast<PyOpView &>(obj).getOperation();
163 } catch (nb::cast_error &) {
164 // Do nothing.
165 }
166
167 if (operation != nullptr) {
168 if (!mlirOperationImplementsInterface(*operation,
169 ConcreteIface::getInterfaceID())) {
170 std::string msg = "the operation does not implement ";
171 throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
172 }
173
174 MlirIdentifier identifier = mlirOperationGetName(*operation);
175 MlirStringRef stringRef = mlirIdentifierStr(identifier);
176 opName = std::string(stringRef.data, stringRef.length);
177 } else {
178 try {
179 opName = nb::cast<std::string>(obj.attr("OPERATION_NAME"));
180 } catch (nb::cast_error &) {
181 throw nb::type_error(
182 "Op interface does not refer to an operation or OpView class");
183 }
184
185 if (!mlirOperationImplementsInterfaceStatic(
186 mlirStringRefCreate(opName.data(), opName.length()),
187 context.resolve().get(), ConcreteIface::getInterfaceID())) {
188 std::string msg = "the operation does not implement ";
189 throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
190 }
191 }
192 }
193
194 /// Creates the Python bindings for this class in the given module.
195 static void bind(nb::module_ &m) {
196 nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
197 cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"),
198 nb::arg("context").none() = nb::none(), constructorDoc)
199 .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject,
200 operationDoc)
201 .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc);
202 ConcreteIface::bindDerived(cls);
203 }
204
205 /// Hook for derived classes to add class-specific bindings.
206 static void bindDerived(ClassTy &cls) {}
207
208 /// Returns `true` if this object was constructed from a subclass of OpView
209 /// rather than from an operation instance.
210 bool isStatic() { return operation == nullptr; }
211
212 /// Returns the operation instance from which this object was constructed.
213 /// Throws a type error if this object was constructed from a subclass of
214 /// OpView.
215 nb::object getOperationObject() {
216 if (operation == nullptr) {
217 throw nb::type_error("Cannot get an operation from a static interface");
218 }
219
220 return operation->getRef().releaseObject();
221 }
222
223 /// Returns the opview of the operation instance from which this object was
224 /// constructed. Throws a type error if this object was constructed form a
225 /// subclass of OpView.
226 nb::object getOpView() {
227 if (operation == nullptr) {
228 throw nb::type_error("Cannot get an opview from a static interface");
229 }
230
231 return operation->createOpView();
232 }
233
234 /// Returns the canonical name of the operation this interface is constructed
235 /// from.
236 const std::string &getOpName() { return opName; }
237
238private:
239 PyOperation *operation = nullptr;
240 std::string opName;
241 nb::object obj;
242};
243
244/// Python wrapper for InferTypeOpInterface. This interface has only static
245/// methods.
246class PyInferTypeOpInterface
247 : public PyConcreteOpInterface<PyInferTypeOpInterface> {
248public:
249 using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface;
250
251 constexpr static const char *pyClassName = "InferTypeOpInterface";
252 constexpr static GetTypeIDFunctionTy getInterfaceID =
253 &mlirInferTypeOpInterfaceTypeID;
254
255 /// C-style user-data structure for type appending callback.
256 struct AppendResultsCallbackData {
257 std::vector<PyType> &inferredTypes;
258 PyMlirContext &pyMlirContext;
259 };
260
261 /// Appends the types provided as the two first arguments to the user-data
262 /// structure (expects AppendResultsCallbackData).
263 static void appendResultsCallback(intptr_t nTypes, MlirType *types,
264 void *userData) {
265 auto *data = static_cast<AppendResultsCallbackData *>(userData);
266 data->inferredTypes.reserve(n: data->inferredTypes.size() + nTypes);
267 for (intptr_t i = 0; i < nTypes; ++i) {
268 data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]);
269 }
270 }
271
272 /// Given the arguments required to build an operation, attempts to infer its
273 /// return types. Throws value_error on failure.
274 std::vector<PyType>
275 inferReturnTypes(std::optional<nb::list> operandList,
276 std::optional<PyAttribute> attributes, void *properties,
277 std::optional<std::vector<PyRegion>> regions,
278 DefaultingPyMlirContext context,
279 DefaultingPyLocation location) {
280 llvm::SmallVector<MlirValue> mlirOperands =
281 wrapOperands(std::move(operandList));
282 llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
283
284 std::vector<PyType> inferredTypes;
285 PyMlirContext &pyContext = context.resolve();
286 AppendResultsCallbackData data{.inferredTypes: inferredTypes, .pyMlirContext: pyContext};
287 MlirStringRef opNameRef =
288 mlirStringRefCreate(getOpName().data(), getOpName().length());
289 MlirAttribute attributeDict =
290 attributes ? attributes->get() : mlirAttributeGetNull();
291
292 MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes(
293 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
294 mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
295 mlirRegions.data(), &appendResultsCallback, &data);
296
297 if (mlirLogicalResultIsFailure(result)) {
298 throw nb::value_error("Failed to infer result types");
299 }
300
301 return inferredTypes;
302 }
303
304 static void bindDerived(ClassTy &cls) {
305 cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
306 nb::arg("operands").none() = nb::none(),
307 nb::arg("attributes").none() = nb::none(),
308 nb::arg("properties").none() = nb::none(),
309 nb::arg("regions").none() = nb::none(),
310 nb::arg("context").none() = nb::none(),
311 nb::arg("loc").none() = nb::none(), inferReturnTypesDoc);
312 }
313};
314
315/// Wrapper around an shaped type components.
316class PyShapedTypeComponents {
317public:
318 PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {}
319 PyShapedTypeComponents(nb::list shape, MlirType elementType)
320 : shape(std::move(shape)), elementType(elementType), ranked(true) {}
321 PyShapedTypeComponents(nb::list shape, MlirType elementType,
322 MlirAttribute attribute)
323 : shape(std::move(shape)), elementType(elementType), attribute(attribute),
324 ranked(true) {}
325 PyShapedTypeComponents(PyShapedTypeComponents &) = delete;
326 PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept
327 : shape(other.shape), elementType(other.elementType),
328 attribute(other.attribute), ranked(other.ranked) {}
329
330 static void bind(nb::module_ &m) {
331 nb::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents")
332 .def_prop_ro(
333 "element_type",
334 [](PyShapedTypeComponents &self) { return self.elementType; },
335 "Returns the element type of the shaped type components.")
336 .def_static(
337 "get",
338 [](PyType &elementType) {
339 return PyShapedTypeComponents(elementType);
340 },
341 nb::arg("element_type"),
342 "Create an shaped type components object with only the element "
343 "type.")
344 .def_static(
345 "get",
346 [](nb::list shape, PyType &elementType) {
347 return PyShapedTypeComponents(std::move(shape), elementType);
348 },
349 nb::arg("shape"), nb::arg("element_type"),
350 "Create a ranked shaped type components object.")
351 .def_static(
352 "get",
353 [](nb::list shape, PyType &elementType, PyAttribute &attribute) {
354 return PyShapedTypeComponents(std::move(shape), elementType,
355 attribute);
356 },
357 nb::arg("shape"), nb::arg("element_type"), nb::arg("attribute"),
358 "Create a ranked shaped type components object with attribute.")
359 .def_prop_ro(
360 "has_rank",
361 [](PyShapedTypeComponents &self) -> bool { return self.ranked; },
362 "Returns whether the given shaped type component is ranked.")
363 .def_prop_ro(
364 "rank",
365 [](PyShapedTypeComponents &self) -> nb::object {
366 if (!self.ranked) {
367 return nb::none();
368 }
369 return nb::int_(self.shape.size());
370 },
371 "Returns the rank of the given ranked shaped type components. If "
372 "the shaped type components does not have a rank, None is "
373 "returned.")
374 .def_prop_ro(
375 "shape",
376 [](PyShapedTypeComponents &self) -> nb::object {
377 if (!self.ranked) {
378 return nb::none();
379 }
380 return nb::list(self.shape);
381 },
382 "Returns the shape of the ranked shaped type components as a list "
383 "of integers. Returns none if the shaped type component does not "
384 "have a rank.");
385 }
386
387 nb::object getCapsule();
388 static PyShapedTypeComponents createFromCapsule(nb::object capsule);
389
390private:
391 nb::list shape;
392 MlirType elementType;
393 MlirAttribute attribute;
394 bool ranked{false};
395};
396
397/// Python wrapper for InferShapedTypeOpInterface. This interface has only
398/// static methods.
399class PyInferShapedTypeOpInterface
400 : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> {
401public:
402 using PyConcreteOpInterface<
403 PyInferShapedTypeOpInterface>::PyConcreteOpInterface;
404
405 constexpr static const char *pyClassName = "InferShapedTypeOpInterface";
406 constexpr static GetTypeIDFunctionTy getInterfaceID =
407 &mlirInferShapedTypeOpInterfaceTypeID;
408
409 /// C-style user-data structure for type appending callback.
410 struct AppendResultsCallbackData {
411 std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents;
412 };
413
414 /// Appends the shaped type components provided as unpacked shape, element
415 /// type, attribute to the user-data.
416 static void appendResultsCallback(bool hasRank, intptr_t rank,
417 const int64_t *shape, MlirType elementType,
418 MlirAttribute attribute, void *userData) {
419 auto *data = static_cast<AppendResultsCallbackData *>(userData);
420 if (!hasRank) {
421 data->inferredShapedTypeComponents.emplace_back(elementType);
422 } else {
423 nb::list shapeList;
424 for (intptr_t i = 0; i < rank; ++i) {
425 shapeList.append(shape[i]);
426 }
427 data->inferredShapedTypeComponents.emplace_back(shapeList, elementType,
428 attribute);
429 }
430 }
431
432 /// Given the arguments required to build an operation, attempts to infer the
433 /// shaped type components. Throws value_error on failure.
434 std::vector<PyShapedTypeComponents> inferReturnTypeComponents(
435 std::optional<nb::list> operandList,
436 std::optional<PyAttribute> attributes, void *properties,
437 std::optional<std::vector<PyRegion>> regions,
438 DefaultingPyMlirContext context, DefaultingPyLocation location) {
439 llvm::SmallVector<MlirValue> mlirOperands =
440 wrapOperands(std::move(operandList));
441 llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
442
443 std::vector<PyShapedTypeComponents> inferredShapedTypeComponents;
444 PyMlirContext &pyContext = context.resolve();
445 AppendResultsCallbackData data{.inferredShapedTypeComponents: inferredShapedTypeComponents};
446 MlirStringRef opNameRef =
447 mlirStringRefCreate(getOpName().data(), getOpName().length());
448 MlirAttribute attributeDict =
449 attributes ? attributes->get() : mlirAttributeGetNull();
450
451 MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes(
452 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
453 mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
454 mlirRegions.data(), &appendResultsCallback, &data);
455
456 if (mlirLogicalResultIsFailure(result)) {
457 throw nb::value_error("Failed to infer result shape type components");
458 }
459
460 return inferredShapedTypeComponents;
461 }
462
463 static void bindDerived(ClassTy &cls) {
464 cls.def("inferReturnTypeComponents",
465 &PyInferShapedTypeOpInterface::inferReturnTypeComponents,
466 nb::arg("operands").none() = nb::none(),
467 nb::arg("attributes").none() = nb::none(),
468 nb::arg("regions").none() = nb::none(),
469 nb::arg("properties").none() = nb::none(),
470 nb::arg("context").none() = nb::none(),
471 nb::arg("loc").none() = nb::none(), inferReturnTypeComponentsDoc);
472 }
473};
474
/// Registers the op-interface Python classes defined in this file
/// (InferTypeOpInterface, ShapedTypeComponents, InferShapedTypeOpInterface)
/// on module `m`.
void populateIRInterfaces(nb::module_ &m) {
  PyInferTypeOpInterface::bind(m);
  // ShapedTypeComponents is the element type of the list returned by
  // InferShapedTypeOpInterface.inferReturnTypeComponents.
  PyShapedTypeComponents::bind(m);
  PyInferShapedTypeOpInterface::bind(m);
}
480
481} // namespace python
482} // namespace mlir
483

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Bindings/Python/IRInterfaces.cpp