//===- IRModules.h - IR Submodules of pybind module -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
| 9 | |
| 10 | #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H |
| 11 | #define MLIR_BINDINGS_PYTHON_IRMODULES_H |
| 12 | |
| 13 | #include <optional> |
| 14 | #include <sstream> |
| 15 | #include <utility> |
| 16 | #include <vector> |
| 17 | |
| 18 | #include "Globals.h" |
| 19 | #include "NanobindUtils.h" |
| 20 | #include "mlir-c/AffineExpr.h" |
| 21 | #include "mlir-c/AffineMap.h" |
| 22 | #include "mlir-c/Diagnostics.h" |
| 23 | #include "mlir-c/IR.h" |
| 24 | #include "mlir-c/IntegerSet.h" |
| 25 | #include "mlir-c/Transforms.h" |
| 26 | #include "mlir/Bindings/Python/Nanobind.h" |
| 27 | #include "mlir/Bindings/Python/NanobindAdaptors.h" |
| 28 | #include "llvm/ADT/DenseMap.h" |
| 29 | #include "llvm/Support/ThreadPool.h" |
| 30 | |
| 31 | namespace mlir { |
| 32 | namespace python { |
| 33 | |
| 34 | class PyBlock; |
| 35 | class PyDiagnostic; |
| 36 | class PyDiagnosticHandler; |
| 37 | class PyInsertionPoint; |
| 38 | class PyLocation; |
| 39 | class DefaultingPyLocation; |
| 40 | class PyMlirContext; |
| 41 | class DefaultingPyMlirContext; |
| 42 | class PyModule; |
| 43 | class PyOperation; |
| 44 | class PyOperationBase; |
| 45 | class PyType; |
| 46 | class PySymbolTable; |
| 47 | class PyValue; |
| 48 | |
| 49 | /// Template for a reference to a concrete type which captures a python |
| 50 | /// reference to its underlying python object. |
| 51 | template <typename T> |
| 52 | class PyObjectRef { |
| 53 | public: |
| 54 | PyObjectRef(T *referrent, nanobind::object object) |
| 55 | : referrent(referrent), object(std::move(object)) { |
| 56 | assert(this->referrent && |
| 57 | "cannot construct PyObjectRef with null referrent" ); |
| 58 | assert(this->object && "cannot construct PyObjectRef with null object" ); |
| 59 | } |
| 60 | PyObjectRef(PyObjectRef &&other) noexcept |
| 61 | : referrent(other.referrent), object(std::move(other.object)) { |
| 62 | other.referrent = nullptr; |
| 63 | assert(!other.object); |
| 64 | } |
| 65 | PyObjectRef(const PyObjectRef &other) |
| 66 | : referrent(other.referrent), object(other.object /* copies */) {} |
| 67 | ~PyObjectRef() = default; |
| 68 | |
| 69 | int getRefCount() { |
| 70 | if (!object) |
| 71 | return 0; |
| 72 | return Py_REFCNT(object.ptr()); |
| 73 | } |
| 74 | |
| 75 | /// Releases the object held by this instance, returning it. |
| 76 | /// This is the proper thing to return from a function that wants to return |
| 77 | /// the reference. Note that this does not work from initializers. |
| 78 | nanobind::object releaseObject() { |
| 79 | assert(referrent && object); |
| 80 | referrent = nullptr; |
| 81 | auto stolen = std::move(object); |
| 82 | return stolen; |
| 83 | } |
| 84 | |
| 85 | T *get() { return referrent; } |
| 86 | T *operator->() { |
| 87 | assert(referrent && object); |
| 88 | return referrent; |
| 89 | } |
| 90 | nanobind::object getObject() { |
| 91 | assert(referrent && object); |
| 92 | return object; |
| 93 | } |
| 94 | operator bool() const { return referrent && object; } |
| 95 | |
| 96 | private: |
| 97 | T *referrent; |
| 98 | nanobind::object object; |
| 99 | }; |
| 100 | |
| 101 | /// Tracks an entry in the thread context stack. New entries are pushed onto |
| 102 | /// here for each with block that activates a new InsertionPoint, Context or |
| 103 | /// Location. |
| 104 | /// |
| 105 | /// Pushing either a Location or InsertionPoint also pushes its associated |
| 106 | /// Context. Pushing a Context will not modify the Location or InsertionPoint |
| 107 | /// unless if they are from a different context, in which case, they are |
| 108 | /// cleared. |
| 109 | class PyThreadContextEntry { |
| 110 | public: |
| 111 | enum class FrameKind { |
| 112 | Context, |
| 113 | InsertionPoint, |
| 114 | Location, |
| 115 | }; |
| 116 | |
| 117 | PyThreadContextEntry(FrameKind frameKind, nanobind::object context, |
| 118 | nanobind::object insertionPoint, |
| 119 | nanobind::object location) |
| 120 | : context(std::move(context)), insertionPoint(std::move(insertionPoint)), |
| 121 | location(std::move(location)), frameKind(frameKind) {} |
| 122 | |
| 123 | /// Gets the top of stack context and return nullptr if not defined. |
| 124 | static PyMlirContext *getDefaultContext(); |
| 125 | |
| 126 | /// Gets the top of stack insertion point and return nullptr if not defined. |
| 127 | static PyInsertionPoint *getDefaultInsertionPoint(); |
| 128 | |
| 129 | /// Gets the top of stack location and returns nullptr if not defined. |
| 130 | static PyLocation *getDefaultLocation(); |
| 131 | |
| 132 | PyMlirContext *getContext(); |
| 133 | PyInsertionPoint *getInsertionPoint(); |
| 134 | PyLocation *getLocation(); |
| 135 | FrameKind getFrameKind() { return frameKind; } |
| 136 | |
| 137 | /// Stack management. |
| 138 | static PyThreadContextEntry *getTopOfStack(); |
| 139 | static nanobind::object pushContext(nanobind::object context); |
| 140 | static void popContext(PyMlirContext &context); |
| 141 | static nanobind::object pushInsertionPoint(nanobind::object insertionPoint); |
| 142 | static void popInsertionPoint(PyInsertionPoint &insertionPoint); |
| 143 | static nanobind::object pushLocation(nanobind::object location); |
| 144 | static void popLocation(PyLocation &location); |
| 145 | |
| 146 | /// Gets the thread local stack. |
| 147 | static std::vector<PyThreadContextEntry> &getStack(); |
| 148 | |
| 149 | private: |
| 150 | static void push(FrameKind frameKind, nanobind::object context, |
| 151 | nanobind::object insertionPoint, nanobind::object location); |
| 152 | |
| 153 | /// An object reference to the PyContext. |
| 154 | nanobind::object context; |
| 155 | /// An object reference to the current insertion point. |
| 156 | nanobind::object insertionPoint; |
| 157 | /// An object reference to the current location. |
| 158 | nanobind::object location; |
| 159 | // The kind of push that was performed. |
| 160 | FrameKind frameKind; |
| 161 | }; |
| 162 | |
| 163 | /// Wrapper around MlirLlvmThreadPool |
| 164 | /// Python object owns the C++ thread pool |
| 165 | class PyThreadPool { |
| 166 | public: |
| 167 | PyThreadPool() { |
| 168 | ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>(); |
| 169 | } |
| 170 | PyThreadPool(const PyThreadPool &) = delete; |
| 171 | PyThreadPool(PyThreadPool &&) = delete; |
| 172 | |
| 173 | int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); } |
| 174 | MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); } |
| 175 | |
| 176 | std::string _mlir_thread_pool_ptr() const { |
| 177 | std::stringstream ss; |
| 178 | ss << ownedThreadPool.get(); |
| 179 | return ss.str(); |
| 180 | } |
| 181 | |
| 182 | private: |
| 183 | std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool; |
| 184 | }; |
| 185 | |
| 186 | /// Wrapper around MlirContext. |
| 187 | using PyMlirContextRef = PyObjectRef<PyMlirContext>; |
| 188 | class PyMlirContext { |
| 189 | public: |
| 190 | PyMlirContext() = delete; |
| 191 | PyMlirContext(MlirContext context); |
| 192 | PyMlirContext(const PyMlirContext &) = delete; |
| 193 | PyMlirContext(PyMlirContext &&) = delete; |
| 194 | |
| 195 | /// For the case of a python __init__ (nanobind::init) method, pybind11 is |
| 196 | /// quite strict about needing to return a pointer that is not yet associated |
| 197 | /// to an nanobind::object. Since the forContext() method acts like a pool, |
| 198 | /// possibly returning a recycled context, it does not satisfy this need. The |
| 199 | /// usual way in python to accomplish such a thing is to override __new__, but |
| 200 | /// that is also not supported by pybind11. Instead, we use this entry |
| 201 | /// point which always constructs a fresh context (which cannot alias an |
| 202 | /// existing one because it is fresh). |
| 203 | static PyMlirContext *createNewContextForInit(); |
| 204 | |
| 205 | /// Returns a context reference for the singleton PyMlirContext wrapper for |
| 206 | /// the given context. |
| 207 | static PyMlirContextRef forContext(MlirContext context); |
| 208 | ~PyMlirContext(); |
| 209 | |
| 210 | /// Accesses the underlying MlirContext. |
| 211 | MlirContext get() { return context; } |
| 212 | |
| 213 | /// Gets a strong reference to this context, which will ensure it is kept |
| 214 | /// alive for the life of the reference. |
| 215 | PyMlirContextRef getRef() { |
| 216 | return PyMlirContextRef(this, nanobind::cast(this)); |
| 217 | } |
| 218 | |
| 219 | /// Gets a capsule wrapping the void* within the MlirContext. |
| 220 | nanobind::object getCapsule(); |
| 221 | |
| 222 | /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. |
| 223 | /// Note that PyMlirContext instances are uniqued, so the returned object |
| 224 | /// may be a pre-existing object. Ownership of the underlying MlirContext |
| 225 | /// is taken by calling this function. |
| 226 | static nanobind::object createFromCapsule(nanobind::object capsule); |
| 227 | |
| 228 | /// Gets the count of live context objects. Used for testing. |
| 229 | static size_t getLiveCount(); |
| 230 | |
| 231 | /// Get a list of Python objects which are still in the live context map. |
| 232 | std::vector<PyOperation *> getLiveOperationObjects(); |
| 233 | |
| 234 | /// Gets the count of live operations associated with this context. |
| 235 | /// Used for testing. |
| 236 | size_t getLiveOperationCount(); |
| 237 | |
| 238 | /// Clears the live operations map, returning the number of entries which were |
| 239 | /// invalidated. To be used as a safety mechanism so that API end-users can't |
| 240 | /// corrupt by holding references they shouldn't have accessed in the first |
| 241 | /// place. |
| 242 | size_t clearLiveOperations(); |
| 243 | |
| 244 | /// Removes an operation from the live operations map and sets it invalid. |
| 245 | /// This is useful for when some non-bindings code destroys the operation and |
| 246 | /// the bindings need to made aware. For example, in the case when pass |
| 247 | /// manager is run. |
| 248 | /// |
| 249 | /// Note that this does *NOT* clear the nested operations. |
| 250 | void clearOperation(MlirOperation op); |
| 251 | |
| 252 | /// Clears all operations nested inside the given op using |
| 253 | /// `clearOperation(MlirOperation)`. |
| 254 | void clearOperationsInside(PyOperationBase &op); |
| 255 | void clearOperationsInside(MlirOperation op); |
| 256 | |
| 257 | /// Clears the operaiton _and_ all operations inside using |
| 258 | /// `clearOperation(MlirOperation)`. |
| 259 | void clearOperationAndInside(PyOperationBase &op); |
| 260 | |
| 261 | /// Gets the count of live modules associated with this context. |
| 262 | /// Used for testing. |
| 263 | size_t getLiveModuleCount(); |
| 264 | |
| 265 | /// Enter and exit the context manager. |
| 266 | static nanobind::object contextEnter(nanobind::object context); |
| 267 | void contextExit(const nanobind::object &excType, |
| 268 | const nanobind::object &excVal, |
| 269 | const nanobind::object &excTb); |
| 270 | |
| 271 | /// Attaches a Python callback as a diagnostic handler, returning a |
| 272 | /// registration object (internally a PyDiagnosticHandler). |
| 273 | nanobind::object attachDiagnosticHandler(nanobind::object callback); |
| 274 | |
| 275 | /// Controls whether error diagnostics should be propagated to diagnostic |
| 276 | /// handlers, instead of being captured by `ErrorCapture`. |
| 277 | void setEmitErrorDiagnostics(bool value) { emitErrorDiagnostics = value; } |
| 278 | struct ErrorCapture; |
| 279 | |
| 280 | private: |
| 281 | // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, |
| 282 | // preserving the relationship that an MlirContext maps to a single |
| 283 | // PyMlirContext wrapper. This could be replaced in the future with an |
| 284 | // extension mechanism on the MlirContext for stashing user pointers. |
| 285 | // Note that this holds a handle, which does not imply ownership. |
| 286 | // Mappings will be removed when the context is destructed. |
| 287 | using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>; |
| 288 | static nanobind::ft_mutex live_contexts_mutex; |
| 289 | static LiveContextMap &getLiveContexts(); |
| 290 | |
| 291 | // Interns all live modules associated with this context. Modules tracked |
| 292 | // in this map are valid. When a module is invalidated, it is removed |
| 293 | // from this map, and while it still exists as an instance, any |
| 294 | // attempt to access it will raise an error. |
| 295 | using LiveModuleMap = |
| 296 | llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>; |
| 297 | LiveModuleMap liveModules; |
| 298 | |
| 299 | // Interns all live operations associated with this context. Operations |
| 300 | // tracked in this map are valid. When an operation is invalidated, it is |
| 301 | // removed from this map, and while it still exists as an instance, any |
| 302 | // attempt to access it will raise an error. |
| 303 | using LiveOperationMap = |
| 304 | llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>; |
| 305 | nanobind::ft_mutex liveOperationsMutex; |
| 306 | |
| 307 | // Guarded by liveOperationsMutex in free-threading mode. |
| 308 | LiveOperationMap liveOperations; |
| 309 | |
| 310 | bool emitErrorDiagnostics = false; |
| 311 | |
| 312 | MlirContext context; |
| 313 | friend class PyModule; |
| 314 | friend class PyOperation; |
| 315 | }; |
| 316 | |
| 317 | /// Used in function arguments when None should resolve to the current context |
| 318 | /// manager set instance. |
| 319 | class DefaultingPyMlirContext |
| 320 | : public Defaulting<DefaultingPyMlirContext, PyMlirContext> { |
| 321 | public: |
| 322 | using Defaulting::Defaulting; |
| 323 | static constexpr const char kTypeDescription[] = "mlir.ir.Context" ; |
| 324 | static PyMlirContext &resolve(); |
| 325 | }; |
| 326 | |
| 327 | /// Base class for all objects that directly or indirectly depend on an |
| 328 | /// MlirContext. The lifetime of the context will extend at least to the |
| 329 | /// lifetime of these instances. |
| 330 | /// Immutable objects that depend on a context extend this directly. |
| 331 | class BaseContextObject { |
| 332 | public: |
| 333 | BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) { |
| 334 | assert(this->contextRef && |
| 335 | "context object constructed with null context ref" ); |
| 336 | } |
| 337 | |
| 338 | /// Accesses the context reference. |
| 339 | PyMlirContextRef &getContext() { return contextRef; } |
| 340 | |
| 341 | private: |
| 342 | PyMlirContextRef contextRef; |
| 343 | }; |
| 344 | |
| 345 | /// Wrapper around an MlirLocation. |
| 346 | class PyLocation : public BaseContextObject { |
| 347 | public: |
| 348 | PyLocation(PyMlirContextRef contextRef, MlirLocation loc) |
| 349 | : BaseContextObject(std::move(contextRef)), loc(loc) {} |
| 350 | |
| 351 | operator MlirLocation() const { return loc; } |
| 352 | MlirLocation get() const { return loc; } |
| 353 | |
| 354 | /// Enter and exit the context manager. |
| 355 | static nanobind::object contextEnter(nanobind::object location); |
| 356 | void contextExit(const nanobind::object &excType, |
| 357 | const nanobind::object &excVal, |
| 358 | const nanobind::object &excTb); |
| 359 | |
| 360 | /// Gets a capsule wrapping the void* within the MlirLocation. |
| 361 | nanobind::object getCapsule(); |
| 362 | |
| 363 | /// Creates a PyLocation from the MlirLocation wrapped by a capsule. |
| 364 | /// Note that PyLocation instances are uniqued, so the returned object |
| 365 | /// may be a pre-existing object. Ownership of the underlying MlirLocation |
| 366 | /// is taken by calling this function. |
| 367 | static PyLocation createFromCapsule(nanobind::object capsule); |
| 368 | |
| 369 | private: |
| 370 | MlirLocation loc; |
| 371 | }; |
| 372 | |
| 373 | /// Python class mirroring the C MlirDiagnostic struct. Note that these structs |
| 374 | /// are only valid for the duration of a diagnostic callback and attempting |
| 375 | /// to access them outside of that will raise an exception. This applies to |
| 376 | /// nested diagnostics (in the notes) as well. |
| 377 | class PyDiagnostic { |
| 378 | public: |
| 379 | PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {} |
| 380 | void invalidate(); |
| 381 | bool isValid() { return valid; } |
| 382 | MlirDiagnosticSeverity getSeverity(); |
| 383 | PyLocation getLocation(); |
| 384 | nanobind::str getMessage(); |
| 385 | nanobind::tuple getNotes(); |
| 386 | |
| 387 | /// Materialized diagnostic information. This is safe to access outside the |
| 388 | /// diagnostic callback. |
| 389 | struct DiagnosticInfo { |
| 390 | MlirDiagnosticSeverity severity; |
| 391 | PyLocation location; |
| 392 | std::string message; |
| 393 | std::vector<DiagnosticInfo> notes; |
| 394 | }; |
| 395 | DiagnosticInfo getInfo(); |
| 396 | |
| 397 | private: |
| 398 | MlirDiagnostic diagnostic; |
| 399 | |
| 400 | void checkValid(); |
| 401 | /// If notes have been materialized from the diagnostic, then this will |
| 402 | /// be populated with the corresponding objects (all castable to |
| 403 | /// PyDiagnostic). |
| 404 | std::optional<nanobind::tuple> materializedNotes; |
| 405 | bool valid = true; |
| 406 | }; |
| 407 | |
| 408 | /// Represents a diagnostic handler attached to the context. The handler's |
| 409 | /// callback will be invoked with PyDiagnostic instances until the detach() |
| 410 | /// method is called or the context is destroyed. A diagnostic handler can be |
| 411 | /// the subject of a `with` block, which will detach it when the block exits. |
| 412 | /// |
| 413 | /// Since diagnostic handlers can call back into Python code which can do |
| 414 | /// unsafe things (i.e. recursively emitting diagnostics, raising exceptions, |
| 415 | /// etc), this is generally not deemed to be a great user-level API. Users |
| 416 | /// should generally use some form of DiagnosticCollector. If the handler raises |
| 417 | /// any exceptions, they will just be emitted to stderr and dropped. |
| 418 | /// |
| 419 | /// The unique usage of this class means that its lifetime management is |
| 420 | /// different from most other parts of the API. Instances are always created |
| 421 | /// in an attached state and can transition to a detached state by either: |
| 422 | /// a) The context being destroyed and unregistering all handlers. |
| 423 | /// b) An explicit call to detach(). |
| 424 | /// The object may remain live from a Python perspective for an arbitrary time |
| 425 | /// after detachment, but there is nothing the user can do with it (since there |
| 426 | /// is no way to attach an existing handler object). |
| 427 | class PyDiagnosticHandler { |
| 428 | public: |
| 429 | PyDiagnosticHandler(MlirContext context, nanobind::object callback); |
| 430 | ~PyDiagnosticHandler(); |
| 431 | |
| 432 | bool isAttached() { return registeredID.has_value(); } |
| 433 | bool getHadError() { return hadError; } |
| 434 | |
| 435 | /// Detaches the handler. Does nothing if not attached. |
| 436 | void detach(); |
| 437 | |
| 438 | nanobind::object contextEnter() { return nanobind::cast(this); } |
| 439 | void contextExit(const nanobind::object &excType, |
| 440 | const nanobind::object &excVal, |
| 441 | const nanobind::object &excTb) { |
| 442 | detach(); |
| 443 | } |
| 444 | |
| 445 | private: |
| 446 | MlirContext context; |
| 447 | nanobind::object callback; |
| 448 | std::optional<MlirDiagnosticHandlerID> registeredID; |
| 449 | bool hadError = false; |
| 450 | friend class PyMlirContext; |
| 451 | }; |
| 452 | |
| 453 | /// RAII object that captures any error diagnostics emitted to the provided |
| 454 | /// context. |
| 455 | struct PyMlirContext::ErrorCapture { |
| 456 | ErrorCapture(PyMlirContextRef ctx) |
| 457 | : ctx(ctx), handlerID(mlirContextAttachDiagnosticHandler( |
| 458 | ctx->get(), handler, /*userData=*/this, |
| 459 | /*deleteUserData=*/nullptr)) {} |
| 460 | ~ErrorCapture() { |
| 461 | mlirContextDetachDiagnosticHandler(ctx->get(), handlerID); |
| 462 | assert(errors.empty() && "unhandled captured errors" ); |
| 463 | } |
| 464 | |
| 465 | std::vector<PyDiagnostic::DiagnosticInfo> take() { |
| 466 | return std::move(errors); |
| 467 | }; |
| 468 | |
| 469 | private: |
| 470 | PyMlirContextRef ctx; |
| 471 | MlirDiagnosticHandlerID handlerID; |
| 472 | std::vector<PyDiagnostic::DiagnosticInfo> errors; |
| 473 | |
| 474 | static MlirLogicalResult handler(MlirDiagnostic diag, void *userData); |
| 475 | }; |
| 476 | |
| 477 | /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in |
| 478 | /// order to differentiate it from the `Dialect` base class which is extended by |
| 479 | /// plugins which extend dialect functionality through extension python code. |
| 480 | /// This should be seen as the "low-level" object and `Dialect` as the |
| 481 | /// high-level, user facing object. |
| 482 | class PyDialectDescriptor : public BaseContextObject { |
| 483 | public: |
| 484 | PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect) |
| 485 | : BaseContextObject(std::move(contextRef)), dialect(dialect) {} |
| 486 | |
| 487 | MlirDialect get() { return dialect; } |
| 488 | |
| 489 | private: |
| 490 | MlirDialect dialect; |
| 491 | }; |
| 492 | |
| 493 | /// User-level object for accessing dialects with dotted syntax such as: |
| 494 | /// ctx.dialect.std |
| 495 | class PyDialects : public BaseContextObject { |
| 496 | public: |
| 497 | PyDialects(PyMlirContextRef contextRef) |
| 498 | : BaseContextObject(std::move(contextRef)) {} |
| 499 | |
| 500 | MlirDialect getDialectForKey(const std::string &key, bool attrError); |
| 501 | }; |
| 502 | |
| 503 | /// User-level dialect object. For dialects that have a registered extension, |
| 504 | /// this will be the base class of the extension dialect type. For un-extended, |
| 505 | /// objects of this type will be returned directly. |
| 506 | class PyDialect { |
| 507 | public: |
| 508 | PyDialect(nanobind::object descriptor) : descriptor(std::move(descriptor)) {} |
| 509 | |
| 510 | nanobind::object getDescriptor() { return descriptor; } |
| 511 | |
| 512 | private: |
| 513 | nanobind::object descriptor; |
| 514 | }; |
| 515 | |
| 516 | /// Wrapper around an MlirDialectRegistry. |
| 517 | /// Upon construction, the Python wrapper takes ownership of the |
| 518 | /// underlying MlirDialectRegistry. |
| 519 | class PyDialectRegistry { |
| 520 | public: |
| 521 | PyDialectRegistry() : registry(mlirDialectRegistryCreate()) {} |
| 522 | PyDialectRegistry(MlirDialectRegistry registry) : registry(registry) {} |
| 523 | ~PyDialectRegistry() { |
| 524 | if (!mlirDialectRegistryIsNull(registry)) |
| 525 | mlirDialectRegistryDestroy(registry); |
| 526 | } |
| 527 | PyDialectRegistry(PyDialectRegistry &) = delete; |
| 528 | PyDialectRegistry(PyDialectRegistry &&other) noexcept |
| 529 | : registry(other.registry) { |
| 530 | other.registry = {nullptr}; |
| 531 | } |
| 532 | |
| 533 | operator MlirDialectRegistry() const { return registry; } |
| 534 | MlirDialectRegistry get() const { return registry; } |
| 535 | |
| 536 | nanobind::object getCapsule(); |
| 537 | static PyDialectRegistry createFromCapsule(nanobind::object capsule); |
| 538 | |
| 539 | private: |
| 540 | MlirDialectRegistry registry; |
| 541 | }; |
| 542 | |
| 543 | /// Used in function arguments when None should resolve to the current context |
| 544 | /// manager set instance. |
| 545 | class DefaultingPyLocation |
| 546 | : public Defaulting<DefaultingPyLocation, PyLocation> { |
| 547 | public: |
| 548 | using Defaulting::Defaulting; |
| 549 | static constexpr const char kTypeDescription[] = "mlir.ir.Location" ; |
| 550 | static PyLocation &resolve(); |
| 551 | |
| 552 | operator MlirLocation() const { return *get(); } |
| 553 | }; |
| 554 | |
| 555 | /// Wrapper around MlirModule. |
| 556 | /// This is the top-level, user-owned object that contains regions/ops/blocks. |
| 557 | class PyModule; |
| 558 | using PyModuleRef = PyObjectRef<PyModule>; |
| 559 | class PyModule : public BaseContextObject { |
| 560 | public: |
| 561 | /// Returns a PyModule reference for the given MlirModule. This may return |
| 562 | /// a pre-existing or new object. |
| 563 | static PyModuleRef forModule(MlirModule module); |
| 564 | PyModule(PyModule &) = delete; |
| 565 | PyModule(PyMlirContext &&) = delete; |
| 566 | ~PyModule(); |
| 567 | |
| 568 | /// Gets the backing MlirModule. |
| 569 | MlirModule get() { return module; } |
| 570 | |
| 571 | /// Gets a strong reference to this module. |
| 572 | PyModuleRef getRef() { |
| 573 | return PyModuleRef(this, nanobind::borrow<nanobind::object>(handle)); |
| 574 | } |
| 575 | |
| 576 | /// Gets a capsule wrapping the void* within the MlirModule. |
| 577 | /// Note that the module does not (yet) provide a corresponding factory for |
| 578 | /// constructing from a capsule as that would require uniquing PyModule |
| 579 | /// instances, which is not currently done. |
| 580 | nanobind::object getCapsule(); |
| 581 | |
| 582 | /// Creates a PyModule from the MlirModule wrapped by a capsule. |
| 583 | /// Note that PyModule instances are uniqued, so the returned object |
| 584 | /// may be a pre-existing object. Ownership of the underlying MlirModule |
| 585 | /// is taken by calling this function. |
| 586 | static nanobind::object createFromCapsule(nanobind::object capsule); |
| 587 | |
| 588 | private: |
| 589 | PyModule(PyMlirContextRef contextRef, MlirModule module); |
| 590 | MlirModule module; |
| 591 | nanobind::handle handle; |
| 592 | }; |
| 593 | |
| 594 | class PyAsmState; |
| 595 | |
| 596 | /// Base class for PyOperation and PyOpView which exposes the primary, user |
| 597 | /// visible methods for manipulating it. |
| 598 | class PyOperationBase { |
| 599 | public: |
| 600 | virtual ~PyOperationBase() = default; |
| 601 | /// Implements the bound 'print' method and helps with others. |
| 602 | void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo, |
| 603 | bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, |
| 604 | bool useNameLocAsPrefix, bool assumeVerified, |
| 605 | nanobind::object fileObject, bool binary, bool skipRegions); |
| 606 | void print(PyAsmState &state, nanobind::object fileObject, bool binary); |
| 607 | |
| 608 | nanobind::object getAsm(bool binary, |
| 609 | std::optional<int64_t> largeElementsLimit, |
| 610 | bool enableDebugInfo, bool prettyDebugInfo, |
| 611 | bool printGenericOpForm, bool useLocalScope, |
| 612 | bool useNameLocAsPrefix, bool assumeVerified, |
| 613 | bool skipRegions); |
| 614 | |
| 615 | // Implement the bound 'writeBytecode' method. |
| 616 | void writeBytecode(const nanobind::object &fileObject, |
| 617 | std::optional<int64_t> bytecodeVersion); |
| 618 | |
| 619 | // Implement the walk method. |
| 620 | void walk(std::function<MlirWalkResult(MlirOperation)> callback, |
| 621 | MlirWalkOrder walkOrder); |
| 622 | |
| 623 | /// Moves the operation before or after the other operation. |
| 624 | void moveAfter(PyOperationBase &other); |
| 625 | void moveBefore(PyOperationBase &other); |
| 626 | |
| 627 | /// Verify the operation. Throws `MLIRError` if verification fails, and |
| 628 | /// returns `true` otherwise. |
| 629 | bool verify(); |
| 630 | |
| 631 | /// Each must provide access to the raw Operation. |
| 632 | virtual PyOperation &getOperation() = 0; |
| 633 | }; |
| 634 | |
| 635 | /// Wrapper around PyOperation. |
| 636 | /// Operations exist in either an attached (dependent) or detached (top-level) |
| 637 | /// state. In the detached state (as on creation), an operation is owned by |
| 638 | /// the creator and its lifetime extends either until its reference count |
| 639 | /// drops to zero or it is attached to a parent, at which point its lifetime |
| 640 | /// is bounded by its top-level parent reference. |
| 641 | class PyOperation; |
| 642 | using PyOperationRef = PyObjectRef<PyOperation>; |
| 643 | class PyOperation : public PyOperationBase, public BaseContextObject { |
| 644 | public: |
| 645 | ~PyOperation() override; |
| 646 | PyOperation &getOperation() override { return *this; } |
| 647 | |
| 648 | /// Returns a PyOperation for the given MlirOperation, optionally associating |
| 649 | /// it with a parentKeepAlive. |
| 650 | static PyOperationRef |
| 651 | forOperation(PyMlirContextRef contextRef, MlirOperation operation, |
| 652 | nanobind::object parentKeepAlive = nanobind::object()); |
| 653 | |
| 654 | /// Creates a detached operation. The operation must not be associated with |
| 655 | /// any existing live operation. |
| 656 | static PyOperationRef |
| 657 | createDetached(PyMlirContextRef contextRef, MlirOperation operation, |
| 658 | nanobind::object parentKeepAlive = nanobind::object()); |
| 659 | |
| 660 | /// Parses a source string (either text assembly or bytecode), creating a |
| 661 | /// detached operation. |
| 662 | static PyOperationRef parse(PyMlirContextRef contextRef, |
| 663 | const std::string &sourceStr, |
| 664 | const std::string &sourceName); |
| 665 | |
| 666 | /// Detaches the operation from its parent block and updates its state |
| 667 | /// accordingly. |
| 668 | void detachFromParent() { |
| 669 | mlirOperationRemoveFromParent(getOperation()); |
| 670 | setDetached(); |
| 671 | parentKeepAlive = nanobind::object(); |
| 672 | } |
| 673 | |
| 674 | /// Gets the backing operation. |
| 675 | operator MlirOperation() const { return get(); } |
| 676 | MlirOperation get() const { |
| 677 | checkValid(); |
| 678 | return operation; |
| 679 | } |
| 680 | |
| 681 | PyOperationRef getRef() { |
| 682 | return PyOperationRef(this, nanobind::borrow<nanobind::object>(handle)); |
| 683 | } |
| 684 | |
| 685 | bool isAttached() { return attached; } |
| 686 | void setAttached(const nanobind::object &parent = nanobind::object()) { |
| 687 | assert(!attached && "operation already attached" ); |
| 688 | attached = true; |
| 689 | } |
| 690 | void setDetached() { |
| 691 | assert(attached && "operation already detached" ); |
| 692 | attached = false; |
| 693 | } |
/// Checks that this wrapper is still valid; get() calls it before returning
/// the operation.
/// NOTE(review): the definition is not in this header. Presumably it raises
/// once `valid` has been cleared (e.g. by setInvalid() or erase()); confirm
/// in the implementation file.
void checkValid() const;

/// Gets the owning block or raises an exception if the operation has no
/// owning block.
PyBlock getBlock();

/// Gets the parent operation or raises an exception if the operation has
/// no parent.
/// NOTE(review): the std::optional return type suggests a missing parent may
/// instead be reported as std::nullopt; confirm which behavior applies.
std::optional<PyOperationRef> getParentOperation();

/// Gets a capsule wrapping the void* within the MlirOperation.
nanobind::object getCapsule();

/// Creates a PyOperation from the MlirOperation wrapped by a capsule.
/// Ownership of the underlying MlirOperation is taken by calling this
/// function.
static nanobind::object createFromCapsule(nanobind::object capsule);

/// Creates an operation. See corresponding python docstring, which is the
/// authoritative description of each parameter.
static nanobind::object
create(std::string_view name, std::optional<std::vector<PyType *>> results,
       llvm::ArrayRef<MlirValue> operands,
       std::optional<nanobind::dict> attributes,
       std::optional<std::vector<PyBlock *>> successors, int regions,
       DefaultingPyLocation location, const nanobind::object &ip,
       bool inferType);

/// Creates an OpView suitable for this operation.
nanobind::object createOpView();

/// Erases the underlying MlirOperation, removes its pointer from the
/// parent context's live operations map, and sets the valid bit false.
void erase();
| 727 | |
| 728 | /// Invalidate the operation. |
| 729 | void setInvalid() { valid = false; } |
| 730 | |
/// Clones this operation.
/// NOTE(review): `ip` is presumably an insertion-point object for the clone,
/// mirroring the `ip` parameter of create(); confirm in the implementation.
nanobind::object clone(const nanobind::object &ip);

/// Constructs a wrapper for `operation` associated with the context held by
/// `contextRef`. A private createInstance() factory also exists; which entry
/// point callers should use is decided in the implementation file.
PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
| 735 | |
private:
  /// Internal factory producing a new wrapper instance.
  /// NOTE(review): `parentKeepAlive` is presumably stored in the member of the
  /// same name; confirm in the implementation file.
  static PyOperationRef createInstance(PyMlirContextRef contextRef,
                                       MlirOperation operation,
                                       nanobind::object parentKeepAlive);

  /// The wrapped C API operation; returned by get() after checkValid().
  MlirOperation operation;
  /// Handle to this wrapper's Python object; getRef() borrows it to build a
  /// PyOperationRef.
  nanobind::handle handle;
  // Keeps the parent alive, regardless of whether it is an Operation or
  // Module.
  // TODO: As implemented, this facility is only sufficient for modeling the
  // trivial module parent back-reference. Generalize this to also account for
  // transitions from detached to attached and address TODOs in the
  // ir_operation.py regarding testing corresponding lifetime guarantees.
  nanobind::object parentKeepAlive;
  /// True while attached; toggled by setAttached()/setDetached().
  bool attached = true;
  /// Cleared by setInvalid(). Presumably consulted by checkValid().
  bool valid = true;

  friend class PyOperationBase;
  friend class PySymbolTable;
| 755 | }; |
| 756 | |
| 757 | /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for |
| 758 | /// providing more instance-specific accessors and serve as the base class for |
| 759 | /// custom ODS-style operation classes. Since this class is subclass on the |
| 760 | /// python side, it must present an __init__ method that operates in pure |
| 761 | /// python types. |
| 762 | class PyOpView : public PyOperationBase { |
| 763 | public: |
| 764 | PyOpView(const nanobind::object &operationObject); |
| 765 | PyOperation &getOperation() override { return operation; } |
| 766 | |
| 767 | nanobind::object getOperationObject() { return operationObject; } |
| 768 | |
| 769 | static nanobind::object |
| 770 | buildGeneric(std::string_view name, std::tuple<int, bool> opRegionSpec, |
| 771 | nanobind::object operandSegmentSpecObj, |
| 772 | nanobind::object resultSegmentSpecObj, |
| 773 | std::optional<nanobind::list> resultTypeList, |
| 774 | nanobind::list operandList, |
| 775 | std::optional<nanobind::dict> attributes, |
| 776 | std::optional<std::vector<PyBlock *>> successors, |
| 777 | std::optional<int> regions, DefaultingPyLocation location, |
| 778 | const nanobind::object &maybeIp); |
| 779 | |
| 780 | /// Construct an instance of a class deriving from OpView, bypassing its |
| 781 | /// `__init__` method. The derived class will typically define a constructor |
| 782 | /// that provides a convenient builder, but we need to side-step this when |
| 783 | /// constructing an `OpView` for an already-built operation. |
| 784 | /// |
| 785 | /// The caller is responsible for verifying that `operation` is a valid |
| 786 | /// operation to construct `cls` with. |
| 787 | static nanobind::object constructDerived(const nanobind::object &cls, |
| 788 | const nanobind::object &operation); |
| 789 | |
| 790 | private: |
| 791 | PyOperation &operation; // For efficient, cast-free access from C++ |
| 792 | nanobind::object operationObject; // Holds the reference. |
| 793 | }; |
| 794 | |
| 795 | /// Wrapper around an MlirRegion. |
| 796 | /// Regions are managed completely by their containing operation. Unlike the |
| 797 | /// C++ API, the python API does not support detached regions. |
| 798 | class PyRegion { |
| 799 | public: |
| 800 | PyRegion(PyOperationRef parentOperation, MlirRegion region) |
| 801 | : parentOperation(std::move(parentOperation)), region(region) { |
| 802 | assert(!mlirRegionIsNull(region) && "python region cannot be null" ); |
| 803 | } |
| 804 | operator MlirRegion() const { return region; } |
| 805 | |
| 806 | MlirRegion get() { return region; } |
| 807 | PyOperationRef &getParentOperation() { return parentOperation; } |
| 808 | |
| 809 | void checkValid() { return parentOperation->checkValid(); } |
| 810 | |
| 811 | private: |
| 812 | PyOperationRef parentOperation; |
| 813 | MlirRegion region; |
| 814 | }; |
| 815 | |
| 816 | /// Wrapper around an MlirAsmState. |
| 817 | class PyAsmState { |
| 818 | public: |
| 819 | PyAsmState(MlirValue value, bool useLocalScope) { |
| 820 | flags = mlirOpPrintingFlagsCreate(); |
| 821 | // The OpPrintingFlags are not exposed Python side, create locally and |
| 822 | // associate lifetime with the state. |
| 823 | if (useLocalScope) |
| 824 | mlirOpPrintingFlagsUseLocalScope(flags); |
| 825 | state = mlirAsmStateCreateForValue(value, flags); |
| 826 | } |
| 827 | |
| 828 | PyAsmState(PyOperationBase &operation, bool useLocalScope) { |
| 829 | flags = mlirOpPrintingFlagsCreate(); |
| 830 | // The OpPrintingFlags are not exposed Python side, create locally and |
| 831 | // associate lifetime with the state. |
| 832 | if (useLocalScope) |
| 833 | mlirOpPrintingFlagsUseLocalScope(flags); |
| 834 | state = |
| 835 | mlirAsmStateCreateForOperation(operation.getOperation().get(), flags); |
| 836 | } |
| 837 | ~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); } |
| 838 | // Delete copy constructors. |
| 839 | PyAsmState(PyAsmState &other) = delete; |
| 840 | PyAsmState(const PyAsmState &other) = delete; |
| 841 | |
| 842 | MlirAsmState get() { return state; } |
| 843 | |
| 844 | private: |
| 845 | MlirAsmState state; |
| 846 | MlirOpPrintingFlags flags; |
| 847 | }; |
| 848 | |
| 849 | /// Wrapper around an MlirBlock. |
| 850 | /// Blocks are managed completely by their containing operation. Unlike the |
| 851 | /// C++ API, the python API does not support detached blocks. |
| 852 | class PyBlock { |
| 853 | public: |
| 854 | PyBlock(PyOperationRef parentOperation, MlirBlock block) |
| 855 | : parentOperation(std::move(parentOperation)), block(block) { |
| 856 | assert(!mlirBlockIsNull(block) && "python block cannot be null" ); |
| 857 | } |
| 858 | |
| 859 | MlirBlock get() { return block; } |
| 860 | PyOperationRef &getParentOperation() { return parentOperation; } |
| 861 | |
| 862 | void checkValid() { return parentOperation->checkValid(); } |
| 863 | |
| 864 | /// Gets a capsule wrapping the void* within the MlirBlock. |
| 865 | nanobind::object getCapsule(); |
| 866 | |
| 867 | private: |
| 868 | PyOperationRef parentOperation; |
| 869 | MlirBlock block; |
| 870 | }; |
| 871 | |
| 872 | /// An insertion point maintains a pointer to a Block and a reference operation. |
| 873 | /// Calls to insert() will insert a new operation before the |
| 874 | /// reference operation. If the reference operation is null, then appends to |
| 875 | /// the end of the block. |
| 876 | class PyInsertionPoint { |
| 877 | public: |
| 878 | /// Creates an insertion point positioned after the last operation in the |
| 879 | /// block, but still inside the block. |
| 880 | PyInsertionPoint(PyBlock &block); |
| 881 | /// Creates an insertion point positioned before a reference operation. |
| 882 | PyInsertionPoint(PyOperationBase &beforeOperationBase); |
| 883 | |
| 884 | /// Shortcut to create an insertion point at the beginning of the block. |
| 885 | static PyInsertionPoint atBlockBegin(PyBlock &block); |
| 886 | /// Shortcut to create an insertion point before the block terminator. |
| 887 | static PyInsertionPoint atBlockTerminator(PyBlock &block); |
| 888 | |
| 889 | /// Inserts an operation. |
| 890 | void insert(PyOperationBase &operationBase); |
| 891 | |
| 892 | /// Enter and exit the context manager. |
| 893 | static nanobind::object contextEnter(nanobind::object insertionPoint); |
| 894 | void contextExit(const nanobind::object &excType, |
| 895 | const nanobind::object &excVal, |
| 896 | const nanobind::object &excTb); |
| 897 | |
| 898 | PyBlock &getBlock() { return block; } |
| 899 | std::optional<PyOperationRef> &getRefOperation() { return refOperation; } |
| 900 | |
| 901 | private: |
| 902 | // Trampoline constructor that avoids null initializing members while |
| 903 | // looking up parents. |
| 904 | PyInsertionPoint(PyBlock block, std::optional<PyOperationRef> refOperation) |
| 905 | : refOperation(std::move(refOperation)), block(std::move(block)) {} |
| 906 | |
| 907 | std::optional<PyOperationRef> refOperation; |
| 908 | PyBlock block; |
| 909 | }; |
| 910 | /// Wrapper around the generic MlirType. |
| 911 | /// The lifetime of a type is bound by the PyContext that created it. |
| 912 | class PyType : public BaseContextObject { |
| 913 | public: |
| 914 | PyType(PyMlirContextRef contextRef, MlirType type) |
| 915 | : BaseContextObject(std::move(contextRef)), type(type) {} |
| 916 | bool operator==(const PyType &other) const; |
| 917 | operator MlirType() const { return type; } |
| 918 | MlirType get() const { return type; } |
| 919 | |
| 920 | /// Gets a capsule wrapping the void* within the MlirType. |
| 921 | nanobind::object getCapsule(); |
| 922 | |
| 923 | /// Creates a PyType from the MlirType wrapped by a capsule. |
| 924 | /// Note that PyType instances are uniqued, so the returned object |
| 925 | /// may be a pre-existing object. Ownership of the underlying MlirType |
| 926 | /// is taken by calling this function. |
| 927 | static PyType createFromCapsule(nanobind::object capsule); |
| 928 | |
| 929 | private: |
| 930 | MlirType type; |
| 931 | }; |
| 932 | |
| 933 | /// A TypeID provides an efficient and unique identifier for a specific C++ |
| 934 | /// type. This allows for a C++ type to be compared, hashed, and stored in an |
| 935 | /// opaque context. This class wraps around the generic MlirTypeID. |
| 936 | class PyTypeID { |
| 937 | public: |
| 938 | PyTypeID(MlirTypeID typeID) : typeID(typeID) {} |
| 939 | // Note, this tests whether the underlying TypeIDs are the same, |
| 940 | // not whether the wrapper MlirTypeIDs are the same, nor whether |
| 941 | // the PyTypeID objects are the same (i.e., PyTypeID is a value type). |
| 942 | bool operator==(const PyTypeID &other) const; |
| 943 | operator MlirTypeID() const { return typeID; } |
| 944 | MlirTypeID get() { return typeID; } |
| 945 | |
| 946 | /// Gets a capsule wrapping the void* within the MlirTypeID. |
| 947 | nanobind::object getCapsule(); |
| 948 | |
| 949 | /// Creates a PyTypeID from the MlirTypeID wrapped by a capsule. |
| 950 | static PyTypeID createFromCapsule(nanobind::object capsule); |
| 951 | |
| 952 | private: |
| 953 | MlirTypeID typeID; |
| 954 | }; |
| 955 | |
| 956 | /// CRTP base classes for Python types that subclass Type and should be |
| 957 | /// castable from it (i.e. via something like IntegerType(t)). |
| 958 | /// By default, type class hierarchies are one level deep (i.e. a |
| 959 | /// concrete type class extends PyType); however, intermediate python-visible |
| 960 | /// base classes can be modeled by specifying a BaseTy. |
| 961 | template <typename DerivedTy, typename BaseTy = PyType> |
| 962 | class PyConcreteType : public BaseTy { |
| 963 | public: |
| 964 | // Derived classes must define statics for: |
| 965 | // IsAFunctionTy isaFunction |
| 966 | // const char *pyClassName |
| 967 | using ClassTy = nanobind::class_<DerivedTy, BaseTy>; |
| 968 | using IsAFunctionTy = bool (*)(MlirType); |
| 969 | using GetTypeIDFunctionTy = MlirTypeID (*)(); |
| 970 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; |
| 971 | |
| 972 | PyConcreteType() = default; |
| 973 | PyConcreteType(PyMlirContextRef contextRef, MlirType t) |
| 974 | : BaseTy(std::move(contextRef), t) {} |
| 975 | PyConcreteType(PyType &orig) |
| 976 | : PyConcreteType(orig.getContext(), castFrom(orig)) {} |
| 977 | |
| 978 | static MlirType castFrom(PyType &orig) { |
| 979 | if (!DerivedTy::isaFunction(orig)) { |
| 980 | auto origRepr = |
| 981 | nanobind::cast<std::string>(nanobind::repr(nanobind::cast(orig))); |
| 982 | throw nanobind::value_error((llvm::Twine("Cannot cast type to " ) + |
| 983 | DerivedTy::pyClassName + " (from " + |
| 984 | origRepr + ")" ) |
| 985 | .str() |
| 986 | .c_str()); |
| 987 | } |
| 988 | return orig; |
| 989 | } |
| 990 | |
| 991 | static void bind(nanobind::module_ &m) { |
| 992 | auto cls = ClassTy(m, DerivedTy::pyClassName); |
| 993 | cls.def(nanobind::init<PyType &>(), nanobind::keep_alive<0, 1>(), |
| 994 | nanobind::arg("cast_from_type" )); |
| 995 | cls.def_static( |
| 996 | "isinstance" , |
| 997 | [](PyType &otherType) -> bool { |
| 998 | return DerivedTy::isaFunction(otherType); |
| 999 | }, |
| 1000 | nanobind::arg("other" )); |
| 1001 | cls.def_prop_ro_static( |
| 1002 | "static_typeid" , [](nanobind::object & /*class*/) -> MlirTypeID { |
| 1003 | if (DerivedTy::getTypeIdFunction) |
| 1004 | return DerivedTy::getTypeIdFunction(); |
| 1005 | throw nanobind::attribute_error( |
| 1006 | (DerivedTy::pyClassName + llvm::Twine(" has no typeid." )) |
| 1007 | .str() |
| 1008 | .c_str()); |
| 1009 | }); |
| 1010 | cls.def_prop_ro("typeid" , [](PyType &self) { |
| 1011 | return nanobind::cast<MlirTypeID>(nanobind::cast(self).attr("typeid" )); |
| 1012 | }); |
| 1013 | cls.def("__repr__" , [](DerivedTy &self) { |
| 1014 | PyPrintAccumulator printAccum; |
| 1015 | printAccum.parts.append(DerivedTy::pyClassName); |
| 1016 | printAccum.parts.append("(" ); |
| 1017 | mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); |
| 1018 | printAccum.parts.append(")" ); |
| 1019 | return printAccum.join(); |
| 1020 | }); |
| 1021 | |
| 1022 | if (DerivedTy::getTypeIdFunction) { |
| 1023 | PyGlobals::get().registerTypeCaster( |
| 1024 | DerivedTy::getTypeIdFunction(), |
| 1025 | nanobind::cast<nanobind::callable>(nanobind::cpp_function( |
| 1026 | [](PyType pyType) -> DerivedTy { return pyType; }))); |
| 1027 | } |
| 1028 | |
| 1029 | DerivedTy::bindDerived(cls); |
| 1030 | } |
| 1031 | |
| 1032 | /// Implemented by derived classes to add methods to the Python subclass. |
| 1033 | static void bindDerived(ClassTy &m) {} |
| 1034 | }; |
| 1035 | |
| 1036 | /// Wrapper around the generic MlirAttribute. |
| 1037 | /// The lifetime of a type is bound by the PyContext that created it. |
| 1038 | class PyAttribute : public BaseContextObject { |
| 1039 | public: |
| 1040 | PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr) |
| 1041 | : BaseContextObject(std::move(contextRef)), attr(attr) {} |
| 1042 | bool operator==(const PyAttribute &other) const; |
| 1043 | operator MlirAttribute() const { return attr; } |
| 1044 | MlirAttribute get() const { return attr; } |
| 1045 | |
| 1046 | /// Gets a capsule wrapping the void* within the MlirAttribute. |
| 1047 | nanobind::object getCapsule(); |
| 1048 | |
| 1049 | /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule. |
| 1050 | /// Note that PyAttribute instances are uniqued, so the returned object |
| 1051 | /// may be a pre-existing object. Ownership of the underlying MlirAttribute |
| 1052 | /// is taken by calling this function. |
| 1053 | static PyAttribute createFromCapsule(nanobind::object capsule); |
| 1054 | |
| 1055 | private: |
| 1056 | MlirAttribute attr; |
| 1057 | }; |
| 1058 | |
| 1059 | /// Represents a Python MlirNamedAttr, carrying an optional owned name. |
| 1060 | /// TODO: Refactor this and the C-API to be based on an Identifier owned |
| 1061 | /// by the context so as to avoid ownership issues here. |
| 1062 | class PyNamedAttribute { |
| 1063 | public: |
| 1064 | /// Constructs a PyNamedAttr that retains an owned name. This should be |
| 1065 | /// used in any code that originates an MlirNamedAttribute from a python |
| 1066 | /// string. |
| 1067 | /// The lifetime of the PyNamedAttr must extend to the lifetime of the |
| 1068 | /// passed attribute. |
| 1069 | PyNamedAttribute(MlirAttribute attr, std::string ownedName); |
| 1070 | |
| 1071 | MlirNamedAttribute namedAttr; |
| 1072 | |
| 1073 | private: |
| 1074 | // Since the MlirNamedAttr contains an internal pointer to the actual |
| 1075 | // memory of the owned string, it must be heap allocated to remain valid. |
| 1076 | // Otherwise, strings that fit within the small object optimization threshold |
| 1077 | // will have their memory address change as the containing object is moved, |
| 1078 | // resulting in an invalid aliased pointer. |
| 1079 | std::unique_ptr<std::string> ownedName; |
| 1080 | }; |
| 1081 | |
| 1082 | /// CRTP base classes for Python attributes that subclass Attribute and should |
| 1083 | /// be castable from it (i.e. via something like StringAttr(attr)). |
| 1084 | /// By default, attribute class hierarchies are one level deep (i.e. a |
| 1085 | /// concrete attribute class extends PyAttribute); however, intermediate |
| 1086 | /// python-visible base classes can be modeled by specifying a BaseTy. |
| 1087 | template <typename DerivedTy, typename BaseTy = PyAttribute> |
| 1088 | class PyConcreteAttribute : public BaseTy { |
| 1089 | public: |
| 1090 | // Derived classes must define statics for: |
| 1091 | // IsAFunctionTy isaFunction |
| 1092 | // const char *pyClassName |
| 1093 | using ClassTy = nanobind::class_<DerivedTy, BaseTy>; |
| 1094 | using IsAFunctionTy = bool (*)(MlirAttribute); |
| 1095 | using GetTypeIDFunctionTy = MlirTypeID (*)(); |
| 1096 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; |
| 1097 | |
| 1098 | PyConcreteAttribute() = default; |
| 1099 | PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) |
| 1100 | : BaseTy(std::move(contextRef), attr) {} |
| 1101 | PyConcreteAttribute(PyAttribute &orig) |
| 1102 | : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} |
| 1103 | |
| 1104 | static MlirAttribute castFrom(PyAttribute &orig) { |
| 1105 | if (!DerivedTy::isaFunction(orig)) { |
| 1106 | auto origRepr = |
| 1107 | nanobind::cast<std::string>(nanobind::repr(nanobind::cast(orig))); |
| 1108 | throw nanobind::value_error((llvm::Twine("Cannot cast attribute to " ) + |
| 1109 | DerivedTy::pyClassName + " (from " + |
| 1110 | origRepr + ")" ) |
| 1111 | .str() |
| 1112 | .c_str()); |
| 1113 | } |
| 1114 | return orig; |
| 1115 | } |
| 1116 | |
| 1117 | static void bind(nanobind::module_ &m, PyType_Slot *slots = nullptr) { |
| 1118 | ClassTy cls; |
| 1119 | if (slots) { |
| 1120 | cls = ClassTy(m, DerivedTy::pyClassName, nanobind::type_slots(slots)); |
| 1121 | } else { |
| 1122 | cls = ClassTy(m, DerivedTy::pyClassName); |
| 1123 | } |
| 1124 | cls.def(nanobind::init<PyAttribute &>(), nanobind::keep_alive<0, 1>(), |
| 1125 | nanobind::arg("cast_from_attr" )); |
| 1126 | cls.def_static( |
| 1127 | "isinstance" , |
| 1128 | [](PyAttribute &otherAttr) -> bool { |
| 1129 | return DerivedTy::isaFunction(otherAttr); |
| 1130 | }, |
| 1131 | nanobind::arg("other" )); |
| 1132 | cls.def_prop_ro( |
| 1133 | "type" , [](PyAttribute &attr) { return mlirAttributeGetType(attr); }); |
| 1134 | cls.def_prop_ro_static( |
| 1135 | "static_typeid" , [](nanobind::object & /*class*/) -> MlirTypeID { |
| 1136 | if (DerivedTy::getTypeIdFunction) |
| 1137 | return DerivedTy::getTypeIdFunction(); |
| 1138 | throw nanobind::attribute_error( |
| 1139 | (DerivedTy::pyClassName + llvm::Twine(" has no typeid." )) |
| 1140 | .str() |
| 1141 | .c_str()); |
| 1142 | }); |
| 1143 | cls.def_prop_ro("typeid" , [](PyAttribute &self) { |
| 1144 | return nanobind::cast<MlirTypeID>(nanobind::cast(self).attr("typeid" )); |
| 1145 | }); |
| 1146 | cls.def("__repr__" , [](DerivedTy &self) { |
| 1147 | PyPrintAccumulator printAccum; |
| 1148 | printAccum.parts.append(DerivedTy::pyClassName); |
| 1149 | printAccum.parts.append("(" ); |
| 1150 | mlirAttributePrint(self, printAccum.getCallback(), |
| 1151 | printAccum.getUserData()); |
| 1152 | printAccum.parts.append(")" ); |
| 1153 | return printAccum.join(); |
| 1154 | }); |
| 1155 | |
| 1156 | if (DerivedTy::getTypeIdFunction) { |
| 1157 | PyGlobals::get().registerTypeCaster( |
| 1158 | DerivedTy::getTypeIdFunction(), |
| 1159 | nanobind::cast<nanobind::callable>( |
| 1160 | nanobind::cpp_function([](PyAttribute pyAttribute) -> DerivedTy { |
| 1161 | return pyAttribute; |
| 1162 | }))); |
| 1163 | } |
| 1164 | |
| 1165 | DerivedTy::bindDerived(cls); |
| 1166 | } |
| 1167 | |
| 1168 | /// Implemented by derived classes to add methods to the Python subclass. |
| 1169 | static void bindDerived(ClassTy &m) {} |
| 1170 | }; |
| 1171 | |
| 1172 | /// Wrapper around the generic MlirValue. |
| 1173 | /// Values are managed completely by the operation that resulted in their |
| 1174 | /// definition. For op result value, this is the operation that defines the |
| 1175 | /// value. For block argument values, this is the operation that contains the |
| 1176 | /// block to which the value is an argument (blocks cannot be detached in Python |
| 1177 | /// bindings so such operation always exists). |
| 1178 | class PyValue { |
| 1179 | public: |
| 1180 | // The virtual here is "load bearing" in that it enables RTTI |
| 1181 | // for PyConcreteValue CRTP classes that support maybeDownCast. |
| 1182 | // See PyValue::maybeDownCast. |
| 1183 | virtual ~PyValue() = default; |
| 1184 | PyValue(PyOperationRef parentOperation, MlirValue value) |
| 1185 | : parentOperation(std::move(parentOperation)), value(value) {} |
| 1186 | operator MlirValue() const { return value; } |
| 1187 | |
| 1188 | MlirValue get() { return value; } |
| 1189 | PyOperationRef &getParentOperation() { return parentOperation; } |
| 1190 | |
| 1191 | void checkValid() { return parentOperation->checkValid(); } |
| 1192 | |
| 1193 | /// Gets a capsule wrapping the void* within the MlirValue. |
| 1194 | nanobind::object getCapsule(); |
| 1195 | |
| 1196 | nanobind::object maybeDownCast(); |
| 1197 | |
| 1198 | /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of |
| 1199 | /// the underlying MlirValue is still tied to the owning operation. |
| 1200 | static PyValue createFromCapsule(nanobind::object capsule); |
| 1201 | |
| 1202 | private: |
| 1203 | PyOperationRef parentOperation; |
| 1204 | MlirValue value; |
| 1205 | }; |
| 1206 | |
| 1207 | /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context. |
| 1208 | class PyAffineExpr : public BaseContextObject { |
| 1209 | public: |
| 1210 | PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) |
| 1211 | : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {} |
| 1212 | bool operator==(const PyAffineExpr &other) const; |
| 1213 | operator MlirAffineExpr() const { return affineExpr; } |
| 1214 | MlirAffineExpr get() const { return affineExpr; } |
| 1215 | |
| 1216 | /// Gets a capsule wrapping the void* within the MlirAffineExpr. |
| 1217 | nanobind::object getCapsule(); |
| 1218 | |
| 1219 | /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule. |
| 1220 | /// Note that PyAffineExpr instances are uniqued, so the returned object |
| 1221 | /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr |
| 1222 | /// is taken by calling this function. |
| 1223 | static PyAffineExpr createFromCapsule(nanobind::object capsule); |
| 1224 | |
| 1225 | PyAffineExpr add(const PyAffineExpr &other) const; |
| 1226 | PyAffineExpr mul(const PyAffineExpr &other) const; |
| 1227 | PyAffineExpr floorDiv(const PyAffineExpr &other) const; |
| 1228 | PyAffineExpr ceilDiv(const PyAffineExpr &other) const; |
| 1229 | PyAffineExpr mod(const PyAffineExpr &other) const; |
| 1230 | |
| 1231 | private: |
| 1232 | MlirAffineExpr affineExpr; |
| 1233 | }; |
| 1234 | |
| 1235 | class PyAffineMap : public BaseContextObject { |
| 1236 | public: |
| 1237 | PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap) |
| 1238 | : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {} |
| 1239 | bool operator==(const PyAffineMap &other) const; |
| 1240 | operator MlirAffineMap() const { return affineMap; } |
| 1241 | MlirAffineMap get() const { return affineMap; } |
| 1242 | |
| 1243 | /// Gets a capsule wrapping the void* within the MlirAffineMap. |
| 1244 | nanobind::object getCapsule(); |
| 1245 | |
| 1246 | /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule. |
| 1247 | /// Note that PyAffineMap instances are uniqued, so the returned object |
| 1248 | /// may be a pre-existing object. Ownership of the underlying MlirAffineMap |
| 1249 | /// is taken by calling this function. |
| 1250 | static PyAffineMap createFromCapsule(nanobind::object capsule); |
| 1251 | |
| 1252 | private: |
| 1253 | MlirAffineMap affineMap; |
| 1254 | }; |
| 1255 | |
| 1256 | class PyIntegerSet : public BaseContextObject { |
| 1257 | public: |
| 1258 | PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet) |
| 1259 | : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {} |
| 1260 | bool operator==(const PyIntegerSet &other) const; |
| 1261 | operator MlirIntegerSet() const { return integerSet; } |
| 1262 | MlirIntegerSet get() const { return integerSet; } |
| 1263 | |
| 1264 | /// Gets a capsule wrapping the void* within the MlirIntegerSet. |
| 1265 | nanobind::object getCapsule(); |
| 1266 | |
| 1267 | /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. |
| 1268 | /// Note that PyIntegerSet instances may be uniqued, so the returned object |
| 1269 | /// may be a pre-existing object. Integer sets are owned by the context. |
| 1270 | static PyIntegerSet createFromCapsule(nanobind::object capsule); |
| 1271 | |
| 1272 | private: |
| 1273 | MlirIntegerSet integerSet; |
| 1274 | }; |
| 1275 | |
| 1276 | /// Bindings for MLIR symbol tables. |
| 1277 | class PySymbolTable { |
| 1278 | public: |
| 1279 | /// Constructs a symbol table for the given operation. |
| 1280 | explicit PySymbolTable(PyOperationBase &operation); |
| 1281 | |
| 1282 | /// Destroys the symbol table. |
| 1283 | ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); } |
| 1284 | |
| 1285 | /// Returns the symbol (opview) with the given name, throws if there is no |
| 1286 | /// such symbol in the table. |
| 1287 | nanobind::object dunderGetItem(const std::string &name); |
| 1288 | |
| 1289 | /// Removes the given operation from the symbol table and erases it. |
| 1290 | void erase(PyOperationBase &symbol); |
| 1291 | |
| 1292 | /// Removes the operation with the given name from the symbol table and erases |
| 1293 | /// it, throws if there is no such symbol in the table. |
| 1294 | void dunderDel(const std::string &name); |
| 1295 | |
| 1296 | /// Inserts the given operation into the symbol table. The operation must have |
| 1297 | /// the symbol trait. |
| 1298 | MlirAttribute insert(PyOperationBase &symbol); |
| 1299 | |
| 1300 | /// Gets and sets the name of a symbol op. |
| 1301 | static MlirAttribute getSymbolName(PyOperationBase &symbol); |
| 1302 | static void setSymbolName(PyOperationBase &symbol, const std::string &name); |
| 1303 | |
| 1304 | /// Gets and sets the visibility of a symbol op. |
| 1305 | static MlirAttribute getVisibility(PyOperationBase &symbol); |
| 1306 | static void setVisibility(PyOperationBase &symbol, |
| 1307 | const std::string &visibility); |
| 1308 | |
| 1309 | /// Replaces all symbol uses within an operation. See the API |
| 1310 | /// mlirSymbolTableReplaceAllSymbolUses for all caveats. |
| 1311 | static void replaceAllSymbolUses(const std::string &oldSymbol, |
| 1312 | const std::string &newSymbol, |
| 1313 | PyOperationBase &from); |
| 1314 | |
| 1315 | /// Walks all symbol tables under and including 'from'. |
| 1316 | static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, |
| 1317 | nanobind::object callback); |
| 1318 | |
| 1319 | /// Casts the bindings class into the C API structure. |
| 1320 | operator MlirSymbolTable() { return symbolTable; } |
| 1321 | |
| 1322 | private: |
| 1323 | PyOperationRef operation; |
| 1324 | MlirSymbolTable symbolTable; |
| 1325 | }; |
| 1326 | |
| 1327 | /// Custom exception that allows access to error diagnostic information. This is |
| 1328 | /// converted to the `ir.MLIRError` python exception when thrown. |
| 1329 | struct MLIRError { |
| 1330 | MLIRError(llvm::Twine message, |
| 1331 | std::vector<PyDiagnostic::DiagnosticInfo> &&errorDiagnostics = {}) |
| 1332 | : message(message.str()), errorDiagnostics(std::move(errorDiagnostics)) {} |
| 1333 | std::string message; |
| 1334 | std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics; |
| 1335 | }; |
| 1336 | |
| 1337 | void populateIRAffine(nanobind::module_ &m); |
| 1338 | void populateIRAttributes(nanobind::module_ &m); |
| 1339 | void populateIRCore(nanobind::module_ &m); |
| 1340 | void populateIRInterfaces(nanobind::module_ &m); |
| 1341 | void populateIRTypes(nanobind::module_ &m); |
| 1342 | |
| 1343 | } // namespace python |
| 1344 | } // namespace mlir |
| 1345 | |
| 1346 | namespace nanobind { |
| 1347 | namespace detail { |
| 1348 | |
| 1349 | template <> |
| 1350 | struct type_caster<mlir::python::DefaultingPyMlirContext> |
| 1351 | : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {}; |
| 1352 | template <> |
| 1353 | struct type_caster<mlir::python::DefaultingPyLocation> |
| 1354 | : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {}; |
| 1355 | |
| 1356 | } // namespace detail |
| 1357 | } // namespace nanobind |
| 1358 | |
| 1359 | #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H |
| 1360 | |