1//===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "IRModule.h"
10
11#include "Globals.h"
12#include "PybindUtils.h"
13
14#include "mlir-c/Bindings/Python/Interop.h"
15#include "mlir-c/BuiltinAttributes.h"
16#include "mlir-c/Debug.h"
17#include "mlir-c/Diagnostics.h"
18#include "mlir-c/IR.h"
19#include "mlir-c/Support.h"
20#include "mlir/Bindings/Python/PybindAdaptors.h"
21#include "llvm/ADT/ArrayRef.h"
22#include "llvm/ADT/SmallVector.h"
23
24#include <optional>
25#include <utility>
26
27namespace py = pybind11;
28using namespace py::literals;
29using namespace mlir;
30using namespace mlir::python;
31
32using llvm::SmallVector;
33using llvm::StringRef;
34using llvm::Twine;
35
36//------------------------------------------------------------------------------
37// Docstrings (trivial, non-duplicated docstrings are included inline).
38//------------------------------------------------------------------------------
39
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises an MLIRError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises an MLIRError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
  infer_type: Whether to infer result types.
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";

// NOTE: the keyword names listed here are user-facing documentation and should
// match the `py::arg` names used where `print` is bound.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Elide elements attributes above this number of
    elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
)";

static const char kOperationPrintStateDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  state: AsmState capturing the operation numbering and flags.
)";

static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

static const char kOperationPrintBytecodeDocstring[] =
    R"(Writes the bytecode form of the operation to a file like object.

Args:
  file: The file like object to write to.
  desired_version: The version of bytecode to emit.
Returns:
  The bytecode writer status.
)";

static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";

static const char kGetNameAsOperand[] =
    R"(Returns the string form of value as an operand (i.e., the ValueID).
)";

static const char kValueReplaceAllUsesWithDocstring[] =
    R"(Replace all uses of value with the new value, updating anything in
