1 | //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include <cstdint> |
10 | #include <optional> |
11 | #include <string> |
12 | #include <utility> |
13 | #include <vector> |
14 | |
15 | #include "IRModule.h" |
16 | #include "mlir-c/BuiltinAttributes.h" |
17 | #include "mlir-c/IR.h" |
18 | #include "mlir-c/Interfaces.h" |
19 | #include "mlir-c/Support.h" |
20 | #include "mlir/Bindings/Python/Nanobind.h" |
21 | #include "llvm/ADT/STLExtras.h" |
22 | #include "llvm/ADT/SmallVector.h" |
23 | |
24 | namespace nb = nanobind; |
25 | |
26 | namespace mlir { |
27 | namespace python { |
28 | |
// Python docstrings attached to the interface bindings below. Embedded '\n'
// characters keep the rendered help text wrapped at the same points.

/// Docstring for the interface constructor (`__init__`).
static constexpr const char *constructorDoc =
    "Creates an interface from a given operation/opview object or from a\n"
    "subclass of OpView. Raises ValueError if the operation does not "
    "implement the\n"
    "interface.";

/// Docstring for the read-only `operation` property.
static constexpr const char *operationDoc =
    "Returns an Operation for which the interface was constructed.";

/// Docstring for the read-only `opview` property.
static constexpr const char *opviewDoc =
    "Returns an OpView subclass _instance_ for which the interface was\n"
    "constructed";

/// Docstring for `InferTypeOpInterface.inferReturnTypes`.
static constexpr const char *inferReturnTypesDoc =
    "Given the arguments required to build an operation, attempts to infer\n"
    "its return types. Raises ValueError on failure.";

/// Docstring for `InferShapedTypeOpInterface.inferReturnTypeComponents`.
static constexpr const char *inferReturnTypeComponentsDoc =
    "Given the arguments required to build an operation, attempts to infer\n"
    "its return shaped type components. Raises ValueError on failure.";
48 | |
49 | namespace { |
50 | |
51 | /// Takes in an optional ist of operands and converts them into a SmallVector |
52 | /// of MlirVlaues. Returns an empty SmallVector if the list is empty. |
53 | llvm::SmallVector<MlirValue> wrapOperands(std::optional<nb::list> operandList) { |
54 | llvm::SmallVector<MlirValue> mlirOperands; |
55 | |
56 | if (!operandList || operandList->size() == 0) { |
57 | return mlirOperands; |
58 | } |
59 | |
60 | // Note: as the list may contain other lists this may not be final size. |
61 | mlirOperands.reserve(operandList->size()); |
62 | for (const auto &&it : llvm::enumerate(*operandList)) { |
63 | if (it.value().is_none()) |
64 | continue; |
65 | |
66 | PyValue *val; |
67 | try { |
68 | val = nb::cast<PyValue *>(it.value()); |
69 | if (!val) |
70 | throw nb::cast_error(); |
71 | mlirOperands.push_back(val->get()); |
72 | continue; |
73 | } catch (nb::cast_error &err) { |
74 | // Intentionally unhandled to try sequence below first. |
75 | (void)err; |
76 | } |
77 | |
78 | try { |
79 | auto vals = nb::cast<nb::sequence>(it.value()); |
80 | for (nb::handle v : vals) { |
81 | try { |
82 | val = nb::cast<PyValue *>(v); |
83 | if (!val) |
84 | throw nb::cast_error(); |
85 | mlirOperands.push_back(val->get()); |
86 | } catch (nb::cast_error &err) { |
87 | throw nb::value_error( |
88 | (llvm::Twine("Operand ") + llvm::Twine(it.index()) + |
89 | " must be a Value or Sequence of Values ("+ err.what() + ")") |
90 | .str() |
91 | .c_str()); |
92 | } |
93 | } |
94 | continue; |
95 | } catch (nb::cast_error &err) { |
96 | throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + |
97 | " must be a Value or Sequence of Values ("+ |
98 | err.what() + ")") |
99 | .str() |
100 | .c_str()); |
101 | } |
102 | |
103 | throw nb::cast_error(); |
104 | } |
105 | |
106 | return mlirOperands; |
107 | } |
108 | |
109 | /// Takes in an optional vector of PyRegions and returns a SmallVector of |
110 | /// MlirRegion. Returns an empty SmallVector if the list is empty. |
111 | llvm::SmallVector<MlirRegion> |
112 | wrapRegions(std::optional<std::vector<PyRegion>> regions) { |
113 | llvm::SmallVector<MlirRegion> mlirRegions; |
114 | |
115 | if (regions) { |
116 | mlirRegions.reserve(regions->size()); |
117 | for (PyRegion ®ion : *regions) { |
118 | mlirRegions.push_back(region); |
119 | } |
120 | } |
121 | |
122 | return mlirRegions; |
123 | } |
124 | |
125 | } // namespace |
126 | |
127 | /// CRTP base class for Python classes representing MLIR Op interfaces. |
128 | /// Interface hierarchies are flat so no base class is expected here. The |
129 | /// derived class is expected to define the following static fields: |
130 | /// - `const char *pyClassName` - the name of the Python class to create; |
131 | /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID |
132 | /// of the interface. |
133 | /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind |
134 | /// interface-specific methods. |
135 | /// |
136 | /// An interface class may be constructed from either an Operation/OpView object |
137 | /// or from a subclass of OpView. In the latter case, only the static interface |
138 | /// methods are available, similarly to calling ConcereteOp::staticMethod on the |
139 | /// C++ side. Implementations of concrete interfaces can use the `isStatic` |
140 | /// method to check whether the interface object was constructed from a class or |
141 | /// an operation/opview instance. The `getOpName` always succeeds and returns a |
142 | /// canonical name of the operation suitable for lookups. |
143 | template <typename ConcreteIface> |
144 | class PyConcreteOpInterface { |
145 | protected: |
146 | using ClassTy = nb::class_<ConcreteIface>; |
147 | using GetTypeIDFunctionTy = MlirTypeID (*)(); |
148 | |
149 | public: |
150 | /// Constructs an interface instance from an object that is either an |
151 | /// operation or a subclass of OpView. In the latter case, only the static |
152 | /// methods of the interface are accessible to the caller. |
153 | PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context) |
154 | : obj(std::move(object)) { |
155 | try { |
156 | operation = &nb::cast<PyOperation &>(obj); |
157 | } catch (nb::cast_error &) { |
158 | // Do nothing. |
159 | } |
160 | |
161 | try { |
162 | operation = &nb::cast<PyOpView &>(obj).getOperation(); |
163 | } catch (nb::cast_error &) { |
164 | // Do nothing. |
165 | } |
166 | |
167 | if (operation != nullptr) { |
168 | if (!mlirOperationImplementsInterface(*operation, |
169 | ConcreteIface::getInterfaceID())) { |
170 | std::string msg = "the operation does not implement "; |
171 | throw nb::value_error((msg + ConcreteIface::pyClassName).c_str()); |
172 | } |
173 | |
174 | MlirIdentifier identifier = mlirOperationGetName(*operation); |
175 | MlirStringRef stringRef = mlirIdentifierStr(identifier); |
176 | opName = std::string(stringRef.data, stringRef.length); |
177 | } else { |
178 | try { |
179 | opName = nb::cast<std::string>(obj.attr("OPERATION_NAME")); |
180 | } catch (nb::cast_error &) { |
181 | throw nb::type_error( |
182 | "Op interface does not refer to an operation or OpView class"); |
183 | } |
184 | |
185 | if (!mlirOperationImplementsInterfaceStatic( |
186 | mlirStringRefCreate(opName.data(), opName.length()), |
187 | context.resolve().get(), ConcreteIface::getInterfaceID())) { |
188 | std::string msg = "the operation does not implement "; |
189 | throw nb::value_error((msg + ConcreteIface::pyClassName).c_str()); |
190 | } |
191 | } |
192 | } |
193 | |
194 | /// Creates the Python bindings for this class in the given module. |
195 | static void bind(nb::module_ &m) { |
196 | nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName); |
197 | cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"), |
198 | nb::arg("context").none() = nb::none(), constructorDoc) |
199 | .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject, |
200 | operationDoc) |
201 | .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc); |
202 | ConcreteIface::bindDerived(cls); |
203 | } |
204 | |
205 | /// Hook for derived classes to add class-specific bindings. |
206 | static void bindDerived(ClassTy &cls) {} |
207 | |
208 | /// Returns `true` if this object was constructed from a subclass of OpView |
209 | /// rather than from an operation instance. |
210 | bool isStatic() { return operation == nullptr; } |
211 | |
212 | /// Returns the operation instance from which this object was constructed. |
213 | /// Throws a type error if this object was constructed from a subclass of |
214 | /// OpView. |
215 | nb::object getOperationObject() { |
216 | if (operation == nullptr) { |
217 | throw nb::type_error("Cannot get an operation from a static interface"); |
218 | } |
219 | |
220 | return operation->getRef().releaseObject(); |
221 | } |
222 | |
223 | /// Returns the opview of the operation instance from which this object was |
224 | /// constructed. Throws a type error if this object was constructed form a |
225 | /// subclass of OpView. |
226 | nb::object getOpView() { |
227 | if (operation == nullptr) { |
228 | throw nb::type_error("Cannot get an opview from a static interface"); |
229 | } |
230 | |
231 | return operation->createOpView(); |
232 | } |
233 | |
234 | /// Returns the canonical name of the operation this interface is constructed |
235 | /// from. |
236 | const std::string &getOpName() { return opName; } |
237 | |
238 | private: |
239 | PyOperation *operation = nullptr; |
240 | std::string opName; |
241 | nb::object obj; |
242 | }; |
243 | |
244 | /// Python wrapper for InferTypeOpInterface. This interface has only static |
245 | /// methods. |
246 | class PyInferTypeOpInterface |
247 | : public PyConcreteOpInterface<PyInferTypeOpInterface> { |
248 | public: |
249 | using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface; |
250 | |
251 | constexpr static const char *pyClassName = "InferTypeOpInterface"; |
252 | constexpr static GetTypeIDFunctionTy getInterfaceID = |
253 | &mlirInferTypeOpInterfaceTypeID; |
254 | |
255 | /// C-style user-data structure for type appending callback. |
256 | struct AppendResultsCallbackData { |
257 | std::vector<PyType> &inferredTypes; |
258 | PyMlirContext &pyMlirContext; |
259 | }; |
260 | |
261 | /// Appends the types provided as the two first arguments to the user-data |
262 | /// structure (expects AppendResultsCallbackData). |
263 | static void appendResultsCallback(intptr_t nTypes, MlirType *types, |
264 | void *userData) { |
265 | auto *data = static_cast<AppendResultsCallbackData *>(userData); |
266 | data->inferredTypes.reserve(n: data->inferredTypes.size() + nTypes); |
267 | for (intptr_t i = 0; i < nTypes; ++i) { |
268 | data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]); |
269 | } |
270 | } |
271 | |
272 | /// Given the arguments required to build an operation, attempts to infer its |
273 | /// return types. Throws value_error on failure. |
274 | std::vector<PyType> |
275 | inferReturnTypes(std::optional<nb::list> operandList, |
276 | std::optional<PyAttribute> attributes, void *properties, |
277 | std::optional<std::vector<PyRegion>> regions, |
278 | DefaultingPyMlirContext context, |
279 | DefaultingPyLocation location) { |
280 | llvm::SmallVector<MlirValue> mlirOperands = |
281 | wrapOperands(std::move(operandList)); |
282 | llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions)); |
283 | |
284 | std::vector<PyType> inferredTypes; |
285 | PyMlirContext &pyContext = context.resolve(); |
286 | AppendResultsCallbackData data{.inferredTypes: inferredTypes, .pyMlirContext: pyContext}; |
287 | MlirStringRef opNameRef = |
288 | mlirStringRefCreate(getOpName().data(), getOpName().length()); |
289 | MlirAttribute attributeDict = |
290 | attributes ? attributes->get() : mlirAttributeGetNull(); |
291 | |
292 | MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes( |
293 | opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), |
294 | mlirOperands.data(), attributeDict, properties, mlirRegions.size(), |
295 | mlirRegions.data(), &appendResultsCallback, &data); |
296 | |
297 | if (mlirLogicalResultIsFailure(result)) { |
298 | throw nb::value_error("Failed to infer result types"); |
299 | } |
300 | |
301 | return inferredTypes; |
302 | } |
303 | |
304 | static void bindDerived(ClassTy &cls) { |
305 | cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, |
306 | nb::arg("operands").none() = nb::none(), |
307 | nb::arg("attributes").none() = nb::none(), |
308 | nb::arg("properties").none() = nb::none(), |
309 | nb::arg("regions").none() = nb::none(), |
310 | nb::arg("context").none() = nb::none(), |
311 | nb::arg("loc").none() = nb::none(), inferReturnTypesDoc); |
312 | } |
313 | }; |
314 | |
315 | /// Wrapper around an shaped type components. |
316 | class PyShapedTypeComponents { |
317 | public: |
318 | PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {} |
319 | PyShapedTypeComponents(nb::list shape, MlirType elementType) |
320 | : shape(std::move(shape)), elementType(elementType), ranked(true) {} |
321 | PyShapedTypeComponents(nb::list shape, MlirType elementType, |
322 | MlirAttribute attribute) |
323 | : shape(std::move(shape)), elementType(elementType), attribute(attribute), |
324 | ranked(true) {} |
325 | PyShapedTypeComponents(PyShapedTypeComponents &) = delete; |
326 | PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept |
327 | : shape(other.shape), elementType(other.elementType), |
328 | attribute(other.attribute), ranked(other.ranked) {} |
329 | |
330 | static void bind(nb::module_ &m) { |
331 | nb::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents") |
332 | .def_prop_ro( |
333 | "element_type", |
334 | [](PyShapedTypeComponents &self) { return self.elementType; }, |
335 | "Returns the element type of the shaped type components.") |
336 | .def_static( |
337 | "get", |
338 | [](PyType &elementType) { |
339 | return PyShapedTypeComponents(elementType); |
340 | }, |
341 | nb::arg("element_type"), |
342 | "Create an shaped type components object with only the element " |
343 | "type.") |
344 | .def_static( |
345 | "get", |
346 | [](nb::list shape, PyType &elementType) { |
347 | return PyShapedTypeComponents(std::move(shape), elementType); |
348 | }, |
349 | nb::arg("shape"), nb::arg( "element_type"), |
350 | "Create a ranked shaped type components object.") |
351 | .def_static( |
352 | "get", |
353 | [](nb::list shape, PyType &elementType, PyAttribute &attribute) { |
354 | return PyShapedTypeComponents(std::move(shape), elementType, |
355 | attribute); |
356 | }, |
357 | nb::arg("shape"), nb::arg( "element_type"), nb::arg( "attribute"), |
358 | "Create a ranked shaped type components object with attribute.") |
359 | .def_prop_ro( |
360 | "has_rank", |
361 | [](PyShapedTypeComponents &self) -> bool { return self.ranked; }, |
362 | "Returns whether the given shaped type component is ranked.") |
363 | .def_prop_ro( |
364 | "rank", |
365 | [](PyShapedTypeComponents &self) -> nb::object { |
366 | if (!self.ranked) { |
367 | return nb::none(); |
368 | } |
369 | return nb::int_(self.shape.size()); |
370 | }, |
371 | "Returns the rank of the given ranked shaped type components. If " |
372 | "the shaped type components does not have a rank, None is " |
373 | "returned.") |
374 | .def_prop_ro( |
375 | "shape", |
376 | [](PyShapedTypeComponents &self) -> nb::object { |
377 | if (!self.ranked) { |
378 | return nb::none(); |
379 | } |
380 | return nb::list(self.shape); |
381 | }, |
382 | "Returns the shape of the ranked shaped type components as a list " |
383 | "of integers. Returns none if the shaped type component does not " |
384 | "have a rank."); |
385 | } |
386 | |
387 | nb::object getCapsule(); |
388 | static PyShapedTypeComponents createFromCapsule(nb::object capsule); |
389 | |
390 | private: |
391 | nb::list shape; |
392 | MlirType elementType; |
393 | MlirAttribute attribute; |
394 | bool ranked{false}; |
395 | }; |
396 | |
397 | /// Python wrapper for InferShapedTypeOpInterface. This interface has only |
398 | /// static methods. |
399 | class PyInferShapedTypeOpInterface |
400 | : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> { |
401 | public: |
402 | using PyConcreteOpInterface< |
403 | PyInferShapedTypeOpInterface>::PyConcreteOpInterface; |
404 | |
405 | constexpr static const char *pyClassName = "InferShapedTypeOpInterface"; |
406 | constexpr static GetTypeIDFunctionTy getInterfaceID = |
407 | &mlirInferShapedTypeOpInterfaceTypeID; |
408 | |
409 | /// C-style user-data structure for type appending callback. |
410 | struct AppendResultsCallbackData { |
411 | std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents; |
412 | }; |
413 | |
414 | /// Appends the shaped type components provided as unpacked shape, element |
415 | /// type, attribute to the user-data. |
416 | static void appendResultsCallback(bool hasRank, intptr_t rank, |
417 | const int64_t *shape, MlirType elementType, |
418 | MlirAttribute attribute, void *userData) { |
419 | auto *data = static_cast<AppendResultsCallbackData *>(userData); |
420 | if (!hasRank) { |
421 | data->inferredShapedTypeComponents.emplace_back(elementType); |
422 | } else { |
423 | nb::list shapeList; |
424 | for (intptr_t i = 0; i < rank; ++i) { |
425 | shapeList.append(shape[i]); |
426 | } |
427 | data->inferredShapedTypeComponents.emplace_back(shapeList, elementType, |
428 | attribute); |
429 | } |
430 | } |
431 | |
432 | /// Given the arguments required to build an operation, attempts to infer the |
433 | /// shaped type components. Throws value_error on failure. |
434 | std::vector<PyShapedTypeComponents> inferReturnTypeComponents( |
435 | std::optional<nb::list> operandList, |
436 | std::optional<PyAttribute> attributes, void *properties, |
437 | std::optional<std::vector<PyRegion>> regions, |
438 | DefaultingPyMlirContext context, DefaultingPyLocation location) { |
439 | llvm::SmallVector<MlirValue> mlirOperands = |
440 | wrapOperands(std::move(operandList)); |
441 | llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions)); |
442 | |
443 | std::vector<PyShapedTypeComponents> inferredShapedTypeComponents; |
444 | PyMlirContext &pyContext = context.resolve(); |
445 | AppendResultsCallbackData data{.inferredShapedTypeComponents: inferredShapedTypeComponents}; |
446 | MlirStringRef opNameRef = |
447 | mlirStringRefCreate(getOpName().data(), getOpName().length()); |
448 | MlirAttribute attributeDict = |
449 | attributes ? attributes->get() : mlirAttributeGetNull(); |
450 | |
451 | MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes( |
452 | opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), |
453 | mlirOperands.data(), attributeDict, properties, mlirRegions.size(), |
454 | mlirRegions.data(), &appendResultsCallback, &data); |
455 | |
456 | if (mlirLogicalResultIsFailure(result)) { |
457 | throw nb::value_error("Failed to infer result shape type components"); |
458 | } |
459 | |
460 | return inferredShapedTypeComponents; |
461 | } |
462 | |
463 | static void bindDerived(ClassTy &cls) { |
464 | cls.def("inferReturnTypeComponents", |
465 | &PyInferShapedTypeOpInterface::inferReturnTypeComponents, |
466 | nb::arg("operands").none() = nb::none(), |
467 | nb::arg("attributes").none() = nb::none(), |
468 | nb::arg("regions").none() = nb::none(), |
469 | nb::arg("properties").none() = nb::none(), |
470 | nb::arg("context").none() = nb::none(), |
471 | nb::arg("loc").none() = nb::none(), inferReturnTypeComponentsDoc); |
472 | } |
473 | }; |
474 | |
475 | void populateIRInterfaces(nb::module_ &m) { |
476 | PyInferTypeOpInterface::bind(m); |
477 | PyShapedTypeComponents::bind(m); |
478 | PyInferShapedTypeOpInterface::bind(m); |
479 | } |
480 | |
481 | } // namespace python |
482 | } // namespace mlir |
483 |
Definitions
- constructorDoc
- operationDoc
- opviewDoc
- inferReturnTypesDoc
- inferReturnTypeComponentsDoc
- wrapOperands
- wrapRegions
- PyConcreteOpInterface
- PyConcreteOpInterface
- bind
- bindDerived
- isStatic
- getOperationObject
- getOpView
- getOpName
- PyInferTypeOpInterface
- pyClassName
- getInterfaceID
- AppendResultsCallbackData
- appendResultsCallback
- inferReturnTypes
- bindDerived
- PyShapedTypeComponents
- PyShapedTypeComponents
- PyShapedTypeComponents
- PyShapedTypeComponents
- PyShapedTypeComponents
- PyShapedTypeComponents
- bind
- PyInferShapedTypeOpInterface
- pyClassName
- getInterfaceID
- AppendResultsCallbackData
- appendResultsCallback
- inferReturnTypeComponents
- bindDerived
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more