1 | //===- IRModules.h - IR Submodules of pybind module -----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
8 | //===----------------------------------------------------------------------===// |
9 | |
10 | #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H |
11 | #define MLIR_BINDINGS_PYTHON_IRMODULES_H |
12 | |
13 | #include <optional> |
14 | #include <utility> |
15 | #include <vector> |
16 | |
17 | #include "Globals.h" |
18 | #include "PybindUtils.h" |
19 | |
20 | #include "mlir-c/AffineExpr.h" |
21 | #include "mlir-c/AffineMap.h" |
22 | #include "mlir-c/Diagnostics.h" |
23 | #include "mlir-c/IR.h" |
24 | #include "mlir-c/IntegerSet.h" |
25 | #include "mlir/Bindings/Python/PybindAdaptors.h" |
26 | #include "llvm/ADT/DenseMap.h" |
27 | |
28 | namespace mlir { |
29 | namespace python { |
30 | |
31 | class PyBlock; |
32 | class PyDiagnostic; |
33 | class PyDiagnosticHandler; |
34 | class PyInsertionPoint; |
35 | class PyLocation; |
36 | class DefaultingPyLocation; |
37 | class PyMlirContext; |
38 | class DefaultingPyMlirContext; |
39 | class PyModule; |
40 | class PyOperation; |
41 | class PyOperationBase; |
42 | class PyType; |
43 | class PySymbolTable; |
44 | class PyValue; |
45 | |
46 | /// Template for a reference to a concrete type which captures a python |
47 | /// reference to its underlying python object. |
48 | template <typename T> |
49 | class PyObjectRef { |
50 | public: |
51 | PyObjectRef(T *referrent, pybind11::object object) |
52 | : referrent(referrent), object(std::move(object)) { |
53 | assert(this->referrent && |
54 | "cannot construct PyObjectRef with null referrent" ); |
55 | assert(this->object && "cannot construct PyObjectRef with null object" ); |
56 | } |
57 | PyObjectRef(PyObjectRef &&other) noexcept |
58 | : referrent(other.referrent), object(std::move(other.object)) { |
59 | other.referrent = nullptr; |
60 | assert(!other.object); |
61 | } |
62 | PyObjectRef(const PyObjectRef &other) |
63 | : referrent(other.referrent), object(other.object /* copies */) {} |
64 | ~PyObjectRef() = default; |
65 | |
66 | int getRefCount() { |
67 | if (!object) |
68 | return 0; |
69 | return object.ref_count(); |
70 | } |
71 | |
72 | /// Releases the object held by this instance, returning it. |
73 | /// This is the proper thing to return from a function that wants to return |
74 | /// the reference. Note that this does not work from initializers. |
75 | pybind11::object releaseObject() { |
76 | assert(referrent && object); |
77 | referrent = nullptr; |
78 | auto stolen = std::move(object); |
79 | return stolen; |
80 | } |
81 | |
82 | T *get() { return referrent; } |
83 | T *operator->() { |
84 | assert(referrent && object); |
85 | return referrent; |
86 | } |
87 | pybind11::object getObject() { |
88 | assert(referrent && object); |
89 | return object; |
90 | } |
91 | operator bool() const { return referrent && object; } |
92 | |
93 | private: |
94 | T *referrent; |
95 | pybind11::object object; |
96 | }; |
97 | |
98 | /// Tracks an entry in the thread context stack. New entries are pushed onto |
99 | /// here for each with block that activates a new InsertionPoint, Context or |
100 | /// Location. |
101 | /// |
102 | /// Pushing either a Location or InsertionPoint also pushes its associated |
103 | /// Context. Pushing a Context will not modify the Location or InsertionPoint |
104 | /// unless if they are from a different context, in which case, they are |
105 | /// cleared. |
106 | class PyThreadContextEntry { |
107 | public: |
108 | enum class FrameKind { |
109 | Context, |
110 | InsertionPoint, |
111 | Location, |
112 | }; |
113 | |
114 | PyThreadContextEntry(FrameKind frameKind, pybind11::object context, |
115 | pybind11::object insertionPoint, |
116 | pybind11::object location) |
117 | : context(std::move(context)), insertionPoint(std::move(insertionPoint)), |
118 | location(std::move(location)), frameKind(frameKind) {} |
119 | |
120 | /// Gets the top of stack context and return nullptr if not defined. |
121 | static PyMlirContext *getDefaultContext(); |
122 | |
123 | /// Gets the top of stack insertion point and return nullptr if not defined. |
124 | static PyInsertionPoint *getDefaultInsertionPoint(); |
125 | |
126 | /// Gets the top of stack location and returns nullptr if not defined. |
127 | static PyLocation *getDefaultLocation(); |
128 | |
129 | PyMlirContext *getContext(); |
130 | PyInsertionPoint *getInsertionPoint(); |
131 | PyLocation *getLocation(); |
132 | FrameKind getFrameKind() { return frameKind; } |
133 | |
134 | /// Stack management. |
135 | static PyThreadContextEntry *getTopOfStack(); |
136 | static pybind11::object pushContext(PyMlirContext &context); |
137 | static void popContext(PyMlirContext &context); |
138 | static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint); |
139 | static void popInsertionPoint(PyInsertionPoint &insertionPoint); |
140 | static pybind11::object pushLocation(PyLocation &location); |
141 | static void popLocation(PyLocation &location); |
142 | |
143 | /// Gets the thread local stack. |
144 | static std::vector<PyThreadContextEntry> &getStack(); |
145 | |
146 | private: |
147 | static void push(FrameKind frameKind, pybind11::object context, |
148 | pybind11::object insertionPoint, pybind11::object location); |
149 | |
150 | /// An object reference to the PyContext. |
151 | pybind11::object context; |
152 | /// An object reference to the current insertion point. |
153 | pybind11::object insertionPoint; |
154 | /// An object reference to the current location. |
155 | pybind11::object location; |
156 | // The kind of push that was performed. |
157 | FrameKind frameKind; |
158 | }; |
159 | |
160 | /// Wrapper around MlirContext. |
161 | using PyMlirContextRef = PyObjectRef<PyMlirContext>; |
162 | class PyMlirContext { |
163 | public: |
164 | PyMlirContext() = delete; |
165 | PyMlirContext(const PyMlirContext &) = delete; |
166 | PyMlirContext(PyMlirContext &&) = delete; |
167 | |
168 | /// For the case of a python __init__ (py::init) method, pybind11 is quite |
169 | /// strict about needing to return a pointer that is not yet associated to |
170 | /// an py::object. Since the forContext() method acts like a pool, possibly |
171 | /// returning a recycled context, it does not satisfy this need. The usual |
172 | /// way in python to accomplish such a thing is to override __new__, but |
173 | /// that is also not supported by pybind11. Instead, we use this entry |
174 | /// point which always constructs a fresh context (which cannot alias an |
175 | /// existing one because it is fresh). |
176 | static PyMlirContext *createNewContextForInit(); |
177 | |
178 | /// Returns a context reference for the singleton PyMlirContext wrapper for |
179 | /// the given context. |
180 | static PyMlirContextRef forContext(MlirContext context); |
181 | ~PyMlirContext(); |
182 | |
183 | /// Accesses the underlying MlirContext. |
184 | MlirContext get() { return context; } |
185 | |
186 | /// Gets a strong reference to this context, which will ensure it is kept |
187 | /// alive for the life of the reference. |
188 | PyMlirContextRef getRef() { |
189 | return PyMlirContextRef(this, pybind11::cast(this)); |
190 | } |
191 | |
192 | /// Gets a capsule wrapping the void* within the MlirContext. |
193 | pybind11::object getCapsule(); |
194 | |
195 | /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. |
196 | /// Note that PyMlirContext instances are uniqued, so the returned object |
197 | /// may be a pre-existing object. Ownership of the underlying MlirContext |
198 | /// is taken by calling this function. |
199 | static pybind11::object createFromCapsule(pybind11::object capsule); |
200 | |
201 | /// Gets the count of live context objects. Used for testing. |
202 | static size_t getLiveCount(); |
203 | |
204 | /// Get a list of Python objects which are still in the live context map. |
205 | std::vector<PyOperation *> getLiveOperationObjects(); |
206 | |
207 | /// Gets the count of live operations associated with this context. |
208 | /// Used for testing. |
209 | size_t getLiveOperationCount(); |
210 | |
211 | /// Clears the live operations map, returning the number of entries which were |
212 | /// invalidated. To be used as a safety mechanism so that API end-users can't |
213 | /// corrupt by holding references they shouldn't have accessed in the first |
214 | /// place. |
215 | size_t clearLiveOperations(); |
216 | |
217 | /// Removes an operation from the live operations map and sets it invalid. |
218 | /// This is useful for when some non-bindings code destroys the operation and |
219 | /// the bindings need to made aware. For example, in the case when pass |
220 | /// manager is run. |
221 | void clearOperation(MlirOperation op); |
222 | |
223 | /// Clears all operations nested inside the given op using |
224 | /// `clearOperation(MlirOperation)`. |
225 | void clearOperationsInside(PyOperationBase &op); |
226 | void clearOperationsInside(MlirOperation op); |
227 | |
228 | /// Gets the count of live modules associated with this context. |
229 | /// Used for testing. |
230 | size_t getLiveModuleCount(); |
231 | |
232 | /// Enter and exit the context manager. |
233 | pybind11::object contextEnter(); |
234 | void contextExit(const pybind11::object &excType, |
235 | const pybind11::object &excVal, |
236 | const pybind11::object &excTb); |
237 | |
238 | /// Attaches a Python callback as a diagnostic handler, returning a |
239 | /// registration object (internally a PyDiagnosticHandler). |
240 | pybind11::object attachDiagnosticHandler(pybind11::object callback); |
241 | |
242 | /// Controls whether error diagnostics should be propagated to diagnostic |
243 | /// handlers, instead of being captured by `ErrorCapture`. |
244 | void setEmitErrorDiagnostics(bool value) { emitErrorDiagnostics = value; } |
245 | struct ErrorCapture; |
246 | |
247 | private: |
248 | PyMlirContext(MlirContext context); |
249 | // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, |
250 | // preserving the relationship that an MlirContext maps to a single |
251 | // PyMlirContext wrapper. This could be replaced in the future with an |
252 | // extension mechanism on the MlirContext for stashing user pointers. |
253 | // Note that this holds a handle, which does not imply ownership. |
254 | // Mappings will be removed when the context is destructed. |
255 | using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>; |
256 | static LiveContextMap &getLiveContexts(); |
257 | |
258 | // Interns all live modules associated with this context. Modules tracked |
259 | // in this map are valid. When a module is invalidated, it is removed |
260 | // from this map, and while it still exists as an instance, any |
261 | // attempt to access it will raise an error. |
262 | using LiveModuleMap = |
263 | llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>; |
264 | LiveModuleMap liveModules; |
265 | |
266 | // Interns all live operations associated with this context. Operations |
267 | // tracked in this map are valid. When an operation is invalidated, it is |
268 | // removed from this map, and while it still exists as an instance, any |
269 | // attempt to access it will raise an error. |
270 | using LiveOperationMap = |
271 | llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>; |
272 | LiveOperationMap liveOperations; |
273 | |
274 | bool emitErrorDiagnostics = false; |
275 | |
276 | MlirContext context; |
277 | friend class PyModule; |
278 | friend class PyOperation; |
279 | }; |
280 | |
281 | /// Used in function arguments when None should resolve to the current context |
282 | /// manager set instance. |
283 | class DefaultingPyMlirContext |
284 | : public Defaulting<DefaultingPyMlirContext, PyMlirContext> { |
285 | public: |
286 | using Defaulting::Defaulting; |
287 | static constexpr const char kTypeDescription[] = "mlir.ir.Context" ; |
288 | static PyMlirContext &resolve(); |
289 | }; |
290 | |
291 | /// Base class for all objects that directly or indirectly depend on an |
292 | /// MlirContext. The lifetime of the context will extend at least to the |
293 | /// lifetime of these instances. |
294 | /// Immutable objects that depend on a context extend this directly. |
295 | class BaseContextObject { |
296 | public: |
297 | BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) { |
298 | assert(this->contextRef && |
299 | "context object constructed with null context ref" ); |
300 | } |
301 | |
302 | /// Accesses the context reference. |
303 | PyMlirContextRef &getContext() { return contextRef; } |
304 | |
305 | private: |
306 | PyMlirContextRef contextRef; |
307 | }; |
308 | |
309 | /// Wrapper around an MlirLocation. |
310 | class PyLocation : public BaseContextObject { |
311 | public: |
312 | PyLocation(PyMlirContextRef contextRef, MlirLocation loc) |
313 | : BaseContextObject(std::move(contextRef)), loc(loc) {} |
314 | |
315 | operator MlirLocation() const { return loc; } |
316 | MlirLocation get() const { return loc; } |
317 | |
318 | /// Enter and exit the context manager. |
319 | pybind11::object contextEnter(); |
320 | void contextExit(const pybind11::object &excType, |
321 | const pybind11::object &excVal, |
322 | const pybind11::object &excTb); |
323 | |
324 | /// Gets a capsule wrapping the void* within the MlirLocation. |
325 | pybind11::object getCapsule(); |
326 | |
327 | /// Creates a PyLocation from the MlirLocation wrapped by a capsule. |
328 | /// Note that PyLocation instances are uniqued, so the returned object |
329 | /// may be a pre-existing object. Ownership of the underlying MlirLocation |
330 | /// is taken by calling this function. |
331 | static PyLocation createFromCapsule(pybind11::object capsule); |
332 | |
333 | private: |
334 | MlirLocation loc; |
335 | }; |
336 | |
337 | /// Python class mirroring the C MlirDiagnostic struct. Note that these structs |
338 | /// are only valid for the duration of a diagnostic callback and attempting |
339 | /// to access them outside of that will raise an exception. This applies to |
340 | /// nested diagnostics (in the notes) as well. |
341 | class PyDiagnostic { |
342 | public: |
343 | PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {} |
344 | void invalidate(); |
345 | bool isValid() { return valid; } |
346 | MlirDiagnosticSeverity getSeverity(); |
347 | PyLocation getLocation(); |
348 | pybind11::str getMessage(); |
349 | pybind11::tuple getNotes(); |
350 | |
351 | /// Materialized diagnostic information. This is safe to access outside the |
352 | /// diagnostic callback. |
353 | struct DiagnosticInfo { |
354 | MlirDiagnosticSeverity severity; |
355 | PyLocation location; |
356 | std::string message; |
357 | std::vector<DiagnosticInfo> notes; |
358 | }; |
359 | DiagnosticInfo getInfo(); |
360 | |
361 | private: |
362 | MlirDiagnostic diagnostic; |
363 | |
364 | void checkValid(); |
365 | /// If notes have been materialized from the diagnostic, then this will |
366 | /// be populated with the corresponding objects (all castable to |
367 | /// PyDiagnostic). |
368 | std::optional<pybind11::tuple> materializedNotes; |
369 | bool valid = true; |
370 | }; |
371 | |
372 | /// Represents a diagnostic handler attached to the context. The handler's |
373 | /// callback will be invoked with PyDiagnostic instances until the detach() |
374 | /// method is called or the context is destroyed. A diagnostic handler can be |
375 | /// the subject of a `with` block, which will detach it when the block exits. |
376 | /// |
377 | /// Since diagnostic handlers can call back into Python code which can do |
378 | /// unsafe things (i.e. recursively emitting diagnostics, raising exceptions, |
379 | /// etc), this is generally not deemed to be a great user-level API. Users |
380 | /// should generally use some form of DiagnosticCollector. If the handler raises |
381 | /// any exceptions, they will just be emitted to stderr and dropped. |
382 | /// |
383 | /// The unique usage of this class means that its lifetime management is |
384 | /// different from most other parts of the API. Instances are always created |
385 | /// in an attached state and can transition to a detached state by either: |
386 | /// a) The context being destroyed and unregistering all handlers. |
387 | /// b) An explicit call to detach(). |
388 | /// The object may remain live from a Python perspective for an arbitrary time |
389 | /// after detachment, but there is nothing the user can do with it (since there |
390 | /// is no way to attach an existing handler object). |
391 | class PyDiagnosticHandler { |
392 | public: |
393 | PyDiagnosticHandler(MlirContext context, pybind11::object callback); |
394 | ~PyDiagnosticHandler(); |
395 | |
396 | bool isAttached() { return registeredID.has_value(); } |
397 | bool getHadError() { return hadError; } |
398 | |
399 | /// Detaches the handler. Does nothing if not attached. |
400 | void detach(); |
401 | |
402 | pybind11::object contextEnter() { return pybind11::cast(this); } |
403 | void contextExit(const pybind11::object &excType, |
404 | const pybind11::object &excVal, |
405 | const pybind11::object &excTb) { |
406 | detach(); |
407 | } |
408 | |
409 | private: |
410 | MlirContext context; |
411 | pybind11::object callback; |
412 | std::optional<MlirDiagnosticHandlerID> registeredID; |
413 | bool hadError = false; |
414 | friend class PyMlirContext; |
415 | }; |
416 | |
417 | /// RAII object that captures any error diagnostics emitted to the provided |
418 | /// context. |
419 | struct PyMlirContext::ErrorCapture { |
420 | ErrorCapture(PyMlirContextRef ctx) |
421 | : ctx(ctx), handlerID(mlirContextAttachDiagnosticHandler( |
422 | ctx->get(), handler, /*userData=*/this, |
423 | /*deleteUserData=*/nullptr)) {} |
424 | ~ErrorCapture() { |
425 | mlirContextDetachDiagnosticHandler(ctx->get(), handlerID); |
426 | assert(errors.empty() && "unhandled captured errors" ); |
427 | } |
428 | |
429 | std::vector<PyDiagnostic::DiagnosticInfo> take() { |
430 | return std::move(errors); |
431 | }; |
432 | |
433 | private: |
434 | PyMlirContextRef ctx; |
435 | MlirDiagnosticHandlerID handlerID; |
436 | std::vector<PyDiagnostic::DiagnosticInfo> errors; |
437 | |
438 | static MlirLogicalResult handler(MlirDiagnostic diag, void *userData); |
439 | }; |
440 | |
441 | /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in |
442 | /// order to differentiate it from the `Dialect` base class which is extended by |
443 | /// plugins which extend dialect functionality through extension python code. |
444 | /// This should be seen as the "low-level" object and `Dialect` as the |
445 | /// high-level, user facing object. |
446 | class PyDialectDescriptor : public BaseContextObject { |
447 | public: |
448 | PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect) |
449 | : BaseContextObject(std::move(contextRef)), dialect(dialect) {} |
450 | |
451 | MlirDialect get() { return dialect; } |
452 | |
453 | private: |
454 | MlirDialect dialect; |
455 | }; |
456 | |
457 | /// User-level object for accessing dialects with dotted syntax such as: |
458 | /// ctx.dialect.std |
459 | class PyDialects : public BaseContextObject { |
460 | public: |
461 | PyDialects(PyMlirContextRef contextRef) |
462 | : BaseContextObject(std::move(contextRef)) {} |
463 | |
464 | MlirDialect getDialectForKey(const std::string &key, bool attrError); |
465 | }; |
466 | |
467 | /// User-level dialect object. For dialects that have a registered extension, |
468 | /// this will be the base class of the extension dialect type. For un-extended, |
469 | /// objects of this type will be returned directly. |
470 | class PyDialect { |
471 | public: |
472 | PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {} |
473 | |
474 | pybind11::object getDescriptor() { return descriptor; } |
475 | |
476 | private: |
477 | pybind11::object descriptor; |
478 | }; |
479 | |
480 | /// Wrapper around an MlirDialectRegistry. |
481 | /// Upon construction, the Python wrapper takes ownership of the |
482 | /// underlying MlirDialectRegistry. |
483 | class PyDialectRegistry { |
484 | public: |
485 | PyDialectRegistry() : registry(mlirDialectRegistryCreate()) {} |
486 | PyDialectRegistry(MlirDialectRegistry registry) : registry(registry) {} |
487 | ~PyDialectRegistry() { |
488 | if (!mlirDialectRegistryIsNull(registry)) |
489 | mlirDialectRegistryDestroy(registry); |
490 | } |
491 | PyDialectRegistry(PyDialectRegistry &) = delete; |
492 | PyDialectRegistry(PyDialectRegistry &&other) noexcept |
493 | : registry(other.registry) { |
494 | other.registry = {nullptr}; |
495 | } |
496 | |
497 | operator MlirDialectRegistry() const { return registry; } |
498 | MlirDialectRegistry get() const { return registry; } |
499 | |
500 | pybind11::object getCapsule(); |
501 | static PyDialectRegistry createFromCapsule(pybind11::object capsule); |
502 | |
503 | private: |
504 | MlirDialectRegistry registry; |
505 | }; |
506 | |
507 | /// Used in function arguments when None should resolve to the current context |
508 | /// manager set instance. |
509 | class DefaultingPyLocation |
510 | : public Defaulting<DefaultingPyLocation, PyLocation> { |
511 | public: |
512 | using Defaulting::Defaulting; |
513 | static constexpr const char kTypeDescription[] = "mlir.ir.Location" ; |
514 | static PyLocation &resolve(); |
515 | |
516 | operator MlirLocation() const { return *get(); } |
517 | }; |
518 | |
519 | /// Wrapper around MlirModule. |
520 | /// This is the top-level, user-owned object that contains regions/ops/blocks. |
521 | class PyModule; |
522 | using PyModuleRef = PyObjectRef<PyModule>; |
523 | class PyModule : public BaseContextObject { |
524 | public: |
525 | /// Returns a PyModule reference for the given MlirModule. This may return |
526 | /// a pre-existing or new object. |
527 | static PyModuleRef forModule(MlirModule module); |
528 | PyModule(PyModule &) = delete; |
529 | PyModule(PyMlirContext &&) = delete; |
530 | ~PyModule(); |
531 | |
532 | /// Gets the backing MlirModule. |
533 | MlirModule get() { return module; } |
534 | |
535 | /// Gets a strong reference to this module. |
536 | PyModuleRef getRef() { |
537 | return PyModuleRef(this, |
538 | pybind11::reinterpret_borrow<pybind11::object>(handle)); |
539 | } |
540 | |
541 | /// Gets a capsule wrapping the void* within the MlirModule. |
542 | /// Note that the module does not (yet) provide a corresponding factory for |
543 | /// constructing from a capsule as that would require uniquing PyModule |
544 | /// instances, which is not currently done. |
545 | pybind11::object getCapsule(); |
546 | |
547 | /// Creates a PyModule from the MlirModule wrapped by a capsule. |
548 | /// Note that PyModule instances are uniqued, so the returned object |
549 | /// may be a pre-existing object. Ownership of the underlying MlirModule |
550 | /// is taken by calling this function. |
551 | static pybind11::object createFromCapsule(pybind11::object capsule); |
552 | |
553 | private: |
554 | PyModule(PyMlirContextRef contextRef, MlirModule module); |
555 | MlirModule module; |
556 | pybind11::handle handle; |
557 | }; |
558 | |
559 | class PyAsmState; |
560 | |
561 | /// Base class for PyOperation and PyOpView which exposes the primary, user |
562 | /// visible methods for manipulating it. |
563 | class PyOperationBase { |
564 | public: |
565 | virtual ~PyOperationBase() = default; |
566 | /// Implements the bound 'print' method and helps with others. |
567 | void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo, |
568 | bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, |
569 | bool assumeVerified, py::object fileObject, bool binary); |
570 | void print(PyAsmState &state, py::object fileObject, bool binary); |
571 | |
572 | pybind11::object getAsm(bool binary, |
573 | std::optional<int64_t> largeElementsLimit, |
574 | bool enableDebugInfo, bool prettyDebugInfo, |
575 | bool printGenericOpForm, bool useLocalScope, |
576 | bool assumeVerified); |
577 | |
578 | // Implement the bound 'writeBytecode' method. |
579 | void writeBytecode(const pybind11::object &fileObject, |
580 | std::optional<int64_t> bytecodeVersion); |
581 | |
582 | // Implement the walk method. |
583 | void walk(std::function<MlirWalkResult(MlirOperation)> callback, |
584 | MlirWalkOrder walkOrder); |
585 | |
586 | /// Moves the operation before or after the other operation. |
587 | void moveAfter(PyOperationBase &other); |
588 | void moveBefore(PyOperationBase &other); |
589 | |
590 | /// Verify the operation. Throws `MLIRError` if verification fails, and |
591 | /// returns `true` otherwise. |
592 | bool verify(); |
593 | |
594 | /// Each must provide access to the raw Operation. |
595 | virtual PyOperation &getOperation() = 0; |
596 | }; |
597 | |
598 | /// Wrapper around PyOperation. |
599 | /// Operations exist in either an attached (dependent) or detached (top-level) |
600 | /// state. In the detached state (as on creation), an operation is owned by |
601 | /// the creator and its lifetime extends either until its reference count |
602 | /// drops to zero or it is attached to a parent, at which point its lifetime |
603 | /// is bounded by its top-level parent reference. |
604 | class PyOperation; |
605 | using PyOperationRef = PyObjectRef<PyOperation>; |
606 | class PyOperation : public PyOperationBase, public BaseContextObject { |
607 | public: |
608 | ~PyOperation() override; |
609 | PyOperation &getOperation() override { return *this; } |
610 | |
611 | /// Returns a PyOperation for the given MlirOperation, optionally associating |
612 | /// it with a parentKeepAlive. |
613 | static PyOperationRef |
614 | forOperation(PyMlirContextRef contextRef, MlirOperation operation, |
615 | pybind11::object parentKeepAlive = pybind11::object()); |
616 | |
617 | /// Creates a detached operation. The operation must not be associated with |
618 | /// any existing live operation. |
619 | static PyOperationRef |
620 | createDetached(PyMlirContextRef contextRef, MlirOperation operation, |
621 | pybind11::object parentKeepAlive = pybind11::object()); |
622 | |
623 | /// Parses a source string (either text assembly or bytecode), creating a |
624 | /// detached operation. |
625 | static PyOperationRef parse(PyMlirContextRef contextRef, |
626 | const std::string &sourceStr, |
627 | const std::string &sourceName); |
628 | |
629 | /// Detaches the operation from its parent block and updates its state |
630 | /// accordingly. |
631 | void detachFromParent() { |
632 | mlirOperationRemoveFromParent(getOperation()); |
633 | setDetached(); |
634 | parentKeepAlive = pybind11::object(); |
635 | } |
636 | |
637 | /// Gets the backing operation. |
638 | operator MlirOperation() const { return get(); } |
639 | MlirOperation get() const { |
640 | checkValid(); |
641 | return operation; |
642 | } |
643 | |
644 | PyOperationRef getRef() { |
645 | return PyOperationRef( |
646 | this, pybind11::reinterpret_borrow<pybind11::object>(handle)); |
647 | } |
648 | |
649 | bool isAttached() { return attached; } |
650 | void setAttached(const pybind11::object &parent = pybind11::object()) { |
651 | assert(!attached && "operation already attached" ); |
652 | attached = true; |
653 | } |
654 | void setDetached() { |
655 | assert(attached && "operation already detached" ); |
656 | attached = false; |
657 | } |
658 | void checkValid() const; |
659 | |
660 | /// Gets the owning block or raises an exception if the operation has no |
661 | /// owning block. |
662 | PyBlock getBlock(); |
663 | |
664 | /// Gets the parent operation or raises an exception if the operation has |
665 | /// no parent. |
666 | std::optional<PyOperationRef> getParentOperation(); |
667 | |
668 | /// Gets a capsule wrapping the void* within the MlirOperation. |
669 | pybind11::object getCapsule(); |
670 | |
671 | /// Creates a PyOperation from the MlirOperation wrapped by a capsule. |
672 | /// Ownership of the underlying MlirOperation is taken by calling this |
673 | /// function. |
674 | static pybind11::object createFromCapsule(pybind11::object capsule); |
675 | |
676 | /// Creates an operation. See corresponding python docstring. |
677 | static pybind11::object |
678 | create(const std::string &name, std::optional<std::vector<PyType *>> results, |
679 | std::optional<std::vector<PyValue *>> operands, |
680 | std::optional<pybind11::dict> attributes, |
681 | std::optional<std::vector<PyBlock *>> successors, int regions, |
682 | DefaultingPyLocation location, const pybind11::object &ip, |
683 | bool inferType); |
684 | |
685 | /// Creates an OpView suitable for this operation. |
686 | pybind11::object createOpView(); |
687 | |
688 | /// Erases the underlying MlirOperation, removes its pointer from the |
689 | /// parent context's live operations map, and sets the valid bit false. |
690 | void erase(); |
691 | |
692 | /// Invalidate the operation. |
693 | void setInvalid() { valid = false; } |
694 | |
695 | /// Clones this operation. |
696 | pybind11::object clone(const pybind11::object &ip); |
697 | |
698 | private: |
699 | PyOperation(PyMlirContextRef contextRef, MlirOperation operation); |
700 | static PyOperationRef createInstance(PyMlirContextRef contextRef, |
701 | MlirOperation operation, |
702 | pybind11::object parentKeepAlive); |
703 | |
704 | MlirOperation operation; |
705 | pybind11::handle handle; |
706 | // Keeps the parent alive, regardless of whether it is an Operation or |
707 | // Module. |
708 | // TODO: As implemented, this facility is only sufficient for modeling the |
709 | // trivial module parent back-reference. Generalize this to also account for |
710 | // transitions from detached to attached and address TODOs in the |
711 | // ir_operation.py regarding testing corresponding lifetime guarantees. |
712 | pybind11::object parentKeepAlive; |
713 | bool attached = true; |
714 | bool valid = true; |
715 | |
716 | friend class PyOperationBase; |
717 | friend class PySymbolTable; |
718 | }; |
719 | |
720 | /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for |
721 | /// providing more instance-specific accessors and serve as the base class for |
722 | /// custom ODS-style operation classes. Since this class is subclass on the |
723 | /// python side, it must present an __init__ method that operates in pure |
724 | /// python types. |
725 | class PyOpView : public PyOperationBase { |
726 | public: |
727 | PyOpView(const pybind11::object &operationObject); |
728 | PyOperation &getOperation() override { return operation; } |
729 | |
730 | pybind11::object getOperationObject() { return operationObject; } |
731 | |
732 | static pybind11::object buildGeneric( |
733 | const pybind11::object &cls, std::optional<pybind11::list> resultTypeList, |
734 | pybind11::list operandList, std::optional<pybind11::dict> attributes, |
735 | std::optional<std::vector<PyBlock *>> successors, |
736 | std::optional<int> regions, DefaultingPyLocation location, |
737 | const pybind11::object &maybeIp); |
738 | |
739 | /// Construct an instance of a class deriving from OpView, bypassing its |
740 | /// `__init__` method. The derived class will typically define a constructor |
741 | /// that provides a convenient builder, but we need to side-step this when |
742 | /// constructing an `OpView` for an already-built operation. |
743 | /// |
744 | /// The caller is responsible for verifying that `operation` is a valid |
745 | /// operation to construct `cls` with. |
746 | static pybind11::object constructDerived(const pybind11::object &cls, |
747 | const PyOperation &operation); |
748 | |
749 | private: |
750 | PyOperation &operation; // For efficient, cast-free access from C++ |
751 | pybind11::object operationObject; // Holds the reference. |
752 | }; |
753 | |
754 | /// Wrapper around an MlirRegion. |
755 | /// Regions are managed completely by their containing operation. Unlike the |
756 | /// C++ API, the python API does not support detached regions. |
757 | class PyRegion { |
758 | public: |
759 | PyRegion(PyOperationRef parentOperation, MlirRegion region) |
760 | : parentOperation(std::move(parentOperation)), region(region) { |
761 | assert(!mlirRegionIsNull(region) && "python region cannot be null" ); |
762 | } |
763 | operator MlirRegion() const { return region; } |
764 | |
765 | MlirRegion get() { return region; } |
766 | PyOperationRef &getParentOperation() { return parentOperation; } |
767 | |
768 | void checkValid() { return parentOperation->checkValid(); } |
769 | |
770 | private: |
771 | PyOperationRef parentOperation; |
772 | MlirRegion region; |
773 | }; |
774 | |
775 | /// Wrapper around an MlirAsmState. |
776 | class PyAsmState { |
777 | public: |
778 | PyAsmState(MlirValue value, bool useLocalScope) { |
779 | flags = mlirOpPrintingFlagsCreate(); |
780 | // The OpPrintingFlags are not exposed Python side, create locally and |
781 | // associate lifetime with the state. |
782 | if (useLocalScope) |
783 | mlirOpPrintingFlagsUseLocalScope(flags); |
784 | state = mlirAsmStateCreateForValue(value, flags); |
785 | } |
786 | |
787 | PyAsmState(PyOperationBase &operation, bool useLocalScope) { |
788 | flags = mlirOpPrintingFlagsCreate(); |
789 | // The OpPrintingFlags are not exposed Python side, create locally and |
790 | // associate lifetime with the state. |
791 | if (useLocalScope) |
792 | mlirOpPrintingFlagsUseLocalScope(flags); |
793 | state = |
794 | mlirAsmStateCreateForOperation(operation.getOperation().get(), flags); |
795 | } |
796 | ~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); } |
797 | // Delete copy constructors. |
798 | PyAsmState(PyAsmState &other) = delete; |
799 | PyAsmState(const PyAsmState &other) = delete; |
800 | |
801 | MlirAsmState get() { return state; } |
802 | |
803 | private: |
804 | MlirAsmState state; |
805 | MlirOpPrintingFlags flags; |
806 | }; |
807 | |
808 | /// Wrapper around an MlirBlock. |
809 | /// Blocks are managed completely by their containing operation. Unlike the |
810 | /// C++ API, the python API does not support detached blocks. |
811 | class PyBlock { |
812 | public: |
813 | PyBlock(PyOperationRef parentOperation, MlirBlock block) |
814 | : parentOperation(std::move(parentOperation)), block(block) { |
815 | assert(!mlirBlockIsNull(block) && "python block cannot be null" ); |
816 | } |
817 | |
818 | MlirBlock get() { return block; } |
819 | PyOperationRef &getParentOperation() { return parentOperation; } |
820 | |
821 | void checkValid() { return parentOperation->checkValid(); } |
822 | |
823 | /// Gets a capsule wrapping the void* within the MlirBlock. |
824 | pybind11::object getCapsule(); |
825 | |
826 | private: |
827 | PyOperationRef parentOperation; |
828 | MlirBlock block; |
829 | }; |
830 | |
831 | /// An insertion point maintains a pointer to a Block and a reference operation. |
832 | /// Calls to insert() will insert a new operation before the |
833 | /// reference operation. If the reference operation is null, then appends to |
834 | /// the end of the block. |
835 | class PyInsertionPoint { |
836 | public: |
837 | /// Creates an insertion point positioned after the last operation in the |
838 | /// block, but still inside the block. |
839 | PyInsertionPoint(PyBlock &block); |
840 | /// Creates an insertion point positioned before a reference operation. |
841 | PyInsertionPoint(PyOperationBase &beforeOperationBase); |
842 | |
843 | /// Shortcut to create an insertion point at the beginning of the block. |
844 | static PyInsertionPoint atBlockBegin(PyBlock &block); |
845 | /// Shortcut to create an insertion point before the block terminator. |
846 | static PyInsertionPoint atBlockTerminator(PyBlock &block); |
847 | |
848 | /// Inserts an operation. |
849 | void insert(PyOperationBase &operationBase); |
850 | |
851 | /// Enter and exit the context manager. |
852 | pybind11::object contextEnter(); |
853 | void contextExit(const pybind11::object &excType, |
854 | const pybind11::object &excVal, |
855 | const pybind11::object &excTb); |
856 | |
857 | PyBlock &getBlock() { return block; } |
858 | std::optional<PyOperationRef> &getRefOperation() { return refOperation; } |
859 | |
860 | private: |
861 | // Trampoline constructor that avoids null initializing members while |
862 | // looking up parents. |
863 | PyInsertionPoint(PyBlock block, std::optional<PyOperationRef> refOperation) |
864 | : refOperation(std::move(refOperation)), block(std::move(block)) {} |
865 | |
866 | std::optional<PyOperationRef> refOperation; |
867 | PyBlock block; |
868 | }; |
869 | /// Wrapper around the generic MlirType. |
870 | /// The lifetime of a type is bound by the PyContext that created it. |
871 | class PyType : public BaseContextObject { |
872 | public: |
873 | PyType(PyMlirContextRef contextRef, MlirType type) |
874 | : BaseContextObject(std::move(contextRef)), type(type) {} |
875 | bool operator==(const PyType &other) const; |
876 | operator MlirType() const { return type; } |
877 | MlirType get() const { return type; } |
878 | |
879 | /// Gets a capsule wrapping the void* within the MlirType. |
880 | pybind11::object getCapsule(); |
881 | |
882 | /// Creates a PyType from the MlirType wrapped by a capsule. |
883 | /// Note that PyType instances are uniqued, so the returned object |
884 | /// may be a pre-existing object. Ownership of the underlying MlirType |
885 | /// is taken by calling this function. |
886 | static PyType createFromCapsule(pybind11::object capsule); |
887 | |
888 | private: |
889 | MlirType type; |
890 | }; |
891 | |
892 | /// A TypeID provides an efficient and unique identifier for a specific C++ |
893 | /// type. This allows for a C++ type to be compared, hashed, and stored in an |
894 | /// opaque context. This class wraps around the generic MlirTypeID. |
895 | class PyTypeID { |
896 | public: |
897 | PyTypeID(MlirTypeID typeID) : typeID(typeID) {} |
898 | // Note, this tests whether the underlying TypeIDs are the same, |
899 | // not whether the wrapper MlirTypeIDs are the same, nor whether |
900 | // the PyTypeID objects are the same (i.e., PyTypeID is a value type). |
901 | bool operator==(const PyTypeID &other) const; |
902 | operator MlirTypeID() const { return typeID; } |
903 | MlirTypeID get() { return typeID; } |
904 | |
905 | /// Gets a capsule wrapping the void* within the MlirTypeID. |
906 | pybind11::object getCapsule(); |
907 | |
908 | /// Creates a PyTypeID from the MlirTypeID wrapped by a capsule. |
909 | static PyTypeID createFromCapsule(pybind11::object capsule); |
910 | |
911 | private: |
912 | MlirTypeID typeID; |
913 | }; |
914 | |
915 | /// CRTP base classes for Python types that subclass Type and should be |
916 | /// castable from it (i.e. via something like IntegerType(t)). |
917 | /// By default, type class hierarchies are one level deep (i.e. a |
918 | /// concrete type class extends PyType); however, intermediate python-visible |
919 | /// base classes can be modeled by specifying a BaseTy. |
920 | template <typename DerivedTy, typename BaseTy = PyType> |
921 | class PyConcreteType : public BaseTy { |
922 | public: |
923 | // Derived classes must define statics for: |
924 | // IsAFunctionTy isaFunction |
925 | // const char *pyClassName |
926 | using ClassTy = pybind11::class_<DerivedTy, BaseTy>; |
927 | using IsAFunctionTy = bool (*)(MlirType); |
928 | using GetTypeIDFunctionTy = MlirTypeID (*)(); |
929 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; |
930 | |
931 | PyConcreteType() = default; |
932 | PyConcreteType(PyMlirContextRef contextRef, MlirType t) |
933 | : BaseTy(std::move(contextRef), t) {} |
934 | PyConcreteType(PyType &orig) |
935 | : PyConcreteType(orig.getContext(), castFrom(orig)) {} |
936 | |
937 | static MlirType castFrom(PyType &orig) { |
938 | if (!DerivedTy::isaFunction(orig)) { |
939 | auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>(); |
940 | throw py::value_error((llvm::Twine("Cannot cast type to " ) + |
941 | DerivedTy::pyClassName + " (from " + origRepr + |
942 | ")" ) |
943 | .str()); |
944 | } |
945 | return orig; |
946 | } |
947 | |
948 | static void bind(pybind11::module &m) { |
949 | auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local()); |
950 | cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>(), |
951 | pybind11::arg("cast_from_type" )); |
952 | cls.def_static( |
953 | "isinstance" , |
954 | [](PyType &otherType) -> bool { |
955 | return DerivedTy::isaFunction(otherType); |
956 | }, |
957 | pybind11::arg("other" )); |
958 | cls.def_property_readonly_static( |
959 | "static_typeid" , [](py::object & /*class*/) -> MlirTypeID { |
960 | if (DerivedTy::getTypeIdFunction) |
961 | return DerivedTy::getTypeIdFunction(); |
962 | throw py::attribute_error( |
963 | (DerivedTy::pyClassName + llvm::Twine(" has no typeid." )).str()); |
964 | }); |
965 | cls.def_property_readonly("typeid" , [](PyType &self) { |
966 | return py::cast(self).attr("typeid" ).cast<MlirTypeID>(); |
967 | }); |
968 | cls.def("__repr__" , [](DerivedTy &self) { |
969 | PyPrintAccumulator printAccum; |
970 | printAccum.parts.append(DerivedTy::pyClassName); |
971 | printAccum.parts.append("(" ); |
972 | mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); |
973 | printAccum.parts.append(")" ); |
974 | return printAccum.join(); |
975 | }); |
976 | |
977 | if (DerivedTy::getTypeIdFunction) { |
978 | PyGlobals::get().registerTypeCaster( |
979 | DerivedTy::getTypeIdFunction(), |
980 | pybind11::cpp_function( |
981 | [](PyType pyType) -> DerivedTy { return pyType; })); |
982 | } |
983 | |
984 | DerivedTy::bindDerived(cls); |
985 | } |
986 | |
987 | /// Implemented by derived classes to add methods to the Python subclass. |
988 | static void bindDerived(ClassTy &m) {} |
989 | }; |
990 | |
991 | /// Wrapper around the generic MlirAttribute. |
992 | /// The lifetime of a type is bound by the PyContext that created it. |
993 | class PyAttribute : public BaseContextObject { |
994 | public: |
995 | PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr) |
996 | : BaseContextObject(std::move(contextRef)), attr(attr) {} |
997 | bool operator==(const PyAttribute &other) const; |
998 | operator MlirAttribute() const { return attr; } |
999 | MlirAttribute get() const { return attr; } |
1000 | |
1001 | /// Gets a capsule wrapping the void* within the MlirAttribute. |
1002 | pybind11::object getCapsule(); |
1003 | |
1004 | /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule. |
1005 | /// Note that PyAttribute instances are uniqued, so the returned object |
1006 | /// may be a pre-existing object. Ownership of the underlying MlirAttribute |
1007 | /// is taken by calling this function. |
1008 | static PyAttribute createFromCapsule(pybind11::object capsule); |
1009 | |
1010 | private: |
1011 | MlirAttribute attr; |
1012 | }; |
1013 | |
1014 | /// Represents a Python MlirNamedAttr, carrying an optional owned name. |
1015 | /// TODO: Refactor this and the C-API to be based on an Identifier owned |
1016 | /// by the context so as to avoid ownership issues here. |
1017 | class PyNamedAttribute { |
1018 | public: |
1019 | /// Constructs a PyNamedAttr that retains an owned name. This should be |
1020 | /// used in any code that originates an MlirNamedAttribute from a python |
1021 | /// string. |
1022 | /// The lifetime of the PyNamedAttr must extend to the lifetime of the |
1023 | /// passed attribute. |
1024 | PyNamedAttribute(MlirAttribute attr, std::string ownedName); |
1025 | |
1026 | MlirNamedAttribute namedAttr; |
1027 | |
1028 | private: |
1029 | // Since the MlirNamedAttr contains an internal pointer to the actual |
1030 | // memory of the owned string, it must be heap allocated to remain valid. |
1031 | // Otherwise, strings that fit within the small object optimization threshold |
1032 | // will have their memory address change as the containing object is moved, |
1033 | // resulting in an invalid aliased pointer. |
1034 | std::unique_ptr<std::string> ownedName; |
1035 | }; |
1036 | |
1037 | /// CRTP base classes for Python attributes that subclass Attribute and should |
1038 | /// be castable from it (i.e. via something like StringAttr(attr)). |
1039 | /// By default, attribute class hierarchies are one level deep (i.e. a |
1040 | /// concrete attribute class extends PyAttribute); however, intermediate |
1041 | /// python-visible base classes can be modeled by specifying a BaseTy. |
1042 | template <typename DerivedTy, typename BaseTy = PyAttribute> |
1043 | class PyConcreteAttribute : public BaseTy { |
1044 | public: |
1045 | // Derived classes must define statics for: |
1046 | // IsAFunctionTy isaFunction |
1047 | // const char *pyClassName |
1048 | using ClassTy = pybind11::class_<DerivedTy, BaseTy>; |
1049 | using IsAFunctionTy = bool (*)(MlirAttribute); |
1050 | using GetTypeIDFunctionTy = MlirTypeID (*)(); |
1051 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; |
1052 | |
1053 | PyConcreteAttribute() = default; |
1054 | PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) |
1055 | : BaseTy(std::move(contextRef), attr) {} |
1056 | PyConcreteAttribute(PyAttribute &orig) |
1057 | : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} |
1058 | |
1059 | static MlirAttribute castFrom(PyAttribute &orig) { |
1060 | if (!DerivedTy::isaFunction(orig)) { |
1061 | auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>(); |
1062 | throw py::value_error((llvm::Twine("Cannot cast attribute to " ) + |
1063 | DerivedTy::pyClassName + " (from " + origRepr + |
1064 | ")" ) |
1065 | .str()); |
1066 | } |
1067 | return orig; |
1068 | } |
1069 | |
1070 | static void bind(pybind11::module &m) { |
1071 | auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(), |
1072 | pybind11::module_local()); |
1073 | cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>(), |
1074 | pybind11::arg("cast_from_attr" )); |
1075 | cls.def_static( |
1076 | "isinstance" , |
1077 | [](PyAttribute &otherAttr) -> bool { |
1078 | return DerivedTy::isaFunction(otherAttr); |
1079 | }, |
1080 | pybind11::arg("other" )); |
1081 | cls.def_property_readonly( |
1082 | "type" , [](PyAttribute &attr) { return mlirAttributeGetType(attr); }); |
1083 | cls.def_property_readonly_static( |
1084 | "static_typeid" , [](py::object & /*class*/) -> MlirTypeID { |
1085 | if (DerivedTy::getTypeIdFunction) |
1086 | return DerivedTy::getTypeIdFunction(); |
1087 | throw py::attribute_error( |
1088 | (DerivedTy::pyClassName + llvm::Twine(" has no typeid." )).str()); |
1089 | }); |
1090 | cls.def_property_readonly("typeid" , [](PyAttribute &self) { |
1091 | return py::cast(self).attr("typeid" ).cast<MlirTypeID>(); |
1092 | }); |
1093 | cls.def("__repr__" , [](DerivedTy &self) { |
1094 | PyPrintAccumulator printAccum; |
1095 | printAccum.parts.append(DerivedTy::pyClassName); |
1096 | printAccum.parts.append("(" ); |
1097 | mlirAttributePrint(self, printAccum.getCallback(), |
1098 | printAccum.getUserData()); |
1099 | printAccum.parts.append(")" ); |
1100 | return printAccum.join(); |
1101 | }); |
1102 | |
1103 | if (DerivedTy::getTypeIdFunction) { |
1104 | PyGlobals::get().registerTypeCaster( |
1105 | DerivedTy::getTypeIdFunction(), |
1106 | pybind11::cpp_function([](PyAttribute pyAttribute) -> DerivedTy { |
1107 | return pyAttribute; |
1108 | })); |
1109 | } |
1110 | |
1111 | DerivedTy::bindDerived(cls); |
1112 | } |
1113 | |
1114 | /// Implemented by derived classes to add methods to the Python subclass. |
1115 | static void bindDerived(ClassTy &m) {} |
1116 | }; |
1117 | |
1118 | /// Wrapper around the generic MlirValue. |
1119 | /// Values are managed completely by the operation that resulted in their |
1120 | /// definition. For op result value, this is the operation that defines the |
1121 | /// value. For block argument values, this is the operation that contains the |
1122 | /// block to which the value is an argument (blocks cannot be detached in Python |
1123 | /// bindings so such operation always exists). |
1124 | class PyValue { |
1125 | public: |
1126 | // The virtual here is "load bearing" in that it enables RTTI |
1127 | // for PyConcreteValue CRTP classes that support maybeDownCast. |
1128 | // See PyValue::maybeDownCast. |
1129 | virtual ~PyValue() = default; |
1130 | PyValue(PyOperationRef parentOperation, MlirValue value) |
1131 | : parentOperation(std::move(parentOperation)), value(value) {} |
1132 | operator MlirValue() const { return value; } |
1133 | |
1134 | MlirValue get() { return value; } |
1135 | PyOperationRef &getParentOperation() { return parentOperation; } |
1136 | |
1137 | void checkValid() { return parentOperation->checkValid(); } |
1138 | |
1139 | /// Gets a capsule wrapping the void* within the MlirValue. |
1140 | pybind11::object getCapsule(); |
1141 | |
1142 | pybind11::object maybeDownCast(); |
1143 | |
1144 | /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of |
1145 | /// the underlying MlirValue is still tied to the owning operation. |
1146 | static PyValue createFromCapsule(pybind11::object capsule); |
1147 | |
1148 | private: |
1149 | PyOperationRef parentOperation; |
1150 | MlirValue value; |
1151 | }; |
1152 | |
1153 | /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context. |
1154 | class PyAffineExpr : public BaseContextObject { |
1155 | public: |
1156 | PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) |
1157 | : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {} |
1158 | bool operator==(const PyAffineExpr &other) const; |
1159 | operator MlirAffineExpr() const { return affineExpr; } |
1160 | MlirAffineExpr get() const { return affineExpr; } |
1161 | |
1162 | /// Gets a capsule wrapping the void* within the MlirAffineExpr. |
1163 | pybind11::object getCapsule(); |
1164 | |
1165 | /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule. |
1166 | /// Note that PyAffineExpr instances are uniqued, so the returned object |
1167 | /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr |
1168 | /// is taken by calling this function. |
1169 | static PyAffineExpr createFromCapsule(pybind11::object capsule); |
1170 | |
1171 | PyAffineExpr add(const PyAffineExpr &other) const; |
1172 | PyAffineExpr mul(const PyAffineExpr &other) const; |
1173 | PyAffineExpr floorDiv(const PyAffineExpr &other) const; |
1174 | PyAffineExpr ceilDiv(const PyAffineExpr &other) const; |
1175 | PyAffineExpr mod(const PyAffineExpr &other) const; |
1176 | |
1177 | private: |
1178 | MlirAffineExpr affineExpr; |
1179 | }; |
1180 | |
1181 | class PyAffineMap : public BaseContextObject { |
1182 | public: |
1183 | PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap) |
1184 | : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {} |
1185 | bool operator==(const PyAffineMap &other) const; |
1186 | operator MlirAffineMap() const { return affineMap; } |
1187 | MlirAffineMap get() const { return affineMap; } |
1188 | |
1189 | /// Gets a capsule wrapping the void* within the MlirAffineMap. |
1190 | pybind11::object getCapsule(); |
1191 | |
1192 | /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule. |
1193 | /// Note that PyAffineMap instances are uniqued, so the returned object |
1194 | /// may be a pre-existing object. Ownership of the underlying MlirAffineMap |
1195 | /// is taken by calling this function. |
1196 | static PyAffineMap createFromCapsule(pybind11::object capsule); |
1197 | |
1198 | private: |
1199 | MlirAffineMap affineMap; |
1200 | }; |
1201 | |
1202 | class PyIntegerSet : public BaseContextObject { |
1203 | public: |
1204 | PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet) |
1205 | : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {} |
1206 | bool operator==(const PyIntegerSet &other) const; |
1207 | operator MlirIntegerSet() const { return integerSet; } |
1208 | MlirIntegerSet get() const { return integerSet; } |
1209 | |
1210 | /// Gets a capsule wrapping the void* within the MlirIntegerSet. |
1211 | pybind11::object getCapsule(); |
1212 | |
1213 | /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. |
1214 | /// Note that PyIntegerSet instances may be uniqued, so the returned object |
1215 | /// may be a pre-existing object. Integer sets are owned by the context. |
1216 | static PyIntegerSet createFromCapsule(pybind11::object capsule); |
1217 | |
1218 | private: |
1219 | MlirIntegerSet integerSet; |
1220 | }; |
1221 | |
1222 | /// Bindings for MLIR symbol tables. |
1223 | class PySymbolTable { |
1224 | public: |
1225 | /// Constructs a symbol table for the given operation. |
1226 | explicit PySymbolTable(PyOperationBase &operation); |
1227 | |
1228 | /// Destroys the symbol table. |
1229 | ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); } |
1230 | |
1231 | /// Returns the symbol (opview) with the given name, throws if there is no |
1232 | /// such symbol in the table. |
1233 | pybind11::object dunderGetItem(const std::string &name); |
1234 | |
1235 | /// Removes the given operation from the symbol table and erases it. |
1236 | void erase(PyOperationBase &symbol); |
1237 | |
1238 | /// Removes the operation with the given name from the symbol table and erases |
1239 | /// it, throws if there is no such symbol in the table. |
1240 | void dunderDel(const std::string &name); |
1241 | |
1242 | /// Inserts the given operation into the symbol table. The operation must have |
1243 | /// the symbol trait. |
1244 | MlirAttribute insert(PyOperationBase &symbol); |
1245 | |
1246 | /// Gets and sets the name of a symbol op. |
1247 | static MlirAttribute getSymbolName(PyOperationBase &symbol); |
1248 | static void setSymbolName(PyOperationBase &symbol, const std::string &name); |
1249 | |
1250 | /// Gets and sets the visibility of a symbol op. |
1251 | static MlirAttribute getVisibility(PyOperationBase &symbol); |
1252 | static void setVisibility(PyOperationBase &symbol, |
1253 | const std::string &visibility); |
1254 | |
1255 | /// Replaces all symbol uses within an operation. See the API |
1256 | /// mlirSymbolTableReplaceAllSymbolUses for all caveats. |
1257 | static void replaceAllSymbolUses(const std::string &oldSymbol, |
1258 | const std::string &newSymbol, |
1259 | PyOperationBase &from); |
1260 | |
1261 | /// Walks all symbol tables under and including 'from'. |
1262 | static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, |
1263 | pybind11::object callback); |
1264 | |
1265 | /// Casts the bindings class into the C API structure. |
1266 | operator MlirSymbolTable() { return symbolTable; } |
1267 | |
1268 | private: |
1269 | PyOperationRef operation; |
1270 | MlirSymbolTable symbolTable; |
1271 | }; |
1272 | |
1273 | /// Custom exception that allows access to error diagnostic information. This is |
1274 | /// converted to the `ir.MLIRError` python exception when thrown. |
1275 | struct MLIRError { |
1276 | MLIRError(llvm::Twine message, |
1277 | std::vector<PyDiagnostic::DiagnosticInfo> &&errorDiagnostics = {}) |
1278 | : message(message.str()), errorDiagnostics(std::move(errorDiagnostics)) {} |
1279 | std::string message; |
1280 | std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics; |
1281 | }; |
1282 | |
1283 | void populateIRAffine(pybind11::module &m); |
1284 | void populateIRAttributes(pybind11::module &m); |
1285 | void populateIRCore(pybind11::module &m); |
1286 | void populateIRInterfaces(pybind11::module &m); |
1287 | void populateIRTypes(pybind11::module &m); |
1288 | |
1289 | } // namespace python |
1290 | } // namespace mlir |
1291 | |
1292 | namespace pybind11 { |
1293 | namespace detail { |
1294 | |
1295 | template <> |
1296 | struct type_caster<mlir::python::DefaultingPyMlirContext> |
1297 | : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {}; |
1298 | template <> |
1299 | struct type_caster<mlir::python::DefaultingPyLocation> |
1300 | : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {}; |
1301 | |
1302 | } // namespace detail |
1303 | } // namespace pybind11 |
1304 | |
1305 | #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H |
1306 | |