the IR that uses 'self' to use the other value instead.
)";
179
180//------------------------------------------------------------------------------
181// Utilities.
182//------------------------------------------------------------------------------
183
184/// Helper for creating an @classmethod.
185template <class Func, typename... Args>
186py::object classmethod(Func f, Args... args) {
187 py::object cf = py::cpp_function(f, args...);
188 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
189}
190
191static py::object
192createCustomDialectWrapper(const std::string &dialectNamespace,
193 py::object dialectDescriptor) {
194 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
195 if (!dialectClass) {
196 // Use the base class.
197 return py::cast(PyDialect(std::move(dialectDescriptor)));
198 }
199
200 // Create the custom implementation.
201 return (*dialectClass)(std::move(dialectDescriptor));
202}
203
204static MlirStringRef toMlirStringRef(const std::string &s) {
205 return mlirStringRefCreate(s.data(), s.size());
206}
207
208/// Create a block, using the current location context if no locations are
209/// specified.
210static MlirBlock createBlock(const py::sequence &pyArgTypes,
211 const std::optional<py::sequence> &pyArgLocs) {
212 SmallVector<MlirType> argTypes;
213 argTypes.reserve(pyArgTypes.size());
214 for (const auto &pyType : pyArgTypes)
215 argTypes.push_back(pyType.cast<PyType &>());
216
217 SmallVector<MlirLocation> argLocs;
218 if (pyArgLocs) {
219 argLocs.reserve(pyArgLocs->size());
220 for (const auto &pyLoc : *pyArgLocs)
221 argLocs.push_back(pyLoc.cast<PyLocation &>());
222 } else if (!argTypes.empty()) {
223 argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
224 }
225
226 if (argTypes.size() != argLocs.size())
227 throw py::value_error(("Expected " + Twine(argTypes.size()) +
228 " locations, got: " + Twine(argLocs.size()))
229 .str());
230 return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
231}
232
233/// Wrapper for the global LLVM debugging flag.
234struct PyGlobalDebugFlag {
235 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
236
237 static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); }
238
239 static void bind(py::module &m) {
240 // Debug flags.
241 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
242 .def_property_static("flag", &PyGlobalDebugFlag::get,
243 &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
244 }
245};
246
247struct PyAttrBuilderMap {
248 static bool dunderContains(const std::string &attributeKind) {
249 return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
250 }
251 static py::function dundeGetItemNamed(const std::string &attributeKind) {
252 auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
253 if (!builder)
254 throw py::key_error(attributeKind);
255 return *builder;
256 }
257 static void dundeSetItemNamed(const std::string &attributeKind,
258 py::function func, bool replace) {
259 PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
260 replace);
261 }
262
263 static void bind(py::module &m) {
264 py::class_<PyAttrBuilderMap>(m, "AttrBuilder", py::module_local())
265 .def_static("contains", &PyAttrBuilderMap::dunderContains)
266 .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed)
267 .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed,
268 "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
269 "Register an attribute builder for building MLIR "
270 "attributes from python values.");
271 }
272};
273
274//------------------------------------------------------------------------------
275// PyBlock
276//------------------------------------------------------------------------------
277
278py::object PyBlock::getCapsule() {
279 return py::reinterpret_steal<py::object>(mlirPythonBlockToCapsule(get()));
280}
281
282//------------------------------------------------------------------------------
283// Collections.
284//------------------------------------------------------------------------------
285
286namespace {
287
288class PyRegionIterator {
289public:
290 PyRegionIterator(PyOperationRef operation)
291 : operation(std::move(operation)) {}
292
293 PyRegionIterator &dunderIter() { return *this; }
294
295 PyRegion dunderNext() {
296 operation->checkValid();
297 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
298 throw py::stop_iteration();
299 }
300 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
301 return PyRegion(operation, region);
302 }
303
304 static void bind(py::module &m) {
305 py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
306 .def("__iter__", &PyRegionIterator::dunderIter)
307 .def("__next__", &PyRegionIterator::dunderNext);
308 }
309
310private:
311 PyOperationRef operation;
312 int nextIndex = 0;
313};
314
315/// Regions of an op are fixed length and indexed numerically so are represented
316/// with a sequence-like container.
317class PyRegionList {
318public:
319 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
320
321 PyRegionIterator dunderIter() {
322 operation->checkValid();
323 return PyRegionIterator(operation);
324 }
325
326 intptr_t dunderLen() {
327 operation->checkValid();
328 return mlirOperationGetNumRegions(operation->get());
329 }
330
331 PyRegion dunderGetItem(intptr_t index) {
332 // dunderLen checks validity.
333 if (index < 0 || index >= dunderLen()) {
334 throw py::index_error("attempt to access out of bounds region");
335 }
336 MlirRegion region = mlirOperationGetRegion(operation->get(), index);
337 return PyRegion(operation, region);
338 }
339
340 static void bind(py::module &m) {
341 py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
342 .def("__len__", &PyRegionList::dunderLen)
343 .def("__iter__", &PyRegionList::dunderIter)
344 .def("__getitem__", &PyRegionList::dunderGetItem);
345 }
346
347private:
348 PyOperationRef operation;
349};
350
351class PyBlockIterator {
352public:
353 PyBlockIterator(PyOperationRef operation, MlirBlock next)
354 : operation(std::move(operation)), next(next) {}
355
356 PyBlockIterator &dunderIter() { return *this; }
357
358 PyBlock dunderNext() {
359 operation->checkValid();
360 if (mlirBlockIsNull(next)) {
361 throw py::stop_iteration();
362 }
363
364 PyBlock returnBlock(operation, next);
365 next = mlirBlockGetNextInRegion(next);
366 return returnBlock;
367 }
368
369 static void bind(py::module &m) {
370 py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
371 .def("__iter__", &PyBlockIterator::dunderIter)
372 .def("__next__", &PyBlockIterator::dunderNext);
373 }
374
375private:
376 PyOperationRef operation;
377 MlirBlock next;
378};
379
380/// Blocks are exposed by the C-API as a forward-only linked list. In Python,
381/// we present them as a more full-featured list-like container but optimize
382/// it for forward iteration. Blocks are always owned by a region.
383class PyBlockList {
384public:
385 PyBlockList(PyOperationRef operation, MlirRegion region)
386 : operation(std::move(operation)), region(region) {}
387
388 PyBlockIterator dunderIter() {
389 operation->checkValid();
390 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
391 }
392
393 intptr_t dunderLen() {
394 operation->checkValid();
395 intptr_t count = 0;
396 MlirBlock block = mlirRegionGetFirstBlock(region);
397 while (!mlirBlockIsNull(block)) {
398 count += 1;
399 block = mlirBlockGetNextInRegion(block);
400 }
401 return count;
402 }
403
404 PyBlock dunderGetItem(intptr_t index) {
405 operation->checkValid();
406 if (index < 0) {
407 throw py::index_error("attempt to access out of bounds block");
408 }
409 MlirBlock block = mlirRegionGetFirstBlock(region);
410 while (!mlirBlockIsNull(block)) {
411 if (index == 0) {
412 return PyBlock(operation, block);
413 }
414 block = mlirBlockGetNextInRegion(block);
415 index -= 1;
416 }
417 throw py::index_error("attempt to access out of bounds block");
418 }
419
420 PyBlock appendBlock(const py::args &pyArgTypes,
421 const std::optional<py::sequence> &pyArgLocs) {
422 operation->checkValid();
423 MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
424 mlirRegionAppendOwnedBlock(region, block);
425 return PyBlock(operation, block);
426 }
427
428 static void bind(py::module &m) {
429 py::class_<PyBlockList>(m, "BlockList", py::module_local())
430 .def("__getitem__", &PyBlockList::dunderGetItem)
431 .def("__iter__", &PyBlockList::dunderIter)
432 .def("__len__", &PyBlockList::dunderLen)
433 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring,
434 py::arg("arg_locs") = std::nullopt);
435 }
436
437private:
438 PyOperationRef operation;
439 MlirRegion region;
440};
441
442class PyOperationIterator {
443public:
444 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
445 : parentOperation(std::move(parentOperation)), next(next) {}
446
447 PyOperationIterator &dunderIter() { return *this; }
448
449 py::object dunderNext() {
450 parentOperation->checkValid();
451 if (mlirOperationIsNull(next)) {
452 throw py::stop_iteration();
453 }
454
455 PyOperationRef returnOperation =
456 PyOperation::forOperation(parentOperation->getContext(), next);
457 next = mlirOperationGetNextInBlock(next);
458 return returnOperation->createOpView();
459 }
460
461 static void bind(py::module &m) {
462 py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
463 .def("__iter__", &PyOperationIterator::dunderIter)
464 .def("__next__", &PyOperationIterator::dunderNext);
465 }
466
467private:
468 PyOperationRef parentOperation;
469 MlirOperation next;
470};
471
472/// Operations are exposed by the C-API as a forward-only linked list. In
473/// Python, we present them as a more full-featured list-like container but
474/// optimize it for forward iteration. Iterable operations are always owned
475/// by a block.
476class PyOperationList {
477public:
478 PyOperationList(PyOperationRef parentOperation, MlirBlock block)
479 : parentOperation(std::move(parentOperation)), block(block) {}
480
481 PyOperationIterator dunderIter() {
482 parentOperation->checkValid();
483 return PyOperationIterator(parentOperation,
484 mlirBlockGetFirstOperation(block));
485 }
486
487 intptr_t dunderLen() {
488 parentOperation->checkValid();
489 intptr_t count = 0;
490 MlirOperation childOp = mlirBlockGetFirstOperation(block);
491 while (!mlirOperationIsNull(childOp)) {
492 count += 1;
493 childOp = mlirOperationGetNextInBlock(childOp);
494 }
495 return count;
496 }
497
498 py::object dunderGetItem(intptr_t index) {
499 parentOperation->checkValid();
500 if (index < 0) {
501 throw py::index_error("attempt to access out of bounds operation");
502 }
503 MlirOperation childOp = mlirBlockGetFirstOperation(block);
504 while (!mlirOperationIsNull(childOp)) {
505 if (index == 0) {
506 return PyOperation::forOperation(parentOperation->getContext(), childOp)
507 ->createOpView();
508 }
509 childOp = mlirOperationGetNextInBlock(childOp);
510 index -= 1;
511 }
512 throw py::index_error("attempt to access out of bounds operation");
513 }
514
515 static void bind(py::module &m) {
516 py::class_<PyOperationList>(m, "OperationList", py::module_local())
517 .def("__getitem__", &PyOperationList::dunderGetItem)
518 .def("__iter__", &PyOperationList::dunderIter)
519 .def("__len__", &PyOperationList::dunderLen);
520 }
521
522private:
523 PyOperationRef parentOperation;
524 MlirBlock block;
525};
526
527class PyOpOperand {
528public:
529 PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
530
531 py::object getOwner() {
532 MlirOperation owner = mlirOpOperandGetOwner(opOperand);
533 PyMlirContextRef context =
534 PyMlirContext::forContext(mlirOperationGetContext(owner));
535 return PyOperation::forOperation(context, owner)->createOpView();
536 }
537
538 size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
539
540 static void bind(py::module &m) {
541 py::class_<PyOpOperand>(m, "OpOperand", py::module_local())
542 .def_property_readonly("owner", &PyOpOperand::getOwner)
543 .def_property_readonly("operand_number",
544 &PyOpOperand::getOperandNumber);
545 }
546
547private:
548 MlirOpOperand opOperand;
549};
550
551class PyOpOperandIterator {
552public:
553 PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
554
555 PyOpOperandIterator &dunderIter() { return *this; }
556
557 PyOpOperand dunderNext() {
558 if (mlirOpOperandIsNull(opOperand))
559 throw py::stop_iteration();
560
561 PyOpOperand returnOpOperand(opOperand);
562 opOperand = mlirOpOperandGetNextUse(opOperand);
563 return returnOpOperand;
564 }
565
566 static void bind(py::module &m) {
567 py::class_<PyOpOperandIterator>(m, "OpOperandIterator", py::module_local())
568 .def("__iter__", &PyOpOperandIterator::dunderIter)
569 .def("__next__", &PyOpOperandIterator::dunderNext);
570 }
571
572private:
573 MlirOpOperand opOperand;
574};
575
576} // namespace
577
578//------------------------------------------------------------------------------
579// PyMlirContext
580//------------------------------------------------------------------------------
581
582PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
583 py::gil_scoped_acquire acquire;
584 auto &liveContexts = getLiveContexts();
585 liveContexts[context.ptr] = this;
586}
587
588PyMlirContext::~PyMlirContext() {
589 // Note that the only public way to construct an instance is via the
590 // forContext method, which always puts the associated handle into
591 // liveContexts.
592 py::gil_scoped_acquire acquire;
593 getLiveContexts().erase(context.ptr);
594 mlirContextDestroy(context);
595}
596
597py::object PyMlirContext::getCapsule() {
598 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
599}
600
601py::object PyMlirContext::createFromCapsule(py::object capsule) {
602 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
603 if (mlirContextIsNull(rawContext))
604 throw py::error_already_set();
605 return forContext(rawContext).releaseObject();
606}
607
608PyMlirContext *PyMlirContext::createNewContextForInit() {
609 MlirContext context = mlirContextCreateWithThreading(false);
610 return new PyMlirContext(context);
611}
612
613PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
614 py::gil_scoped_acquire acquire;
615 auto &liveContexts = getLiveContexts();
616 auto it = liveContexts.find(context.ptr);
617 if (it == liveContexts.end()) {
618 // Create.
619 PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
620 py::object pyRef = py::cast(unownedContextWrapper);
621 assert(pyRef && "cast to py::object failed");
622 liveContexts[context.ptr] = unownedContextWrapper;
623 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
624 }
625 // Use existing.
626 py::object pyRef = py::cast(it->second);
627 return PyMlirContextRef(it->second, std::move(pyRef));
628}
629
630PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
631 static LiveContextMap liveContexts;
632 return liveContexts;
633}
634
635size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
636
637size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
638
639std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
640 std::vector<PyOperation *> liveObjects;
641 for (auto &entry : liveOperations)
642 liveObjects.push_back(entry.second.second);
643 return liveObjects;
644}
645
646size_t PyMlirContext::clearLiveOperations() {
647 for (auto &op : liveOperations)
648 op.second.second->setInvalid();
649 size_t numInvalidated = liveOperations.size();
650 liveOperations.clear();
651 return numInvalidated;
652}
653
654void PyMlirContext::clearOperation(MlirOperation op) {
655 auto it = liveOperations.find(op.ptr);
656 if (it != liveOperations.end()) {
657 it->second.second->setInvalid();
658 liveOperations.erase(it);
659 }
660}
661
662void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
663 typedef struct {
664 PyOperation &rootOp;
665 bool rootSeen;
666 } callBackData;
667 callBackData data{.rootOp: op.getOperation(), .rootSeen: false};
668 // Mark all ops below the op that the passmanager will be rooted
669 // at (but not op itself - note the preorder) as invalid.
670 MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
671 void *userData) {
672 callBackData *data = static_cast<callBackData *>(userData);
673 if (LLVM_LIKELY(data->rootSeen))
674 data->rootOp.getOperation().getContext()->clearOperation(op);
675 else
676 data->rootSeen = true;
677 return MlirWalkResult::MlirWalkResultAdvance;
678 };
679 mlirOperationWalk(op.getOperation(), invalidatingCallback,
680 static_cast<void *>(&data), MlirWalkPreOrder);
681}
682void PyMlirContext::clearOperationsInside(MlirOperation op) {
683 PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
684 clearOperationsInside(opRef->getOperation());
685}
686
687size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
688
689pybind11::object PyMlirContext::contextEnter() {
690 return PyThreadContextEntry::pushContext(*this);
691}
692
693void PyMlirContext::contextExit(const pybind11::object &excType,
694 const pybind11::object &excVal,
695 const pybind11::object &excTb) {
696 PyThreadContextEntry::popContext(context&: *this);
697}
698
699py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
700 // Note that ownership is transferred to the delete callback below by way of
701 // an explicit inc_ref (borrow).
702 PyDiagnosticHandler *pyHandler =
703 new PyDiagnosticHandler(get(), std::move(callback));
704 py::object pyHandlerObject =
705 py::cast(pyHandler, py::return_value_policy::take_ownership);
706 pyHandlerObject.inc_ref();
707
708 // In these C callbacks, the userData is a PyDiagnosticHandler* that is
709 // guaranteed to be known to pybind.
710 auto handlerCallback =
711 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
712 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
713 py::object pyDiagnosticObject =
714 py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
715
716 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
717 bool result = false;
718 {
719 // Since this can be called from arbitrary C++ contexts, always get the
720 // gil.
721 py::gil_scoped_acquire gil;
722 try {
723 result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
724 } catch (std::exception &e) {
725 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
726 e.what());
727 pyHandler->hadError = true;
728 }
729 }
730
731 pyDiagnostic->invalidate();
732 return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
733 };
734 auto deleteCallback = +[](void *userData) {
735 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
736 assert(pyHandler->registeredID && "handler is not registered");
737 pyHandler->registeredID.reset();
738
739 // Decrement reference, balancing the inc_ref() above.
740 py::object pyHandlerObject =
741 py::cast(pyHandler, py::return_value_policy::reference);
742 pyHandlerObject.dec_ref();
743 };
744
745 pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
746 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
747 return pyHandlerObject;
748}
749
750MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
751 void *userData) {
752 auto *self = static_cast<ErrorCapture *>(userData);
753 // Check if the context requested we emit errors instead of capturing them.
754 if (self->ctx->emitErrorDiagnostics)
755 return mlirLogicalResultFailure();
756
757 if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError)
758 return mlirLogicalResultFailure();
759
760 self->errors.emplace_back(args: PyDiagnostic(diag).getInfo());
761 return mlirLogicalResultSuccess();
762}
763
764PyMlirContext &DefaultingPyMlirContext::resolve() {
765 PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
766 if (!context) {
767 throw std::runtime_error(
768 "An MLIR function requires a Context but none was provided in the call "
769 "or from the surrounding environment. Either pass to the function with "
770 "a 'context=' argument or establish a default using 'with Context():'");
771 }
772 return *context;
773}
774
775//------------------------------------------------------------------------------
776// PyThreadContextEntry management
777//------------------------------------------------------------------------------
778
779std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
780 static thread_local std::vector<PyThreadContextEntry> stack;
781 return stack;
782}
783
784PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
785 auto &stack = getStack();
786 if (stack.empty())
787 return nullptr;
788 return &stack.back();
789}
790
791void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
792 py::object insertionPoint,
793 py::object location) {
794 auto &stack = getStack();
795 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
796 std::move(location));
797 // If the new stack has more than one entry and the context of the new top
798 // entry matches the previous, copy the insertionPoint and location from the
799 // previous entry if missing from the new top entry.
800 if (stack.size() > 1) {
801 auto &prev = *(stack.rbegin() + 1);
802 auto &current = stack.back();
803 if (current.context.is(prev.context)) {
804 // Default non-context objects from the previous entry.
805 if (!current.insertionPoint)
806 current.insertionPoint = prev.insertionPoint;
807 if (!current.location)
808 current.location = prev.location;
809 }
810 }
811}
812
813PyMlirContext *PyThreadContextEntry::getContext() {
814 if (!context)
815 return nullptr;
816 return py::cast<PyMlirContext *>(context);
817}
818
819PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
820 if (!insertionPoint)
821 return nullptr;
822 return py::cast<PyInsertionPoint *>(insertionPoint);
823}
824
825PyLocation *PyThreadContextEntry::getLocation() {
826 if (!location)
827 return nullptr;
828 return py::cast<PyLocation *>(location);
829}
830
831PyMlirContext *PyThreadContextEntry::getDefaultContext() {
832 auto *tos = getTopOfStack();
833 return tos ? tos->getContext() : nullptr;
834}
835
836PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
837 auto *tos = getTopOfStack();
838 return tos ? tos->getInsertionPoint() : nullptr;
839}
840
841PyLocation *PyThreadContextEntry::getDefaultLocation() {
842 auto *tos = getTopOfStack();
843 return tos ? tos->getLocation() : nullptr;
844}
845
846py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
847 py::object contextObj = py::cast(context);
848 push(FrameKind::Context, /*context=*/contextObj,
849 /*insertionPoint=*/py::object(),
850 /*location=*/py::object());
851 return contextObj;
852}
853
854void PyThreadContextEntry::popContext(PyMlirContext &context) {
855 auto &stack = getStack();
856 if (stack.empty())
857 throw std::runtime_error("Unbalanced Context enter/exit");
858 auto &tos = stack.back();
859 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
860 throw std::runtime_error("Unbalanced Context enter/exit");
861 stack.pop_back();
862}
863
864py::object
865PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
866 py::object contextObj =
867 insertionPoint.getBlock().getParentOperation()->getContext().getObject();
868 py::object insertionPointObj = py::cast(insertionPoint);
869 push(FrameKind::InsertionPoint,
870 /*context=*/contextObj,
871 /*insertionPoint=*/insertionPointObj,
872 /*location=*/py::object());
873 return insertionPointObj;
874}
875
876void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
877 auto &stack = getStack();
878 if (stack.empty())
879 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
880 auto &tos = stack.back();
881 if (tos.frameKind != FrameKind::InsertionPoint &&
882 tos.getInsertionPoint() != &insertionPoint)
883 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
884 stack.pop_back();
885}
886
887py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
888 py::object contextObj = location.getContext().getObject();
889 py::object locationObj = py::cast(location);
890 push(FrameKind::Location, /*context=*/contextObj,
891 /*insertionPoint=*/py::object(),
892 /*location=*/locationObj);
893 return locationObj;
894}
895
896void PyThreadContextEntry::popLocation(PyLocation &location) {
897 auto &stack = getStack();
898 if (stack.empty())
899 throw std::runtime_error("Unbalanced Location enter/exit");
900 auto &tos = stack.back();
901 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
902 throw std::runtime_error("Unbalanced Location enter/exit");
903 stack.pop_back();
904}
905
906//------------------------------------------------------------------------------
907// PyDiagnostic*
908//------------------------------------------------------------------------------
909
910void PyDiagnostic::invalidate() {
911 valid = false;
912 if (materializedNotes) {
913 for (auto &noteObject : *materializedNotes) {
914 PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
915 note->invalidate();
916 }
917 }
918}
919
920PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
921 py::object callback)
922 : context(context), callback(std::move(callback)) {}
923
924PyDiagnosticHandler::~PyDiagnosticHandler() = default;
925
926void PyDiagnosticHandler::detach() {
927 if (!registeredID)
928 return;
929 MlirDiagnosticHandlerID localID = *registeredID;
930 mlirContextDetachDiagnosticHandler(context, localID);
931 assert(!registeredID && "should have unregistered");
932 // Not strictly necessary but keeps stale pointers from being around to cause
933 // issues.
934 context = {nullptr};
935}
936
937void PyDiagnostic::checkValid() {
938 if (!valid) {
939 throw std::invalid_argument(
940 "Diagnostic is invalid (used outside of callback)");
941 }
942}
943
944MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
945 checkValid();
946 return mlirDiagnosticGetSeverity(diagnostic);
947}
948
949PyLocation PyDiagnostic::getLocation() {
950 checkValid();
951 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
952 MlirContext context = mlirLocationGetContext(loc);
953 return PyLocation(PyMlirContext::forContext(context), loc);
954}
955
956py::str PyDiagnostic::getMessage() {
957 checkValid();
958 py::object fileObject = py::module::import("io").attr("StringIO")();
959 PyFileAccumulator accum(fileObject, /*binary=*/false);
960 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
961 return fileObject.attr("getvalue")();
962}
963
964py::tuple PyDiagnostic::getNotes() {
965 checkValid();
966 if (materializedNotes)
967 return *materializedNotes;
968 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
969 materializedNotes = py::tuple(numNotes);
970 for (intptr_t i = 0; i < numNotes; ++i) {
971 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
972 (*materializedNotes)[i] = PyDiagnostic(noteDiag);
973 }
974 return *materializedNotes;
975}
976
977PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() {
978 std::vector<DiagnosticInfo> notes;
979 for (py::handle n : getNotes())
980 notes.emplace_back(n.cast<PyDiagnostic>().getInfo());
981 return {getSeverity(), getLocation(), getMessage(), std::move(notes)};
982}
983
984//------------------------------------------------------------------------------
985// PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
986//------------------------------------------------------------------------------
987
988MlirDialect PyDialects::getDialectForKey(const std::string &key,
989 bool attrError) {
990 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
991 {key.data(), key.size()});
992 if (mlirDialectIsNull(dialect)) {
993 std::string msg = (Twine("Dialect '") + key + "' not found").str();
994 if (attrError)
995 throw py::attribute_error(msg);
996 throw py::index_error(msg);
997 }
998 return dialect;
999}
1000
1001py::object PyDialectRegistry::getCapsule() {
1002 return py::reinterpret_steal<py::object>(
1003 mlirPythonDialectRegistryToCapsule(*this));
1004}
1005
1006PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) {
1007 MlirDialectRegistry rawRegistry =
1008 mlirPythonCapsuleToDialectRegistry(capsule.ptr());
1009 if (mlirDialectRegistryIsNull(rawRegistry))
1010 throw py::error_already_set();
1011 return PyDialectRegistry(rawRegistry);
1012}
1013
1014//------------------------------------------------------------------------------
1015// PyLocation
1016//------------------------------------------------------------------------------
1017
1018py::object PyLocation::getCapsule() {
1019 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
1020}
1021
1022PyLocation PyLocation::createFromCapsule(py::object capsule) {
1023 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
1024 if (mlirLocationIsNull(rawLoc))
1025 throw py::error_already_set();
1026 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
1027 rawLoc);
1028}
1029
1030py::object PyLocation::contextEnter() {
1031 return PyThreadContextEntry::pushLocation(*this);
1032}
1033
1034void PyLocation::contextExit(const pybind11::object &excType,
1035 const pybind11::object &excVal,
1036 const pybind11::object &excTb) {
1037 PyThreadContextEntry::popLocation(location&: *this);
1038}
1039
1040PyLocation &DefaultingPyLocation::resolve() {
1041 auto *location = PyThreadContextEntry::getDefaultLocation();
1042 if (!location) {
1043 throw std::runtime_error(
1044 "An MLIR function requires a Location but none was provided in the "
1045 "call or from the surrounding environment. Either pass to the function "
1046 "with a 'loc=' argument or establish a default using 'with loc:'");
1047 }
1048 return *location;
1049}
1050
1051//------------------------------------------------------------------------------
1052// PyModule
1053//------------------------------------------------------------------------------
1054
1055PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
1056 : BaseContextObject(std::move(contextRef)), module(module) {}
1057
1058PyModule::~PyModule() {
1059 py::gil_scoped_acquire acquire;
1060 auto &liveModules = getContext()->liveModules;
1061 assert(liveModules.count(module.ptr) == 1 &&
1062 "destroying module not in live map");
1063 liveModules.erase(module.ptr);
1064 mlirModuleDestroy(module);
1065}
1066
1067PyModuleRef PyModule::forModule(MlirModule module) {
1068 MlirContext context = mlirModuleGetContext(module);
1069 PyMlirContextRef contextRef = PyMlirContext::forContext(context);
1070
1071 py::gil_scoped_acquire acquire;
1072 auto &liveModules = contextRef->liveModules;
1073 auto it = liveModules.find(module.ptr);
1074 if (it == liveModules.end()) {
1075 // Create.
1076 PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1077 // Note that the default return value policy on cast is automatic_reference,
1078 // which does not take ownership (delete will not be called).
1079 // Just be explicit.
1080 py::object pyRef =
1081 py::cast(unownedModule, py::return_value_policy::take_ownership);
1082 unownedModule->handle = pyRef;
1083 liveModules[module.ptr] =
1084 std::make_pair(unownedModule->handle, unownedModule);
1085 return PyModuleRef(unownedModule, std::move(pyRef));
1086 }
1087 // Use existing.
1088 PyModule *existing = it->second.second;
1089 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
1090 return PyModuleRef(existing, std::move(pyRef));
1091}
1092
1093py::object PyModule::createFromCapsule(py::object capsule) {
1094 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
1095 if (mlirModuleIsNull(rawModule))
1096 throw py::error_already_set();
1097 return forModule(rawModule).releaseObject();
1098}
1099
1100py::object PyModule::getCapsule() {
1101 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
1102}
1103
1104//------------------------------------------------------------------------------
1105// PyOperation
1106//------------------------------------------------------------------------------
1107
1108PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
1109 : BaseContextObject(std::move(contextRef)), operation(operation) {}
1110
1111PyOperation::~PyOperation() {
1112 // If the operation has already been invalidated there is nothing to do.
1113 if (!valid)
1114 return;
1115 auto &liveOperations = getContext()->liveOperations;
1116 assert(liveOperations.count(operation.ptr) == 1 &&
1117 "destroying operation not in live map");
1118 liveOperations.erase(operation.ptr);
1119 if (!isAttached()) {
1120 mlirOperationDestroy(operation);
1121 }
1122}
1123
1124PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
1125 MlirOperation operation,
1126 py::object parentKeepAlive) {
1127 auto &liveOperations = contextRef->liveOperations;
1128 // Create.
1129 PyOperation *unownedOperation =
1130 new PyOperation(std::move(contextRef), operation);
1131 // Note that the default return value policy on cast is automatic_reference,
1132 // which does not take ownership (delete will not be called).
1133 // Just be explicit.
1134 py::object pyRef =
1135 py::cast(unownedOperation, py::return_value_policy::take_ownership);
1136 unownedOperation->handle = pyRef;
1137 if (parentKeepAlive) {
1138 unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
1139 }
1140 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
1141 return PyOperationRef(unownedOperation, std::move(pyRef));
1142}
1143
1144PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
1145 MlirOperation operation,
1146 py::object parentKeepAlive) {
1147 auto &liveOperations = contextRef->liveOperations;
1148 auto it = liveOperations.find(operation.ptr);
1149 if (it == liveOperations.end()) {
1150 // Create.
1151 return createInstance(std::move(contextRef), operation,
1152 std::move(parentKeepAlive));
1153 }
1154 // Use existing.
1155 PyOperation *existing = it->second.second;
1156 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
1157 return PyOperationRef(existing, std::move(pyRef));
1158}
1159
1160PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
1161 MlirOperation operation,
1162 py::object parentKeepAlive) {
1163 auto &liveOperations = contextRef->liveOperations;
1164 assert(liveOperations.count(operation.ptr) == 0 &&
1165 "cannot create detached operation that already exists");
1166 (void)liveOperations;
1167
1168 PyOperationRef created = createInstance(std::move(contextRef), operation,
1169 std::move(parentKeepAlive));
1170 created->attached = false;
1171 return created;
1172}
1173
1174PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
1175 const std::string &sourceStr,
1176 const std::string &sourceName) {
1177 PyMlirContext::ErrorCapture errors(contextRef);
1178 MlirOperation op =
1179 mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
1180 toMlirStringRef(sourceName));
1181 if (mlirOperationIsNull(op))
1182 throw MLIRError("Unable to parse operation assembly", errors.take());
1183 return PyOperation::createDetached(std::move(contextRef), op);
1184}
1185
1186void PyOperation::checkValid() const {
1187 if (!valid) {
1188 throw std::runtime_error("the operation has been invalidated");
1189 }
1190}
1191
1192void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
1193 bool enableDebugInfo, bool prettyDebugInfo,
1194 bool printGenericOpForm, bool useLocalScope,
1195 bool assumeVerified, py::object fileObject,
1196 bool binary) {
1197 PyOperation &operation = getOperation();
1198 operation.checkValid();
1199 if (fileObject.is_none())
1200 fileObject = py::module::import("sys").attr("stdout");
1201
1202 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1203 if (largeElementsLimit)
1204 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1205 if (enableDebugInfo)
1206 mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
1207 /*prettyForm=*/prettyDebugInfo);
1208 if (printGenericOpForm)
1209 mlirOpPrintingFlagsPrintGenericOpForm(flags);
1210 if (useLocalScope)
1211 mlirOpPrintingFlagsUseLocalScope(flags);
1212 if (assumeVerified)
1213 mlirOpPrintingFlagsAssumeVerified(flags);
1214
1215 PyFileAccumulator accum(fileObject, binary);
1216 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1217 accum.getUserData());
1218 mlirOpPrintingFlagsDestroy(flags);
1219}
1220
1221void PyOperationBase::print(PyAsmState &state, py::object fileObject,
1222 bool binary) {
1223 PyOperation &operation = getOperation();
1224 operation.checkValid();
1225 if (fileObject.is_none())
1226 fileObject = py::module::import("sys").attr("stdout");
1227 PyFileAccumulator accum(fileObject, binary);
1228 mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1229 accum.getUserData());
1230}
1231
1232void PyOperationBase::writeBytecode(const py::object &fileObject,
1233 std::optional<int64_t> bytecodeVersion) {
1234 PyOperation &operation = getOperation();
1235 operation.checkValid();
1236 PyFileAccumulator accum(fileObject, /*binary=*/true);
1237
1238 if (!bytecodeVersion.has_value())
1239 return mlirOperationWriteBytecode(operation, accum.getCallback(),
1240 accum.getUserData());
1241
1242 MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
1243 mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
1244 MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig(
1245 operation, config, accum.getCallback(), accum.getUserData());
1246 mlirBytecodeWriterConfigDestroy(config);
1247 if (mlirLogicalResultIsFailure(res))
1248 throw py::value_error((Twine("Unable to honor desired bytecode version ") +
1249 Twine(*bytecodeVersion))
1250 .str());
1251}
1252
1253void PyOperationBase::walk(
1254 std::function<MlirWalkResult(MlirOperation)> callback,
1255 MlirWalkOrder walkOrder) {
1256 PyOperation &operation = getOperation();
1257 operation.checkValid();
1258 struct UserData {
1259 std::function<MlirWalkResult(MlirOperation)> callback;
1260 bool gotException;
1261 std::string exceptionWhat;
1262 py::object exceptionType;
1263 };
1264 UserData userData{callback, false, {}, {}};
1265 MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1266 void *userData) {
1267 UserData *calleeUserData = static_cast<UserData *>(userData);
1268 try {
1269 return (calleeUserData->callback)(op);
1270 } catch (py::error_already_set &e) {
1271 calleeUserData->gotException = true;
1272 calleeUserData->exceptionWhat = e.what();
1273 calleeUserData->exceptionType = e.type();
1274 return MlirWalkResult::MlirWalkResultInterrupt;
1275 }
1276 };
1277 mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
1278 if (userData.gotException) {
1279 std::string message("Exception raised in callback: ");
1280 message.append(str: userData.exceptionWhat);
1281 throw std::runtime_error(message);
1282 }
1283}
1284
1285py::object PyOperationBase::getAsm(bool binary,
1286 std::optional<int64_t> largeElementsLimit,
1287 bool enableDebugInfo, bool prettyDebugInfo,
1288 bool printGenericOpForm, bool useLocalScope,
1289 bool assumeVerified) {
1290 py::object fileObject;
1291 if (binary) {
1292 fileObject = py::module::import("io").attr("BytesIO")();
1293 } else {
1294 fileObject = py::module::import("io").attr("StringIO")();
1295 }
1296 print(/*largeElementsLimit=*/largeElementsLimit,
1297 /*enableDebugInfo=*/enableDebugInfo,
1298 /*prettyDebugInfo=*/prettyDebugInfo,
1299 /*printGenericOpForm=*/printGenericOpForm,
1300 /*useLocalScope=*/useLocalScope,
1301 /*assumeVerified=*/assumeVerified,
1302 /*fileObject=*/fileObject,
1303 /*binary=*/binary);
1304
1305 return fileObject.attr("getvalue")();
1306}
1307
1308void PyOperationBase::moveAfter(PyOperationBase &other) {
1309 PyOperation &operation = getOperation();
1310 PyOperation &otherOp = other.getOperation();
1311 operation.checkValid();
1312 otherOp.checkValid();
1313 mlirOperationMoveAfter(operation, otherOp);
1314 operation.parentKeepAlive = otherOp.parentKeepAlive;
1315}
1316
1317void PyOperationBase::moveBefore(PyOperationBase &other) {
1318 PyOperation &operation = getOperation();
1319 PyOperation &otherOp = other.getOperation();
1320 operation.checkValid();
1321 otherOp.checkValid();
1322 mlirOperationMoveBefore(operation, otherOp);
1323 operation.parentKeepAlive = otherOp.parentKeepAlive;
1324}
1325
1326bool PyOperationBase::verify() {
1327 PyOperation &op = getOperation();
1328 PyMlirContext::ErrorCapture errors(op.getContext());
1329 if (!mlirOperationVerify(op.get()))
1330 throw MLIRError("Verification failed", errors.take());
1331 return true;
1332}
1333
1334std::optional<PyOperationRef> PyOperation::getParentOperation() {
1335 checkValid();
1336 if (!isAttached())
1337 throw py::value_error("Detached operations have no parent");
1338 MlirOperation operation = mlirOperationGetParentOperation(get());
1339 if (mlirOperationIsNull(operation))
1340 return {};
1341 return PyOperation::forOperation(getContext(), operation);
1342}
1343
1344PyBlock PyOperation::getBlock() {
1345 checkValid();
1346 std::optional<PyOperationRef> parentOperation = getParentOperation();
1347 MlirBlock block = mlirOperationGetBlock(get());
1348 assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1349 assert(parentOperation && "Operation has no parent");
1350 return PyBlock{std::move(*parentOperation), block};
1351}
1352
1353py::object PyOperation::getCapsule() {
1354 checkValid();
1355 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
1356}
1357
1358py::object PyOperation::createFromCapsule(py::object capsule) {
1359 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1360 if (mlirOperationIsNull(rawOperation))
1361 throw py::error_already_set();
1362 MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1363 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1364 .releaseObject();
1365}
1366
1367static void maybeInsertOperation(PyOperationRef &op,
1368 const py::object &maybeIp) {
1369 // InsertPoint active?
1370 if (!maybeIp.is(py::cast(false))) {
1371 PyInsertionPoint *ip;
1372 if (maybeIp.is_none()) {
1373 ip = PyThreadContextEntry::getDefaultInsertionPoint();
1374 } else {
1375 ip = py::cast<PyInsertionPoint *>(maybeIp);
1376 }
1377 if (ip)
1378 ip->insert(*op.get());
1379 }
1380}
1381
1382py::object PyOperation::create(const std::string &name,
1383 std::optional<std::vector<PyType *>> results,
1384 std::optional<std::vector<PyValue *>> operands,
1385 std::optional<py::dict> attributes,
1386 std::optional<std::vector<PyBlock *>> successors,
1387 int regions, DefaultingPyLocation location,
1388 const py::object &maybeIp, bool inferType) {
1389 llvm::SmallVector<MlirValue, 4> mlirOperands;
1390 llvm::SmallVector<MlirType, 4> mlirResults;
1391 llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
1392 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
1393
1394 // General parameter validation.
1395 if (regions < 0)
1396 throw py::value_error("number of regions must be >= 0");
1397
1398 // Unpack/validate operands.
1399 if (operands) {
1400 mlirOperands.reserve(operands->size());
1401 for (PyValue *operand : *operands) {
1402 if (!operand)
1403 throw py::value_error("operand value cannot be None");
1404 mlirOperands.push_back(operand->get());
1405 }
1406 }
1407
1408 // Unpack/validate results.
1409 if (results) {
1410 mlirResults.reserve(results->size());
1411 for (PyType *result : *results) {
1412 // TODO: Verify result type originate from the same context.
1413 if (!result)
1414 throw py::value_error("result type cannot be None");
1415 mlirResults.push_back(*result);
1416 }
1417 }
1418 // Unpack/validate attributes.
1419 if (attributes) {
1420 mlirAttributes.reserve(attributes->size());
1421 for (auto &it : *attributes) {
1422 std::string key;
1423 try {
1424 key = it.first.cast<std::string>();
1425 } catch (py::cast_error &err) {
1426 std::string msg = "Invalid attribute key (not a string) when "
1427 "attempting to create the operation \"" +
1428 name + "\" (" + err.what() + ")";
1429 throw py::cast_error(msg);
1430 }
1431 try {
1432 auto &attribute = it.second.cast<PyAttribute &>();
1433 // TODO: Verify attribute originates from the same context.
1434 mlirAttributes.emplace_back(std::move(key), attribute);
1435 } catch (py::reference_cast_error &) {
1436 // This exception seems thrown when the value is "None".
1437 std::string msg =
1438 "Found an invalid (`None`?) attribute value for the key \"" + key +
1439 "\" when attempting to create the operation \"" + name + "\"";
1440 throw py::cast_error(msg);
1441 } catch (py::cast_error &err) {
1442 std::string msg = "Invalid attribute value for the key \"" + key +
1443 "\" when attempting to create the operation \"" +
1444 name + "\" (" + err.what() + ")";
1445 throw py::cast_error(msg);
1446 }
1447 }
1448 }
1449 // Unpack/validate successors.
1450 if (successors) {
1451 mlirSuccessors.reserve(successors->size());
1452 for (auto *successor : *successors) {
1453 // TODO: Verify successor originate from the same context.
1454 if (!successor)
1455 throw py::value_error("successor block cannot be None");
1456 mlirSuccessors.push_back(successor->get());
1457 }
1458 }
1459
1460 // Apply unpacked/validated to the operation state. Beyond this
1461 // point, exceptions cannot be thrown or else the state will leak.
1462 MlirOperationState state =
1463 mlirOperationStateGet(toMlirStringRef(name), location);
1464 if (!mlirOperands.empty())
1465 mlirOperationStateAddOperands(&state, mlirOperands.size(),
1466 mlirOperands.data());
1467 state.enableResultTypeInference = inferType;
1468 if (!mlirResults.empty())
1469 mlirOperationStateAddResults(&state, mlirResults.size(),
1470 mlirResults.data());
1471 if (!mlirAttributes.empty()) {
1472 // Note that the attribute names directly reference bytes in
1473 // mlirAttributes, so that vector must not be changed from here
1474 // on.
1475 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
1476 mlirNamedAttributes.reserve(mlirAttributes.size());
1477 for (auto &it : mlirAttributes)
1478 mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1479 mlirIdentifierGet(mlirAttributeGetContext(it.second),
1480 toMlirStringRef(it.first)),
1481 it.second));
1482 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1483 mlirNamedAttributes.data());
1484 }
1485 if (!mlirSuccessors.empty())
1486 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1487 mlirSuccessors.data());
1488 if (regions) {
1489 llvm::SmallVector<MlirRegion, 4> mlirRegions;
1490 mlirRegions.resize(regions);
1491 for (int i = 0; i < regions; ++i)
1492 mlirRegions[i] = mlirRegionCreate();
1493 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1494 mlirRegions.data());
1495 }
1496
1497 // Construct the operation.
1498 MlirOperation operation = mlirOperationCreate(&state);
1499 if (!operation.ptr)
1500 throw py::value_error("Operation creation failed");
1501 PyOperationRef created =
1502 PyOperation::createDetached(location->getContext(), operation);
1503 maybeInsertOperation(created, maybeIp);
1504
1505 return created->createOpView();
1506}
1507
1508py::object PyOperation::clone(const py::object &maybeIp) {
1509 MlirOperation clonedOperation = mlirOperationClone(operation);
1510 PyOperationRef cloned =
1511 PyOperation::createDetached(getContext(), clonedOperation);
1512 maybeInsertOperation(cloned, maybeIp);
1513
1514 return cloned->createOpView();
1515}
1516
1517py::object PyOperation::createOpView() {
1518 checkValid();
1519 MlirIdentifier ident = mlirOperationGetName(get());
1520 MlirStringRef identStr = mlirIdentifierStr(ident);
1521 auto operationCls = PyGlobals::get().lookupOperationClass(
1522 StringRef(identStr.data, identStr.length));
1523 if (operationCls)
1524 return PyOpView::constructDerived(*operationCls, *getRef().get());
1525 return py::cast(PyOpView(getRef().getObject()));
1526}
1527
1528void PyOperation::erase() {
1529 checkValid();
1530 // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1531 // Python reference to a child operation is live. All children should also
1532 // have their `valid` bit set to false.
1533 auto &liveOperations = getContext()->liveOperations;
1534 if (liveOperations.count(operation.ptr))
1535 liveOperations.erase(operation.ptr);
1536 mlirOperationDestroy(operation);
1537 valid = false;
1538}
1539
1540//------------------------------------------------------------------------------
1541// PyOpView
1542//------------------------------------------------------------------------------
1543
1544static void populateResultTypes(StringRef name, py::list resultTypeList,
1545 const py::object &resultSegmentSpecObj,
1546 std::vector<int32_t> &resultSegmentLengths,
1547 std::vector<PyType *> &resultTypes) {
1548 resultTypes.reserve(n: resultTypeList.size());
1549 if (resultSegmentSpecObj.is_none()) {
1550 // Non-variadic result unpacking.
1551 for (const auto &it : llvm::enumerate(resultTypeList)) {
1552 try {
1553 resultTypes.push_back(py::cast<PyType *>(it.value()));
1554 if (!resultTypes.back())
1555 throw py::cast_error();
1556 } catch (py::cast_error &err) {
1557 throw py::value_error((llvm::Twine("Result ") +
1558 llvm::Twine(it.index()) + " of operation \"" +
1559 name + "\" must be a Type (" + err.what() + ")")
1560 .str());
1561 }
1562 }
1563 } else {
1564 // Sized result unpacking.
1565 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1566 if (resultSegmentSpec.size() != resultTypeList.size()) {
1567 throw py::value_error((llvm::Twine("Operation \"") + name +
1568 "\" requires " +
1569 llvm::Twine(resultSegmentSpec.size()) +
1570 " result segments but was provided " +
1571 llvm::Twine(resultTypeList.size()))
1572 .str());
1573 }
1574 resultSegmentLengths.reserve(n: resultTypeList.size());
1575 for (const auto &it :
1576 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1577 int segmentSpec = std::get<1>(it.value());
1578 if (segmentSpec == 1 || segmentSpec == 0) {
1579 // Unpack unary element.
1580 try {
1581 auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1582 if (resultType) {
1583 resultTypes.push_back(resultType);
1584 resultSegmentLengths.push_back(1);
1585 } else if (segmentSpec == 0) {
1586 // Allowed to be optional.
1587 resultSegmentLengths.push_back(0);
1588 } else {
1589 throw py::cast_error("was None and result is not optional");
1590 }
1591 } catch (py::cast_error &err) {
1592 throw py::value_error((llvm::Twine("Result ") +
1593 llvm::Twine(it.index()) + " of operation \"" +
1594 name + "\" must be a Type (" + err.what() +
1595 ")")
1596 .str());
1597 }
1598 } else if (segmentSpec == -1) {
1599 // Unpack sequence by appending.
1600 try {
1601 if (std::get<0>(it.value()).is_none()) {
1602 // Treat it as an empty list.
1603 resultSegmentLengths.push_back(0);
1604 } else {
1605 // Unpack the list.
1606 auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1607 for (py::object segmentItem : segment) {
1608 resultTypes.push_back(py::cast<PyType *>(segmentItem));
1609 if (!resultTypes.back()) {
1610 throw py::cast_error("contained a None item");
1611 }
1612 }
1613 resultSegmentLengths.push_back(segment.size());
1614 }
1615 } catch (std::exception &err) {
1616 // NOTE: Sloppy to be using a catch-all here, but there are at least
1617 // three different unrelated exceptions that can be thrown in the
1618 // above "casts". Just keep the scope above small and catch them all.
1619 throw py::value_error((llvm::Twine("Result ") +
1620 llvm::Twine(it.index()) + " of operation \"" +
1621 name + "\" must be a Sequence of Types (" +
1622 err.what() + ")")
1623 .str());
1624 }
1625 } else {
1626 throw py::value_error("Unexpected segment spec");
1627 }
1628 }
1629 }
1630}
1631
1632py::object PyOpView::buildGeneric(
1633 const py::object &cls, std::optional<py::list> resultTypeList,
1634 py::list operandList, std::optional<py::dict> attributes,
1635 std::optional<std::vector<PyBlock *>> successors,
1636 std::optional<int> regions, DefaultingPyLocation location,
1637 const py::object &maybeIp) {
1638 PyMlirContextRef context = location->getContext();
1639 // Class level operation construction metadata.
1640 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1641 // Operand and result segment specs are either none, which does no
1642 // variadic unpacking, or a list of ints with segment sizes, where each
1643 // element is either a positive number (typically 1 for a scalar) or -1 to
1644 // indicate that it is derived from the length of the same-indexed operand
1645 // or result (implying that it is a list at that position).
1646 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1647 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1648
1649 std::vector<int32_t> operandSegmentLengths;
1650 std::vector<int32_t> resultSegmentLengths;
1651
1652 // Validate/determine region count.
1653 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1654 int opMinRegionCount = std::get<0>(opRegionSpec);
1655 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1656 if (!regions) {
1657 regions = opMinRegionCount;
1658 }
1659 if (*regions < opMinRegionCount) {
1660 throw py::value_error(
1661 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1662 llvm::Twine(opMinRegionCount) +
1663 " regions but was built with regions=" + llvm::Twine(*regions))
1664 .str());
1665 }
1666 if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1667 throw py::value_error(
1668 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1669 llvm::Twine(opMinRegionCount) +
1670 " regions but was built with regions=" + llvm::Twine(*regions))
1671 .str());
1672 }
1673
1674 // Unpack results.
1675 std::vector<PyType *> resultTypes;
1676 if (resultTypeList.has_value()) {
1677 populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1678 resultSegmentLengths, resultTypes);
1679 }
1680
1681 // Unpack operands.
1682 std::vector<PyValue *> operands;
1683 operands.reserve(n: operands.size());
1684 if (operandSegmentSpecObj.is_none()) {
1685 // Non-sized operand unpacking.
1686 for (const auto &it : llvm::enumerate(operandList)) {
1687 try {
1688 operands.push_back(py::cast<PyValue *>(it.value()));
1689 if (!operands.back())
1690 throw py::cast_error();
1691 } catch (py::cast_error &err) {
1692 throw py::value_error((llvm::Twine("Operand ") +
1693 llvm::Twine(it.index()) + " of operation \"" +
1694 name + "\" must be a Value (" + err.what() + ")")
1695 .str());
1696 }
1697 }
1698 } else {
1699 // Sized operand unpacking.
1700 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1701 if (operandSegmentSpec.size() != operandList.size()) {
1702 throw py::value_error((llvm::Twine("Operation \"") + name +
1703 "\" requires " +
1704 llvm::Twine(operandSegmentSpec.size()) +
1705 "operand segments but was provided " +
1706 llvm::Twine(operandList.size()))
1707 .str());
1708 }
1709 operandSegmentLengths.reserve(n: operandList.size());
1710 for (const auto &it :
1711 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1712 int segmentSpec = std::get<1>(it.value());
1713 if (segmentSpec == 1 || segmentSpec == 0) {
1714 // Unpack unary element.
1715 try {
1716 auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1717 if (operandValue) {
1718 operands.push_back(operandValue);
1719 operandSegmentLengths.push_back(1);
1720 } else if (segmentSpec == 0) {
1721 // Allowed to be optional.
1722 operandSegmentLengths.push_back(0);
1723 } else {
1724 throw py::cast_error("was None and operand is not optional");
1725 }
1726 } catch (py::cast_error &err) {
1727 throw py::value_error((llvm::Twine("Operand ") +
1728 llvm::Twine(it.index()) + " of operation \"" +
1729 name + "\" must be a Value (" + err.what() +
1730 ")")
1731 .str());
1732 }
1733 } else if (segmentSpec == -1) {
1734 // Unpack sequence by appending.
1735 try {
1736 if (std::get<0>(it.value()).is_none()) {
1737 // Treat it as an empty list.
1738 operandSegmentLengths.push_back(0);
1739 } else {
1740 // Unpack the list.
1741 auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1742 for (py::object segmentItem : segment) {
1743 operands.push_back(py::cast<PyValue *>(segmentItem));
1744 if (!operands.back()) {
1745 throw py::cast_error("contained a None item");
1746 }
1747 }
1748 operandSegmentLengths.push_back(segment.size());
1749 }
1750 } catch (std::exception &err) {
1751 // NOTE: Sloppy to be using a catch-all here, but there are at least
1752 // three different unrelated exceptions that can be thrown in the
1753 // above "casts". Just keep the scope above small and catch them all.
1754 throw py::value_error((llvm::Twine("Operand ") +
1755 llvm::Twine(it.index()) + " of operation \"" +
1756 name + "\" must be a Sequence of Values (" +
1757 err.what() + ")")
1758 .str());
1759 }
1760 } else {
1761 throw py::value_error("Unexpected segment spec");
1762 }
1763 }
1764 }
1765
1766 // Merge operand/result segment lengths into attributes if needed.
1767 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1768 // Dup.
1769 if (attributes) {
1770 attributes = py::dict(*attributes);
1771 } else {
1772 attributes = py::dict();
1773 }
1774 if (attributes->contains("resultSegmentSizes") ||
1775 attributes->contains("operandSegmentSizes")) {
1776 throw py::value_error("Manually setting a 'resultSegmentSizes' or "
1777 "'operandSegmentSizes' attribute is unsupported. "
1778 "Use Operation.create for such low-level access.");
1779 }
1780
1781 // Add resultSegmentSizes attribute.
1782 if (!resultSegmentLengths.empty()) {
1783 MlirAttribute segmentLengthAttr =
1784 mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
1785 resultSegmentLengths.data());
1786 (*attributes)["resultSegmentSizes"] =
1787 PyAttribute(context, segmentLengthAttr);
1788 }
1789
1790 // Add operandSegmentSizes attribute.
1791 if (!operandSegmentLengths.empty()) {
1792 MlirAttribute segmentLengthAttr =
1793 mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
1794 operandSegmentLengths.data());
1795 (*attributes)["operandSegmentSizes"] =
1796 PyAttribute(context, segmentLengthAttr);
1797 }
1798 }
1799
1800 // Delegate to create.
1801 return PyOperation::create(name,
1802 /*results=*/std::move(resultTypes),
1803 /*operands=*/std::move(operands),
1804 /*attributes=*/std::move(attributes),
1805 /*successors=*/std::move(successors),
1806 /*regions=*/*regions, location, maybeIp,
1807 !resultTypeList);
1808}
1809
1810pybind11::object PyOpView::constructDerived(const pybind11::object &cls,
1811 const PyOperation &operation) {
1812 // TODO: pybind11 2.6 supports a more direct form.
1813 // Upgrade many years from now.
1814 // auto opViewType = py::type::of<PyOpView>();
1815 py::handle opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1816 py::object instance = cls.attr("__new__")(cls);
1817 opViewType.attr("__init__")(instance, operation);
1818 return instance;
1819}
1820
1821PyOpView::PyOpView(const py::object &operationObject)
1822 // Casting through the PyOperationBase base-class and then back to the
1823 // Operation lets us accept any PyOperationBase subclass.
1824 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
1825 operationObject(operation.getRef().getObject()) {}
1826
1827//------------------------------------------------------------------------------
1828// PyInsertionPoint.
1829//------------------------------------------------------------------------------
1830
1831PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1832
1833PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
1834 : refOperation(beforeOperationBase.getOperation().getRef()),
1835 block((*refOperation)->getBlock()) {}
1836
1837void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1838 PyOperation &operation = operationBase.getOperation();
1839 if (operation.isAttached())
1840 throw py::value_error(
1841 "Attempt to insert operation that is already attached");
1842 block.getParentOperation()->checkValid();
1843 MlirOperation beforeOp = {nullptr};
1844 if (refOperation) {
1845 // Insert before operation.
1846 (*refOperation)->checkValid();
1847 beforeOp = (*refOperation)->get();
1848 } else {
1849 // Insert at end (before null) is only valid if the block does not
1850 // already end in a known terminator (violating this will cause assertion
1851 // failures later).
1852 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1853 throw py::index_error("Cannot insert operation at the end of a block "
1854 "that already has a terminator. Did you mean to "
1855 "use 'InsertionPoint.at_block_terminator(block)' "
1856 "versus 'InsertionPoint(block)'?");
1857 }
1858 }
1859 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1860 operation.setAttached();
1861}
1862
1863PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1864 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1865 if (mlirOperationIsNull(firstOp)) {
1866 // Just insert at end.
1867 return PyInsertionPoint(block);
1868 }
1869
1870 // Insert before first op.
1871 PyOperationRef firstOpRef = PyOperation::forOperation(
1872 block.getParentOperation()->getContext(), firstOp);
1873 return PyInsertionPoint{block, std::move(firstOpRef)};
1874}
1875
1876PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1877 MlirOperation terminator = mlirBlockGetTerminator(block.get());
1878 if (mlirOperationIsNull(terminator))
1879 throw py::value_error("Block has no terminator");
1880 PyOperationRef terminatorOpRef = PyOperation::forOperation(
1881 block.getParentOperation()->getContext(), terminator);
1882 return PyInsertionPoint{block, std::move(terminatorOpRef)};
1883}
1884
1885py::object PyInsertionPoint::contextEnter() {
1886 return PyThreadContextEntry::pushInsertionPoint(*this);
1887}
1888
1889void PyInsertionPoint::contextExit(const pybind11::object &excType,
1890 const pybind11::object &excVal,
1891 const pybind11::object &excTb) {
1892 PyThreadContextEntry::popInsertionPoint(insertionPoint&: *this);
1893}
1894
1895//------------------------------------------------------------------------------
1896// PyAttribute.
1897//------------------------------------------------------------------------------
1898
1899bool PyAttribute::operator==(const PyAttribute &other) const {
1900 return mlirAttributeEqual(attr, other.attr);
1901}
1902
1903py::object PyAttribute::getCapsule() {
1904 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1905}
1906
1907PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1908 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1909 if (mlirAttributeIsNull(rawAttr))
1910 throw py::error_already_set();
1911 return PyAttribute(
1912 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1913}
1914
1915//------------------------------------------------------------------------------
1916// PyNamedAttribute.
1917//------------------------------------------------------------------------------
1918
1919PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1920 : ownedName(new std::string(std::move(ownedName))) {
1921 namedAttr = mlirNamedAttributeGet(
1922 mlirIdentifierGet(mlirAttributeGetContext(attr),
1923 toMlirStringRef(*this->ownedName)),
1924 attr);
1925}
1926
1927//------------------------------------------------------------------------------
1928// PyType.
1929//------------------------------------------------------------------------------
1930
1931bool PyType::operator==(const PyType &other) const {
1932 return mlirTypeEqual(type, other.type);
1933}
1934
1935py::object PyType::getCapsule() {
1936 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1937}
1938
1939PyType PyType::createFromCapsule(py::object capsule) {
1940 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1941 if (mlirTypeIsNull(rawType))
1942 throw py::error_already_set();
1943 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1944 rawType);
1945}
1946
1947//------------------------------------------------------------------------------
1948// PyTypeID.
1949//------------------------------------------------------------------------------
1950
1951py::object PyTypeID::getCapsule() {
1952 return py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(*this));
1953}
1954
1955PyTypeID PyTypeID::createFromCapsule(py::object capsule) {
1956 MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
1957 if (mlirTypeIDIsNull(mlirTypeID))
1958 throw py::error_already_set();
1959 return PyTypeID(mlirTypeID);
1960}
1961bool PyTypeID::operator==(const PyTypeID &other) const {
1962 return mlirTypeIDEqual(typeID, other.typeID);
1963}
1964
1965//------------------------------------------------------------------------------
1966// PyValue and subclasses.
1967//------------------------------------------------------------------------------
1968
1969pybind11::object PyValue::getCapsule() {
1970 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1971}
1972
1973pybind11::object PyValue::maybeDownCast() {
1974 MlirType type = mlirValueGetType(get());
1975 MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
1976 assert(!mlirTypeIDIsNull(mlirTypeID) &&
1977 "mlirTypeID was expected to be non-null.");
1978 std::optional<pybind11::function> valueCaster =
1979 PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
1980 // py::return_value_policy::move means use std::move to move the return value
1981 // contents into a new instance that will be owned by Python.
1982 py::object thisObj = py::cast(this, py::return_value_policy::move);
1983 if (!valueCaster)
1984 return thisObj;
1985 return valueCaster.value()(thisObj);
1986}
1987
1988PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1989 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1990 if (mlirValueIsNull(value))
1991 throw py::error_already_set();
1992 MlirOperation owner;
1993 if (mlirValueIsAOpResult(value))
1994 owner = mlirOpResultGetOwner(value);
1995 if (mlirValueIsABlockArgument(value))
1996 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1997 if (mlirOperationIsNull(owner))
1998 throw py::error_already_set();
1999 MlirContext ctx = mlirOperationGetContext(owner);
2000 PyOperationRef ownerRef =
2001 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
2002 return PyValue(ownerRef, value);
2003}
2004
2005//------------------------------------------------------------------------------
2006// PySymbolTable.
2007//------------------------------------------------------------------------------
2008
2009PySymbolTable::PySymbolTable(PyOperationBase &operation)
2010 : operation(operation.getOperation().getRef()) {
2011 symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
2012 if (mlirSymbolTableIsNull(symbolTable)) {
2013 throw py::cast_error("Operation is not a Symbol Table.");
2014 }
2015}
2016
2017py::object PySymbolTable::dunderGetItem(const std::string &name) {
2018 operation->checkValid();
2019 MlirOperation symbol = mlirSymbolTableLookup(
2020 symbolTable, mlirStringRefCreate(name.data(), name.length()));
2021 if (mlirOperationIsNull(symbol))
2022 throw py::key_error("Symbol '" + name + "' not in the symbol table.");
2023
2024 return PyOperation::forOperation(operation->getContext(), symbol,
2025 operation.getObject())
2026 ->createOpView();
2027}
2028
2029void PySymbolTable::erase(PyOperationBase &symbol) {
2030 operation->checkValid();
2031 symbol.getOperation().checkValid();
2032 mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
2033 // The operation is also erased, so we must invalidate it. There may be Python
2034 // references to this operation so we don't want to delete it from the list of
2035 // live operations here.
2036 symbol.getOperation().valid = false;
2037}
2038
2039void PySymbolTable::dunderDel(const std::string &name) {
2040 py::object operation = dunderGetItem(name);
2041 erase(py::cast<PyOperationBase &>(operation));
2042}
2043
2044MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) {
2045 operation->checkValid();
2046 symbol.getOperation().checkValid();
2047 MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
2048 symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
2049 if (mlirAttributeIsNull(symbolAttr))
2050 throw py::value_error("Expected operation to have a symbol name.");
2051 return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get());
2052}
2053
2054MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
2055 // Op must already be a symbol.
2056 PyOperation &operation = symbol.getOperation();
2057 operation.checkValid();
2058 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
2059 MlirAttribute existingNameAttr =
2060 mlirOperationGetAttributeByName(operation.get(), attrName);
2061 if (mlirAttributeIsNull(existingNameAttr))
2062 throw py::value_error("Expected operation to have a symbol name.");
2063 return existingNameAttr;
2064}
2065
2066void PySymbolTable::setSymbolName(PyOperationBase &symbol,
2067 const std::string &name) {
2068 // Op must already be a symbol.
2069 PyOperation &operation = symbol.getOperation();
2070 operation.checkValid();
2071 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
2072 MlirAttribute existingNameAttr =
2073 mlirOperationGetAttributeByName(operation.get(), attrName);
2074 if (mlirAttributeIsNull(existingNameAttr))
2075 throw py::value_error("Expected operation to have a symbol name.");
2076 MlirAttribute newNameAttr =
2077 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
2078 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
2079}
2080
2081MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
2082 PyOperation &operation = symbol.getOperation();
2083 operation.checkValid();
2084 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
2085 MlirAttribute existingVisAttr =
2086 mlirOperationGetAttributeByName(operation.get(), attrName);
2087 if (mlirAttributeIsNull(existingVisAttr))
2088 throw py::value_error("Expected operation to have a symbol visibility.");
2089 return existingVisAttr;
2090}
2091
2092void PySymbolTable::setVisibility(PyOperationBase &symbol,
2093 const std::string &visibility) {
2094 if (visibility != "public" && visibility != "private" &&
2095 visibility != "nested")
2096 throw py::value_error(
2097 "Expected visibility to be 'public', 'private' or 'nested'");
2098 PyOperation &operation = symbol.getOperation();
2099 operation.checkValid();
2100 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
2101 MlirAttribute existingVisAttr =
2102 mlirOperationGetAttributeByName(operation.get(), attrName);
2103 if (mlirAttributeIsNull(existingVisAttr))
2104 throw py::value_error("Expected operation to have a symbol visibility.");
2105 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
2106 toMlirStringRef(visibility));
2107 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
2108}
2109
2110void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
2111 const std::string &newSymbol,
2112 PyOperationBase &from) {
2113 PyOperation &fromOperation = from.getOperation();
2114 fromOperation.checkValid();
2115 if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
2116 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
2117 from.getOperation())))
2118
2119 throw py::value_error("Symbol rename failed");
2120}
2121
2122void PySymbolTable::walkSymbolTables(PyOperationBase &from,
2123 bool allSymUsesVisible,
2124 py::object callback) {
2125 PyOperation &fromOperation = from.getOperation();
2126 fromOperation.checkValid();
2127 struct UserData {
2128 PyMlirContextRef context;
2129 py::object callback;
2130 bool gotException;
2131 std::string exceptionWhat;
2132 py::object exceptionType;
2133 };
2134 UserData userData{
2135 fromOperation.getContext(), std::move(callback), false, {}, {}};
2136 mlirSymbolTableWalkSymbolTables(
2137 fromOperation.get(), allSymUsesVisible,
2138 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
2139 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
2140 auto pyFoundOp =
2141 PyOperation::forOperation(calleeUserData->context, foundOp);
2142 if (calleeUserData->gotException)
2143 return;
2144 try {
2145 calleeUserData->callback(pyFoundOp.getObject(), isVisible);
2146 } catch (py::error_already_set &e) {
2147 calleeUserData->gotException = true;
2148 calleeUserData->exceptionWhat = e.what();
2149 calleeUserData->exceptionType = e.type();
2150 }
2151 },
2152 static_cast<void *>(&userData));
2153 if (userData.gotException) {
2154 std::string message("Exception raised in callback: ");
2155 message.append(str: userData.exceptionWhat);
2156 throw std::runtime_error(message);
2157 }
2158}
2159
2160namespace {
2161/// CRTP base class for Python MLIR values that subclass Value and should be
2162/// castable from it. The value hierarchy is one level deep and is not supposed
2163/// to accommodate other levels unless core MLIR changes.
2164template <typename DerivedTy>
2165class PyConcreteValue : public PyValue {
2166public:
2167 // Derived classes must define statics for:
2168 // IsAFunctionTy isaFunction
2169 // const char *pyClassName
2170 // and redefine bindDerived.
2171 using ClassTy = py::class_<DerivedTy, PyValue>;
2172 using IsAFunctionTy = bool (*)(MlirValue);
2173
2174 PyConcreteValue() = default;
2175 PyConcreteValue(PyOperationRef operationRef, MlirValue value)
2176 : PyValue(operationRef, value) {}
2177 PyConcreteValue(PyValue &orig)
2178 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
2179
2180 /// Attempts to cast the original value to the derived type and throws on
2181 /// type mismatches.
2182 static MlirValue castFrom(PyValue &orig) {
2183 if (!DerivedTy::isaFunction(orig.get())) {
2184 auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
2185 throw py::value_error((Twine("Cannot cast value to ") +
2186 DerivedTy::pyClassName + " (from " + origRepr +
2187 ")")
2188 .str());
2189 }
2190 return orig.get();
2191 }
2192
2193 /// Binds the Python module objects to functions of this class.
2194 static void bind(py::module &m) {
2195 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
2196 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
2197 cls.def_static(
2198 "isinstance",
2199 [](PyValue &otherValue) -> bool {
2200 return DerivedTy::isaFunction(otherValue);
2201 },
2202 py::arg("other_value"));
2203 cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
2204 [](DerivedTy &self) { return self.maybeDownCast(); });
2205 DerivedTy::bindDerived(cls);
2206 }
2207
2208 /// Implemented by derived classes to add methods to the Python subclass.
2209 static void bindDerived(ClassTy &m) {}
2210};
2211
2212/// Python wrapper for MlirBlockArgument.
2213class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
2214public:
2215 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
2216 static constexpr const char *pyClassName = "BlockArgument";
2217 using PyConcreteValue::PyConcreteValue;
2218
2219 static void bindDerived(ClassTy &c) {
2220 c.def_property_readonly("owner", [](PyBlockArgument &self) {
2221 return PyBlock(self.getParentOperation(),
2222 mlirBlockArgumentGetOwner(self.get()));
2223 });
2224 c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
2225 return mlirBlockArgumentGetArgNumber(self.get());
2226 });
2227 c.def(
2228 "set_type",
2229 [](PyBlockArgument &self, PyType type) {
2230 return mlirBlockArgumentSetType(self.get(), type);
2231 },
2232 py::arg("type"));
2233 }
2234};
2235
2236/// Python wrapper for MlirOpResult.
2237class PyOpResult : public PyConcreteValue<PyOpResult> {
2238public:
2239 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
2240 static constexpr const char *pyClassName = "OpResult";
2241 using PyConcreteValue::PyConcreteValue;
2242
2243 static void bindDerived(ClassTy &c) {
2244 c.def_property_readonly("owner", [](PyOpResult &self) {
2245 assert(
2246 mlirOperationEqual(self.getParentOperation()->get(),
2247 mlirOpResultGetOwner(self.get())) &&
2248 "expected the owner of the value in Python to match that in the IR");
2249 return self.getParentOperation().getObject();
2250 });
2251 c.def_property_readonly("result_number", [](PyOpResult &self) {
2252 return mlirOpResultGetResultNumber(self.get());
2253 });
2254 }
2255};
2256
2257/// Returns the list of types of the values held by container.
2258template <typename Container>
2259static std::vector<MlirType> getValueTypes(Container &container,
2260 PyMlirContextRef &context) {
2261 std::vector<MlirType> result;
2262 result.reserve(container.size());
2263 for (int i = 0, e = container.size(); i < e; ++i) {
2264 result.push_back(mlirValueGetType(container.getElement(i).get()));
2265 }
2266 return result;
2267}
2268
2269/// A list of block arguments. Internally, these are stored as consecutive
2270/// elements, random access is cheap. The argument list is associated with the
2271/// operation that contains the block (detached blocks are not allowed in
2272/// Python bindings) and extends its lifetime.
2273class PyBlockArgumentList
2274 : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
2275public:
2276 static constexpr const char *pyClassName = "BlockArgumentList";
2277 using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
2278
2279 PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
2280 intptr_t startIndex = 0, intptr_t length = -1,
2281 intptr_t step = 1)
2282 : Sliceable(startIndex,
2283 length == -1 ? mlirBlockGetNumArguments(block) : length,
2284 step),
2285 operation(std::move(operation)), block(block) {}
2286
2287 static void bindDerived(ClassTy &c) {
2288 c.def_property_readonly("types", [](PyBlockArgumentList &self) {
2289 return getValueTypes(self, self.operation->getContext());
2290 });
2291 }
2292
2293private:
2294 /// Give the parent CRTP class access to hook implementations below.
2295 friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
2296
2297 /// Returns the number of arguments in the list.
2298 intptr_t getRawNumElements() {
2299 operation->checkValid();
2300 return mlirBlockGetNumArguments(block);
2301 }
2302
2303 /// Returns `pos`-the element in the list.
2304 PyBlockArgument getRawElement(intptr_t pos) {
2305 MlirValue argument = mlirBlockGetArgument(block, pos);
2306 return PyBlockArgument(operation, argument);
2307 }
2308
2309 /// Returns a sublist of this list.
2310 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2311 intptr_t step) {
2312 return PyBlockArgumentList(operation, block, startIndex, length, step);
2313 }
2314
2315 PyOperationRef operation;
2316 MlirBlock block;
2317};
2318
2319/// A list of operation operands. Internally, these are stored as consecutive
2320/// elements, random access is cheap. The (returned) operand list is associated
2321/// with the operation whose operands these are, and thus extends the lifetime
2322/// of this operation.
2323class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2324public:
2325 static constexpr const char *pyClassName = "OpOperandList";
2326 using SliceableT = Sliceable<PyOpOperandList, PyValue>;
2327
2328 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2329 intptr_t length = -1, intptr_t step = 1)
2330 : Sliceable(startIndex,
2331 length == -1 ? mlirOperationGetNumOperands(operation->get())
2332 : length,
2333 step),
2334 operation(operation) {}
2335
2336 void dunderSetItem(intptr_t index, PyValue value) {
2337 index = wrapIndex(index);
2338 mlirOperationSetOperand(operation->get(), index, value.get());
2339 }
2340
2341 static void bindDerived(ClassTy &c) {
2342 c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2343 }
2344
2345private:
2346 /// Give the parent CRTP class access to hook implementations below.
2347 friend class Sliceable<PyOpOperandList, PyValue>;
2348
2349 intptr_t getRawNumElements() {
2350 operation->checkValid();
2351 return mlirOperationGetNumOperands(operation->get());
2352 }
2353
2354 PyValue getRawElement(intptr_t pos) {
2355 MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2356 MlirOperation owner;
2357 if (mlirValueIsAOpResult(operand))
2358 owner = mlirOpResultGetOwner(operand);
2359 else if (mlirValueIsABlockArgument(operand))
2360 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
2361 else
2362 assert(false && "Value must be an block arg or op result.");
2363 PyOperationRef pyOwner =
2364 PyOperation::forOperation(operation->getContext(), owner);
2365 return PyValue(pyOwner, operand);
2366 }
2367
2368 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2369 return PyOpOperandList(operation, startIndex, length, step);
2370 }
2371
2372 PyOperationRef operation;
2373};
2374
2375/// A list of operation results. Internally, these are stored as consecutive
2376/// elements, random access is cheap. The (returned) result list is associated
2377/// with the operation whose results these are, and thus extends the lifetime of
2378/// this operation.
2379class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
2380public:
2381 static constexpr const char *pyClassName = "OpResultList";
2382 using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
2383
2384 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
2385 intptr_t length = -1, intptr_t step = 1)
2386 : Sliceable(startIndex,
2387 length == -1 ? mlirOperationGetNumResults(operation->get())
2388 : length,
2389 step),
2390 operation(std::move(operation)) {}
2391
2392 static void bindDerived(ClassTy &c) {
2393 c.def_property_readonly("types", [](PyOpResultList &self) {
2394 return getValueTypes(self, self.operation->getContext());
2395 });
2396 c.def_property_readonly("owner", [](PyOpResultList &self) {
2397 return self.operation->createOpView();
2398 });
2399 }
2400
2401private:
2402 /// Give the parent CRTP class access to hook implementations below.
2403 friend class Sliceable<PyOpResultList, PyOpResult>;
2404
2405 intptr_t getRawNumElements() {
2406 operation->checkValid();
2407 return mlirOperationGetNumResults(operation->get());
2408 }
2409
2410 PyOpResult getRawElement(intptr_t index) {
2411 PyValue value(operation, mlirOperationGetResult(operation->get(), index));
2412 return PyOpResult(value);
2413 }
2414
2415 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2416 return PyOpResultList(operation, startIndex, length, step);
2417 }
2418
2419 PyOperationRef operation;
2420};
2421
2422/// A list of operation successors. Internally, these are stored as consecutive
2423/// elements, random access is cheap. The (returned) successor list is
2424/// associated with the operation whose successors these are, and thus extends
2425/// the lifetime of this operation.
2426class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
2427public:
2428 static constexpr const char *pyClassName = "OpSuccessors";
2429
2430 PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
2431 intptr_t length = -1, intptr_t step = 1)
2432 : Sliceable(startIndex,
2433 length == -1 ? mlirOperationGetNumSuccessors(operation->get())
2434 : length,
2435 step),
2436 operation(operation) {}
2437
2438 void dunderSetItem(intptr_t index, PyBlock block) {
2439 index = wrapIndex(index);
2440 mlirOperationSetSuccessor(operation->get(), index, block.get());
2441 }
2442
2443 static void bindDerived(ClassTy &c) {
2444 c.def("__setitem__", &PyOpSuccessors::dunderSetItem);
2445 }
2446
2447private:
2448 /// Give the parent CRTP class access to hook implementations below.
2449 friend class Sliceable<PyOpSuccessors, PyBlock>;
2450
2451 intptr_t getRawNumElements() {
2452 operation->checkValid();
2453 return mlirOperationGetNumSuccessors(operation->get());
2454 }
2455
2456 PyBlock getRawElement(intptr_t pos) {
2457 MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2458 return PyBlock(operation, block);
2459 }
2460
2461 PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2462 return PyOpSuccessors(operation, startIndex, length, step);
2463 }
2464
2465 PyOperationRef operation;
2466};
2467
2468/// A list of operation attributes. Can be indexed by name, producing
2469/// attributes, or by index, producing named attributes.
2470class PyOpAttributeMap {
2471public:
2472 PyOpAttributeMap(PyOperationRef operation)
2473 : operation(std::move(operation)) {}
2474
2475 MlirAttribute dunderGetItemNamed(const std::string &name) {
2476 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2477 toMlirStringRef(name));
2478 if (mlirAttributeIsNull(attr)) {
2479 throw py::key_error("attempt to access a non-existent attribute");
2480 }
2481 return attr;
2482 }
2483
2484 PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2485 if (index < 0 || index >= dunderLen()) {
2486 throw py::index_error("attempt to access out of bounds attribute");
2487 }
2488 MlirNamedAttribute namedAttr =
2489 mlirOperationGetAttribute(operation->get(), index);
2490 return PyNamedAttribute(
2491 namedAttr.attribute,
2492 std::string(mlirIdentifierStr(namedAttr.name).data,
2493 mlirIdentifierStr(namedAttr.name).length));
2494 }
2495
2496 void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2497 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2498 attr);
2499 }
2500
2501 void dunderDelItem(const std::string &name) {
2502 int removed = mlirOperationRemoveAttributeByName(operation->get(),
2503 toMlirStringRef(name));
2504 if (!removed)
2505 throw py::key_error("attempt to delete a non-existent attribute");
2506 }
2507
2508 intptr_t dunderLen() {
2509 return mlirOperationGetNumAttributes(operation->get());
2510 }
2511
2512 bool dunderContains(const std::string &name) {
2513 return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2514 operation->get(), toMlirStringRef(name)));
2515 }
2516
2517 static void bind(py::module &m) {
2518 py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
2519 .def("__contains__", &PyOpAttributeMap::dunderContains)
2520 .def("__len__", &PyOpAttributeMap::dunderLen)
2521 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2522 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2523 .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2524 .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2525 }
2526
2527private:
2528 PyOperationRef operation;
2529};
2530
2531} // namespace
2532
2533//------------------------------------------------------------------------------
2534// Populates the core exports of the 'ir' submodule.
2535//------------------------------------------------------------------------------
2536
2537void mlir::python::populateIRCore(py::module &m) {
2538 //----------------------------------------------------------------------------
2539 // Enums.
2540 //----------------------------------------------------------------------------
2541 py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
2542 .value("ERROR", MlirDiagnosticError)
2543 .value("WARNING", MlirDiagnosticWarning)
2544 .value("NOTE", MlirDiagnosticNote)
2545 .value("REMARK", MlirDiagnosticRemark);
2546
2547 py::enum_<MlirWalkOrder>(m, "WalkOrder", py::module_local())
2548 .value("PRE_ORDER", MlirWalkPreOrder)
2549 .value("POST_ORDER", MlirWalkPostOrder);
2550
2551 py::enum_<MlirWalkResult>(m, "WalkResult", py::module_local())
2552 .value("ADVANCE", MlirWalkResultAdvance)
2553 .value("INTERRUPT", MlirWalkResultInterrupt)
2554 .value("SKIP", MlirWalkResultSkip);
2555
2556 //----------------------------------------------------------------------------
2557 // Mapping of Diagnostics.
2558 //----------------------------------------------------------------------------
2559 py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
2560 .def_property_readonly("severity", &PyDiagnostic::getSeverity)
2561 .def_property_readonly("location", &PyDiagnostic::getLocation)
2562 .def_property_readonly("message", &PyDiagnostic::getMessage)
2563 .def_property_readonly("notes", &PyDiagnostic::getNotes)
2564 .def("__str__", [](PyDiagnostic &self) -> py::str {
2565 if (!self.isValid())
2566 return "<Invalid Diagnostic>";
2567 return self.getMessage();
2568 });
2569
2570 py::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo",
2571 py::module_local())
2572 .def(py::init<>([](PyDiagnostic diag) { return diag.getInfo(); }))
2573 .def_readonly("severity", &PyDiagnostic::DiagnosticInfo::severity)
2574 .def_readonly("location", &PyDiagnostic::DiagnosticInfo::location)
2575 .def_readonly("message", &PyDiagnostic::DiagnosticInfo::message)
2576 .def_readonly("notes", &PyDiagnostic::DiagnosticInfo::notes)
2577 .def("__str__",
2578 [](PyDiagnostic::DiagnosticInfo &self) { return self.message; });
2579
2580 py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
2581 .def("detach", &PyDiagnosticHandler::detach)
2582 .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
2583 .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
2584 .def("__enter__", &PyDiagnosticHandler::contextEnter)
2585 .def("__exit__", &PyDiagnosticHandler::contextExit);
2586
2587 //----------------------------------------------------------------------------
2588 // Mapping of MlirContext.
2589 // Note that this is exported as _BaseContext. The containing, Python level
2590 // __init__.py will subclass it with site-specific functionality and set a
2591 // "Context" attribute on this module.
2592 //----------------------------------------------------------------------------
2593 py::class_<PyMlirContext>(m, "_BaseContext", py::module_local())
2594 .def(py::init<>(&PyMlirContext::createNewContextForInit))
2595 .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2596 .def("_get_context_again",
2597 [](PyMlirContext &self) {
2598 PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2599 return ref.releaseObject();
2600 })
2601 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2602 .def("_get_live_operation_objects",
2603 &PyMlirContext::getLiveOperationObjects)
2604 .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2605 .def("_clear_live_operations_inside",
2606 py::overload_cast<MlirOperation>(
2607 &PyMlirContext::clearOperationsInside))
2608 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2609 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2610 &PyMlirContext::getCapsule)
2611 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2612 .def("__enter__", &PyMlirContext::contextEnter)
2613 .def("__exit__", &PyMlirContext::contextExit)
2614 .def_property_readonly_static(
2615 "current",
2616 [](py::object & /*class*/) {
2617 auto *context = PyThreadContextEntry::getDefaultContext();
2618 if (!context)
2619 return py::none().cast<py::object>();
2620 return py::cast(context);
2621 },
2622 "Gets the Context bound to the current thread or raises ValueError")
2623 .def_property_readonly(
2624 "dialects",
2625 [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2626 "Gets a container for accessing dialects by name")
2627 .def_property_readonly(
2628 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2629 "Alias for 'dialect'")
2630 .def(
2631 "get_dialect_descriptor",
2632 [=](PyMlirContext &self, std::string &name) {
2633 MlirDialect dialect = mlirContextGetOrLoadDialect(
2634 self.get(), {name.data(), name.size()});
2635 if (mlirDialectIsNull(dialect)) {
2636 throw py::value_error(
2637 (Twine("Dialect '") + name + "' not found").str());
2638 }
2639 return PyDialectDescriptor(self.getRef(), dialect);
2640 },
2641 py::arg("dialect_name"),
2642 "Gets or loads a dialect by name, returning its descriptor object")
2643 .def_property(
2644 "allow_unregistered_dialects",
2645 [](PyMlirContext &self) -> bool {
2646 return mlirContextGetAllowUnregisteredDialects(self.get());
2647 },
2648 [](PyMlirContext &self, bool value) {
2649 mlirContextSetAllowUnregisteredDialects(self.get(), value);
2650 })
2651 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2652 py::arg("callback"),
2653 "Attaches a diagnostic handler that will receive callbacks")
2654 .def(
2655 "enable_multithreading",
2656 [](PyMlirContext &self, bool enable) {
2657 mlirContextEnableMultithreading(self.get(), enable);
2658 },
2659 py::arg("enable"))
2660 .def(
2661 "is_registered_operation",
2662 [](PyMlirContext &self, std::string &name) {
2663 return mlirContextIsRegisteredOperation(
2664 self.get(), MlirStringRef{name.data(), name.size()});
2665 },
2666 py::arg("operation_name"))
2667 .def(
2668 "append_dialect_registry",
2669 [](PyMlirContext &self, PyDialectRegistry &registry) {
2670 mlirContextAppendDialectRegistry(self.get(), registry);
2671 },
2672 py::arg("registry"))
2673 .def_property("emit_error_diagnostics", nullptr,
2674 &PyMlirContext::setEmitErrorDiagnostics,
2675 "Emit error diagnostics to diagnostic handlers. By default "
2676 "error diagnostics are captured and reported through "
2677 "MLIRError exceptions.")
2678 .def("load_all_available_dialects", [](PyMlirContext &self) {
2679 mlirContextLoadAllAvailableDialects(self.get());
2680 });
2681
2682 //----------------------------------------------------------------------------
2683 // Mapping of PyDialectDescriptor
2684 //----------------------------------------------------------------------------
2685 py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
2686 .def_property_readonly("namespace",
2687 [](PyDialectDescriptor &self) {
2688 MlirStringRef ns =
2689 mlirDialectGetNamespace(self.get());
2690 return py::str(ns.data, ns.length);
2691 })
2692 .def("__repr__", [](PyDialectDescriptor &self) {
2693 MlirStringRef ns = mlirDialectGetNamespace(self.get());
2694 std::string repr("<DialectDescriptor ");
2695 repr.append(ns.data, ns.length);
2696 repr.append(">");
2697 return repr;
2698 });
2699
2700 //----------------------------------------------------------------------------
2701 // Mapping of PyDialects
2702 //----------------------------------------------------------------------------
2703 py::class_<PyDialects>(m, "Dialects", py::module_local())
2704 .def("__getitem__",
2705 [=](PyDialects &self, std::string keyName) {
2706 MlirDialect dialect =
2707 self.getDialectForKey(keyName, /*attrError=*/false);
2708 py::object descriptor =
2709 py::cast(PyDialectDescriptor{self.getContext(), dialect});
2710 return createCustomDialectWrapper(keyName, std::move(descriptor));
2711 })
2712 .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2713 MlirDialect dialect =
2714 self.getDialectForKey(attrName, /*attrError=*/true);
2715 py::object descriptor =
2716 py::cast(PyDialectDescriptor{self.getContext(), dialect});
2717 return createCustomDialectWrapper(attrName, std::move(descriptor));
2718 });
2719
2720 //----------------------------------------------------------------------------
2721 // Mapping of PyDialect
2722 //----------------------------------------------------------------------------
2723 py::class_<PyDialect>(m, "Dialect", py::module_local())
2724 .def(py::init<py::object>(), py::arg("descriptor"))
2725 .def_property_readonly(
2726 "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2727 .def("__repr__", [](py::object self) {
2728 auto clazz = self.attr("__class__");
2729 return py::str("<Dialect ") +
2730 self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2731 clazz.attr("__module__") + py::str(".") +
2732 clazz.attr("__name__") + py::str(")>");
2733 });
2734
2735 //----------------------------------------------------------------------------
2736 // Mapping of PyDialectRegistry
2737 //----------------------------------------------------------------------------
2738 py::class_<PyDialectRegistry>(m, "DialectRegistry", py::module_local())
2739 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2740 &PyDialectRegistry::getCapsule)
2741 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
2742 .def(py::init<>());
2743
2744 //----------------------------------------------------------------------------
2745 // Mapping of Location
2746 //----------------------------------------------------------------------------
2747 py::class_<PyLocation>(m, "Location", py::module_local())
2748 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2749 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2750 .def("__enter__", &PyLocation::contextEnter)
2751 .def("__exit__", &PyLocation::contextExit)
2752 .def("__eq__",
2753 [](PyLocation &self, PyLocation &other) -> bool {
2754 return mlirLocationEqual(self, other);
2755 })
2756 .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2757 .def_property_readonly_static(
2758 "current",
2759 [](py::object & /*class*/) {
2760 auto *loc = PyThreadContextEntry::getDefaultLocation();
2761 if (!loc)
2762 throw py::value_error("No current Location");
2763 return loc;
2764 },
2765 "Gets the Location bound to the current thread or raises ValueError")
2766 .def_static(
2767 "unknown",
2768 [](DefaultingPyMlirContext context) {
2769 return PyLocation(context->getRef(),
2770 mlirLocationUnknownGet(context->get()));
2771 },
2772 py::arg("context") = py::none(),
2773 "Gets a Location representing an unknown location")
2774 .def_static(
2775 "callsite",
2776 [](PyLocation callee, const std::vector<PyLocation> &frames,
2777 DefaultingPyMlirContext context) {
2778 if (frames.empty())
2779 throw py::value_error("No caller frames provided");
2780 MlirLocation caller = frames.back().get();
2781 for (const PyLocation &frame :
2782 llvm::reverse(llvm::ArrayRef(frames).drop_back()))
2783 caller = mlirLocationCallSiteGet(frame.get(), caller);
2784 return PyLocation(context->getRef(),
2785 mlirLocationCallSiteGet(callee.get(), caller));
2786 },
2787 py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2788 kContextGetCallSiteLocationDocstring)
2789 .def_static(
2790 "file",
2791 [](std::string filename, int line, int col,
2792 DefaultingPyMlirContext context) {
2793 return PyLocation(
2794 context->getRef(),
2795 mlirLocationFileLineColGet(
2796 context->get(), toMlirStringRef(filename), line, col));
2797 },
2798 py::arg("filename"), py::arg("line"), py::arg("col"),
2799 py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2800 .def_static(
2801 "fused",
2802 [](const std::vector<PyLocation> &pyLocations,
2803 std::optional<PyAttribute> metadata,
2804 DefaultingPyMlirContext context) {
2805 llvm::SmallVector<MlirLocation, 4> locations;
2806 locations.reserve(pyLocations.size());
2807 for (auto &pyLocation : pyLocations)
2808 locations.push_back(pyLocation.get());
2809 MlirLocation location = mlirLocationFusedGet(
2810 context->get(), locations.size(), locations.data(),
2811 metadata ? metadata->get() : MlirAttribute{0});
2812 return PyLocation(context->getRef(), location);
2813 },
2814 py::arg("locations"), py::arg("metadata") = py::none(),
2815 py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
2816 .def_static(
2817 "name",
2818 [](std::string name, std::optional<PyLocation> childLoc,
2819 DefaultingPyMlirContext context) {
2820 return PyLocation(
2821 context->getRef(),
2822 mlirLocationNameGet(
2823 context->get(), toMlirStringRef(name),
2824 childLoc ? childLoc->get()
2825 : mlirLocationUnknownGet(context->get())));
2826 },
2827 py::arg("name"), py::arg("childLoc") = py::none(),
2828 py::arg("context") = py::none(), kContextGetNameLocationDocString)
2829 .def_static(
2830 "from_attr",
2831 [](PyAttribute &attribute, DefaultingPyMlirContext context) {
2832 return PyLocation(context->getRef(),
2833 mlirLocationFromAttribute(attribute));
2834 },
2835 py::arg("attribute"), py::arg("context") = py::none(),
2836 "Gets a Location from a LocationAttr")
2837 .def_property_readonly(
2838 "context",
2839 [](PyLocation &self) { return self.getContext().getObject(); },
2840 "Context that owns the Location")
2841 .def_property_readonly(
2842 "attr",
2843 [](PyLocation &self) { return mlirLocationGetAttribute(self); },
2844 "Get the underlying LocationAttr")
2845 .def(
2846 "emit_error",
2847 [](PyLocation &self, std::string message) {
2848 mlirEmitError(self, message.c_str());
2849 },
2850 py::arg("message"), "Emits an error at this location")
2851 .def("__repr__", [](PyLocation &self) {
2852 PyPrintAccumulator printAccum;
2853 mlirLocationPrint(self, printAccum.getCallback(),
2854 printAccum.getUserData());
2855 return printAccum.join();
2856 });
2857
2858 //----------------------------------------------------------------------------
2859 // Mapping of Module
2860 //----------------------------------------------------------------------------
2861 py::class_<PyModule>(m, "Module", py::module_local())
2862 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2863 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2864 .def_static(
2865 "parse",
2866 [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
2867 PyMlirContext::ErrorCapture errors(context->getRef());
2868 MlirModule module = mlirModuleCreateParse(
2869 context->get(), toMlirStringRef(moduleAsm));
2870 if (mlirModuleIsNull(module))
2871 throw MLIRError("Unable to parse module assembly", errors.take());
2872 return PyModule::forModule(module).releaseObject();
2873 },
2874 py::arg("asm"), py::arg("context") = py::none(),
2875 kModuleParseDocstring)
2876 .def_static(
2877 "create",
2878 [](DefaultingPyLocation loc) {
2879 MlirModule module = mlirModuleCreateEmpty(loc);
2880 return PyModule::forModule(module).releaseObject();
2881 },
2882 py::arg("loc") = py::none(), "Creates an empty module")
2883 .def_property_readonly(
2884 "context",
2885 [](PyModule &self) { return self.getContext().getObject(); },
2886 "Context that created the Module")
2887 .def_property_readonly(
2888 "operation",
2889 [](PyModule &self) {
2890 return PyOperation::forOperation(self.getContext(),
2891 mlirModuleGetOperation(self.get()),
2892 self.getRef().releaseObject())
2893 .releaseObject();
2894 },
2895 "Accesses the module as an operation")
2896 .def_property_readonly(
2897 "body",
2898 [](PyModule &self) {
2899 PyOperationRef moduleOp = PyOperation::forOperation(
2900 self.getContext(), mlirModuleGetOperation(self.get()),
2901 self.getRef().releaseObject());
2902 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
2903 return returnBlock;
2904 },
2905 "Return the block for this module")
2906 .def(
2907 "dump",
2908 [](PyModule &self) {
2909 mlirOperationDump(mlirModuleGetOperation(self.get()));
2910 },
2911 kDumpDocstring)
2912 .def(
2913 "__str__",
2914 [](py::object self) {
2915 // Defer to the operation's __str__.
2916 return self.attr("operation").attr("__str__")();
2917 },
2918 kOperationStrDunderDocstring);
2919
2920 //----------------------------------------------------------------------------
2921 // Mapping of Operation.
2922 //----------------------------------------------------------------------------
2923 py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2924 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2925 [](PyOperationBase &self) {
2926 return self.getOperation().getCapsule();
2927 })
2928 .def("__eq__",
2929 [](PyOperationBase &self, PyOperationBase &other) {
2930 return &self.getOperation() == &other.getOperation();
2931 })
2932 .def("__eq__",
2933 [](PyOperationBase &self, py::object other) { return false; })
2934 .def("__hash__",
2935 [](PyOperationBase &self) {
2936 return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2937 })
2938 .def_property_readonly("attributes",
2939 [](PyOperationBase &self) {
2940 return PyOpAttributeMap(
2941 self.getOperation().getRef());
2942 })
2943 .def_property_readonly(
2944 "context",
2945 [](PyOperationBase &self) {
2946 PyOperation &concreteOperation = self.getOperation();
2947 concreteOperation.checkValid();
2948 return concreteOperation.getContext().getObject();
2949 },
2950 "Context that owns the Operation")
2951 .def_property_readonly("name",
2952 [](PyOperationBase &self) {
2953 auto &concreteOperation = self.getOperation();
2954 concreteOperation.checkValid();
2955 MlirOperation operation =
2956 concreteOperation.get();
2957 MlirStringRef name = mlirIdentifierStr(
2958 mlirOperationGetName(operation));
2959 return py::str(name.data, name.length);
2960 })
2961 .def_property_readonly("operands",
2962 [](PyOperationBase &self) {
2963 return PyOpOperandList(
2964 self.getOperation().getRef());
2965 })
2966 .def_property_readonly("regions",
2967 [](PyOperationBase &self) {
2968 return PyRegionList(
2969 self.getOperation().getRef());
2970 })
2971 .def_property_readonly(
2972 "results",
2973 [](PyOperationBase &self) {
2974 return PyOpResultList(self.getOperation().getRef());
2975 },
2976 "Returns the list of Operation results.")
2977 .def_property_readonly(
2978 "result",
2979 [](PyOperationBase &self) {
2980 auto &operation = self.getOperation();
2981 auto numResults = mlirOperationGetNumResults(operation);
2982 if (numResults != 1) {
2983 auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2984 throw py::value_error(
2985 (Twine("Cannot call .result on operation ") +
2986 StringRef(name.data, name.length) + " which has " +
2987 Twine(numResults) +
2988 " results (it is only valid for operations with a "
2989 "single result)")
2990 .str());
2991 }
2992 return PyOpResult(operation.getRef(),
2993 mlirOperationGetResult(operation, 0))
2994 .maybeDownCast();
2995 },
2996 "Shortcut to get an op result if it has only one (throws an error "
2997 "otherwise).")
2998 .def_property_readonly(
2999 "location",
3000 [](PyOperationBase &self) {
3001 PyOperation &operation = self.getOperation();
3002 return PyLocation(operation.getContext(),
3003 mlirOperationGetLocation(operation.get()));
3004 },
3005 "Returns the source location the operation was defined or derived "
3006 "from.")
3007 .def_property_readonly("parent",
3008 [](PyOperationBase &self) -> py::object {
3009 auto parent =
3010 self.getOperation().getParentOperation();
3011 if (parent)
3012 return parent->getObject();
3013 return py::none();
3014 })
3015 .def(
3016 "__str__",
3017 [](PyOperationBase &self) {
3018 return self.getAsm(/*binary=*/false,
3019 /*largeElementsLimit=*/std::nullopt,
3020 /*enableDebugInfo=*/false,
3021 /*prettyDebugInfo=*/false,
3022 /*printGenericOpForm=*/false,
3023 /*useLocalScope=*/false,
3024 /*assumeVerified=*/false);
3025 },
3026 "Returns the assembly form of the operation.")
3027 .def("print",
3028 py::overload_cast<PyAsmState &, pybind11::object, bool>(
3029 &PyOperationBase::print),
3030 py::arg("state"), py::arg("file") = py::none(),
3031 py::arg("binary") = false, kOperationPrintStateDocstring)
3032 .def("print",
3033 py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
3034 bool, py::object, bool>(&PyOperationBase::print),
3035 // Careful: Lots of arguments must match up with print method.
3036 py::arg("large_elements_limit") = py::none(),
3037 py::arg("enable_debug_info") = false,
3038 py::arg("pretty_debug_info") = false,
3039 py::arg("print_generic_op_form") = false,
3040 py::arg("use_local_scope") = false,
3041 py::arg("assume_verified") = false, py::arg("file") = py::none(),
3042 py::arg("binary") = false, kOperationPrintDocstring)
3043 .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"),
3044 py::arg("desired_version") = py::none(),
3045 kOperationPrintBytecodeDocstring)
3046 .def("get_asm", &PyOperationBase::getAsm,
3047 // Careful: Lots of arguments must match up with get_asm method.
3048 py::arg("binary") = false,
3049 py::arg("large_elements_limit") = py::none(),
3050 py::arg("enable_debug_info") = false,
3051 py::arg("pretty_debug_info") = false,
3052 py::arg("print_generic_op_form") = false,
3053 py::arg("use_local_scope") = false,
3054 py::arg("assume_verified") = false, kOperationGetAsmDocstring)
3055 .def("verify", &PyOperationBase::verify,
3056 "Verify the operation. Raises MLIRError if verification fails, and "
3057 "returns true otherwise.")
3058 .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
3059 "Puts self immediately after the other operation in its parent "
3060 "block.")
3061 .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
3062 "Puts self immediately before the other operation in its parent "
3063 "block.")
3064 .def(
3065 "clone",
3066 [](PyOperationBase &self, py::object ip) {
3067 return self.getOperation().clone(ip);
3068 },
3069 py::arg("ip") = py::none())
3070 .def(
3071 "detach_from_parent",
3072 [](PyOperationBase &self) {
3073 PyOperation &operation = self.getOperation();
3074 operation.checkValid();
3075 if (!operation.isAttached())
3076 throw py::value_error("Detached operation has no parent.");
3077
3078 operation.detachFromParent();
3079 return operation.createOpView();
3080 },
3081 "Detaches the operation from its parent block.")
3082 .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
3083 .def("walk", &PyOperationBase::walk, py::arg("callback"),
3084 py::arg("walk_order") = MlirWalkPostOrder);
3085
3086 py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
3087 .def_static("create", &PyOperation::create, py::arg("name"),
3088 py::arg("results") = py::none(),
3089 py::arg("operands") = py::none(),
3090 py::arg("attributes") = py::none(),
3091 py::arg("successors") = py::none(), py::arg("regions") = 0,
3092 py::arg("loc") = py::none(), py::arg("ip") = py::none(),
3093 py::arg("infer_type") = false, kOperationCreateDocstring)
3094 .def_static(
3095 "parse",
3096 [](const std::string &sourceStr, const std::string &sourceName,
3097 DefaultingPyMlirContext context) {
3098 return PyOperation::parse(context->getRef(), sourceStr, sourceName)
3099 ->createOpView();
3100 },
3101 py::arg("source"), py::kw_only(), py::arg("source_name") = "",
3102 py::arg("context") = py::none(),
3103 "Parses an operation. Supports both text assembly format and binary "
3104 "bytecode format.")
3105 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
3106 &PyOperation::getCapsule)
3107 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
3108 .def_property_readonly("operation", [](py::object self) { return self; })
3109 .def_property_readonly("opview", &PyOperation::createOpView)
3110 .def_property_readonly(
3111 "successors",
3112 [](PyOperationBase &self) {
3113 return PyOpSuccessors(self.getOperation().getRef());
3114 },
3115 "Returns the list of Operation successors.");
3116
3117 auto opViewClass =
3118 py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
3119 .def(py::init<py::object>(), py::arg("operation"))
3120 .def_property_readonly("operation", &PyOpView::getOperationObject)
3121 .def_property_readonly("opview", [](py::object self) { return self; })
3122 .def(
3123 "__str__",
3124 [](PyOpView &self) { return py::str(self.getOperationObject()); })
3125 .def_property_readonly(
3126 "successors",
3127 [](PyOperationBase &self) {
3128 return PyOpSuccessors(self.getOperation().getRef());
3129 },
3130 "Returns the list of Operation successors.");
3131 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
3132 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
3133 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
3134 opViewClass.attr("build_generic") = classmethod(
3135 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
3136 py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
3137 py::arg("successors") = py::none(), py::arg("regions") = py::none(),
3138 py::arg("loc") = py::none(), py::arg("ip") = py::none(),
3139 "Builds a specific, generated OpView based on class level attributes.");
3140 opViewClass.attr("parse") = classmethod(
3141 [](const py::object &cls, const std::string &sourceStr,
3142 const std::string &sourceName, DefaultingPyMlirContext context) {
3143 PyOperationRef parsed =
3144 PyOperation::parse(context->getRef(), sourceStr, sourceName);
3145
3146 // Check if the expected operation was parsed, and cast to to the
3147 // appropriate `OpView` subclass if successful.
3148 // NOTE: This accesses attributes that have been automatically added to
3149 // `OpView` subclasses, and is not intended to be used on `OpView`
3150 // directly.
3151 std::string clsOpName =
3152 py::cast<std::string>(cls.attr("OPERATION_NAME"));
3153 MlirStringRef identifier =
3154 mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
3155 std::string_view parsedOpName(identifier.data, identifier.length);
3156 if (clsOpName != parsedOpName)
3157 throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
3158 parsedOpName + "'");
3159 return PyOpView::constructDerived(cls, *parsed.get());
3160 },
3161 py::arg("cls"), py::arg("source"), py::kw_only(),
3162 py::arg("source_name") = "", py::arg("context") = py::none(),
3163 "Parses a specific, generated OpView based on class level attributes");
3164
3165 //----------------------------------------------------------------------------
3166 // Mapping of PyRegion.
3167 //----------------------------------------------------------------------------
3168 py::class_<PyRegion>(m, "Region", py::module_local())
3169 .def_property_readonly(
3170 "blocks",
3171 [](PyRegion &self) {
3172 return PyBlockList(self.getParentOperation(), self.get());
3173 },
3174 "Returns a forward-optimized sequence of blocks.")
3175 .def_property_readonly(
3176 "owner",
3177 [](PyRegion &self) {
3178 return self.getParentOperation()->createOpView();
3179 },
3180 "Returns the operation owning this region.")
3181 .def(
3182 "__iter__",
3183 [](PyRegion &self) {
3184 self.checkValid();
3185 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
3186 return PyBlockIterator(self.getParentOperation(), firstBlock);
3187 },
3188 "Iterates over blocks in the region.")
3189 .def("__eq__",
3190 [](PyRegion &self, PyRegion &other) {
3191 return self.get().ptr == other.get().ptr;
3192 })
3193 .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
3194
3195 //----------------------------------------------------------------------------
3196 // Mapping of PyBlock.
3197 //----------------------------------------------------------------------------
3198 py::class_<PyBlock>(m, "Block", py::module_local())
3199 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule)
3200 .def_property_readonly(
3201 "owner",
3202 [](PyBlock &self) {
3203 return self.getParentOperation()->createOpView();
3204 },
3205 "Returns the owning operation of this block.")
3206 .def_property_readonly(
3207 "region",
3208 [](PyBlock &self) {
3209 MlirRegion region = mlirBlockGetParentRegion(self.get());
3210 return PyRegion(self.getParentOperation(), region);
3211 },
3212 "Returns the owning region of this block.")
3213 .def_property_readonly(
3214 "arguments",
3215 [](PyBlock &self) {
3216 return PyBlockArgumentList(self.getParentOperation(), self.get());
3217 },
3218 "Returns a list of block arguments.")
3219 .def_property_readonly(
3220 "operations",
3221 [](PyBlock &self) {
3222 return PyOperationList(self.getParentOperation(), self.get());
3223 },
3224 "Returns a forward-optimized sequence of operations.")
3225 .def_static(
3226 "create_at_start",
3227 [](PyRegion &parent, const py::list &pyArgTypes,
3228 const std::optional<py::sequence> &pyArgLocs) {
3229 parent.checkValid();
3230 MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3231 mlirRegionInsertOwnedBlock(parent, 0, block);
3232 return PyBlock(parent.getParentOperation(), block);
3233 },
3234 py::arg("parent"), py::arg("arg_types") = py::list(),
3235 py::arg("arg_locs") = std::nullopt,
3236 "Creates and returns a new Block at the beginning of the given "
3237 "region (with given argument types and locations).")
3238 .def(
3239 "append_to",
3240 [](PyBlock &self, PyRegion &region) {
3241 MlirBlock b = self.get();
3242 if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
3243 mlirBlockDetach(b);
3244 mlirRegionAppendOwnedBlock(region.get(), b);
3245 },
3246 "Append this block to a region, transferring ownership if necessary")
3247 .def(
3248 "create_before",
3249 [](PyBlock &self, const py::args &pyArgTypes,
3250 const std::optional<py::sequence> &pyArgLocs) {
3251 self.checkValid();
3252 MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3253 MlirRegion region = mlirBlockGetParentRegion(self.get());
3254 mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
3255 return PyBlock(self.getParentOperation(), block);
3256 },
3257 py::arg("arg_locs") = std::nullopt,
3258 "Creates and returns a new Block before this block "
3259 "(with given argument types and locations).")
3260 .def(
3261 "create_after",
3262 [](PyBlock &self, const py::args &pyArgTypes,
3263 const std::optional<py::sequence> &pyArgLocs) {
3264 self.checkValid();
3265 MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3266 MlirRegion region = mlirBlockGetParentRegion(self.get());
3267 mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
3268 return PyBlock(self.getParentOperation(), block);
3269 },
3270 py::arg("arg_locs") = std::nullopt,
3271 "Creates and returns a new Block after this block "
3272 "(with given argument types and locations).")
3273 .def(
3274 "__iter__",
3275 [](PyBlock &self) {
3276 self.checkValid();
3277 MlirOperation firstOperation =
3278 mlirBlockGetFirstOperation(self.get());
3279 return PyOperationIterator(self.getParentOperation(),
3280 firstOperation);
3281 },
3282 "Iterates over operations in the block.")
3283 .def("__eq__",
3284 [](PyBlock &self, PyBlock &other) {
3285 return self.get().ptr == other.get().ptr;
3286 })
3287 .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
3288 .def("__hash__",
3289 [](PyBlock &self) {
3290 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3291 })
3292 .def(
3293 "__str__",
3294 [](PyBlock &self) {
3295 self.checkValid();
3296 PyPrintAccumulator printAccum;
3297 mlirBlockPrint(self.get(), printAccum.getCallback(),
3298 printAccum.getUserData());
3299 return printAccum.join();
3300 },
3301 "Returns the assembly form of the block.")
3302 .def(
3303 "append",
3304 [](PyBlock &self, PyOperationBase &operation) {
3305 if (operation.getOperation().isAttached())
3306 operation.getOperation().detachFromParent();
3307
3308 MlirOperation mlirOperation = operation.getOperation().get();
3309 mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
3310 operation.getOperation().setAttached(
3311 self.getParentOperation().getObject());
3312 },
3313 py::arg("operation"),
3314 "Appends an operation to this block. If the operation is currently "
3315 "in another block, it will be moved.");
3316
3317 //----------------------------------------------------------------------------
3318 // Mapping of PyInsertionPoint.
3319 //----------------------------------------------------------------------------
3320
3321 py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
3322 .def(py::init<PyBlock &>(), py::arg("block"),
3323 "Inserts after the last operation but still inside the block.")
3324 .def("__enter__", &PyInsertionPoint::contextEnter)
3325 .def("__exit__", &PyInsertionPoint::contextExit)
3326 .def_property_readonly_static(
3327 "current",
3328 [](py::object & /*class*/) {
3329 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
3330 if (!ip)
3331 throw py::value_error("No current InsertionPoint");
3332 return ip;
3333 },
3334 "Gets the InsertionPoint bound to the current thread or raises "
3335 "ValueError if none has been set")
3336 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
3337 "Inserts before a referenced operation.")
3338 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
3339 py::arg("block"), "Inserts at the beginning of the block.")
3340 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
3341 py::arg("block"), "Inserts before the block terminator.")
3342 .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
3343 "Inserts an operation.")
3344 .def_property_readonly(
3345 "block", [](PyInsertionPoint &self) { return self.getBlock(); },
3346 "Returns the block that this InsertionPoint points to.")
3347 .def_property_readonly(
3348 "ref_operation",
3349 [](PyInsertionPoint &self) -> py::object {
3350 auto refOperation = self.getRefOperation();
3351 if (refOperation)
3352 return refOperation->getObject();
3353 return py::none();
3354 },
3355 "The reference operation before which new operations are "
3356 "inserted, or None if the insertion point is at the end of "
3357 "the block");
3358
3359 //----------------------------------------------------------------------------
3360 // Mapping of PyAttribute.
3361 //----------------------------------------------------------------------------
3362 py::class_<PyAttribute>(m, "Attribute", py::module_local())
3363 // Delegate to the PyAttribute copy constructor, which will also lifetime
3364 // extend the backing context which owns the MlirAttribute.
3365 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
3366 "Casts the passed attribute to the generic Attribute")
3367 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
3368 &PyAttribute::getCapsule)
3369 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
3370 .def_static(
3371 "parse",
3372 [](const std::string &attrSpec, DefaultingPyMlirContext context) {
3373 PyMlirContext::ErrorCapture errors(context->getRef());
3374 MlirAttribute attr = mlirAttributeParseGet(
3375 context->get(), toMlirStringRef(attrSpec));
3376 if (mlirAttributeIsNull(attr))
3377 throw MLIRError("Unable to parse attribute", errors.take());
3378 return attr;
3379 },
3380 py::arg("asm"), py::arg("context") = py::none(),
3381 "Parses an attribute from an assembly form. Raises an MLIRError on "
3382 "failure.")
3383 .def_property_readonly(
3384 "context",
3385 [](PyAttribute &self) { return self.getContext().getObject(); },
3386 "Context that owns the Attribute")
3387 .def_property_readonly(
3388 "type", [](PyAttribute &self) { return mlirAttributeGetType(self); })
3389 .def(
3390 "get_named",
3391 [](PyAttribute &self, std::string name) {
3392 return PyNamedAttribute(self, std::move(name));
3393 },
3394 py::keep_alive<0, 1>(), "Binds a name to the attribute")
3395 .def("__eq__",
3396 [](PyAttribute &self, PyAttribute &other) { return self == other; })
3397 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
3398 .def("__hash__",
3399 [](PyAttribute &self) {
3400 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3401 })
3402 .def(
3403 "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
3404 kDumpDocstring)
3405 .def(
3406 "__str__",
3407 [](PyAttribute &self) {
3408 PyPrintAccumulator printAccum;
3409 mlirAttributePrint(self, printAccum.getCallback(),
3410 printAccum.getUserData());
3411 return printAccum.join();
3412 },
3413 "Returns the assembly form of the Attribute.")
3414 .def("__repr__",
3415 [](PyAttribute &self) {
3416 // Generally, assembly formats are not printed for __repr__ because
3417 // this can cause exceptionally long debug output and exceptions.
3418 // However, attribute values are generally considered useful and
3419 // are printed. This may need to be re-evaluated if debug dumps end
3420 // up being excessive.
3421 PyPrintAccumulator printAccum;
3422 printAccum.parts.append("Attribute(");
3423 mlirAttributePrint(self, printAccum.getCallback(),
3424 printAccum.getUserData());
3425 printAccum.parts.append(")");
3426 return printAccum.join();
3427 })
3428 .def_property_readonly(
3429 "typeid",
3430 [](PyAttribute &self) -> MlirTypeID {
3431 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3432 assert(!mlirTypeIDIsNull(mlirTypeID) &&
3433 "mlirTypeID was expected to be non-null.");
3434 return mlirTypeID;
3435 })
3436 .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) {
3437 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3438 assert(!mlirTypeIDIsNull(mlirTypeID) &&
3439 "mlirTypeID was expected to be non-null.");
3440 std::optional<pybind11::function> typeCaster =
3441 PyGlobals::get().lookupTypeCaster(mlirTypeID,
3442 mlirAttributeGetDialect(self));
3443 if (!typeCaster)
3444 return py::cast(self);
3445 return typeCaster.value()(self);
3446 });
3447
3448 //----------------------------------------------------------------------------
3449 // Mapping of PyNamedAttribute
3450 //----------------------------------------------------------------------------
3451 py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
3452 .def("__repr__",
3453 [](PyNamedAttribute &self) {
3454 PyPrintAccumulator printAccum;
3455 printAccum.parts.append("NamedAttribute(");
3456 printAccum.parts.append(
3457 py::str(mlirIdentifierStr(self.namedAttr.name).data,
3458 mlirIdentifierStr(self.namedAttr.name).length));
3459 printAccum.parts.append("=");
3460 mlirAttributePrint(self.namedAttr.attribute,
3461 printAccum.getCallback(),
3462 printAccum.getUserData());
3463 printAccum.parts.append(")");
3464 return printAccum.join();
3465 })
3466 .def_property_readonly(
3467 "name",
3468 [](PyNamedAttribute &self) {
3469 return py::str(mlirIdentifierStr(self.namedAttr.name).data,
3470 mlirIdentifierStr(self.namedAttr.name).length);
3471 },
3472 "The name of the NamedAttribute binding")
3473 .def_property_readonly(
3474 "attr",
3475 [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
3476 py::keep_alive<0, 1>(),
3477 "The underlying generic attribute of the NamedAttribute binding");
3478
3479 //----------------------------------------------------------------------------
3480 // Mapping of PyType.
3481 //----------------------------------------------------------------------------
3482 py::class_<PyType>(m, "Type", py::module_local())
3483 // Delegate to the PyType copy constructor, which will also lifetime
3484 // extend the backing context which owns the MlirType.
3485 .def(py::init<PyType &>(), py::arg("cast_from_type"),
3486 "Casts the passed type to the generic Type")
3487 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3488 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3489 .def_static(
3490 "parse",
3491 [](std::string typeSpec, DefaultingPyMlirContext context) {
3492 PyMlirContext::ErrorCapture errors(context->getRef());
3493 MlirType type =
3494 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3495 if (mlirTypeIsNull(type))
3496 throw MLIRError("Unable to parse type", errors.take());
3497 return type;
3498 },
3499 py::arg("asm"), py::arg("context") = py::none(),
3500 kContextParseTypeDocstring)
3501 .def_property_readonly(
3502 "context", [](PyType &self) { return self.getContext().getObject(); },
3503 "Context that owns the Type")
3504 .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3505 .def("__eq__", [](PyType &self, py::object &other) { return false; })
3506 .def("__hash__",
3507 [](PyType &self) {
3508 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3509 })
3510 .def(
3511 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3512 .def(
3513 "__str__",
3514 [](PyType &self) {
3515 PyPrintAccumulator printAccum;
3516 mlirTypePrint(self, printAccum.getCallback(),
3517 printAccum.getUserData());
3518 return printAccum.join();
3519 },
3520 "Returns the assembly form of the type.")
3521 .def("__repr__",
3522 [](PyType &self) {
3523 // Generally, assembly formats are not printed for __repr__ because
3524 // this can cause exceptionally long debug output and exceptions.
3525 // However, types are an exception as they typically have compact
3526 // assembly forms and printing them is useful.
3527 PyPrintAccumulator printAccum;
3528 printAccum.parts.append("Type(");
3529 mlirTypePrint(self, printAccum.getCallback(),
3530 printAccum.getUserData());
3531 printAccum.parts.append(")");
3532 return printAccum.join();
3533 })
3534 .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
3535 [](PyType &self) {
3536 MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3537 assert(!mlirTypeIDIsNull(mlirTypeID) &&
3538 "mlirTypeID was expected to be non-null.");
3539 std::optional<pybind11::function> typeCaster =
3540 PyGlobals::get().lookupTypeCaster(mlirTypeID,
3541 mlirTypeGetDialect(self));
3542 if (!typeCaster)
3543 return py::cast(self);
3544 return typeCaster.value()(self);
3545 })
3546 .def_property_readonly("typeid", [](PyType &self) -> MlirTypeID {
3547 MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3548 if (!mlirTypeIDIsNull(mlirTypeID))
3549 return mlirTypeID;
3550 auto origRepr =
3551 pybind11::repr(pybind11::cast(self)).cast<std::string>();
3552 throw py::value_error(
3553 (origRepr + llvm::Twine(" has no typeid.")).str());
3554 });
3555
3556 //----------------------------------------------------------------------------
3557 // Mapping of PyTypeID.
3558 //----------------------------------------------------------------------------
3559 py::class_<PyTypeID>(m, "TypeID", py::module_local())
3560 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
3561 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
3562 // Note, this tests whether the underlying TypeIDs are the same,
3563 // not whether the wrapper MlirTypeIDs are the same, nor whether
3564 // the Python objects are the same (i.e., PyTypeID is a value type).
3565 .def("__eq__",
3566 [](PyTypeID &self, PyTypeID &other) { return self == other; })
3567 .def("__eq__",
3568 [](PyTypeID &self, const py::object &other) { return false; })
3569 // Note, this gives the hash value of the underlying TypeID, not the
3570 // hash value of the Python object, nor the hash value of the
3571 // MlirTypeID wrapper.
3572 .def("__hash__", [](PyTypeID &self) {
3573 return static_cast<size_t>(mlirTypeIDHashValue(self));
3574 });
3575
3576 //----------------------------------------------------------------------------
3577 // Mapping of Value.
3578 //----------------------------------------------------------------------------
3579 py::class_<PyValue>(m, "Value", py::module_local())
3580 .def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"))
3581 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3582 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3583 .def_property_readonly(
3584 "context",
3585 [](PyValue &self) { return self.getParentOperation()->getContext(); },
3586 "Context in which the value lives.")
3587 .def(
3588 "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3589 kDumpDocstring)
3590 .def_property_readonly(
3591 "owner",
3592 [](PyValue &self) -> py::object {
3593 MlirValue v = self.get();
3594 if (mlirValueIsAOpResult(v)) {
3595 assert(
3596 mlirOperationEqual(self.getParentOperation()->get(),
3597 mlirOpResultGetOwner(self.get())) &&
3598 "expected the owner of the value in Python to match that in "
3599 "the IR");
3600 return self.getParentOperation().getObject();
3601 }
3602
3603 if (mlirValueIsABlockArgument(v)) {
3604 MlirBlock block = mlirBlockArgumentGetOwner(self.get());
3605 return py::cast(PyBlock(self.getParentOperation(), block));
3606 }
3607
3608 assert(false && "Value must be a block argument or an op result");
3609 return py::none();
3610 })
3611 .def_property_readonly("uses",
3612 [](PyValue &self) {
3613 return PyOpOperandIterator(
3614 mlirValueGetFirstUse(self.get()));
3615 })
3616 .def("__eq__",
3617 [](PyValue &self, PyValue &other) {
3618 return self.get().ptr == other.get().ptr;
3619 })
3620 .def("__eq__", [](PyValue &self, py::object other) { return false; })
3621 .def("__hash__",
3622 [](PyValue &self) {
3623 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3624 })
3625 .def(
3626 "__str__",
3627 [](PyValue &self) {
3628 PyPrintAccumulator printAccum;
3629 printAccum.parts.append("Value(");
3630 mlirValuePrint(self.get(), printAccum.getCallback(),
3631 printAccum.getUserData());
3632 printAccum.parts.append(")");
3633 return printAccum.join();
3634 },
3635 kValueDunderStrDocstring)
3636 .def(
3637 "get_name",
3638 [](PyValue &self, bool useLocalScope) {
3639 PyPrintAccumulator printAccum;
3640 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
3641 if (useLocalScope)
3642 mlirOpPrintingFlagsUseLocalScope(flags);
3643 MlirAsmState valueState =
3644 mlirAsmStateCreateForValue(self.get(), flags);
3645 mlirValuePrintAsOperand(self.get(), valueState,
3646 printAccum.getCallback(),
3647 printAccum.getUserData());
3648 mlirOpPrintingFlagsDestroy(flags);
3649 mlirAsmStateDestroy(valueState);
3650 return printAccum.join();
3651 },
3652 py::arg("use_local_scope") = false)
3653 .def(
3654 "get_name",
3655 [](PyValue &self, std::reference_wrapper<PyAsmState> state) {
3656 PyPrintAccumulator printAccum;
3657 MlirAsmState valueState = state.get().get();
3658 mlirValuePrintAsOperand(self.get(), valueState,
3659 printAccum.getCallback(),
3660 printAccum.getUserData());
3661 return printAccum.join();
3662 },
3663 py::arg("state"), kGetNameAsOperand)
3664 .def_property_readonly(
3665 "type", [](PyValue &self) { return mlirValueGetType(self.get()); })
3666 .def(
3667 "set_type",
3668 [](PyValue &self, const PyType &type) {
3669 return mlirValueSetType(self.get(), type);
3670 },
3671 py::arg("type"))
3672 .def(
3673 "replace_all_uses_with",
3674 [](PyValue &self, PyValue &with) {
3675 mlirValueReplaceAllUsesOfWith(self.get(), with.get());
3676 },
3677 kValueReplaceAllUsesWithDocstring)
3678 .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
3679 [](PyValue &self) { return self.maybeDownCast(); });
3680 PyBlockArgument::bind(m);
3681 PyOpResult::bind(m);
3682 PyOpOperand::bind(m);
3683
3684 py::class_<PyAsmState>(m, "AsmState", py::module_local())
3685 .def(py::init<PyValue &, bool>(), py::arg("value"),
3686 py::arg("use_local_scope") = false)
3687 .def(py::init<PyOperationBase &, bool>(), py::arg("op"),
3688 py::arg("use_local_scope") = false);
3689
3690 //----------------------------------------------------------------------------
3691 // Mapping of SymbolTable.
3692 //----------------------------------------------------------------------------
3693 py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
3694 .def(py::init<PyOperationBase &>())
3695 .def("__getitem__", &PySymbolTable::dunderGetItem)
3696 .def("insert", &PySymbolTable::insert, py::arg("operation"))
3697 .def("erase", &PySymbolTable::erase, py::arg("operation"))
3698 .def("__delitem__", &PySymbolTable::dunderDel)
3699 .def("__contains__",
3700 [](PySymbolTable &table, const std::string &name) {
3701 return !mlirOperationIsNull(mlirSymbolTableLookup(
3702 table, mlirStringRefCreate(name.data(), name.length())));
3703 })
3704 // Static helpers.
3705 .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
3706 py::arg("symbol"), py::arg("name"))
3707 .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
3708 py::arg("symbol"))
3709 .def_static("get_visibility", &PySymbolTable::getVisibility,
3710 py::arg("symbol"))
3711 .def_static("set_visibility", &PySymbolTable::setVisibility,
3712 py::arg("symbol"), py::arg("visibility"))
3713 .def_static("replace_all_symbol_uses",
3714 &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
3715 py::arg("new_symbol"), py::arg("from_op"))
3716 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
3717 py::arg("from_op"), py::arg("all_sym_uses_visible"),
3718 py::arg("callback"));
3719
  // Container bindings: the list/iterator/map helper classes returned by the
  // properties bound above (e.g. Value.uses constructs a PyOpOperandIterator).
  // Each class registers itself via its static bind(m); the registration order
  // is kept as-is.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandIterator::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyOpSuccessors::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings (presumably the global debug-flag toggle, going by the
  // class name; see PyGlobalDebugFlag::bind).
  PyGlobalDebugFlag::bind(m);

  // Attribute builder getter (see PyAttrBuilderMap::bind for the exposed API).
  PyAttrBuilderMap::bind(m);
3739
3740 py::register_local_exception_translator([](std::exception_ptr p) {
3741 // We can't define exceptions with custom fields through pybind, so instead
3742 // the exception class is defined in python and imported here.
3743 try {
3744 if (p)
3745 std::rethrow_exception(p);
3746 } catch (const MLIRError &e) {
3747 py::object obj = py::module_::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
3748 .attr("MLIRError")(e.message, e.errorDiagnostics);
3749 PyErr_SetObject(PyExc_Exception, obj.ptr());
3750 }
3751 });
3752}
3753

source code of mlir/lib/Bindings/Python/IRCore.cpp