1 | //===- IRModules.h - IR Submodules of pybind module -----------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
8 | //===----------------------------------------------------------------------===// |
9 | |
10 | #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H |
11 | #define MLIR_BINDINGS_PYTHON_IRMODULES_H |
12 | |
13 | #include <optional> |
14 | #include <sstream> |
15 | #include <utility> |
16 | #include <vector> |
17 | |
18 | #include "Globals.h" |
19 | #include "NanobindUtils.h" |
20 | #include "mlir-c/AffineExpr.h" |
21 | #include "mlir-c/AffineMap.h" |
22 | #include "mlir-c/Diagnostics.h" |
23 | #include "mlir-c/IR.h" |
24 | #include "mlir-c/IntegerSet.h" |
25 | #include "mlir-c/Transforms.h" |
26 | #include "mlir/Bindings/Python/Nanobind.h" |
27 | #include "mlir/Bindings/Python/NanobindAdaptors.h" |
28 | #include "llvm/ADT/DenseMap.h" |
29 | #include "llvm/Support/ThreadPool.h" |
30 | |
31 | namespace mlir { |
32 | namespace python { |
33 | |
34 | class PyBlock; |
35 | class PyDiagnostic; |
36 | class PyDiagnosticHandler; |
37 | class PyInsertionPoint; |
38 | class PyLocation; |
39 | class DefaultingPyLocation; |
40 | class PyMlirContext; |
41 | class DefaultingPyMlirContext; |
42 | class PyModule; |
43 | class PyOperation; |
44 | class PyOperationBase; |
45 | class PyType; |
46 | class PySymbolTable; |
47 | class PyValue; |
48 | |
49 | /// Template for a reference to a concrete type which captures a python |
50 | /// reference to its underlying python object. |
51 | template <typename T> |
52 | class PyObjectRef { |
53 | public: |
54 | PyObjectRef(T *referrent, nanobind::object object) |
55 | : referrent(referrent), object(std::move(object)) { |
56 | assert(this->referrent && |
57 | "cannot construct PyObjectRef with null referrent"); |
58 | assert(this->object && "cannot construct PyObjectRef with null object"); |
59 | } |
60 | PyObjectRef(PyObjectRef &&other) noexcept |
61 | : referrent(other.referrent), object(std::move(other.object)) { |
62 | other.referrent = nullptr; |
63 | assert(!other.object); |
64 | } |
65 | PyObjectRef(const PyObjectRef &other) |
66 | : referrent(other.referrent), object(other.object /* copies */) {} |
67 | ~PyObjectRef() = default; |
68 | |
69 | int getRefCount() { |
70 | if (!object) |
71 | return 0; |
72 | return Py_REFCNT(object.ptr()); |
73 | } |
74 | |
75 | /// Releases the object held by this instance, returning it. |
76 | /// This is the proper thing to return from a function that wants to return |
77 | /// the reference. Note that this does not work from initializers. |
78 | nanobind::object releaseObject() { |
79 | assert(referrent && object); |
80 | referrent = nullptr; |
81 | auto stolen = std::move(object); |
82 | return stolen; |
83 | } |
84 | |
85 | T *get() { return referrent; } |
86 | T *operator->() { |
87 | assert(referrent && object); |
88 | return referrent; |
89 | } |
90 | nanobind::object getObject() { |
91 | assert(referrent && object); |
92 | return object; |
93 | } |
94 | operator bool() const { return referrent && object; } |
95 | |
96 | private: |
97 | T *referrent; |
98 | nanobind::object object; |
99 | }; |
100 | |
101 | /// Tracks an entry in the thread context stack. New entries are pushed onto |
102 | /// here for each with block that activates a new InsertionPoint, Context or |
103 | /// Location. |
104 | /// |
105 | /// Pushing either a Location or InsertionPoint also pushes its associated |
106 | /// Context. Pushing a Context will not modify the Location or InsertionPoint |
107 | /// unless if they are from a different context, in which case, they are |
108 | /// cleared. |
109 | class PyThreadContextEntry { |
110 | public: |
111 | enum class FrameKind { |
112 | Context, |
113 | InsertionPoint, |
114 | Location, |
115 | }; |
116 | |
117 | PyThreadContextEntry(FrameKind frameKind, nanobind::object context, |
118 | nanobind::object insertionPoint, |
119 | nanobind::object location) |
120 | : context(std::move(context)), insertionPoint(std::move(insertionPoint)), |
121 | location(std::move(location)), frameKind(frameKind) {} |
122 | |
123 | /// Gets the top of stack context and return nullptr if not defined. |
124 | static PyMlirContext *getDefaultContext(); |
125 | |
126 | /// Gets the top of stack insertion point and return nullptr if not defined. |
127 | static PyInsertionPoint *getDefaultInsertionPoint(); |
128 | |
129 | /// Gets the top of stack location and returns nullptr if not defined. |
130 | static PyLocation *getDefaultLocation(); |
131 | |
132 | PyMlirContext *getContext(); |
133 | PyInsertionPoint *getInsertionPoint(); |
134 | PyLocation *getLocation(); |
135 | FrameKind getFrameKind() { return frameKind; } |
136 | |
137 | /// Stack management. |
138 | static PyThreadContextEntry *getTopOfStack(); |
139 | static nanobind::object pushContext(nanobind::object context); |
140 | static void popContext(PyMlirContext &context); |
141 | static nanobind::object pushInsertionPoint(nanobind::object insertionPoint); |
142 | static void popInsertionPoint(PyInsertionPoint &insertionPoint); |
143 | static nanobind::object pushLocation(nanobind::object location); |
144 | static void popLocation(PyLocation &location); |
145 | |
146 | /// Gets the thread local stack. |
147 | static std::vector<PyThreadContextEntry> &getStack(); |
148 | |
149 | private: |
150 | static void push(FrameKind frameKind, nanobind::object context, |
151 | nanobind::object insertionPoint, nanobind::object location); |
152 | |
153 | /// An object reference to the PyContext. |
154 | nanobind::object context; |
155 | /// An object reference to the current insertion point. |
156 | nanobind::object insertionPoint; |
157 | /// An object reference to the current location. |
158 | nanobind::object location; |
159 | // The kind of push that was performed. |
160 | FrameKind frameKind; |
161 | }; |
162 | |
163 | /// Wrapper around MlirLlvmThreadPool |
164 | /// Python object owns the C++ thread pool |
165 | class PyThreadPool { |
166 | public: |
167 | PyThreadPool() { |
168 | ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>(); |
169 | } |
170 | PyThreadPool(const PyThreadPool &) = delete; |
171 | PyThreadPool(PyThreadPool &&) = delete; |
172 | |
173 | int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); } |
174 | MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); } |
175 | |
176 | std::string _mlir_thread_pool_ptr() const { |
177 | std::stringstream ss; |
178 | ss << ownedThreadPool.get(); |
179 | return ss.str(); |
180 | } |
181 | |
182 | private: |
183 | std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool; |
184 | }; |
185 | |
186 | /// Wrapper around MlirContext. |
187 | using PyMlirContextRef = PyObjectRef<PyMlirContext>; |
188 | class PyMlirContext { |
189 | public: |
190 | PyMlirContext() = delete; |
191 | PyMlirContext(MlirContext context); |
192 | PyMlirContext(const PyMlirContext &) = delete; |
193 | PyMlirContext(PyMlirContext &&) = delete; |
194 | |
195 | /// For the case of a python __init__ (nanobind::init) method, pybind11 is |
196 | /// quite strict about needing to return a pointer that is not yet associated |
197 | /// to an nanobind::object. Since the forContext() method acts like a pool, |
198 | /// possibly returning a recycled context, it does not satisfy this need. The |
199 | /// usual way in python to accomplish such a thing is to override __new__, but |
200 | /// that is also not supported by pybind11. Instead, we use this entry |
201 | /// point which always constructs a fresh context (which cannot alias an |
202 | /// existing one because it is fresh). |
203 | static PyMlirContext *createNewContextForInit(); |
204 | |
205 | /// Returns a context reference for the singleton PyMlirContext wrapper for |
206 | /// the given context. |
207 | static PyMlirContextRef forContext(MlirContext context); |
208 | ~PyMlirContext(); |
209 | |
210 | /// Accesses the underlying MlirContext. |
211 | MlirContext get() { return context; } |
212 | |
213 | /// Gets a strong reference to this context, which will ensure it is kept |
214 | /// alive for the life of the reference. |
215 | PyMlirContextRef getRef() { |
216 | return PyMlirContextRef(this, nanobind::cast(this)); |
217 | } |
218 | |
219 | /// Gets a capsule wrapping the void* within the MlirContext. |
220 | nanobind::object getCapsule(); |
221 | |
222 | /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. |
223 | /// Note that PyMlirContext instances are uniqued, so the returned object |
224 | /// may be a pre-existing object. Ownership of the underlying MlirContext |
225 | /// is taken by calling this function. |
226 | static nanobind::object createFromCapsule(nanobind::object capsule); |
227 | |
228 | /// Gets the count of live context objects. Used for testing. |
229 | static size_t getLiveCount(); |
230 | |
231 | /// Get a list of Python objects which are still in the live context map. |
232 | std::vector<PyOperation *> getLiveOperationObjects(); |
233 | |
234 | /// Gets the count of live operations associated with this context. |
235 | /// Used for testing. |
236 | size_t getLiveOperationCount(); |
237 | |
238 | /// Clears the live operations map, returning the number of entries which were |
239 | /// invalidated. To be used as a safety mechanism so that API end-users can't |
240 | /// corrupt by holding references they shouldn't have accessed in the first |
241 | /// place. |
242 | size_t clearLiveOperations(); |
243 | |
244 | /// Removes an operation from the live operations map and sets it invalid. |
245 | /// This is useful for when some non-bindings code destroys the operation and |
246 | /// the bindings need to made aware. For example, in the case when pass |
247 | /// manager is run. |
248 | /// |
249 | /// Note that this does *NOT* clear the nested operations. |
250 | void clearOperation(MlirOperation op); |
251 | |
252 | /// Clears all operations nested inside the given op using |
253 | /// `clearOperation(MlirOperation)`. |
254 | void clearOperationsInside(PyOperationBase &op); |
255 | void clearOperationsInside(MlirOperation op); |
256 | |
257 | /// Clears the operaiton _and_ all operations inside using |
258 | /// `clearOperation(MlirOperation)`. |
259 | void clearOperationAndInside(PyOperationBase &op); |
260 | |
261 | /// Gets the count of live modules associated with this context. |
262 | /// Used for testing. |
263 | size_t getLiveModuleCount(); |
264 | |
265 | /// Enter and exit the context manager. |
266 | static nanobind::object contextEnter(nanobind::object context); |
267 | void contextExit(const nanobind::object &excType, |
268 | const nanobind::object &excVal, |
269 | const nanobind::object &excTb); |
270 | |
271 | /// Attaches a Python callback as a diagnostic handler, returning a |
272 | /// registration object (internally a PyDiagnosticHandler). |
273 | nanobind::object attachDiagnosticHandler(nanobind::object callback); |
274 | |
275 | /// Controls whether error diagnostics should be propagated to diagnostic |
276 | /// handlers, instead of being captured by `ErrorCapture`. |
277 | void setEmitErrorDiagnostics(bool value) { emitErrorDiagnostics = value; } |
278 | struct ErrorCapture; |
279 | |
280 | private: |
281 | // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, |
282 | // preserving the relationship that an MlirContext maps to a single |
283 | // PyMlirContext wrapper. This could be replaced in the future with an |
284 | // extension mechanism on the MlirContext for stashing user pointers. |
285 | // Note that this holds a handle, which does not imply ownership. |
286 | // Mappings will be removed when the context is destructed. |
287 | using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>; |
288 | static nanobind::ft_mutex live_contexts_mutex; |
289 | static LiveContextMap &getLiveContexts(); |
290 | |
291 | // Interns all live modules associated with this context. Modules tracked |
292 | // in this map are valid. When a module is invalidated, it is removed |
293 | // from this map, and while it still exists as an instance, any |
294 | // attempt to access it will raise an error. |
295 | using LiveModuleMap = |
296 | llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>; |
297 | LiveModuleMap liveModules; |
298 | |
299 | // Interns all live operations associated with this context. Operations |
300 | // tracked in this map are valid. When an operation is invalidated, it is |
301 | // removed from this map, and while it still exists as an instance, any |
302 | // attempt to access it will raise an error. |
303 | using LiveOperationMap = |
304 | llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>; |
305 | nanobind::ft_mutex liveOperationsMutex; |
306 | |
307 | // Guarded by liveOperationsMutex in free-threading mode. |
308 | LiveOperationMap liveOperations; |
309 | |
310 | bool emitErrorDiagnostics = false; |
311 | |
312 | MlirContext context; |
313 | friend class PyModule; |
314 | friend class PyOperation; |
315 | }; |
316 | |
317 | /// Used in function arguments when None should resolve to the current context |
318 | /// manager set instance. |
319 | class DefaultingPyMlirContext |
320 | : public Defaulting<DefaultingPyMlirContext, PyMlirContext> { |
321 | public: |
322 | using Defaulting::Defaulting; |
323 | static constexpr const char kTypeDescription[] = "mlir.ir.Context"; |
324 | static PyMlirContext &resolve(); |
325 | }; |
326 | |
327 | /// Base class for all objects that directly or indirectly depend on an |
328 | /// MlirContext. The lifetime of the context will extend at least to the |
329 | /// lifetime of these instances. |
330 | /// Immutable objects that depend on a context extend this directly. |
331 | class BaseContextObject { |
332 | public: |
333 | BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) { |
334 | assert(this->contextRef && |
335 | "context object constructed with null context ref"); |
336 | } |
337 | |
338 | /// Accesses the context reference. |
339 | PyMlirContextRef &getContext() { return contextRef; } |
340 | |
341 | private: |
342 | PyMlirContextRef contextRef; |
343 | }; |
344 | |
345 | /// Wrapper around an MlirLocation. |
346 | class PyLocation : public BaseContextObject { |
347 | public: |
348 | PyLocation(PyMlirContextRef contextRef, MlirLocation loc) |
349 | : BaseContextObject(std::move(contextRef)), loc(loc) {} |
350 | |
351 | operator MlirLocation() const { return loc; } |
352 | MlirLocation get() const { return loc; } |
353 | |
354 | /// Enter and exit the context manager. |
355 | static nanobind::object contextEnter(nanobind::object location); |
356 | void contextExit(const nanobind::object &excType, |
357 | const nanobind::object &excVal, |
358 | const nanobind::object &excTb); |
359 | |
360 | /// Gets a capsule wrapping the void* within the MlirLocation. |
361 | nanobind::object getCapsule(); |
362 | |
363 | /// Creates a PyLocation from the MlirLocation wrapped by a capsule. |
364 | /// Note that PyLocation instances are uniqued, so the returned object |
365 | /// may be a pre-existing object. Ownership of the underlying MlirLocation |
366 | /// is taken by calling this function. |
367 | static PyLocation createFromCapsule(nanobind::object capsule); |
368 | |
369 | private: |
370 | MlirLocation loc; |
371 | }; |
372 | |
373 | /// Python class mirroring the C MlirDiagnostic struct. Note that these structs |
374 | /// are only valid for the duration of a diagnostic callback and attempting |
375 | /// to access them outside of that will raise an exception. This applies to |
376 | /// nested diagnostics (in the notes) as well. |
377 | class PyDiagnostic { |
378 | public: |
379 | PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {} |
380 | void invalidate(); |
381 | bool isValid() { return valid; } |
382 | MlirDiagnosticSeverity getSeverity(); |
383 | PyLocation getLocation(); |
384 | nanobind::str getMessage(); |
385 | nanobind::tuple getNotes(); |
386 | |
387 | /// Materialized diagnostic information. This is safe to access outside the |
388 | /// diagnostic callback. |
389 | struct DiagnosticInfo { |
390 | MlirDiagnosticSeverity severity; |
391 | PyLocation location; |
392 | std::string message; |
393 | std::vector<DiagnosticInfo> notes; |
394 | }; |
395 | DiagnosticInfo getInfo(); |
396 | |
397 | private: |
398 | MlirDiagnostic diagnostic; |
399 | |
400 | void checkValid(); |
401 | /// If notes have been materialized from the diagnostic, then this will |
402 | /// be populated with the corresponding objects (all castable to |
403 | /// PyDiagnostic). |
404 | std::optional<nanobind::tuple> materializedNotes; |
405 | bool valid = true; |
406 | }; |
407 | |
408 | /// Represents a diagnostic handler attached to the context. The handler's |
409 | /// callback will be invoked with PyDiagnostic instances until the detach() |
410 | /// method is called or the context is destroyed. A diagnostic handler can be |
411 | /// the subject of a `with` block, which will detach it when the block exits. |
412 | /// |
413 | /// Since diagnostic handlers can call back into Python code which can do |
414 | /// unsafe things (i.e. recursively emitting diagnostics, raising exceptions, |
415 | /// etc), this is generally not deemed to be a great user-level API. Users |
416 | /// should generally use some form of DiagnosticCollector. If the handler raises |
417 | /// any exceptions, they will just be emitted to stderr and dropped. |
418 | /// |
419 | /// The unique usage of this class means that its lifetime management is |
420 | /// different from most other parts of the API. Instances are always created |
421 | /// in an attached state and can transition to a detached state by either: |
422 | /// a) The context being destroyed and unregistering all handlers. |
423 | /// b) An explicit call to detach(). |
424 | /// The object may remain live from a Python perspective for an arbitrary time |
425 | /// after detachment, but there is nothing the user can do with it (since there |
426 | /// is no way to attach an existing handler object). |
427 | class PyDiagnosticHandler { |
428 | public: |
429 | PyDiagnosticHandler(MlirContext context, nanobind::object callback); |
430 | ~PyDiagnosticHandler(); |
431 | |
432 | bool isAttached() { return registeredID.has_value(); } |
433 | bool getHadError() { return hadError; } |
434 | |
435 | /// Detaches the handler. Does nothing if not attached. |
436 | void detach(); |
437 | |
438 | nanobind::object contextEnter() { return nanobind::cast(this); } |
439 | void contextExit(const nanobind::object &excType, |
440 | const nanobind::object &excVal, |
441 | const nanobind::object &excTb) { |
442 | detach(); |
443 | } |
444 | |
445 | private: |
446 | MlirContext context; |
447 | nanobind::object callback; |
448 | std::optional<MlirDiagnosticHandlerID> registeredID; |
449 | bool hadError = false; |
450 | friend class PyMlirContext; |
451 | }; |
452 | |
453 | /// RAII object that captures any error diagnostics emitted to the provided |
454 | /// context. |
455 | struct PyMlirContext::ErrorCapture { |
456 | ErrorCapture(PyMlirContextRef ctx) |
457 | : ctx(ctx), handlerID(mlirContextAttachDiagnosticHandler( |
458 | ctx->get(), handler, /*userData=*/this, |
459 | /*deleteUserData=*/nullptr)) {} |
460 | ~ErrorCapture() { |
461 | mlirContextDetachDiagnosticHandler(ctx->get(), handlerID); |
462 | assert(errors.empty() && "unhandled captured errors"); |
463 | } |
464 | |
465 | std::vector<PyDiagnostic::DiagnosticInfo> take() { |
466 | return std::move(errors); |
467 | }; |
468 | |
469 | private: |
470 | PyMlirContextRef ctx; |
471 | MlirDiagnosticHandlerID handlerID; |
472 | std::vector<PyDiagnostic::DiagnosticInfo> errors; |
473 | |
474 | static MlirLogicalResult handler(MlirDiagnostic diag, void *userData); |
475 | }; |
476 | |
477 | /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in |
478 | /// order to differentiate it from the `Dialect` base class which is extended by |
479 | /// plugins which extend dialect functionality through extension python code. |
480 | /// This should be seen as the "low-level" object and `Dialect` as the |
481 | /// high-level, user facing object. |
482 | class PyDialectDescriptor : public BaseContextObject { |
483 | public: |
484 | PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect) |
485 | : BaseContextObject(std::move(contextRef)), dialect(dialect) {} |
486 | |
487 | MlirDialect get() { return dialect; } |
488 | |
489 | private: |
490 | MlirDialect dialect; |
491 | }; |
492 | |
493 | /// User-level object for accessing dialects with dotted syntax such as: |
494 | /// ctx.dialect.std |
495 | class PyDialects : public BaseContextObject { |
496 | public: |
497 | PyDialects(PyMlirContextRef contextRef) |
498 | : BaseContextObject(std::move(contextRef)) {} |
499 | |
500 | MlirDialect getDialectForKey(const std::string &key, bool attrError); |
501 | }; |
502 | |
503 | /// User-level dialect object. For dialects that have a registered extension, |
504 | /// this will be the base class of the extension dialect type. For un-extended, |
505 | /// objects of this type will be returned directly. |
506 | class PyDialect { |
507 | public: |
508 | PyDialect(nanobind::object descriptor) : descriptor(std::move(descriptor)) {} |
509 | |
510 | nanobind::object getDescriptor() { return descriptor; } |
511 | |
512 | private: |
513 | nanobind::object descriptor; |
514 | }; |
515 | |
516 | /// Wrapper around an MlirDialectRegistry. |
517 | /// Upon construction, the Python wrapper takes ownership of the |
518 | /// underlying MlirDialectRegistry. |
519 | class PyDialectRegistry { |
520 | public: |
521 | PyDialectRegistry() : registry(mlirDialectRegistryCreate()) {} |
522 | PyDialectRegistry(MlirDialectRegistry registry) : registry(registry) {} |
523 | ~PyDialectRegistry() { |
524 | if (!mlirDialectRegistryIsNull(registry)) |
525 | mlirDialectRegistryDestroy(registry); |
526 | } |
527 | PyDialectRegistry(PyDialectRegistry &) = delete; |
528 | PyDialectRegistry(PyDialectRegistry &&other) noexcept |
529 | : registry(other.registry) { |
530 | other.registry = {nullptr}; |
531 | } |
532 | |
533 | operator MlirDialectRegistry() const { return registry; } |
534 | MlirDialectRegistry get() const { return registry; } |
535 | |
536 | nanobind::object getCapsule(); |
537 | static PyDialectRegistry createFromCapsule(nanobind::object capsule); |
538 | |
539 | private: |
540 | MlirDialectRegistry registry; |
541 | }; |
542 | |
543 | /// Used in function arguments when None should resolve to the current context |
544 | /// manager set instance. |
545 | class DefaultingPyLocation |
546 | : public Defaulting<DefaultingPyLocation, PyLocation> { |
547 | public: |
548 | using Defaulting::Defaulting; |
549 | static constexpr const char kTypeDescription[] = "mlir.ir.Location"; |
550 | static PyLocation &resolve(); |
551 | |
552 | operator MlirLocation() const { return *get(); } |
553 | }; |
554 | |
555 | /// Wrapper around MlirModule. |
556 | /// This is the top-level, user-owned object that contains regions/ops/blocks. |
557 | class PyModule; |
558 | using PyModuleRef = PyObjectRef<PyModule>; |
559 | class PyModule : public BaseContextObject { |
560 | public: |
561 | /// Returns a PyModule reference for the given MlirModule. This may return |
562 | /// a pre-existing or new object. |
563 | static PyModuleRef forModule(MlirModule module); |
564 | PyModule(PyModule &) = delete; |
565 | PyModule(PyMlirContext &&) = delete; |
566 | ~PyModule(); |
567 | |
568 | /// Gets the backing MlirModule. |
569 | MlirModule get() { return module; } |
570 | |
571 | /// Gets a strong reference to this module. |
572 | PyModuleRef getRef() { |
573 | return PyModuleRef(this, nanobind::borrow<nanobind::object>(handle)); |
574 | } |
575 | |
576 | /// Gets a capsule wrapping the void* within the MlirModule. |
577 | /// Note that the module does not (yet) provide a corresponding factory for |
578 | /// constructing from a capsule as that would require uniquing PyModule |
579 | /// instances, which is not currently done. |
580 | nanobind::object getCapsule(); |
581 | |
582 | /// Creates a PyModule from the MlirModule wrapped by a capsule. |
583 | /// Note that PyModule instances are uniqued, so the returned object |
584 | /// may be a pre-existing object. Ownership of the underlying MlirModule |
585 | /// is taken by calling this function. |
586 | static nanobind::object createFromCapsule(nanobind::object capsule); |
587 | |
588 | private: |
589 | PyModule(PyMlirContextRef contextRef, MlirModule module); |
590 | MlirModule module; |
591 | nanobind::handle handle; |
592 | }; |
593 | |
594 | class PyAsmState; |
595 | |
596 | /// Base class for PyOperation and PyOpView which exposes the primary, user |
597 | /// visible methods for manipulating it. |
598 | class PyOperationBase { |
599 | public: |
600 | virtual ~PyOperationBase() = default; |
601 | /// Implements the bound 'print' method and helps with others. |
602 | void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo, |
603 | bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, |
604 | bool useNameLocAsPrefix, bool assumeVerified, |
605 | nanobind::object fileObject, bool binary, bool skipRegions); |
606 | void print(PyAsmState &state, nanobind::object fileObject, bool binary); |
607 | |
608 | nanobind::object getAsm(bool binary, |
609 | std::optional<int64_t> largeElementsLimit, |
610 | bool enableDebugInfo, bool prettyDebugInfo, |
611 | bool printGenericOpForm, bool useLocalScope, |
612 | bool useNameLocAsPrefix, bool assumeVerified, |
613 | bool skipRegions); |
614 | |
615 | // Implement the bound 'writeBytecode' method. |
616 | void writeBytecode(const nanobind::object &fileObject, |
617 | std::optional<int64_t> bytecodeVersion); |
618 | |
619 | // Implement the walk method. |
620 | void walk(std::function<MlirWalkResult(MlirOperation)> callback, |
621 | MlirWalkOrder walkOrder); |
622 | |
623 | /// Moves the operation before or after the other operation. |
624 | void moveAfter(PyOperationBase &other); |
625 | void moveBefore(PyOperationBase &other); |
626 | |
627 | /// Verify the operation. Throws `MLIRError` if verification fails, and |
628 | /// returns `true` otherwise. |
629 | bool verify(); |
630 | |
631 | /// Each must provide access to the raw Operation. |
632 | virtual PyOperation &getOperation() = 0; |
633 | }; |
634 | |
635 | /// Wrapper around PyOperation. |
636 | /// Operations exist in either an attached (dependent) or detached (top-level) |
637 | /// state. In the detached state (as on creation), an operation is owned by |
638 | /// the creator and its lifetime extends either until its reference count |
639 | /// drops to zero or it is attached to a parent, at which point its lifetime |
640 | /// is bounded by its top-level parent reference. |
641 | class PyOperation; |
642 | using PyOperationRef = PyObjectRef<PyOperation>; |
643 | class PyOperation : public PyOperationBase, public BaseContextObject { |
644 | public: |
645 | ~PyOperation() override; |
646 | PyOperation &getOperation() override { return *this; } |
647 | |
648 | /// Returns a PyOperation for the given MlirOperation, optionally associating |
649 | /// it with a parentKeepAlive. |
650 | static PyOperationRef |
651 | forOperation(PyMlirContextRef contextRef, MlirOperation operation, |
652 | nanobind::object parentKeepAlive = nanobind::object()); |
653 | |
654 | /// Creates a detached operation. The operation must not be associated with |
655 | /// any existing live operation. |
656 | static PyOperationRef |
657 | createDetached(PyMlirContextRef contextRef, MlirOperation operation, |
658 | nanobind::object parentKeepAlive = nanobind::object()); |
659 | |
660 | /// Parses a source string (either text assembly or bytecode), creating a |
661 | /// detached operation. |
662 | static PyOperationRef parse(PyMlirContextRef contextRef, |
663 | const std::string &sourceStr, |
664 | const std::string &sourceName); |
665 | |
666 | /// Detaches the operation from its parent block and updates its state |
667 | /// accordingly. |
668 | void detachFromParent() { |
669 | mlirOperationRemoveFromParent(getOperation()); |
670 | setDetached(); |
671 | parentKeepAlive = nanobind::object(); |
672 | } |
673 | |
674 | /// Gets the backing operation. |
675 | operator MlirOperation() const { return get(); } |
676 | MlirOperation get() const { |
677 | checkValid(); |
678 | return operation; |
679 | } |
680 | |
681 | PyOperationRef getRef() { |
682 | return PyOperationRef(this, nanobind::borrow<nanobind::object>(handle)); |
683 | } |
684 | |
685 | bool isAttached() { return attached; } |
686 | void setAttached(const nanobind::object &parent = nanobind::object()) { |
687 | assert(!attached && "operation already attached"); |
688 | attached = true; |
689 | } |
690 | void setDetached() { |
691 | assert(attached && "operation already detached"); |
692 | attached = false; |
693 | } |
694 | void checkValid() const; |
695 | |
696 | /// Gets the owning block or raises an exception if the operation has no |
697 | /// owning block. |
698 | PyBlock getBlock(); |
699 | |
700 | /// Gets the parent operation or raises an exception if the operation has |
701 | /// no parent. |
702 | std::optional<PyOperationRef> getParentOperation(); |
703 | |
704 | /// Gets a capsule wrapping the void* within the MlirOperation. |
705 | nanobind::object getCapsule(); |
706 | |
707 | /// Creates a PyOperation from the MlirOperation wrapped by a capsule. |
708 | /// Ownership of the underlying MlirOperation is taken by calling this |
709 | /// function. |
710 | static nanobind::object createFromCapsule(nanobind::object capsule); |
711 | |
712 | /// Creates an operation. See corresponding python docstring. |
713 | static nanobind::object |
714 | create(std::string_view name, std::optional<std::vector<PyType *>> results, |
715 | llvm::ArrayRef<MlirValue> operands, |
716 | std::optional<nanobind::dict> attributes, |
717 | std::optional<std::vector<PyBlock *>> successors, int regions, |
718 | DefaultingPyLocation location, const nanobind::object &ip, |
719 | bool inferType); |
720 | |
721 | /// Creates an OpView suitable for this operation. |
722 | nanobind::object createOpView(); |
723 | |
724 | /// Erases the underlying MlirOperation, removes its pointer from the |
725 | /// parent context's live operations map, and sets the valid bit false. |
726 | void erase(); |
727 | |
728 | /// Invalidate the operation. |
729 | void setInvalid() { valid = false; } |
730 | |
731 | /// Clones this operation. |
732 | nanobind::object clone(const nanobind::object &ip); |
733 | |
734 | PyOperation(PyMlirContextRef contextRef, MlirOperation operation); |
735 | |
736 | private: |
737 | static PyOperationRef createInstance(PyMlirContextRef contextRef, |
738 | MlirOperation operation, |
739 | nanobind::object parentKeepAlive); |
740 | |
741 | MlirOperation operation; |
742 | nanobind::handle handle; |
743 | // Keeps the parent alive, regardless of whether it is an Operation or |
744 | // Module. |
745 | // TODO: As implemented, this facility is only sufficient for modeling the |
746 | // trivial module parent back-reference. Generalize this to also account for |
747 | // transitions from detached to attached and address TODOs in the |
748 | // ir_operation.py regarding testing corresponding lifetime guarantees. |
749 | nanobind::object parentKeepAlive; |
750 | bool attached = true; |
751 | bool valid = true; |
752 | |
753 | friend class PyOperationBase; |
754 | friend class PySymbolTable; |
755 | }; |
756 | |
757 | /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for |
758 | /// providing more instance-specific accessors and serve as the base class for |
759 | /// custom ODS-style operation classes. Since this class is subclass on the |
760 | /// python side, it must present an __init__ method that operates in pure |
761 | /// python types. |
762 | class PyOpView : public PyOperationBase { |
763 | public: |
764 | PyOpView(const nanobind::object &operationObject); |
765 | PyOperation &getOperation() override { return operation; } |
766 | |
767 | nanobind::object getOperationObject() { return operationObject; } |
768 | |
769 | static nanobind::object |
770 | buildGeneric(std::string_view name, std::tuple<int, bool> opRegionSpec, |
771 | nanobind::object operandSegmentSpecObj, |
772 | nanobind::object resultSegmentSpecObj, |
773 | std::optional<nanobind::list> resultTypeList, |
774 | nanobind::list operandList, |
775 | std::optional<nanobind::dict> attributes, |
776 | std::optional<std::vector<PyBlock *>> successors, |
777 | std::optional<int> regions, DefaultingPyLocation location, |
778 | const nanobind::object &maybeIp); |
779 | |
780 | /// Construct an instance of a class deriving from OpView, bypassing its |
781 | /// `__init__` method. The derived class will typically define a constructor |
782 | /// that provides a convenient builder, but we need to side-step this when |
783 | /// constructing an `OpView` for an already-built operation. |
784 | /// |
785 | /// The caller is responsible for verifying that `operation` is a valid |
786 | /// operation to construct `cls` with. |
787 | static nanobind::object constructDerived(const nanobind::object &cls, |
788 | const nanobind::object &operation); |
789 | |
790 | private: |
791 | PyOperation &operation; // For efficient, cast-free access from C++ |
792 | nanobind::object operationObject; // Holds the reference. |
793 | }; |
794 | |
795 | /// Wrapper around an MlirRegion. |
796 | /// Regions are managed completely by their containing operation. Unlike the |
797 | /// C++ API, the python API does not support detached regions. |
798 | class PyRegion { |
799 | public: |
800 | PyRegion(PyOperationRef parentOperation, MlirRegion region) |
801 | : parentOperation(std::move(parentOperation)), region(region) { |
802 | assert(!mlirRegionIsNull(region) && "python region cannot be null"); |
803 | } |
804 | operator MlirRegion() const { return region; } |
805 | |
806 | MlirRegion get() { return region; } |
807 | PyOperationRef &getParentOperation() { return parentOperation; } |
808 | |
809 | void checkValid() { return parentOperation->checkValid(); } |
810 | |
811 | private: |
812 | PyOperationRef parentOperation; |
813 | MlirRegion region; |
814 | }; |
815 | |
816 | /// Wrapper around an MlirAsmState. |
817 | class PyAsmState { |
818 | public: |
819 | PyAsmState(MlirValue value, bool useLocalScope) { |
820 | flags = mlirOpPrintingFlagsCreate(); |
821 | // The OpPrintingFlags are not exposed Python side, create locally and |
822 | // associate lifetime with the state. |
823 | if (useLocalScope) |
824 | mlirOpPrintingFlagsUseLocalScope(flags); |
825 | state = mlirAsmStateCreateForValue(value, flags); |
826 | } |
827 | |
828 | PyAsmState(PyOperationBase &operation, bool useLocalScope) { |
829 | flags = mlirOpPrintingFlagsCreate(); |
830 | // The OpPrintingFlags are not exposed Python side, create locally and |
831 | // associate lifetime with the state. |
832 | if (useLocalScope) |
833 | mlirOpPrintingFlagsUseLocalScope(flags); |
834 | state = |
835 | mlirAsmStateCreateForOperation(operation.getOperation().get(), flags); |
836 | } |
837 | ~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); } |
838 | // Delete copy constructors. |
839 | PyAsmState(PyAsmState &other) = delete; |
840 | PyAsmState(const PyAsmState &other) = delete; |
841 | |
842 | MlirAsmState get() { return state; } |
843 | |
844 | private: |
845 | MlirAsmState state; |
846 | MlirOpPrintingFlags flags; |
847 | }; |
848 | |
849 | /// Wrapper around an MlirBlock. |
850 | /// Blocks are managed completely by their containing operation. Unlike the |
851 | /// C++ API, the python API does not support detached blocks. |
852 | class PyBlock { |
853 | public: |
854 | PyBlock(PyOperationRef parentOperation, MlirBlock block) |
855 | : parentOperation(std::move(parentOperation)), block(block) { |
856 | assert(!mlirBlockIsNull(block) && "python block cannot be null"); |
857 | } |
858 | |
859 | MlirBlock get() { return block; } |
860 | PyOperationRef &getParentOperation() { return parentOperation; } |
861 | |
862 | void checkValid() { return parentOperation->checkValid(); } |
863 | |
864 | /// Gets a capsule wrapping the void* within the MlirBlock. |
865 | nanobind::object getCapsule(); |
866 | |
867 | private: |
868 | PyOperationRef parentOperation; |
869 | MlirBlock block; |
870 | }; |
871 | |
872 | /// An insertion point maintains a pointer to a Block and a reference operation. |
873 | /// Calls to insert() will insert a new operation before the |
874 | /// reference operation. If the reference operation is null, then appends to |
875 | /// the end of the block. |
876 | class PyInsertionPoint { |
877 | public: |
878 | /// Creates an insertion point positioned after the last operation in the |
879 | /// block, but still inside the block. |
880 | PyInsertionPoint(PyBlock &block); |
881 | /// Creates an insertion point positioned before a reference operation. |
882 | PyInsertionPoint(PyOperationBase &beforeOperationBase); |
883 | |
884 | /// Shortcut to create an insertion point at the beginning of the block. |
885 | static PyInsertionPoint atBlockBegin(PyBlock &block); |
886 | /// Shortcut to create an insertion point before the block terminator. |
887 | static PyInsertionPoint atBlockTerminator(PyBlock &block); |
888 | |
889 | /// Inserts an operation. |
890 | void insert(PyOperationBase &operationBase); |
891 | |
892 | /// Enter and exit the context manager. |
893 | static nanobind::object contextEnter(nanobind::object insertionPoint); |
894 | void contextExit(const nanobind::object &excType, |
895 | const nanobind::object &excVal, |
896 | const nanobind::object &excTb); |
897 | |
898 | PyBlock &getBlock() { return block; } |
899 | std::optional<PyOperationRef> &getRefOperation() { return refOperation; } |
900 | |
901 | private: |
902 | // Trampoline constructor that avoids null initializing members while |
903 | // looking up parents. |
904 | PyInsertionPoint(PyBlock block, std::optional<PyOperationRef> refOperation) |
905 | : refOperation(std::move(refOperation)), block(std::move(block)) {} |
906 | |
907 | std::optional<PyOperationRef> refOperation; |
908 | PyBlock block; |
909 | }; |
910 | /// Wrapper around the generic MlirType. |
911 | /// The lifetime of a type is bound by the PyContext that created it. |
912 | class PyType : public BaseContextObject { |
913 | public: |
914 | PyType(PyMlirContextRef contextRef, MlirType type) |
915 | : BaseContextObject(std::move(contextRef)), type(type) {} |
916 | bool operator==(const PyType &other) const; |
917 | operator MlirType() const { return type; } |
918 | MlirType get() const { return type; } |
919 | |
920 | /// Gets a capsule wrapping the void* within the MlirType. |
921 | nanobind::object getCapsule(); |
922 | |
923 | /// Creates a PyType from the MlirType wrapped by a capsule. |
924 | /// Note that PyType instances are uniqued, so the returned object |
925 | /// may be a pre-existing object. Ownership of the underlying MlirType |
926 | /// is taken by calling this function. |
927 | static PyType createFromCapsule(nanobind::object capsule); |
928 | |
929 | private: |
930 | MlirType type; |
931 | }; |
932 | |
933 | /// A TypeID provides an efficient and unique identifier for a specific C++ |
934 | /// type. This allows for a C++ type to be compared, hashed, and stored in an |
935 | /// opaque context. This class wraps around the generic MlirTypeID. |
936 | class PyTypeID { |
937 | public: |
938 | PyTypeID(MlirTypeID typeID) : typeID(typeID) {} |
939 | // Note, this tests whether the underlying TypeIDs are the same, |
940 | // not whether the wrapper MlirTypeIDs are the same, nor whether |
941 | // the PyTypeID objects are the same (i.e., PyTypeID is a value type). |
942 | bool operator==(const PyTypeID &other) const; |
943 | operator MlirTypeID() const { return typeID; } |
944 | MlirTypeID get() { return typeID; } |
945 | |
946 | /// Gets a capsule wrapping the void* within the MlirTypeID. |
947 | nanobind::object getCapsule(); |
948 | |
949 | /// Creates a PyTypeID from the MlirTypeID wrapped by a capsule. |
950 | static PyTypeID createFromCapsule(nanobind::object capsule); |
951 | |
952 | private: |
953 | MlirTypeID typeID; |
954 | }; |
955 | |
956 | /// CRTP base classes for Python types that subclass Type and should be |
957 | /// castable from it (i.e. via something like IntegerType(t)). |
958 | /// By default, type class hierarchies are one level deep (i.e. a |
959 | /// concrete type class extends PyType); however, intermediate python-visible |
960 | /// base classes can be modeled by specifying a BaseTy. |
961 | template <typename DerivedTy, typename BaseTy = PyType> |
962 | class PyConcreteType : public BaseTy { |
963 | public: |
964 | // Derived classes must define statics for: |
965 | // IsAFunctionTy isaFunction |
966 | // const char *pyClassName |
967 | using ClassTy = nanobind::class_<DerivedTy, BaseTy>; |
968 | using IsAFunctionTy = bool (*)(MlirType); |
969 | using GetTypeIDFunctionTy = MlirTypeID (*)(); |
970 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; |
971 | |
972 | PyConcreteType() = default; |
973 | PyConcreteType(PyMlirContextRef contextRef, MlirType t) |
974 | : BaseTy(std::move(contextRef), t) {} |
975 | PyConcreteType(PyType &orig) |
976 | : PyConcreteType(orig.getContext(), castFrom(orig)) {} |
977 | |
978 | static MlirType castFrom(PyType &orig) { |
979 | if (!DerivedTy::isaFunction(orig)) { |
980 | auto origRepr = |
981 | nanobind::cast<std::string>(nanobind::repr(nanobind::cast(orig))); |
982 | throw nanobind::value_error((llvm::Twine("Cannot cast type to ") + |
983 | DerivedTy::pyClassName + " (from "+ |
984 | origRepr + ")") |
985 | .str() |
986 | .c_str()); |
987 | } |
988 | return orig; |
989 | } |
990 | |
991 | static void bind(nanobind::module_ &m) { |
992 | auto cls = ClassTy(m, DerivedTy::pyClassName); |
993 | cls.def(nanobind::init<PyType &>(), nanobind::keep_alive<0, 1>(), |
994 | nanobind::arg("cast_from_type")); |
995 | cls.def_static( |
996 | "isinstance", |
997 | [](PyType &otherType) -> bool { |
998 | return DerivedTy::isaFunction(otherType); |
999 | }, |
1000 | nanobind::arg("other")); |
1001 | cls.def_prop_ro_static( |
1002 | "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID { |
1003 | if (DerivedTy::getTypeIdFunction) |
1004 | return DerivedTy::getTypeIdFunction(); |
1005 | throw nanobind::attribute_error( |
1006 | (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")) |
1007 | .str() |
1008 | .c_str()); |
1009 | }); |
1010 | cls.def_prop_ro("typeid", [](PyType &self) { |
1011 | return nanobind::cast<MlirTypeID>(nanobind::cast(self).attr("typeid")); |
1012 | }); |
1013 | cls.def("__repr__", [](DerivedTy &self) { |
1014 | PyPrintAccumulator printAccum; |
1015 | printAccum.parts.append(DerivedTy::pyClassName); |
1016 | printAccum.parts.append("("); |
1017 | mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); |
1018 | printAccum.parts.append(")"); |
1019 | return printAccum.join(); |
1020 | }); |
1021 | |
1022 | if (DerivedTy::getTypeIdFunction) { |
1023 | PyGlobals::get().registerTypeCaster( |
1024 | DerivedTy::getTypeIdFunction(), |
1025 | nanobind::cast<nanobind::callable>(nanobind::cpp_function( |
1026 | [](PyType pyType) -> DerivedTy { return pyType; }))); |
1027 | } |
1028 | |
1029 | DerivedTy::bindDerived(cls); |
1030 | } |
1031 | |
1032 | /// Implemented by derived classes to add methods to the Python subclass. |
1033 | static void bindDerived(ClassTy &m) {} |
1034 | }; |
1035 | |
1036 | /// Wrapper around the generic MlirAttribute. |
1037 | /// The lifetime of a type is bound by the PyContext that created it. |
1038 | class PyAttribute : public BaseContextObject { |
1039 | public: |
1040 | PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr) |
1041 | : BaseContextObject(std::move(contextRef)), attr(attr) {} |
1042 | bool operator==(const PyAttribute &other) const; |
1043 | operator MlirAttribute() const { return attr; } |
1044 | MlirAttribute get() const { return attr; } |
1045 | |
1046 | /// Gets a capsule wrapping the void* within the MlirAttribute. |
1047 | nanobind::object getCapsule(); |
1048 | |
1049 | /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule. |
1050 | /// Note that PyAttribute instances are uniqued, so the returned object |
1051 | /// may be a pre-existing object. Ownership of the underlying MlirAttribute |
1052 | /// is taken by calling this function. |
1053 | static PyAttribute createFromCapsule(nanobind::object capsule); |
1054 | |
1055 | private: |
1056 | MlirAttribute attr; |
1057 | }; |
1058 | |
1059 | /// Represents a Python MlirNamedAttr, carrying an optional owned name. |
1060 | /// TODO: Refactor this and the C-API to be based on an Identifier owned |
1061 | /// by the context so as to avoid ownership issues here. |
1062 | class PyNamedAttribute { |
1063 | public: |
1064 | /// Constructs a PyNamedAttr that retains an owned name. This should be |
1065 | /// used in any code that originates an MlirNamedAttribute from a python |
1066 | /// string. |
1067 | /// The lifetime of the PyNamedAttr must extend to the lifetime of the |
1068 | /// passed attribute. |
1069 | PyNamedAttribute(MlirAttribute attr, std::string ownedName); |
1070 | |
1071 | MlirNamedAttribute namedAttr; |
1072 | |
1073 | private: |
1074 | // Since the MlirNamedAttr contains an internal pointer to the actual |
1075 | // memory of the owned string, it must be heap allocated to remain valid. |
1076 | // Otherwise, strings that fit within the small object optimization threshold |
1077 | // will have their memory address change as the containing object is moved, |
1078 | // resulting in an invalid aliased pointer. |
1079 | std::unique_ptr<std::string> ownedName; |
1080 | }; |
1081 | |
1082 | /// CRTP base classes for Python attributes that subclass Attribute and should |
1083 | /// be castable from it (i.e. via something like StringAttr(attr)). |
1084 | /// By default, attribute class hierarchies are one level deep (i.e. a |
1085 | /// concrete attribute class extends PyAttribute); however, intermediate |
1086 | /// python-visible base classes can be modeled by specifying a BaseTy. |
1087 | template <typename DerivedTy, typename BaseTy = PyAttribute> |
1088 | class PyConcreteAttribute : public BaseTy { |
1089 | public: |
1090 | // Derived classes must define statics for: |
1091 | // IsAFunctionTy isaFunction |
1092 | // const char *pyClassName |
1093 | using ClassTy = nanobind::class_<DerivedTy, BaseTy>; |
1094 | using IsAFunctionTy = bool (*)(MlirAttribute); |
1095 | using GetTypeIDFunctionTy = MlirTypeID (*)(); |
1096 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; |
1097 | |
1098 | PyConcreteAttribute() = default; |
1099 | PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) |
1100 | : BaseTy(std::move(contextRef), attr) {} |
1101 | PyConcreteAttribute(PyAttribute &orig) |
1102 | : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} |
1103 | |
1104 | static MlirAttribute castFrom(PyAttribute &orig) { |
1105 | if (!DerivedTy::isaFunction(orig)) { |
1106 | auto origRepr = |
1107 | nanobind::cast<std::string>(nanobind::repr(nanobind::cast(orig))); |
1108 | throw nanobind::value_error((llvm::Twine("Cannot cast attribute to ") + |
1109 | DerivedTy::pyClassName + " (from "+ |
1110 | origRepr + ")") |
1111 | .str() |
1112 | .c_str()); |
1113 | } |
1114 | return orig; |
1115 | } |
1116 | |
1117 | static void bind(nanobind::module_ &m, PyType_Slot *slots = nullptr) { |
1118 | ClassTy cls; |
1119 | if (slots) { |
1120 | cls = ClassTy(m, DerivedTy::pyClassName, nanobind::type_slots(slots)); |
1121 | } else { |
1122 | cls = ClassTy(m, DerivedTy::pyClassName); |
1123 | } |
1124 | cls.def(nanobind::init<PyAttribute &>(), nanobind::keep_alive<0, 1>(), |
1125 | nanobind::arg("cast_from_attr")); |
1126 | cls.def_static( |
1127 | "isinstance", |
1128 | [](PyAttribute &otherAttr) -> bool { |
1129 | return DerivedTy::isaFunction(otherAttr); |
1130 | }, |
1131 | nanobind::arg("other")); |
1132 | cls.def_prop_ro( |
1133 | "type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); }); |
1134 | cls.def_prop_ro_static( |
1135 | "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID { |
1136 | if (DerivedTy::getTypeIdFunction) |
1137 | return DerivedTy::getTypeIdFunction(); |
1138 | throw nanobind::attribute_error( |
1139 | (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")) |
1140 | .str() |
1141 | .c_str()); |
1142 | }); |
1143 | cls.def_prop_ro("typeid", [](PyAttribute &self) { |
1144 | return nanobind::cast<MlirTypeID>(nanobind::cast(self).attr("typeid")); |
1145 | }); |
1146 | cls.def("__repr__", [](DerivedTy &self) { |
1147 | PyPrintAccumulator printAccum; |
1148 | printAccum.parts.append(DerivedTy::pyClassName); |
1149 | printAccum.parts.append("("); |
1150 | mlirAttributePrint(self, printAccum.getCallback(), |
1151 | printAccum.getUserData()); |
1152 | printAccum.parts.append(")"); |
1153 | return printAccum.join(); |
1154 | }); |
1155 | |
1156 | if (DerivedTy::getTypeIdFunction) { |
1157 | PyGlobals::get().registerTypeCaster( |
1158 | DerivedTy::getTypeIdFunction(), |
1159 | nanobind::cast<nanobind::callable>( |
1160 | nanobind::cpp_function([](PyAttribute pyAttribute) -> DerivedTy { |
1161 | return pyAttribute; |
1162 | }))); |
1163 | } |
1164 | |
1165 | DerivedTy::bindDerived(cls); |
1166 | } |
1167 | |
1168 | /// Implemented by derived classes to add methods to the Python subclass. |
1169 | static void bindDerived(ClassTy &m) {} |
1170 | }; |
1171 | |
1172 | /// Wrapper around the generic MlirValue. |
1173 | /// Values are managed completely by the operation that resulted in their |
1174 | /// definition. For op result value, this is the operation that defines the |
1175 | /// value. For block argument values, this is the operation that contains the |
1176 | /// block to which the value is an argument (blocks cannot be detached in Python |
1177 | /// bindings so such operation always exists). |
1178 | class PyValue { |
1179 | public: |
1180 | // The virtual here is "load bearing" in that it enables RTTI |
1181 | // for PyConcreteValue CRTP classes that support maybeDownCast. |
1182 | // See PyValue::maybeDownCast. |
1183 | virtual ~PyValue() = default; |
1184 | PyValue(PyOperationRef parentOperation, MlirValue value) |
1185 | : parentOperation(std::move(parentOperation)), value(value) {} |
1186 | operator MlirValue() const { return value; } |
1187 | |
1188 | MlirValue get() { return value; } |
1189 | PyOperationRef &getParentOperation() { return parentOperation; } |
1190 | |
1191 | void checkValid() { return parentOperation->checkValid(); } |
1192 | |
1193 | /// Gets a capsule wrapping the void* within the MlirValue. |
1194 | nanobind::object getCapsule(); |
1195 | |
1196 | nanobind::object maybeDownCast(); |
1197 | |
1198 | /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of |
1199 | /// the underlying MlirValue is still tied to the owning operation. |
1200 | static PyValue createFromCapsule(nanobind::object capsule); |
1201 | |
1202 | private: |
1203 | PyOperationRef parentOperation; |
1204 | MlirValue value; |
1205 | }; |
1206 | |
1207 | /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context. |
1208 | class PyAffineExpr : public BaseContextObject { |
1209 | public: |
1210 | PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) |
1211 | : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {} |
1212 | bool operator==(const PyAffineExpr &other) const; |
1213 | operator MlirAffineExpr() const { return affineExpr; } |
1214 | MlirAffineExpr get() const { return affineExpr; } |
1215 | |
1216 | /// Gets a capsule wrapping the void* within the MlirAffineExpr. |
1217 | nanobind::object getCapsule(); |
1218 | |
1219 | /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule. |
1220 | /// Note that PyAffineExpr instances are uniqued, so the returned object |
1221 | /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr |
1222 | /// is taken by calling this function. |
1223 | static PyAffineExpr createFromCapsule(nanobind::object capsule); |
1224 | |
1225 | PyAffineExpr add(const PyAffineExpr &other) const; |
1226 | PyAffineExpr mul(const PyAffineExpr &other) const; |
1227 | PyAffineExpr floorDiv(const PyAffineExpr &other) const; |
1228 | PyAffineExpr ceilDiv(const PyAffineExpr &other) const; |
1229 | PyAffineExpr mod(const PyAffineExpr &other) const; |
1230 | |
1231 | private: |
1232 | MlirAffineExpr affineExpr; |
1233 | }; |
1234 | |
1235 | class PyAffineMap : public BaseContextObject { |
1236 | public: |
1237 | PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap) |
1238 | : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {} |
1239 | bool operator==(const PyAffineMap &other) const; |
1240 | operator MlirAffineMap() const { return affineMap; } |
1241 | MlirAffineMap get() const { return affineMap; } |
1242 | |
1243 | /// Gets a capsule wrapping the void* within the MlirAffineMap. |
1244 | nanobind::object getCapsule(); |
1245 | |
1246 | /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule. |
1247 | /// Note that PyAffineMap instances are uniqued, so the returned object |
1248 | /// may be a pre-existing object. Ownership of the underlying MlirAffineMap |
1249 | /// is taken by calling this function. |
1250 | static PyAffineMap createFromCapsule(nanobind::object capsule); |
1251 | |
1252 | private: |
1253 | MlirAffineMap affineMap; |
1254 | }; |
1255 | |
1256 | class PyIntegerSet : public BaseContextObject { |
1257 | public: |
1258 | PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet) |
1259 | : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {} |
1260 | bool operator==(const PyIntegerSet &other) const; |
1261 | operator MlirIntegerSet() const { return integerSet; } |
1262 | MlirIntegerSet get() const { return integerSet; } |
1263 | |
1264 | /// Gets a capsule wrapping the void* within the MlirIntegerSet. |
1265 | nanobind::object getCapsule(); |
1266 | |
1267 | /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. |
1268 | /// Note that PyIntegerSet instances may be uniqued, so the returned object |
1269 | /// may be a pre-existing object. Integer sets are owned by the context. |
1270 | static PyIntegerSet createFromCapsule(nanobind::object capsule); |
1271 | |
1272 | private: |
1273 | MlirIntegerSet integerSet; |
1274 | }; |
1275 | |
1276 | /// Bindings for MLIR symbol tables. |
1277 | class PySymbolTable { |
1278 | public: |
1279 | /// Constructs a symbol table for the given operation. |
1280 | explicit PySymbolTable(PyOperationBase &operation); |
1281 | |
1282 | /// Destroys the symbol table. |
1283 | ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); } |
1284 | |
1285 | /// Returns the symbol (opview) with the given name, throws if there is no |
1286 | /// such symbol in the table. |
1287 | nanobind::object dunderGetItem(const std::string &name); |
1288 | |
1289 | /// Removes the given operation from the symbol table and erases it. |
1290 | void erase(PyOperationBase &symbol); |
1291 | |
1292 | /// Removes the operation with the given name from the symbol table and erases |
1293 | /// it, throws if there is no such symbol in the table. |
1294 | void dunderDel(const std::string &name); |
1295 | |
1296 | /// Inserts the given operation into the symbol table. The operation must have |
1297 | /// the symbol trait. |
1298 | MlirAttribute insert(PyOperationBase &symbol); |
1299 | |
1300 | /// Gets and sets the name of a symbol op. |
1301 | static MlirAttribute getSymbolName(PyOperationBase &symbol); |
1302 | static void setSymbolName(PyOperationBase &symbol, const std::string &name); |
1303 | |
1304 | /// Gets and sets the visibility of a symbol op. |
1305 | static MlirAttribute getVisibility(PyOperationBase &symbol); |
1306 | static void setVisibility(PyOperationBase &symbol, |
1307 | const std::string &visibility); |
1308 | |
1309 | /// Replaces all symbol uses within an operation. See the API |
1310 | /// mlirSymbolTableReplaceAllSymbolUses for all caveats. |
1311 | static void replaceAllSymbolUses(const std::string &oldSymbol, |
1312 | const std::string &newSymbol, |
1313 | PyOperationBase &from); |
1314 | |
1315 | /// Walks all symbol tables under and including 'from'. |
1316 | static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, |
1317 | nanobind::object callback); |
1318 | |
1319 | /// Casts the bindings class into the C API structure. |
1320 | operator MlirSymbolTable() { return symbolTable; } |
1321 | |
1322 | private: |
1323 | PyOperationRef operation; |
1324 | MlirSymbolTable symbolTable; |
1325 | }; |
1326 | |
1327 | /// Custom exception that allows access to error diagnostic information. This is |
1328 | /// converted to the `ir.MLIRError` python exception when thrown. |
1329 | struct MLIRError { |
1330 | MLIRError(llvm::Twine message, |
1331 | std::vector<PyDiagnostic::DiagnosticInfo> &&errorDiagnostics = {}) |
1332 | : message(message.str()), errorDiagnostics(std::move(errorDiagnostics)) {} |
1333 | std::string message; |
1334 | std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics; |
1335 | }; |
1336 | |
1337 | void populateIRAffine(nanobind::module_ &m); |
1338 | void populateIRAttributes(nanobind::module_ &m); |
1339 | void populateIRCore(nanobind::module_ &m); |
1340 | void populateIRInterfaces(nanobind::module_ &m); |
1341 | void populateIRTypes(nanobind::module_ &m); |
1342 | |
1343 | } // namespace python |
1344 | } // namespace mlir |
1345 | |
1346 | namespace nanobind { |
1347 | namespace detail { |
1348 | |
1349 | template <> |
1350 | struct type_caster<mlir::python::DefaultingPyMlirContext> |
1351 | : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {}; |
1352 | template <> |
1353 | struct type_caster<mlir::python::DefaultingPyLocation> |
1354 | : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {}; |
1355 | |
1356 | } // namespace detail |
1357 | } // namespace nanobind |
1358 | |
1359 | #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H |
1360 |
Definitions
- PyObjectRef
- PyObjectRef
- PyObjectRef
- PyObjectRef
- ~PyObjectRef
- getRefCount
- releaseObject
- get
- operator->
- getObject
- operator bool
- PyThreadContextEntry
- FrameKind
- PyThreadContextEntry
- getFrameKind
- PyThreadPool
- PyThreadPool
- PyThreadPool
- PyThreadPool
- getMaxConcurrency
- get
- _mlir_thread_pool_ptr
- PyMlirContext
- PyMlirContext
- PyMlirContext
- PyMlirContext
- get
- getRef
- setEmitErrorDiagnostics
- DefaultingPyMlirContext
- kTypeDescription
- BaseContextObject
- BaseContextObject
- getContext
- PyLocation
- PyLocation
- PyDiagnostic
- PyDiagnostic
- isValid
- DiagnosticInfo
- PyDiagnosticHandler
- isAttached
- getHadError
- contextEnter
- contextExit
- ErrorCapture
- ErrorCapture
- ~ErrorCapture
- take
- PyDialectDescriptor
- PyDialectDescriptor
- get
- PyDialects
- PyDialects
- PyDialect
- PyDialect
- getDescriptor
- PyDialectRegistry
- PyDialectRegistry
- PyDialectRegistry
- ~PyDialectRegistry
- PyDialectRegistry
- PyDialectRegistry
- DefaultingPyLocation
- kTypeDescription
- PyModule
- PyModule
- PyModule
- get
- getRef
- PyOperationBase
- ~PyOperationBase
- PyOperation
- getOperation
- detachFromParent
- setInvalid
- PyOpView
- getOperation
- getOperationObject
- PyRegion
- PyRegion
- PyAsmState
- PyAsmState
- PyAsmState
- ~PyAsmState
- PyAsmState
- PyAsmState
- get
- PyBlock
- PyBlock
- get
- getParentOperation
- checkValid
- PyInsertionPoint
- getBlock
- getRefOperation
- PyInsertionPoint
- PyType
- PyType
- PyTypeID
- PyTypeID
- PyConcreteType
- getTypeIdFunction
- PyConcreteType
- PyConcreteType
- PyConcreteType
- castFrom
- bind
- bindDerived
- PyAttribute
- PyAttribute
- PyNamedAttribute
- PyConcreteAttribute
- getTypeIdFunction
- PyConcreteAttribute
- PyConcreteAttribute
- PyConcreteAttribute
- castFrom
- bind
- bindDerived
- PyValue
- ~PyValue
- PyValue
- PyAffineExpr
- PyAffineExpr
- PyAffineMap
- PyAffineMap
- PyIntegerSet
- PyIntegerSet
- PySymbolTable
- ~PySymbolTable
- MLIRError
- MLIRError
Learn to use CMake with our Intro Training
Find out more