1//===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Globals.h"
10#include "IRModule.h"
11#include "NanobindUtils.h"
12#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
13#include "mlir-c/BuiltinAttributes.h"
14#include "mlir-c/Debug.h"
15#include "mlir-c/Diagnostics.h"
16#include "mlir-c/IR.h"
17#include "mlir-c/Support.h"
18#include "mlir/Bindings/Python/Nanobind.h"
19#include "mlir/Bindings/Python/NanobindAdaptors.h"
20#include "nanobind/nanobind.h"
21#include "llvm/ADT/ArrayRef.h"
22#include "llvm/ADT/SmallVector.h"
23#include "llvm/Support/raw_ostream.h"
24
25#include <optional>
26#include <system_error>
27#include <utility>
28
29namespace nb = nanobind;
30using namespace nb::literals;
31using namespace mlir;
32using namespace mlir::python;
33
34using llvm::SmallVector;
35using llvm::StringRef;
36using llvm::Twine;
37
38//------------------------------------------------------------------------------
39// Docstrings (trivial, non-duplicated docstrings are included inline).
40//------------------------------------------------------------------------------
41
// Python-facing docstrings for Context/Module/Operation factory bindings.
// These strings are user-visible help text; the bindings that reference them
// are outside this chunk.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises an MLIRError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Location factory docstrings (one-line summaries).
static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kContextGetFileRangeDocstring[] =
    R"(Gets a Location representing a file, line and column range)";

static const char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

// NOTE(review): spelled "DocString" (capital S) unlike its siblings; kept
// as-is because use sites elsewhere reference this exact name.
static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises an MLIRError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Docstring for the generic operation-creation entry point.
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
  infer_type: Whether to infer result types.
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
91
// Docstring for Operation.print(). Keyword names listed here must match the
// actual Python keyword arguments of the binding.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: If set, elides elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
  skip_regions: Whether to skip printing regions. Defaults to False.
)";
118
// Docstring for the print() overload that takes a pre-built AsmState.
static const char kOperationPrintStateDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  state: AsmState capturing the operation numbering and flags.
)";

// Docstring for Operation.get_asm(); defers to print() for shared kwargs.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

static const char kOperationPrintBytecodeDocstring[] =
    R"(Write the bytecode form of the operation to a file like object.

Args:
  file: The file like object to write to.
  desired_version: The version of bytecode to emit.
Returns:
  The bytecode writer status.
)";

static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Used by the BlockList.append binding in this file.
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";

static const char kGetNameAsOperand[] =
    R"(Returns the string form of value as an operand (i.e., the ValueID).
)";

static const char kValueReplaceAllUsesWithDocstring[] =
    R"(Replace all uses of value with the new value, updating anything in
the IR that uses 'self' to use the other value instead.
)";
185
// Docstring for Value.replace_all_uses_except(). The raw string must start
// directly with the text: a leading '"' would appear verbatim in help().
static const char kValueReplaceAllUsesExceptDocstring[] =
    R"(Replace all uses of this value with the 'with' value, except for those
in 'exceptions'. 'exceptions' can be either a single operation or a list of
operations.
)";
191
192//------------------------------------------------------------------------------
193// Utilities.
194//------------------------------------------------------------------------------
195
196/// Helper for creating an @classmethod.
197template <class Func, typename... Args>
198nb::object classmethod(Func f, Args... args) {
199 nb::object cf = nb::cpp_function(f, args...);
200 return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
201}
202
203static nb::object
204createCustomDialectWrapper(const std::string &dialectNamespace,
205 nb::object dialectDescriptor) {
206 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
207 if (!dialectClass) {
208 // Use the base class.
209 return nb::cast(PyDialect(std::move(dialectDescriptor)));
210 }
211
212 // Create the custom implementation.
213 return (*dialectClass)(std::move(dialectDescriptor));
214}
215
216static MlirStringRef toMlirStringRef(const std::string &s) {
217 return mlirStringRefCreate(s.data(), s.size());
218}
219
220static MlirStringRef toMlirStringRef(std::string_view s) {
221 return mlirStringRefCreate(s.data(), s.size());
222}
223
224static MlirStringRef toMlirStringRef(const nb::bytes &s) {
225 return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
226}
227
228/// Create a block, using the current location context if no locations are
229/// specified.
230static MlirBlock createBlock(const nb::sequence &pyArgTypes,
231 const std::optional<nb::sequence> &pyArgLocs) {
232 SmallVector<MlirType> argTypes;
233 argTypes.reserve(nb::len(pyArgTypes));
234 for (const auto &pyType : pyArgTypes)
235 argTypes.push_back(nb::cast<PyType &>(pyType));
236
237 SmallVector<MlirLocation> argLocs;
238 if (pyArgLocs) {
239 argLocs.reserve(nb::len(*pyArgLocs));
240 for (const auto &pyLoc : *pyArgLocs)
241 argLocs.push_back(nb::cast<PyLocation &>(pyLoc));
242 } else if (!argTypes.empty()) {
243 argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
244 }
245
246 if (argTypes.size() != argLocs.size())
247 throw nb::value_error(("Expected " + Twine(argTypes.size()) +
248 " locations, got: " + Twine(argLocs.size()))
249 .str()
250 .c_str());
251 return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
252}
253
254/// Wrapper for the global LLVM debugging flag.
255struct PyGlobalDebugFlag {
256 static void set(nb::object &o, bool enable) {
257 nb::ft_lock_guard lock(mutex);
258 mlirEnableGlobalDebug(enable);
259 }
260
261 static bool get(const nb::object &) {
262 nb::ft_lock_guard lock(mutex);
263 return mlirIsGlobalDebugEnabled();
264 }
265
266 static void bind(nb::module_ &m) {
267 // Debug flags.
268 nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
269 .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
270 &PyGlobalDebugFlag::set, "LLVM-wide debug flag")
271 .def_static(
272 "set_types",
273 [](const std::string &type) {
274 nb::ft_lock_guard lock(mutex);
275 mlirSetGlobalDebugType(type.c_str());
276 },
277 "types"_a, "Sets specific debug types to be produced by LLVM")
278 .def_static("set_types", [](const std::vector<std::string> &types) {
279 std::vector<const char *> pointers;
280 pointers.reserve(types.size());
281 for (const std::string &str : types)
282 pointers.push_back(str.c_str());
283 nb::ft_lock_guard lock(mutex);
284 mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
285 });
286 }
287
288private:
289 static nb::ft_mutex mutex;
290};
291
292nb::ft_mutex PyGlobalDebugFlag::mutex;
293
294struct PyAttrBuilderMap {
295 static bool dunderContains(const std::string &attributeKind) {
296 return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
297 }
298 static nb::callable dunderGetItemNamed(const std::string &attributeKind) {
299 auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
300 if (!builder)
301 throw nb::key_error(attributeKind.c_str());
302 return *builder;
303 }
304 static void dunderSetItemNamed(const std::string &attributeKind,
305 nb::callable func, bool replace) {
306 PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
307 replace);
308 }
309
310 static void bind(nb::module_ &m) {
311 nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
312 .def_static("contains", &PyAttrBuilderMap::dunderContains)
313 .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed)
314 .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
315 "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
316 "Register an attribute builder for building MLIR "
317 "attributes from python values.");
318 }
319};
320
321//------------------------------------------------------------------------------
322// PyBlock
323//------------------------------------------------------------------------------
324
325nb::object PyBlock::getCapsule() {
326 return nb::steal<nb::object>(mlirPythonBlockToCapsule(get()));
327}
328
329//------------------------------------------------------------------------------
330// Collections.
331//------------------------------------------------------------------------------
332
333namespace {
334
335class PyRegionIterator {
336public:
337 PyRegionIterator(PyOperationRef operation)
338 : operation(std::move(operation)) {}
339
340 PyRegionIterator &dunderIter() { return *this; }
341
342 PyRegion dunderNext() {
343 operation->checkValid();
344 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
345 throw nb::stop_iteration();
346 }
347 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
348 return PyRegion(operation, region);
349 }
350
351 static void bind(nb::module_ &m) {
352 nb::class_<PyRegionIterator>(m, "RegionIterator")
353 .def("__iter__", &PyRegionIterator::dunderIter)
354 .def("__next__", &PyRegionIterator::dunderNext);
355 }
356
357private:
358 PyOperationRef operation;
359 int nextIndex = 0;
360};
361
362/// Regions of an op are fixed length and indexed numerically so are represented
363/// with a sequence-like container.
364class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
365public:
366 static constexpr const char *pyClassName = "RegionSequence";
367
368 PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
369 intptr_t length = -1, intptr_t step = 1)
370 : Sliceable(startIndex,
371 length == -1 ? mlirOperationGetNumRegions(operation->get())
372 : length,
373 step),
374 operation(std::move(operation)) {}
375
376 PyRegionIterator dunderIter() {
377 operation->checkValid();
378 return PyRegionIterator(operation);
379 }
380
381 static void bindDerived(ClassTy &c) {
382 c.def("__iter__", &PyRegionList::dunderIter);
383 }
384
385private:
386 /// Give the parent CRTP class access to hook implementations below.
387 friend class Sliceable<PyRegionList, PyRegion>;
388
389 intptr_t getRawNumElements() {
390 operation->checkValid();
391 return mlirOperationGetNumRegions(operation->get());
392 }
393
394 PyRegion getRawElement(intptr_t pos) {
395 operation->checkValid();
396 return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
397 }
398
399 PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
400 return PyRegionList(operation, startIndex, length, step);
401 }
402
403 PyOperationRef operation;
404};
405
406class PyBlockIterator {
407public:
408 PyBlockIterator(PyOperationRef operation, MlirBlock next)
409 : operation(std::move(operation)), next(next) {}
410
411 PyBlockIterator &dunderIter() { return *this; }
412
413 PyBlock dunderNext() {
414 operation->checkValid();
415 if (mlirBlockIsNull(next)) {
416 throw nb::stop_iteration();
417 }
418
419 PyBlock returnBlock(operation, next);
420 next = mlirBlockGetNextInRegion(next);
421 return returnBlock;
422 }
423
424 static void bind(nb::module_ &m) {
425 nb::class_<PyBlockIterator>(m, "BlockIterator")
426 .def("__iter__", &PyBlockIterator::dunderIter)
427 .def("__next__", &PyBlockIterator::dunderNext);
428 }
429
430private:
431 PyOperationRef operation;
432 MlirBlock next;
433};
434
435/// Blocks are exposed by the C-API as a forward-only linked list. In Python,
436/// we present them as a more full-featured list-like container but optimize
437/// it for forward iteration. Blocks are always owned by a region.
438class PyBlockList {
439public:
440 PyBlockList(PyOperationRef operation, MlirRegion region)
441 : operation(std::move(operation)), region(region) {}
442
443 PyBlockIterator dunderIter() {
444 operation->checkValid();
445 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
446 }
447
448 intptr_t dunderLen() {
449 operation->checkValid();
450 intptr_t count = 0;
451 MlirBlock block = mlirRegionGetFirstBlock(region);
452 while (!mlirBlockIsNull(block)) {
453 count += 1;
454 block = mlirBlockGetNextInRegion(block);
455 }
456 return count;
457 }
458
459 PyBlock dunderGetItem(intptr_t index) {
460 operation->checkValid();
461 if (index < 0) {
462 index += dunderLen();
463 }
464 if (index < 0) {
465 throw nb::index_error("attempt to access out of bounds block");
466 }
467 MlirBlock block = mlirRegionGetFirstBlock(region);
468 while (!mlirBlockIsNull(block)) {
469 if (index == 0) {
470 return PyBlock(operation, block);
471 }
472 block = mlirBlockGetNextInRegion(block);
473 index -= 1;
474 }
475 throw nb::index_error("attempt to access out of bounds block");
476 }
477
478 PyBlock appendBlock(const nb::args &pyArgTypes,
479 const std::optional<nb::sequence> &pyArgLocs) {
480 operation->checkValid();
481 MlirBlock block =
482 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
483 mlirRegionAppendOwnedBlock(region, block);
484 return PyBlock(operation, block);
485 }
486
487 static void bind(nb::module_ &m) {
488 nb::class_<PyBlockList>(m, "BlockList")
489 .def("__getitem__", &PyBlockList::dunderGetItem)
490 .def("__iter__", &PyBlockList::dunderIter)
491 .def("__len__", &PyBlockList::dunderLen)
492 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring,
493 nb::arg("args"), nb::kw_only(),
494 nb::arg("arg_locs") = std::nullopt);
495 }
496
497private:
498 PyOperationRef operation;
499 MlirRegion region;
500};
501
502class PyOperationIterator {
503public:
504 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
505 : parentOperation(std::move(parentOperation)), next(next) {}
506
507 PyOperationIterator &dunderIter() { return *this; }
508
509 nb::object dunderNext() {
510 parentOperation->checkValid();
511 if (mlirOperationIsNull(next)) {
512 throw nb::stop_iteration();
513 }
514
515 PyOperationRef returnOperation =
516 PyOperation::forOperation(parentOperation->getContext(), next);
517 next = mlirOperationGetNextInBlock(next);
518 return returnOperation->createOpView();
519 }
520
521 static void bind(nb::module_ &m) {
522 nb::class_<PyOperationIterator>(m, "OperationIterator")
523 .def("__iter__", &PyOperationIterator::dunderIter)
524 .def("__next__", &PyOperationIterator::dunderNext);
525 }
526
527private:
528 PyOperationRef parentOperation;
529 MlirOperation next;
530};
531
532/// Operations are exposed by the C-API as a forward-only linked list. In
533/// Python, we present them as a more full-featured list-like container but
534/// optimize it for forward iteration. Iterable operations are always owned
535/// by a block.
536class PyOperationList {
537public:
538 PyOperationList(PyOperationRef parentOperation, MlirBlock block)
539 : parentOperation(std::move(parentOperation)), block(block) {}
540
541 PyOperationIterator dunderIter() {
542 parentOperation->checkValid();
543 return PyOperationIterator(parentOperation,
544 mlirBlockGetFirstOperation(block));
545 }
546
547 intptr_t dunderLen() {
548 parentOperation->checkValid();
549 intptr_t count = 0;
550 MlirOperation childOp = mlirBlockGetFirstOperation(block);
551 while (!mlirOperationIsNull(childOp)) {
552 count += 1;
553 childOp = mlirOperationGetNextInBlock(childOp);
554 }
555 return count;
556 }
557
558 nb::object dunderGetItem(intptr_t index) {
559 parentOperation->checkValid();
560 if (index < 0) {
561 index += dunderLen();
562 }
563 if (index < 0) {
564 throw nb::index_error("attempt to access out of bounds operation");
565 }
566 MlirOperation childOp = mlirBlockGetFirstOperation(block);
567 while (!mlirOperationIsNull(childOp)) {
568 if (index == 0) {
569 return PyOperation::forOperation(parentOperation->getContext(), childOp)
570 ->createOpView();
571 }
572 childOp = mlirOperationGetNextInBlock(childOp);
573 index -= 1;
574 }
575 throw nb::index_error("attempt to access out of bounds operation");
576 }
577
578 static void bind(nb::module_ &m) {
579 nb::class_<PyOperationList>(m, "OperationList")
580 .def("__getitem__", &PyOperationList::dunderGetItem)
581 .def("__iter__", &PyOperationList::dunderIter)
582 .def("__len__", &PyOperationList::dunderLen);
583 }
584
585private:
586 PyOperationRef parentOperation;
587 MlirBlock block;
588};
589
590class PyOpOperand {
591public:
592 PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
593
594 nb::object getOwner() {
595 MlirOperation owner = mlirOpOperandGetOwner(opOperand);
596 PyMlirContextRef context =
597 PyMlirContext::forContext(mlirOperationGetContext(owner));
598 return PyOperation::forOperation(context, owner)->createOpView();
599 }
600
601 size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
602
603 static void bind(nb::module_ &m) {
604 nb::class_<PyOpOperand>(m, "OpOperand")
605 .def_prop_ro("owner", &PyOpOperand::getOwner)
606 .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber);
607 }
608
609private:
610 MlirOpOperand opOperand;
611};
612
613class PyOpOperandIterator {
614public:
615 PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
616
617 PyOpOperandIterator &dunderIter() { return *this; }
618
619 PyOpOperand dunderNext() {
620 if (mlirOpOperandIsNull(opOperand))
621 throw nb::stop_iteration();
622
623 PyOpOperand returnOpOperand(opOperand);
624 opOperand = mlirOpOperandGetNextUse(opOperand);
625 return returnOpOperand;
626 }
627
628 static void bind(nb::module_ &m) {
629 nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
630 .def("__iter__", &PyOpOperandIterator::dunderIter)
631 .def("__next__", &PyOpOperandIterator::dunderNext);
632 }
633
634private:
635 MlirOpOperand opOperand;
636};
637
638} // namespace
639
640//------------------------------------------------------------------------------
641// PyMlirContext
642//------------------------------------------------------------------------------
643
644PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
645 nb::gil_scoped_acquire acquire;
646 nb::ft_lock_guard lock(live_contexts_mutex);
647 auto &liveContexts = getLiveContexts();
648 liveContexts[context.ptr] = this;
649}
650
651PyMlirContext::~PyMlirContext() {
652 // Note that the only public way to construct an instance is via the
653 // forContext method, which always puts the associated handle into
654 // liveContexts.
655 nb::gil_scoped_acquire acquire;
656 {
657 nb::ft_lock_guard lock(live_contexts_mutex);
658 getLiveContexts().erase(context.ptr);
659 }
660 mlirContextDestroy(context);
661}
662
663nb::object PyMlirContext::getCapsule() {
664 return nb::steal<nb::object>(mlirPythonContextToCapsule(get()));
665}
666
667nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
668 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
669 if (mlirContextIsNull(rawContext))
670 throw nb::python_error();
671 return forContext(rawContext).releaseObject();
672}
673
674PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
675 nb::gil_scoped_acquire acquire;
676 nb::ft_lock_guard lock(live_contexts_mutex);
677 auto &liveContexts = getLiveContexts();
678 auto it = liveContexts.find(context.ptr);
679 if (it == liveContexts.end()) {
680 // Create.
681 PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
682 nb::object pyRef = nb::cast(unownedContextWrapper);
683 assert(pyRef && "cast to nb::object failed");
684 liveContexts[context.ptr] = unownedContextWrapper;
685 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
686 }
687 // Use existing.
688 nb::object pyRef = nb::cast(it->second);
689 return PyMlirContextRef(it->second, std::move(pyRef));
690}
691
692nb::ft_mutex PyMlirContext::live_contexts_mutex;
693
694PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
695 static LiveContextMap liveContexts;
696 return liveContexts;
697}
698
699size_t PyMlirContext::getLiveCount() {
700 nb::ft_lock_guard lock(live_contexts_mutex);
701 return getLiveContexts().size();
702}
703
704size_t PyMlirContext::getLiveOperationCount() {
705 nb::ft_lock_guard lock(liveOperationsMutex);
706 return liveOperations.size();
707}
708
709std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
710 std::vector<PyOperation *> liveObjects;
711 nb::ft_lock_guard lock(liveOperationsMutex);
712 for (auto &entry : liveOperations)
713 liveObjects.push_back(entry.second.second);
714 return liveObjects;
715}
716
717size_t PyMlirContext::clearLiveOperations() {
718
719 LiveOperationMap operations;
720 {
721 nb::ft_lock_guard lock(liveOperationsMutex);
722 std::swap(operations, liveOperations);
723 }
724 for (auto &op : operations)
725 op.second.second->setInvalid();
726 size_t numInvalidated = operations.size();
727 return numInvalidated;
728}
729
730void PyMlirContext::clearOperation(MlirOperation op) {
731 PyOperation *py_op;
732 {
733 nb::ft_lock_guard lock(liveOperationsMutex);
734 auto it = liveOperations.find(op.ptr);
735 if (it == liveOperations.end()) {
736 return;
737 }
738 py_op = it->second.second;
739 liveOperations.erase(it);
740 }
741 py_op->setInvalid();
742}
743
744void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
745 typedef struct {
746 PyOperation &rootOp;
747 bool rootSeen;
748 } callBackData;
749 callBackData data{.rootOp: op.getOperation(), .rootSeen: false};
750 // Mark all ops below the op that the passmanager will be rooted
751 // at (but not op itself - note the preorder) as invalid.
752 MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
753 void *userData) {
754 callBackData *data = static_cast<callBackData *>(userData);
755 if (LLVM_LIKELY(data->rootSeen))
756 data->rootOp.getOperation().getContext()->clearOperation(op);
757 else
758 data->rootSeen = true;
759 return MlirWalkResult::MlirWalkResultAdvance;
760 };
761 mlirOperationWalk(op.getOperation(), invalidatingCallback,
762 static_cast<void *>(&data), MlirWalkPreOrder);
763}
764void PyMlirContext::clearOperationsInside(MlirOperation op) {
765 PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
766 clearOperationsInside(opRef->getOperation());
767}
768
769void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
770 MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
771 void *userData) {
772 PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
773 contextRef->clearOperation(op);
774 return MlirWalkResult::MlirWalkResultAdvance;
775 };
776 mlirOperationWalk(op.getOperation(), invalidatingCallback,
777 &op.getOperation().getContext(), MlirWalkPreOrder);
778}
779
780size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
781
782nb::object PyMlirContext::contextEnter(nb::object context) {
783 return PyThreadContextEntry::pushContext(context);
784}
785
786void PyMlirContext::contextExit(const nb::object &excType,
787 const nb::object &excVal,
788 const nb::object &excTb) {
789 PyThreadContextEntry::popContext(context&: *this);
790}
791
792nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) {
793 // Note that ownership is transferred to the delete callback below by way of
794 // an explicit inc_ref (borrow).
795 PyDiagnosticHandler *pyHandler =
796 new PyDiagnosticHandler(get(), std::move(callback));
797 nb::object pyHandlerObject =
798 nb::cast(pyHandler, nb::rv_policy::take_ownership);
799 pyHandlerObject.inc_ref();
800
801 // In these C callbacks, the userData is a PyDiagnosticHandler* that is
802 // guaranteed to be known to pybind.
803 auto handlerCallback =
804 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
805 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
806 nb::object pyDiagnosticObject =
807 nb::cast(pyDiagnostic, nb::rv_policy::take_ownership);
808
809 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
810 bool result = false;
811 {
812 // Since this can be called from arbitrary C++ contexts, always get the
813 // gil.
814 nb::gil_scoped_acquire gil;
815 try {
816 result = nb::cast<bool>(pyHandler->callback(pyDiagnostic));
817 } catch (std::exception &e) {
818 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
819 e.what());
820 pyHandler->hadError = true;
821 }
822 }
823
824 pyDiagnostic->invalidate();
825 return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
826 };
827 auto deleteCallback = +[](void *userData) {
828 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
829 assert(pyHandler->registeredID && "handler is not registered");
830 pyHandler->registeredID.reset();
831
832 // Decrement reference, balancing the inc_ref() above.
833 nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference);
834 pyHandlerObject.dec_ref();
835 };
836
837 pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
838 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
839 return pyHandlerObject;
840}
841
842MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
843 void *userData) {
844 auto *self = static_cast<ErrorCapture *>(userData);
845 // Check if the context requested we emit errors instead of capturing them.
846 if (self->ctx->emitErrorDiagnostics)
847 return mlirLogicalResultFailure();
848
849 if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError)
850 return mlirLogicalResultFailure();
851
852 self->errors.emplace_back(args: PyDiagnostic(diag).getInfo());
853 return mlirLogicalResultSuccess();
854}
855
856PyMlirContext &DefaultingPyMlirContext::resolve() {
857 PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
858 if (!context) {
859 throw std::runtime_error(
860 "An MLIR function requires a Context but none was provided in the call "
861 "or from the surrounding environment. Either pass to the function with "
862 "a 'context=' argument or establish a default using 'with Context():'");
863 }
864 return *context;
865}
866
867//------------------------------------------------------------------------------
868// PyThreadContextEntry management
869//------------------------------------------------------------------------------
870
871std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
872 static thread_local std::vector<PyThreadContextEntry> stack;
873 return stack;
874}
875
876PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
877 auto &stack = getStack();
878 if (stack.empty())
879 return nullptr;
880 return &stack.back();
881}
882
883void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
884 nb::object insertionPoint,
885 nb::object location) {
886 auto &stack = getStack();
887 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
888 std::move(location));
889 // If the new stack has more than one entry and the context of the new top
890 // entry matches the previous, copy the insertionPoint and location from the
891 // previous entry if missing from the new top entry.
892 if (stack.size() > 1) {
893 auto &prev = *(stack.rbegin() + 1);
894 auto &current = stack.back();
895 if (current.context.is(prev.context)) {
896 // Default non-context objects from the previous entry.
897 if (!current.insertionPoint)
898 current.insertionPoint = prev.insertionPoint;
899 if (!current.location)
900 current.location = prev.location;
901 }
902 }
903}
904
905PyMlirContext *PyThreadContextEntry::getContext() {
906 if (!context)
907 return nullptr;
908 return nb::cast<PyMlirContext *>(context);
909}
910
911PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
912 if (!insertionPoint)
913 return nullptr;
914 return nb::cast<PyInsertionPoint *>(insertionPoint);
915}
916
917PyLocation *PyThreadContextEntry::getLocation() {
918 if (!location)
919 return nullptr;
920 return nb::cast<PyLocation *>(location);
921}
922
923PyMlirContext *PyThreadContextEntry::getDefaultContext() {
924 auto *tos = getTopOfStack();
925 return tos ? tos->getContext() : nullptr;
926}
927
928PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
929 auto *tos = getTopOfStack();
930 return tos ? tos->getInsertionPoint() : nullptr;
931}
932
933PyLocation *PyThreadContextEntry::getDefaultLocation() {
934 auto *tos = getTopOfStack();
935 return tos ? tos->getLocation() : nullptr;
936}
937
938nb::object PyThreadContextEntry::pushContext(nb::object context) {
939 push(FrameKind::Context, /*context=*/context,
940 /*insertionPoint=*/nb::object(),
941 /*location=*/nb::object());
942 return context;
943}
944
945void PyThreadContextEntry::popContext(PyMlirContext &context) {
946 auto &stack = getStack();
947 if (stack.empty())
948 throw std::runtime_error("Unbalanced Context enter/exit");
949 auto &tos = stack.back();
950 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
951 throw std::runtime_error("Unbalanced Context enter/exit");
952 stack.pop_back();
953}
954
955nb::object
956PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) {
957 PyInsertionPoint &insertionPoint =
958 nb::cast<PyInsertionPoint &>(insertionPointObj);
959 nb::object contextObj =
960 insertionPoint.getBlock().getParentOperation()->getContext().getObject();
961 push(FrameKind::InsertionPoint,
962 /*context=*/contextObj,
963 /*insertionPoint=*/insertionPointObj,
964 /*location=*/nb::object());
965 return insertionPointObj;
966}
967
968void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
969 auto &stack = getStack();
970 if (stack.empty())
971 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
972 auto &tos = stack.back();
973 if (tos.frameKind != FrameKind::InsertionPoint &&
974 tos.getInsertionPoint() != &insertionPoint)
975 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
976 stack.pop_back();
977}
978
979nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) {
980 PyLocation &location = nb::cast<PyLocation &>(locationObj);
981 nb::object contextObj = location.getContext().getObject();
982 push(FrameKind::Location, /*context=*/contextObj,
983 /*insertionPoint=*/nb::object(),
984 /*location=*/locationObj);
985 return locationObj;
986}
987
988void PyThreadContextEntry::popLocation(PyLocation &location) {
989 auto &stack = getStack();
990 if (stack.empty())
991 throw std::runtime_error("Unbalanced Location enter/exit");
992 auto &tos = stack.back();
993 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
994 throw std::runtime_error("Unbalanced Location enter/exit");
995 stack.pop_back();
996}
997
998//------------------------------------------------------------------------------
999// PyDiagnostic*
1000//------------------------------------------------------------------------------
1001
1002void PyDiagnostic::invalidate() {
1003 valid = false;
1004 if (materializedNotes) {
1005 for (nb::handle noteObject : *materializedNotes) {
1006 PyDiagnostic *note = nb::cast<PyDiagnostic *>(noteObject);
1007 note->invalidate();
1008 }
1009 }
1010}
1011
/// Stores the MLIR context and the Python callback for a diagnostic handler.
/// Registration with the context is not performed here.
PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
                                         nb::object callback)
    : context(context), callback(std::move(callback)) {}

/// Defaulted; does not detach the handler itself.
PyDiagnosticHandler::~PyDiagnosticHandler() = default;
1017
/// Detaches this handler from its MLIR context. Does nothing if the handler
/// is not currently registered (`registeredID` is empty).
void PyDiagnosticHandler::detach() {
  if (!registeredID)
    return;
  // Copy the ID out first: detaching is expected to reset `registeredID` (the
  // assert below checks this). NOTE(review): presumably reset by a callback
  // run from mlirContextDetachDiagnosticHandler; confirm where it is cleared.
  MlirDiagnosticHandlerID localID = *registeredID;
  mlirContextDetachDiagnosticHandler(context, localID);
  assert(!registeredID && "should have unregistered");
  // Not strictly necessary but keeps stale pointers from being around to cause
  // issues.
  context = {nullptr};
}
1028
1029void PyDiagnostic::checkValid() {
1030 if (!valid) {
1031 throw std::invalid_argument(
1032 "Diagnostic is invalid (used outside of callback)");
1033 }
1034}
1035
1036MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
1037 checkValid();
1038 return mlirDiagnosticGetSeverity(diagnostic);
1039}
1040
1041PyLocation PyDiagnostic::getLocation() {
1042 checkValid();
1043 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
1044 MlirContext context = mlirLocationGetContext(loc);
1045 return PyLocation(PyMlirContext::forContext(context), loc);
1046}
1047
1048nb::str PyDiagnostic::getMessage() {
1049 checkValid();
1050 nb::object fileObject = nb::module_::import_("io").attr("StringIO")();
1051 PyFileAccumulator accum(fileObject, /*binary=*/false);
1052 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
1053 return nb::cast<nb::str>(fileObject.attr("getvalue")());
1054}
1055
1056nb::tuple PyDiagnostic::getNotes() {
1057 checkValid();
1058 if (materializedNotes)
1059 return *materializedNotes;
1060 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
1061 nb::tuple notes = nb::steal<nb::tuple>(PyTuple_New(numNotes));
1062 for (intptr_t i = 0; i < numNotes; ++i) {
1063 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
1064 nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag));
1065 PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr());
1066 }
1067 materializedNotes = std::move(notes);
1068
1069 return *materializedNotes;
1070}
1071
1072PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() {
1073 std::vector<DiagnosticInfo> notes;
1074 for (nb::handle n : getNotes())
1075 notes.emplace_back(nb::cast<PyDiagnostic>(n).getInfo());
1076 return {getSeverity(), getLocation(), nb::cast<std::string>(getMessage()),
1077 std::move(notes)};
1078}
1079
1080//------------------------------------------------------------------------------
1081// PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
1082//------------------------------------------------------------------------------
1083
1084MlirDialect PyDialects::getDialectForKey(const std::string &key,
1085 bool attrError) {
1086 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
1087 {key.data(), key.size()});
1088 if (mlirDialectIsNull(dialect)) {
1089 std::string msg = (Twine("Dialect '") + key + "' not found").str();
1090 if (attrError)
1091 throw nb::attribute_error(msg.c_str());
1092 throw nb::index_error(msg.c_str());
1093 }
1094 return dialect;
1095}
1096
1097nb::object PyDialectRegistry::getCapsule() {
1098 return nb::steal<nb::object>(mlirPythonDialectRegistryToCapsule(*this));
1099}
1100
1101PyDialectRegistry PyDialectRegistry::createFromCapsule(nb::object capsule) {
1102 MlirDialectRegistry rawRegistry =
1103 mlirPythonCapsuleToDialectRegistry(capsule.ptr());
1104 if (mlirDialectRegistryIsNull(rawRegistry))
1105 throw nb::python_error();
1106 return PyDialectRegistry(rawRegistry);
1107}
1108
1109//------------------------------------------------------------------------------
1110// PyLocation
1111//------------------------------------------------------------------------------
1112
1113nb::object PyLocation::getCapsule() {
1114 return nb::steal<nb::object>(mlirPythonLocationToCapsule(*this));
1115}
1116
1117PyLocation PyLocation::createFromCapsule(nb::object capsule) {
1118 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
1119 if (mlirLocationIsNull(rawLoc))
1120 throw nb::python_error();
1121 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
1122 rawLoc);
1123}
1124
1125nb::object PyLocation::contextEnter(nb::object locationObj) {
1126 return PyThreadContextEntry::pushLocation(locationObj);
1127}
1128
1129void PyLocation::contextExit(const nb::object &excType,
1130 const nb::object &excVal,
1131 const nb::object &excTb) {
1132 PyThreadContextEntry::popLocation(location&: *this);
1133}
1134
1135PyLocation &DefaultingPyLocation::resolve() {
1136 auto *location = PyThreadContextEntry::getDefaultLocation();
1137 if (!location) {
1138 throw std::runtime_error(
1139 "An MLIR function requires a Location but none was provided in the "
1140 "call or from the surrounding environment. Either pass to the function "
1141 "with a 'loc=' argument or establish a default using 'with loc:'");
1142 }
1143 return *location;
1144}
1145
1146//------------------------------------------------------------------------------
1147// PyModule
1148//------------------------------------------------------------------------------
1149
/// Wraps `module`, keeping a reference to its Python context (`contextRef`).
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
    : BaseContextObject(std::move(contextRef)), module(module) {}

/// Unregisters this module from its context's `liveModules` map and destroys
/// the underlying MLIR module.
PyModule::~PyModule() {
  // The GIL is taken before touching `liveModules`. NOTE(review): presumably
  // because destruction may happen without the GIL held; confirm.
  nb::gil_scoped_acquire acquire;
  auto &liveModules = getContext()->liveModules;
  assert(liveModules.count(module.ptr) == 1 &&
         "destroying module not in live map");
  liveModules.erase(module.ptr);
  mlirModuleDestroy(module);
}
1161
/// Returns the Python wrapper for `module`, creating it on first request.
///
/// The owning context keeps a `liveModules` map from the raw module pointer to
/// (Python handle, PyModule*). Repeated lookups of the same MlirModule
/// therefore yield the same Python object. A newly created PyModule is handed
/// to Python with take_ownership: when the Python object is collected the
/// PyModule is deleted, and ~PyModule destroys the MLIR module.
PyModuleRef PyModule::forModule(MlirModule module) {
  MlirContext context = mlirModuleGetContext(module);
  PyMlirContextRef contextRef = PyMlirContext::forContext(context);

  nb::gil_scoped_acquire acquire;
  auto &liveModules = contextRef->liveModules;
  auto it = liveModules.find(module.ptr);
  if (it == liveModules.end()) {
    // Create.
    PyModule *unownedModule = new PyModule(std::move(contextRef), module);
    // Note that the default return value policy on cast is automatic_reference,
    // which does not take ownership (delete will not be called).
    // Just be explicit.
    nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
    unownedModule->handle = pyRef;
    // Register before returning so later lookups find this wrapper.
    liveModules[module.ptr] =
        std::make_pair(unownedModule->handle, unownedModule);
    return PyModuleRef(unownedModule, std::move(pyRef));
  }
  // Use existing.
  PyModule *existing = it->second.second;
  // The map stores a non-owning handle; take a new reference for the caller.
  nb::object pyRef = nb::borrow<nb::object>(it->second.first);
  return PyModuleRef(existing, std::move(pyRef));
}
1186
1187nb::object PyModule::createFromCapsule(nb::object capsule) {
1188 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
1189 if (mlirModuleIsNull(rawModule))
1190 throw nb::python_error();
1191 return forModule(rawModule).releaseObject();
1192}
1193
1194nb::object PyModule::getCapsule() {
1195 return nb::steal<nb::object>(mlirPythonModuleToCapsule(get()));
1196}
1197
1198//------------------------------------------------------------------------------
1199// PyOperation
1200//------------------------------------------------------------------------------
1201
/// Wraps `operation`, keeping a reference to its Python context.
PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
    : BaseContextObject(std::move(contextRef)), operation(operation) {}

/// Releases the wrapper. Attached operations are only unregistered from the
/// context; detached operations (owned by Python) are also destroyed via
/// erase().
PyOperation::~PyOperation() {
  // If the operation has already been invalidated there is nothing to do.
  if (!valid)
    return;

  // Otherwise, invalidate the operation and remove it from live map when it is
  // attached.
  if (isAttached()) {
    getContext()->clearOperation(*this);
  } else {
    // And destroy it when it is detached, i.e. owned by Python, in which case
    // all nested operations must be invalidated and removed from the live map
    // as well.
    erase();
  }
}
1221
namespace {

// Constructs a new object of type T in-place on the Python heap, returning a
// PyObjectRef to it, loosely analogous to std::make_shared<T>().
//
// The steps are: allocate an uninitialized nanobind instance of T's bound
// Python type, placement-new T into its storage, then mark the instance ready
// so nanobind treats it as constructed.
template <typename T, class... Args>
PyObjectRef<T> makeObjectRef(Args &&...args) {
  nb::handle type = nb::type<T>();
  nb::object instance = nb::inst_alloc(type);
  T *ptr = nb::inst_ptr<T>(instance);
  new (ptr) T(std::forward<Args>(args)...);
  nb::inst_mark_ready(instance);
  return PyObjectRef<T>(ptr, std::move(instance));
}

} // namespace
1237
1238PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
1239 MlirOperation operation,
1240 nb::object parentKeepAlive) {
1241 // Create.
1242 PyOperationRef unownedOperation =
1243 makeObjectRef<PyOperation>(std::move(contextRef), operation);
1244 unownedOperation->handle = unownedOperation.getObject();
1245 if (parentKeepAlive) {
1246 unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
1247 }
1248 return unownedOperation;
1249}
1250
1251PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
1252 MlirOperation operation,
1253 nb::object parentKeepAlive) {
1254 nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
1255 auto &liveOperations = contextRef->liveOperations;
1256 auto it = liveOperations.find(operation.ptr);
1257 if (it == liveOperations.end()) {
1258 // Create.
1259 PyOperationRef result = createInstance(std::move(contextRef), operation,
1260 std::move(parentKeepAlive));
1261 liveOperations[operation.ptr] =
1262 std::make_pair(result.getObject(), result.get());
1263 return result;
1264 }
1265 // Use existing.
1266 PyOperation *existing = it->second.second;
1267 nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1268 return PyOperationRef(existing, std::move(pyRef));
1269}
1270
1271PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
1272 MlirOperation operation,
1273 nb::object parentKeepAlive) {
1274 nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
1275 auto &liveOperations = contextRef->liveOperations;
1276 assert(liveOperations.count(operation.ptr) == 0 &&
1277 "cannot create detached operation that already exists");
1278 (void)liveOperations;
1279 PyOperationRef created = createInstance(std::move(contextRef), operation,
1280 std::move(parentKeepAlive));
1281 liveOperations[operation.ptr] =
1282 std::make_pair(created.getObject(), created.get());
1283 created->attached = false;
1284 return created;
1285}
1286
1287PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
1288 const std::string &sourceStr,
1289 const std::string &sourceName) {
1290 PyMlirContext::ErrorCapture errors(contextRef);
1291 MlirOperation op =
1292 mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
1293 toMlirStringRef(sourceName));
1294 if (mlirOperationIsNull(op))
1295 throw MLIRError("Unable to parse operation assembly", errors.take());
1296 return PyOperation::createDetached(std::move(contextRef), op);
1297}
1298
1299void PyOperation::checkValid() const {
1300 if (!valid) {
1301 throw std::runtime_error("the operation has been invalidated");
1302 }
1303}
1304
1305void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
1306 bool enableDebugInfo, bool prettyDebugInfo,
1307 bool printGenericOpForm, bool useLocalScope,
1308 bool useNameLocAsPrefix, bool assumeVerified,
1309 nb::object fileObject, bool binary,
1310 bool skipRegions) {
1311 PyOperation &operation = getOperation();
1312 operation.checkValid();
1313 if (fileObject.is_none())
1314 fileObject = nb::module_::import_("sys").attr("stdout");
1315
1316 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1317 if (largeElementsLimit) {
1318 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1319 mlirOpPrintingFlagsElideLargeResourceString(flags, *largeElementsLimit);
1320 }
1321 if (enableDebugInfo)
1322 mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
1323 /*prettyForm=*/prettyDebugInfo);
1324 if (printGenericOpForm)
1325 mlirOpPrintingFlagsPrintGenericOpForm(flags);
1326 if (useLocalScope)
1327 mlirOpPrintingFlagsUseLocalScope(flags);
1328 if (assumeVerified)
1329 mlirOpPrintingFlagsAssumeVerified(flags);
1330 if (skipRegions)
1331 mlirOpPrintingFlagsSkipRegions(flags);
1332 if (useNameLocAsPrefix)
1333 mlirOpPrintingFlagsPrintNameLocAsPrefix(flags);
1334
1335 PyFileAccumulator accum(fileObject, binary);
1336 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1337 accum.getUserData());
1338 mlirOpPrintingFlagsDestroy(flags);
1339}
1340
1341void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
1342 bool binary) {
1343 PyOperation &operation = getOperation();
1344 operation.checkValid();
1345 if (fileObject.is_none())
1346 fileObject = nb::module_::import_("sys").attr("stdout");
1347 PyFileAccumulator accum(fileObject, binary);
1348 mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1349 accum.getUserData());
1350}
1351
1352void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject,
1353 std::optional<int64_t> bytecodeVersion) {
1354 PyOperation &operation = getOperation();
1355 operation.checkValid();
1356 PyFileAccumulator accum(fileOrStringObject, /*binary=*/true);
1357
1358 if (!bytecodeVersion.has_value())
1359 return mlirOperationWriteBytecode(operation, accum.getCallback(),
1360 accum.getUserData());
1361
1362 MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
1363 mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
1364 MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig(
1365 operation, config, accum.getCallback(), accum.getUserData());
1366 mlirBytecodeWriterConfigDestroy(config);
1367 if (mlirLogicalResultIsFailure(res))
1368 throw nb::value_error((Twine("Unable to honor desired bytecode version ") +
1369 Twine(*bytecodeVersion))
1370 .str()
1371 .c_str());
1372}
1373
1374void PyOperationBase::walk(
1375 std::function<MlirWalkResult(MlirOperation)> callback,
1376 MlirWalkOrder walkOrder) {
1377 PyOperation &operation = getOperation();
1378 operation.checkValid();
1379 struct UserData {
1380 std::function<MlirWalkResult(MlirOperation)> callback;
1381 bool gotException;
1382 std::string exceptionWhat;
1383 nb::object exceptionType;
1384 };
1385 UserData userData{callback, false, {}, {}};
1386 MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1387 void *userData) {
1388 UserData *calleeUserData = static_cast<UserData *>(userData);
1389 try {
1390 return (calleeUserData->callback)(op);
1391 } catch (nb::python_error &e) {
1392 calleeUserData->gotException = true;
1393 calleeUserData->exceptionWhat = std::string(e.what());
1394 calleeUserData->exceptionType = nb::borrow(e.type());
1395 return MlirWalkResult::MlirWalkResultInterrupt;
1396 }
1397 };
1398 mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
1399 if (userData.gotException) {
1400 std::string message("Exception raised in callback: ");
1401 message.append(str: userData.exceptionWhat);
1402 throw std::runtime_error(message);
1403 }
1404}
1405
1406nb::object PyOperationBase::getAsm(bool binary,
1407 std::optional<int64_t> largeElementsLimit,
1408 bool enableDebugInfo, bool prettyDebugInfo,
1409 bool printGenericOpForm, bool useLocalScope,
1410 bool useNameLocAsPrefix, bool assumeVerified,
1411 bool skipRegions) {
1412 nb::object fileObject;
1413 if (binary) {
1414 fileObject = nb::module_::import_("io").attr("BytesIO")();
1415 } else {
1416 fileObject = nb::module_::import_("io").attr("StringIO")();
1417 }
1418 print(/*largeElementsLimit=*/largeElementsLimit,
1419 /*enableDebugInfo=*/enableDebugInfo,
1420 /*prettyDebugInfo=*/prettyDebugInfo,
1421 /*printGenericOpForm=*/printGenericOpForm,
1422 /*useLocalScope=*/useLocalScope,
1423 /*useNameLocAsPrefix=*/useNameLocAsPrefix,
1424 /*assumeVerified=*/assumeVerified,
1425 /*fileObject=*/fileObject,
1426 /*binary=*/binary,
1427 /*skipRegions=*/skipRegions);
1428
1429 return fileObject.attr("getvalue")();
1430}
1431
1432void PyOperationBase::moveAfter(PyOperationBase &other) {
1433 PyOperation &operation = getOperation();
1434 PyOperation &otherOp = other.getOperation();
1435 operation.checkValid();
1436 otherOp.checkValid();
1437 mlirOperationMoveAfter(operation, otherOp);
1438 operation.parentKeepAlive = otherOp.parentKeepAlive;
1439}
1440
1441void PyOperationBase::moveBefore(PyOperationBase &other) {
1442 PyOperation &operation = getOperation();
1443 PyOperation &otherOp = other.getOperation();
1444 operation.checkValid();
1445 otherOp.checkValid();
1446 mlirOperationMoveBefore(operation, otherOp);
1447 operation.parentKeepAlive = otherOp.parentKeepAlive;
1448}
1449
1450bool PyOperationBase::verify() {
1451 PyOperation &op = getOperation();
1452 PyMlirContext::ErrorCapture errors(op.getContext());
1453 if (!mlirOperationVerify(op.get()))
1454 throw MLIRError("Verification failed", errors.take());
1455 return true;
1456}
1457
1458std::optional<PyOperationRef> PyOperation::getParentOperation() {
1459 checkValid();
1460 if (!isAttached())
1461 throw nb::value_error("Detached operations have no parent");
1462 MlirOperation operation = mlirOperationGetParentOperation(get());
1463 if (mlirOperationIsNull(operation))
1464 return {};
1465 return PyOperation::forOperation(getContext(), operation);
1466}
1467
1468PyBlock PyOperation::getBlock() {
1469 checkValid();
1470 std::optional<PyOperationRef> parentOperation = getParentOperation();
1471 MlirBlock block = mlirOperationGetBlock(get());
1472 assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1473 assert(parentOperation && "Operation has no parent");
1474 return PyBlock{std::move(*parentOperation), block};
1475}
1476
1477nb::object PyOperation::getCapsule() {
1478 checkValid();
1479 return nb::steal<nb::object>(mlirPythonOperationToCapsule(get()));
1480}
1481
1482nb::object PyOperation::createFromCapsule(nb::object capsule) {
1483 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1484 if (mlirOperationIsNull(rawOperation))
1485 throw nb::python_error();
1486 MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1487 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1488 .releaseObject();
1489}
1490
1491static void maybeInsertOperation(PyOperationRef &op,
1492 const nb::object &maybeIp) {
1493 // InsertPoint active?
1494 if (!maybeIp.is(nb::cast(false))) {
1495 PyInsertionPoint *ip;
1496 if (maybeIp.is_none()) {
1497 ip = PyThreadContextEntry::getDefaultInsertionPoint();
1498 } else {
1499 ip = nb::cast<PyInsertionPoint *>(maybeIp);
1500 }
1501 if (ip)
1502 ip->insert(*op.get());
1503 }
1504}
1505
/// Creates a new operation named `name` and returns its Python object.
///
/// `results`, `attributes` and `successors` are validated and unpacked into
/// local vectors before any MlirOperationState is built. `regions` fresh
/// regions are created and handed to the state. The operation is created
/// detached in the context of `location` and is then inserted according to
/// `maybeIp` (see maybeInsertOperation). `inferType` is forwarded as the
/// state's `enableResultTypeInference`.
///
/// Throws nb::value_error for a negative region count, a None result type or
/// successor, or when mlirOperationCreate returns null. Throws nb::type_error
/// for a non-string attribute key or a failed attribute cast. Throws
/// std::runtime_error when the attribute cast raises one (comment below
/// suggests this happens for `None` values).
nb::object PyOperation::create(std::string_view name,
                               std::optional<std::vector<PyType *>> results,
                               llvm::ArrayRef<MlirValue> operands,
                               std::optional<nb::dict> attributes,
                               std::optional<std::vector<PyBlock *>> successors,
                               int regions, DefaultingPyLocation location,
                               const nb::object &maybeIp, bool inferType) {
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw nb::value_error("number of regions must be >= 0");

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw nb::value_error("result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (std::pair<nb::handle, nb::handle> it : *attributes) {
      std::string key;
      try {
        key = nb::cast<std::string>(it.first);
      } catch (nb::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          std::string(name) + "\" (" + err.what() + ")";
        throw nb::type_error(msg.c_str());
      }
      try {
        auto &attribute = nb::cast<PyAttribute &>(it.second);
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (nb::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          std::string(name) + "\" (" + err.what() + ")";
        throw nb::type_error(msg.c_str());
      } catch (std::runtime_error &) {
        // This exception seems thrown when the value is "None".
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" +
            std::string(name) + "\"";
        throw std::runtime_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw nb::value_error("successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!operands.empty())
    mlirOperationStateAddOperands(&state, operands.size(), operands.data());
  state.enableResultTypeInference = inferType;
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // The regions are created here and passed to the state via the "Owned"
    // API; NOTE(review): presumably ownership moves to the new operation.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  MlirOperation operation = mlirOperationCreate(&state);
  if (!operation.ptr)
    throw nb::value_error("Operation creation failed");
  // New operations start detached (Python-owned) and may then be inserted.
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);
  maybeInsertOperation(created, maybeIp);

  return created.getObject();
}
1620
1621nb::object PyOperation::clone(const nb::object &maybeIp) {
1622 MlirOperation clonedOperation = mlirOperationClone(operation);
1623 PyOperationRef cloned =
1624 PyOperation::createDetached(getContext(), clonedOperation);
1625 maybeInsertOperation(cloned, maybeIp);
1626
1627 return cloned->createOpView();
1628}
1629
1630nb::object PyOperation::createOpView() {
1631 checkValid();
1632 MlirIdentifier ident = mlirOperationGetName(get());
1633 MlirStringRef identStr = mlirIdentifierStr(ident);
1634 auto operationCls = PyGlobals::get().lookupOperationClass(
1635 StringRef(identStr.data, identStr.length));
1636 if (operationCls)
1637 return PyOpView::constructDerived(*operationCls, getRef().getObject());
1638 return nb::cast(PyOpView(getRef().getObject()));
1639}
1640
1641void PyOperation::erase() {
1642 checkValid();
1643 getContext()->clearOperationAndInside(*this);
1644 mlirOperationDestroy(operation);
1645}
1646
namespace {
/// CRTP base class for Python MLIR values that subclass Value and should be
/// castable from it. The value hierarchy is one level deep and is not supposed
/// to accommodate other levels unless core MLIR changes.
template <typename DerivedTy>
class PyConcreteValue : public PyValue {
public:
  // Derived classes must define statics for:
  //   IsAFunctionTy isaFunction
  //   const char *pyClassName
  // and redefine bindDerived.
  using ClassTy = nb::class_<DerivedTy, PyValue>;
  using IsAFunctionTy = bool (*)(MlirValue);

  PyConcreteValue() = default;
  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
      : PyValue(operationRef, value) {}
  // Downcasting constructor: keeps `orig`'s parent operation and throws
  // nb::value_error (via castFrom) if `orig` is not a DerivedTy.
  PyConcreteValue(PyValue &orig)
      : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}

  /// Attempts to cast the original value to the derived type and throws on
  /// type mismatches.
  static MlirValue castFrom(PyValue &orig) {
    if (!DerivedTy::isaFunction(orig.get())) {
      auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
      throw nb::value_error((Twine("Cannot cast value to ") +
                             DerivedTy::pyClassName + " (from " + origRepr +
                             ")")
                                .str()
                                .c_str());
    }
    return orig.get();
  }

  /// Binds the Python module objects to functions of this class.
  /// Registers: a constructor from Value (keeping the argument alive for the
  /// new object's lifetime), a static `isinstance`, and the maybe-downcast
  /// hook. Finally it calls DerivedTy::bindDerived for subclass members.
  static void bind(nb::module_ &m) {
    auto cls = ClassTy(m, DerivedTy::pyClassName);
    cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
    cls.def_static(
        "isinstance",
        [](PyValue &otherValue) -> bool {
          return DerivedTy::isaFunction(otherValue);
        },
        nb::arg("other_value"));
    cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
            [](DerivedTy &self) { return self.maybeDownCast(); });
    DerivedTy::bindDerived(cls);
  }

  /// Implemented by derived classes to add methods to the Python subclass.
  static void bindDerived(ClassTy &m) {}
};

} // namespace
1701
1702/// Python wrapper for MlirOpResult.
1703class PyOpResult : public PyConcreteValue<PyOpResult> {
1704public:
1705 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1706 static constexpr const char *pyClassName = "OpResult";
1707 using PyConcreteValue::PyConcreteValue;
1708
1709 static void bindDerived(ClassTy &c) {
1710 c.def_prop_ro("owner", [](PyOpResult &self) {
1711 assert(
1712 mlirOperationEqual(self.getParentOperation()->get(),
1713 mlirOpResultGetOwner(self.get())) &&
1714 "expected the owner of the value in Python to match that in the IR");
1715 return self.getParentOperation().getObject();
1716 });
1717 c.def_prop_ro("result_number", [](PyOpResult &self) {
1718 return mlirOpResultGetResultNumber(self.get());
1719 });
1720 }
1721};
1722
1723/// Returns the list of types of the values held by container.
1724template <typename Container>
1725static std::vector<MlirType> getValueTypes(Container &container,
1726 PyMlirContextRef &context) {
1727 std::vector<MlirType> result;
1728 result.reserve(container.size());
1729 for (int i = 0, e = container.size(); i < e; ++i) {
1730 result.push_back(mlirValueGetType(container.getElement(i).get()));
1731 }
1732 return result;
1733}
1734
/// A list of operation results. Internally, these are stored as consecutive
/// elements, random access is cheap. The (returned) result list is associated
/// with the operation whose results these are, and thus extends the lifetime of
/// this operation.
class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
public:
  static constexpr const char *pyClassName = "OpResultList";
  using SliceableT = Sliceable<PyOpResultList, PyOpResult>;

  // A `length` of -1 means "all results of `operation`".
  PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
                 intptr_t length = -1, intptr_t step = 1)
      : Sliceable(startIndex,
                  length == -1 ? mlirOperationGetNumResults(operation->get())
                               : length,
                  step),
        operation(std::move(operation)) {}

  /// Adds `types` (the result types) and `owner` (an OpView of the operation).
  static void bindDerived(ClassTy &c) {
    c.def_prop_ro("types", [](PyOpResultList &self) {
      return getValueTypes(self, self.operation->getContext());
    });
    c.def_prop_ro("owner", [](PyOpResultList &self) {
      return self.operation->createOpView();
    });
  }

  PyOperationRef &getOperation() { return operation; }

private:
  /// Give the parent CRTP class access to hook implementations below.
  friend class Sliceable<PyOpResultList, PyOpResult>;

  // Sliceable hook: total number of results (validates the operation first).
  intptr_t getRawNumElements() {
    operation->checkValid();
    return mlirOperationGetNumResults(operation->get());
  }

  // Sliceable hook: wraps the result at raw index `index` as an OpResult.
  PyOpResult getRawElement(intptr_t index) {
    PyValue value(operation, mlirOperationGetResult(operation->get(), index));
    return PyOpResult(value);
  }

  // Sliceable hook: builds a sub-list over the same operation.
  PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
    return PyOpResultList(operation, startIndex, length, step);
  }

  PyOperationRef operation;
};
1783
1784//------------------------------------------------------------------------------
1785// PyOpView
1786//------------------------------------------------------------------------------
1787
1788static void populateResultTypes(StringRef name, nb::list resultTypeList,
1789 const nb::object &resultSegmentSpecObj,
1790 std::vector<int32_t> &resultSegmentLengths,
1791 std::vector<PyType *> &resultTypes) {
1792 resultTypes.reserve(n: resultTypeList.size());
1793 if (resultSegmentSpecObj.is_none()) {
1794 // Non-variadic result unpacking.
1795 for (const auto &it : llvm::enumerate(resultTypeList)) {
1796 try {
1797 resultTypes.push_back(nb::cast<PyType *>(it.value()));
1798 if (!resultTypes.back())
1799 throw nb::cast_error();
1800 } catch (nb::cast_error &err) {
1801 throw nb::value_error((llvm::Twine("Result ") +
1802 llvm::Twine(it.index()) + " of operation \"" +
1803 name + "\" must be a Type (" + err.what() + ")")
1804 .str()
1805 .c_str());
1806 }
1807 }
1808 } else {
1809 // Sized result unpacking.
1810 auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj);
1811 if (resultSegmentSpec.size() != resultTypeList.size()) {
1812 throw nb::value_error((llvm::Twine("Operation \"") + name +
1813 "\" requires " +
1814 llvm::Twine(resultSegmentSpec.size()) +
1815 " result segments but was provided " +
1816 llvm::Twine(resultTypeList.size()))
1817 .str()
1818 .c_str());
1819 }
1820 resultSegmentLengths.reserve(n: resultTypeList.size());
1821 for (const auto &it :
1822 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1823 int segmentSpec = std::get<1>(it.value());
1824 if (segmentSpec == 1 || segmentSpec == 0) {
1825 // Unpack unary element.
1826 try {
1827 auto *resultType = nb::cast<PyType *>(std::get<0>(it.value()));
1828 if (resultType) {
1829 resultTypes.push_back(resultType);
1830 resultSegmentLengths.push_back(1);
1831 } else if (segmentSpec == 0) {
1832 // Allowed to be optional.
1833 resultSegmentLengths.push_back(0);
1834 } else {
1835 throw nb::value_error(
1836 (llvm::Twine("Result ") + llvm::Twine(it.index()) +
1837 " of operation \"" + name +
1838 "\" must be a Type (was None and result is not optional)")
1839 .str()
1840 .c_str());
1841 }
1842 } catch (nb::cast_error &err) {
1843 throw nb::value_error((llvm::Twine("Result ") +
1844 llvm::Twine(it.index()) + " of operation \"" +
1845 name + "\" must be a Type (" + err.what() +
1846 ")")
1847 .str()
1848 .c_str());
1849 }
1850 } else if (segmentSpec == -1) {
1851 // Unpack sequence by appending.
1852 try {
1853 if (std::get<0>(it.value()).is_none()) {
1854 // Treat it as an empty list.
1855 resultSegmentLengths.push_back(0);
1856 } else {
1857 // Unpack the list.
1858 auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
1859 for (nb::handle segmentItem : segment) {
1860 resultTypes.push_back(nb::cast<PyType *>(segmentItem));
1861 if (!resultTypes.back()) {
1862 throw nb::type_error("contained a None item");
1863 }
1864 }
1865 resultSegmentLengths.push_back(nb::len(segment));
1866 }
1867 } catch (std::exception &err) {
1868 // NOTE: Sloppy to be using a catch-all here, but there are at least
1869 // three different unrelated exceptions that can be thrown in the
1870 // above "casts". Just keep the scope above small and catch them all.
1871 throw nb::value_error((llvm::Twine("Result ") +
1872 llvm::Twine(it.index()) + " of operation \"" +
1873 name + "\" must be a Sequence of Types (" +
1874 err.what() + ")")
1875 .str()
1876 .c_str());
1877 }
1878 } else {
1879 throw nb::value_error("Unexpected segment spec");
1880 }
1881 }
1882 }
1883}
1884
1885static MlirValue getUniqueResult(MlirOperation operation) {
1886 auto numResults = mlirOperationGetNumResults(operation);
1887 if (numResults != 1) {
1888 auto name = mlirIdentifierStr(mlirOperationGetName(operation));
1889 throw nb::value_error((Twine("Cannot call .result on operation ") +
1890 StringRef(name.data, name.length) + " which has " +
1891 Twine(numResults) +
1892 " results (it is only valid for operations with a "
1893 "single result)")
1894 .str()
1895 .c_str());
1896 }
1897 return mlirOperationGetResult(operation, 0);
1898}
1899
1900static MlirValue getOpResultOrValue(nb::handle operand) {
1901 if (operand.is_none()) {
1902 throw nb::value_error("contained a None item");
1903 }
1904 PyOperationBase *op;
1905 if (nb::try_cast<PyOperationBase *>(operand, op)) {
1906 return getUniqueResult(op->getOperation());
1907 }
1908 PyOpResultList *opResultList;
1909 if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
1910 return getUniqueResult(opResultList->getOperation()->get());
1911 }
1912 PyValue *value;
1913 if (nb::try_cast<PyValue *>(operand, value)) {
1914 return value->get();
1915 }
1916 throw nb::value_error("is not a Value");
1917}
1918
1919nb::object PyOpView::buildGeneric(
1920 std::string_view name, std::tuple<int, bool> opRegionSpec,
1921 nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
1922 std::optional<nb::list> resultTypeList, nb::list operandList,
1923 std::optional<nb::dict> attributes,
1924 std::optional<std::vector<PyBlock *>> successors,
1925 std::optional<int> regions, DefaultingPyLocation location,
1926 const nb::object &maybeIp) {
1927 PyMlirContextRef context = location->getContext();
1928
1929 // Class level operation construction metadata.
1930 // Operand and result segment specs are either none, which does no
1931 // variadic unpacking, or a list of ints with segment sizes, where each
1932 // element is either a positive number (typically 1 for a scalar) or -1 to
1933 // indicate that it is derived from the length of the same-indexed operand
1934 // or result (implying that it is a list at that position).
1935 std::vector<int32_t> operandSegmentLengths;
1936 std::vector<int32_t> resultSegmentLengths;
1937
1938 // Validate/determine region count.
1939 int opMinRegionCount = std::get<0>(t&: opRegionSpec);
1940 bool opHasNoVariadicRegions = std::get<1>(t&: opRegionSpec);
1941 if (!regions) {
1942 regions = opMinRegionCount;
1943 }
1944 if (*regions < opMinRegionCount) {
1945 throw nb::value_error(
1946 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1947 llvm::Twine(opMinRegionCount) +
1948 " regions but was built with regions=" + llvm::Twine(*regions))
1949 .str()
1950 .c_str());
1951 }
1952 if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1953 throw nb::value_error(
1954 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1955 llvm::Twine(opMinRegionCount) +
1956 " regions but was built with regions=" + llvm::Twine(*regions))
1957 .str()
1958 .c_str());
1959 }
1960
1961 // Unpack results.
1962 std::vector<PyType *> resultTypes;
1963 if (resultTypeList.has_value()) {
1964 populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1965 resultSegmentLengths, resultTypes);
1966 }
1967
1968 // Unpack operands.
1969 llvm::SmallVector<MlirValue, 4> operands;
1970 operands.reserve(operands.size());
1971 if (operandSegmentSpecObj.is_none()) {
1972 // Non-sized operand unpacking.
1973 for (const auto &it : llvm::enumerate(operandList)) {
1974 try {
1975 operands.push_back(getOpResultOrValue(it.value()));
1976 } catch (nb::builtin_exception &err) {
1977 throw nb::value_error((llvm::Twine("Operand ") +
1978 llvm::Twine(it.index()) + " of operation \"" +
1979 name + "\" must be a Value (" + err.what() + ")")
1980 .str()
1981 .c_str());
1982 }
1983 }
1984 } else {
1985 // Sized operand unpacking.
1986 auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj);
1987 if (operandSegmentSpec.size() != operandList.size()) {
1988 throw nb::value_error((llvm::Twine("Operation \"") + name +
1989 "\" requires " +
1990 llvm::Twine(operandSegmentSpec.size()) +
1991 "operand segments but was provided " +
1992 llvm::Twine(operandList.size()))
1993 .str()
1994 .c_str());
1995 }
1996 operandSegmentLengths.reserve(n: operandList.size());
1997 for (const auto &it :
1998 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1999 int segmentSpec = std::get<1>(it.value());
2000 if (segmentSpec == 1 || segmentSpec == 0) {
2001 // Unpack unary element.
2002 auto &operand = std::get<0>(it.value());
2003 if (!operand.is_none()) {
2004 try {
2005
2006 operands.push_back(getOpResultOrValue(operand));
2007 } catch (nb::builtin_exception &err) {
2008 throw nb::value_error((llvm::Twine("Operand ") +
2009 llvm::Twine(it.index()) +
2010 " of operation \"" + name +
2011 "\" must be a Value (" + err.what() + ")")
2012 .str()
2013 .c_str());
2014 }
2015
2016 operandSegmentLengths.push_back(1);
2017 } else if (segmentSpec == 0) {
2018 // Allowed to be optional.
2019 operandSegmentLengths.push_back(0);
2020 } else {
2021 throw nb::value_error(
2022 (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
2023 " of operation \"" + name +
2024 "\" must be a Value (was None and operand is not optional)")
2025 .str()
2026 .c_str());
2027 }
2028 } else if (segmentSpec == -1) {
2029 // Unpack sequence by appending.
2030 try {
2031 if (std::get<0>(it.value()).is_none()) {
2032 // Treat it as an empty list.
2033 operandSegmentLengths.push_back(0);
2034 } else {
2035 // Unpack the list.
2036 auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
2037 for (nb::handle segmentItem : segment) {
2038 operands.push_back(getOpResultOrValue(segmentItem));
2039 }
2040 operandSegmentLengths.push_back(nb::len(segment));
2041 }
2042 } catch (std::exception &err) {
2043 // NOTE: Sloppy to be using a catch-all here, but there are at least
2044 // three different unrelated exceptions that can be thrown in the
2045 // above "casts". Just keep the scope above small and catch them all.
2046 throw nb::value_error((llvm::Twine("Operand ") +
2047 llvm::Twine(it.index()) + " of operation \"" +
2048 name + "\" must be a Sequence of Values (" +
2049 err.what() + ")")
2050 .str()
2051 .c_str());
2052 }
2053 } else {
2054 throw nb::value_error("Unexpected segment spec");
2055 }
2056 }
2057 }
2058
2059 // Merge operand/result segment lengths into attributes if needed.
2060 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
2061 // Dup.
2062 if (attributes) {
2063 attributes = nb::dict(*attributes);
2064 } else {
2065 attributes = nb::dict();
2066 }
2067 if (attributes->contains("resultSegmentSizes") ||
2068 attributes->contains("operandSegmentSizes")) {
2069 throw nb::value_error("Manually setting a 'resultSegmentSizes' or "
2070 "'operandSegmentSizes' attribute is unsupported. "
2071 "Use Operation.create for such low-level access.");
2072 }
2073
2074 // Add resultSegmentSizes attribute.
2075 if (!resultSegmentLengths.empty()) {
2076 MlirAttribute segmentLengthAttr =
2077 mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
2078 resultSegmentLengths.data());
2079 (*attributes)["resultSegmentSizes"] =
2080 PyAttribute(context, segmentLengthAttr);
2081 }
2082
2083 // Add operandSegmentSizes attribute.
2084 if (!operandSegmentLengths.empty()) {
2085 MlirAttribute segmentLengthAttr =
2086 mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
2087 operandSegmentLengths.data());
2088 (*attributes)["operandSegmentSizes"] =
2089 PyAttribute(context, segmentLengthAttr);
2090 }
2091 }
2092
2093 // Delegate to create.
2094 return PyOperation::create(name,
2095 /*results=*/std::move(resultTypes),
2096 /*operands=*/std::move(operands),
2097 /*attributes=*/std::move(attributes),
2098 /*successors=*/std::move(successors),
2099 /*regions=*/*regions, location, maybeIp,
2100 !resultTypeList);
2101}
2102
2103nb::object PyOpView::constructDerived(const nb::object &cls,
2104 const nb::object &operation) {
2105 nb::handle opViewType = nb::type<PyOpView>();
2106 nb::object instance = cls.attr("__new__")(cls);
2107 opViewType.attr("__init__")(instance, operation);
2108 return instance;
2109}
2110
/// Creates an OpView over the operation underlying `operationObject`.
///
/// NOTE(review): the stored `operationObject` member comes from the
/// PyOperation's own reference, not from the argument. The two may differ
/// when another OpView was passed in.
PyOpView::PyOpView(const nb::object &operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
      // Reads the `operation` member initialized just above; this relies on
      // `operation` being declared before `operationObject` in the class.
      operationObject(operation.getRef().getObject()) {}
2116
2117//------------------------------------------------------------------------------
2118// PyInsertionPoint.
2119//------------------------------------------------------------------------------
2120
/// Creates an insertion point with no reference operation. insert() then
/// appends new operations to the end of `block`.
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}

/// Creates an insertion point that inserts immediately before
/// `beforeOperationBase`, in the block that contains it.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      // Reads the already-initialized `refOperation` member, so this relies on
      // `refOperation` being declared before `block` in the class.
      block((*refOperation)->getBlock()) {}
2126
2127void PyInsertionPoint::insert(PyOperationBase &operationBase) {
2128 PyOperation &operation = operationBase.getOperation();
2129 if (operation.isAttached())
2130 throw nb::value_error(
2131 "Attempt to insert operation that is already attached");
2132 block.getParentOperation()->checkValid();
2133 MlirOperation beforeOp = {nullptr};
2134 if (refOperation) {
2135 // Insert before operation.
2136 (*refOperation)->checkValid();
2137 beforeOp = (*refOperation)->get();
2138 } else {
2139 // Insert at end (before null) is only valid if the block does not
2140 // already end in a known terminator (violating this will cause assertion
2141 // failures later).
2142 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
2143 throw nb::index_error("Cannot insert operation at the end of a block "
2144 "that already has a terminator. Did you mean to "
2145 "use 'InsertionPoint.at_block_terminator(block)' "
2146 "versus 'InsertionPoint(block)'?");
2147 }
2148 }
2149 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
2150 operation.setAttached();
2151}
2152
2153PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
2154 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
2155 if (mlirOperationIsNull(firstOp)) {
2156 // Just insert at end.
2157 return PyInsertionPoint(block);
2158 }
2159
2160 // Insert before first op.
2161 PyOperationRef firstOpRef = PyOperation::forOperation(
2162 block.getParentOperation()->getContext(), firstOp);
2163 return PyInsertionPoint{block, std::move(firstOpRef)};
2164}
2165
2166PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
2167 MlirOperation terminator = mlirBlockGetTerminator(block.get());
2168 if (mlirOperationIsNull(terminator))
2169 throw nb::value_error("Block has no terminator");
2170 PyOperationRef terminatorOpRef = PyOperation::forOperation(
2171 block.getParentOperation()->getContext(), terminator);
2172 return PyInsertionPoint{block, std::move(terminatorOpRef)};
2173}
2174
2175nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
2176 return PyThreadContextEntry::pushInsertionPoint(insertPoint);
2177}
2178
2179void PyInsertionPoint::contextExit(const nb::object &excType,
2180 const nb::object &excVal,
2181 const nb::object &excTb) {
2182 PyThreadContextEntry::popInsertionPoint(insertionPoint&: *this);
2183}
2184
2185//------------------------------------------------------------------------------
2186// PyAttribute.
2187//------------------------------------------------------------------------------
2188
2189bool PyAttribute::operator==(const PyAttribute &other) const {
2190 return mlirAttributeEqual(attr, other.attr);
2191}
2192
2193nb::object PyAttribute::getCapsule() {
2194 return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this));
2195}
2196
2197PyAttribute PyAttribute::createFromCapsule(nb::object capsule) {
2198 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
2199 if (mlirAttributeIsNull(rawAttr))
2200 throw nb::python_error();
2201 return PyAttribute(
2202 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
2203}
2204
2205//------------------------------------------------------------------------------
2206// PyNamedAttribute.
2207//------------------------------------------------------------------------------
2208
2209PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
2210 : ownedName(new std::string(std::move(ownedName))) {
2211 namedAttr = mlirNamedAttributeGet(
2212 mlirIdentifierGet(mlirAttributeGetContext(attr),
2213 toMlirStringRef(*this->ownedName)),
2214 attr);
2215}
2216
2217//------------------------------------------------------------------------------
2218// PyType.
2219//------------------------------------------------------------------------------
2220
2221bool PyType::operator==(const PyType &other) const {
2222 return mlirTypeEqual(type, other.type);
2223}
2224
2225nb::object PyType::getCapsule() {
2226 return nb::steal<nb::object>(mlirPythonTypeToCapsule(*this));
2227}
2228
2229PyType PyType::createFromCapsule(nb::object capsule) {
2230 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
2231 if (mlirTypeIsNull(rawType))
2232 throw nb::python_error();
2233 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
2234 rawType);
2235}
2236
2237//------------------------------------------------------------------------------
2238// PyTypeID.
2239//------------------------------------------------------------------------------
2240
2241nb::object PyTypeID::getCapsule() {
2242 return nb::steal<nb::object>(mlirPythonTypeIDToCapsule(*this));
2243}
2244
2245PyTypeID PyTypeID::createFromCapsule(nb::object capsule) {
2246 MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
2247 if (mlirTypeIDIsNull(mlirTypeID))
2248 throw nb::python_error();
2249 return PyTypeID(mlirTypeID);
2250}
2251bool PyTypeID::operator==(const PyTypeID &other) const {
2252 return mlirTypeIDEqual(typeID, other.typeID);
2253}
2254
2255//------------------------------------------------------------------------------
2256// PyValue and subclasses.
2257//------------------------------------------------------------------------------
2258
2259nb::object PyValue::getCapsule() {
2260 return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
2261}
2262
2263nb::object PyValue::maybeDownCast() {
2264 MlirType type = mlirValueGetType(get());
2265 MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
2266 assert(!mlirTypeIDIsNull(mlirTypeID) &&
2267 "mlirTypeID was expected to be non-null.");
2268 std::optional<nb::callable> valueCaster =
2269 PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
2270 // nb::rv_policy::move means use std::move to move the return value
2271 // contents into a new instance that will be owned by Python.
2272 nb::object thisObj = nb::cast(this, nb::rv_policy::move);
2273 if (!valueCaster)
2274 return thisObj;
2275 return valueCaster.value()(thisObj);
2276}
2277
2278PyValue PyValue::createFromCapsule(nb::object capsule) {
2279 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
2280 if (mlirValueIsNull(value))
2281 throw nb::python_error();
2282 MlirOperation owner;
2283 if (mlirValueIsAOpResult(value))
2284 owner = mlirOpResultGetOwner(value);
2285 if (mlirValueIsABlockArgument(value))
2286 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
2287 if (mlirOperationIsNull(owner))
2288 throw nb::python_error();
2289 MlirContext ctx = mlirOperationGetContext(owner);
2290 PyOperationRef ownerRef =
2291 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
2292 return PyValue(ownerRef, value);
2293}
2294
2295//------------------------------------------------------------------------------
2296// PySymbolTable.
2297//------------------------------------------------------------------------------
2298
2299PySymbolTable::PySymbolTable(PyOperationBase &operation)
2300 : operation(operation.getOperation().getRef()) {
2301 symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
2302 if (mlirSymbolTableIsNull(symbolTable)) {
2303 throw nb::type_error("Operation is not a Symbol Table.");
2304 }
2305}
2306
2307nb::object PySymbolTable::dunderGetItem(const std::string &name) {
2308 operation->checkValid();
2309 MlirOperation symbol = mlirSymbolTableLookup(
2310 symbolTable, mlirStringRefCreate(name.data(), name.length()));
2311 if (mlirOperationIsNull(symbol))
2312 throw nb::key_error(
2313 ("Symbol '" + name + "' not in the symbol table.").c_str());
2314
2315 return PyOperation::forOperation(operation->getContext(), symbol,
2316 operation.getObject())
2317 ->createOpView();
2318}
2319
2320void PySymbolTable::erase(PyOperationBase &symbol) {
2321 operation->checkValid();
2322 symbol.getOperation().checkValid();
2323 mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
2324 // The operation is also erased, so we must invalidate it. There may be Python
2325 // references to this operation so we don't want to delete it from the list of
2326 // live operations here.
2327 symbol.getOperation().valid = false;
2328}
2329
2330void PySymbolTable::dunderDel(const std::string &name) {
2331 nb::object operation = dunderGetItem(name);
2332 erase(nb::cast<PyOperationBase &>(operation));
2333}
2334
2335MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) {
2336 operation->checkValid();
2337 symbol.getOperation().checkValid();
2338 MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
2339 symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
2340 if (mlirAttributeIsNull(symbolAttr))
2341 throw nb::value_error("Expected operation to have a symbol name.");
2342 return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get());
2343}
2344
2345MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
2346 // Op must already be a symbol.
2347 PyOperation &operation = symbol.getOperation();
2348 operation.checkValid();
2349 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
2350 MlirAttribute existingNameAttr =
2351 mlirOperationGetAttributeByName(operation.get(), attrName);
2352 if (mlirAttributeIsNull(existingNameAttr))
2353 throw nb::value_error("Expected operation to have a symbol name.");
2354 return existingNameAttr;
2355}
2356
2357void PySymbolTable::setSymbolName(PyOperationBase &symbol,
2358 const std::string &name) {
2359 // Op must already be a symbol.
2360 PyOperation &operation = symbol.getOperation();
2361 operation.checkValid();
2362 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
2363 MlirAttribute existingNameAttr =
2364 mlirOperationGetAttributeByName(operation.get(), attrName);
2365 if (mlirAttributeIsNull(existingNameAttr))
2366 throw nb::value_error("Expected operation to have a symbol name.");
2367 MlirAttribute newNameAttr =
2368 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
2369 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
2370}
2371
2372MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
2373 PyOperation &operation = symbol.getOperation();
2374 operation.checkValid();
2375 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
2376 MlirAttribute existingVisAttr =
2377 mlirOperationGetAttributeByName(operation.get(), attrName);
2378 if (mlirAttributeIsNull(existingVisAttr))
2379 throw nb::value_error("Expected operation to have a symbol visibility.");
2380 return existingVisAttr;
2381}
2382
2383void PySymbolTable::setVisibility(PyOperationBase &symbol,
2384 const std::string &visibility) {
2385 if (visibility != "public" && visibility != "private" &&
2386 visibility != "nested")
2387 throw nb::value_error(
2388 "Expected visibility to be 'public', 'private' or 'nested'");
2389 PyOperation &operation = symbol.getOperation();
2390 operation.checkValid();
2391 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
2392 MlirAttribute existingVisAttr =
2393 mlirOperationGetAttributeByName(operation.get(), attrName);
2394 if (mlirAttributeIsNull(existingVisAttr))
2395 throw nb::value_error("Expected operation to have a symbol visibility.");
2396 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
2397 toMlirStringRef(visibility));
2398 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
2399}
2400
2401void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
2402 const std::string &newSymbol,
2403 PyOperationBase &from) {
2404 PyOperation &fromOperation = from.getOperation();
2405 fromOperation.checkValid();
2406 if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
2407 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
2408 from.getOperation())))
2409
2410 throw nb::value_error("Symbol rename failed");
2411}
2412
2413void PySymbolTable::walkSymbolTables(PyOperationBase &from,
2414 bool allSymUsesVisible,
2415 nb::object callback) {
2416 PyOperation &fromOperation = from.getOperation();
2417 fromOperation.checkValid();
2418 struct UserData {
2419 PyMlirContextRef context;
2420 nb::object callback;
2421 bool gotException;
2422 std::string exceptionWhat;
2423 nb::object exceptionType;
2424 };
2425 UserData userData{
2426 fromOperation.getContext(), std::move(callback), false, {}, {}};
2427 mlirSymbolTableWalkSymbolTables(
2428 fromOperation.get(), allSymUsesVisible,
2429 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
2430 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
2431 auto pyFoundOp =
2432 PyOperation::forOperation(calleeUserData->context, foundOp);
2433 if (calleeUserData->gotException)
2434 return;
2435 try {
2436 calleeUserData->callback(pyFoundOp.getObject(), isVisible);
2437 } catch (nb::python_error &e) {
2438 calleeUserData->gotException = true;
2439 calleeUserData->exceptionWhat = e.what();
2440 calleeUserData->exceptionType = nb::borrow(e.type());
2441 }
2442 },
2443 static_cast<void *>(&userData));
2444 if (userData.gotException) {
2445 std::string message("Exception raised in callback: ");
2446 message.append(str: userData.exceptionWhat);
2447 throw std::runtime_error(message);
2448 }
2449}
2450
2451namespace {
2452
2453/// Python wrapper for MlirBlockArgument.
2454class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
2455public:
2456 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
2457 static constexpr const char *pyClassName = "BlockArgument";
2458 using PyConcreteValue::PyConcreteValue;
2459
2460 static void bindDerived(ClassTy &c) {
2461 c.def_prop_ro("owner", [](PyBlockArgument &self) {
2462 return PyBlock(self.getParentOperation(),
2463 mlirBlockArgumentGetOwner(self.get()));
2464 });
2465 c.def_prop_ro("arg_number", [](PyBlockArgument &self) {
2466 return mlirBlockArgumentGetArgNumber(self.get());
2467 });
2468 c.def(
2469 "set_type",
2470 [](PyBlockArgument &self, PyType type) {
2471 return mlirBlockArgumentSetType(self.get(), type);
2472 },
2473 nb::arg("type"));
2474 }
2475};
2476
2477/// A list of block arguments. Internally, these are stored as consecutive
2478/// elements, random access is cheap. The argument list is associated with the
2479/// operation that contains the block (detached blocks are not allowed in
2480/// Python bindings) and extends its lifetime.
2481class PyBlockArgumentList
2482 : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
2483public:
2484 static constexpr const char *pyClassName = "BlockArgumentList";
2485 using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
2486
2487 PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
2488 intptr_t startIndex = 0, intptr_t length = -1,
2489 intptr_t step = 1)
2490 : Sliceable(startIndex,
2491 length == -1 ? mlirBlockGetNumArguments(block) : length,
2492 step),
2493 operation(std::move(operation)), block(block) {}
2494
2495 static void bindDerived(ClassTy &c) {
2496 c.def_prop_ro("types", [](PyBlockArgumentList &self) {
2497 return getValueTypes(self, self.operation->getContext());
2498 });
2499 }
2500
2501private:
2502 /// Give the parent CRTP class access to hook implementations below.
2503 friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
2504
2505 /// Returns the number of arguments in the list.
2506 intptr_t getRawNumElements() {
2507 operation->checkValid();
2508 return mlirBlockGetNumArguments(block);
2509 }
2510
2511 /// Returns `pos`-the element in the list.
2512 PyBlockArgument getRawElement(intptr_t pos) {
2513 MlirValue argument = mlirBlockGetArgument(block, pos);
2514 return PyBlockArgument(operation, argument);
2515 }
2516
2517 /// Returns a sublist of this list.
2518 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2519 intptr_t step) {
2520 return PyBlockArgumentList(operation, block, startIndex, length, step);
2521 }
2522
2523 PyOperationRef operation;
2524 MlirBlock block;
2525};
2526
2527/// A list of operation operands. Internally, these are stored as consecutive
2528/// elements, random access is cheap. The (returned) operand list is associated
2529/// with the operation whose operands these are, and thus extends the lifetime
2530/// of this operation.
2531class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2532public:
2533 static constexpr const char *pyClassName = "OpOperandList";
2534 using SliceableT = Sliceable<PyOpOperandList, PyValue>;
2535
2536 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2537 intptr_t length = -1, intptr_t step = 1)
2538 : Sliceable(startIndex,
2539 length == -1 ? mlirOperationGetNumOperands(operation->get())
2540 : length,
2541 step),
2542 operation(operation) {}
2543
2544 void dunderSetItem(intptr_t index, PyValue value) {
2545 index = wrapIndex(index);
2546 mlirOperationSetOperand(operation->get(), index, value.get());
2547 }
2548
2549 static void bindDerived(ClassTy &c) {
2550 c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2551 }
2552
2553private:
2554 /// Give the parent CRTP class access to hook implementations below.
2555 friend class Sliceable<PyOpOperandList, PyValue>;
2556
2557 intptr_t getRawNumElements() {
2558 operation->checkValid();
2559 return mlirOperationGetNumOperands(operation->get());
2560 }
2561
2562 PyValue getRawElement(intptr_t pos) {
2563 MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2564 MlirOperation owner;
2565 if (mlirValueIsAOpResult(operand))
2566 owner = mlirOpResultGetOwner(operand);
2567 else if (mlirValueIsABlockArgument(operand))
2568 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
2569 else
2570 assert(false && "Value must be an block arg or op result.");
2571 PyOperationRef pyOwner =
2572 PyOperation::forOperation(operation->getContext(), owner);
2573 return PyValue(pyOwner, operand);
2574 }
2575
2576 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2577 return PyOpOperandList(operation, startIndex, length, step);
2578 }
2579
2580 PyOperationRef operation;
2581};
2582
2583/// A list of operation successors. Internally, these are stored as consecutive
2584/// elements, random access is cheap. The (returned) successor list is
2585/// associated with the operation whose successors these are, and thus extends
2586/// the lifetime of this operation.
2587class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
2588public:
2589 static constexpr const char *pyClassName = "OpSuccessors";
2590
2591 PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
2592 intptr_t length = -1, intptr_t step = 1)
2593 : Sliceable(startIndex,
2594 length == -1 ? mlirOperationGetNumSuccessors(operation->get())
2595 : length,
2596 step),
2597 operation(operation) {}
2598
2599 void dunderSetItem(intptr_t index, PyBlock block) {
2600 index = wrapIndex(index);
2601 mlirOperationSetSuccessor(operation->get(), index, block.get());
2602 }
2603
2604 static void bindDerived(ClassTy &c) {
2605 c.def("__setitem__", &PyOpSuccessors::dunderSetItem);
2606 }
2607
2608private:
2609 /// Give the parent CRTP class access to hook implementations below.
2610 friend class Sliceable<PyOpSuccessors, PyBlock>;
2611
2612 intptr_t getRawNumElements() {
2613 operation->checkValid();
2614 return mlirOperationGetNumSuccessors(operation->get());
2615 }
2616
2617 PyBlock getRawElement(intptr_t pos) {
2618 MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2619 return PyBlock(operation, block);
2620 }
2621
2622 PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2623 return PyOpSuccessors(operation, startIndex, length, step);
2624 }
2625
2626 PyOperationRef operation;
2627};
2628
2629/// A list of operation attributes. Can be indexed by name, producing
2630/// attributes, or by index, producing named attributes.
2631class PyOpAttributeMap {
2632public:
2633 PyOpAttributeMap(PyOperationRef operation)
2634 : operation(std::move(operation)) {}
2635
2636 MlirAttribute dunderGetItemNamed(const std::string &name) {
2637 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2638 toMlirStringRef(name));
2639 if (mlirAttributeIsNull(attr)) {
2640 throw nb::key_error("attempt to access a non-existent attribute");
2641 }
2642 return attr;
2643 }
2644
2645 PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2646 if (index < 0) {
2647 index += dunderLen();
2648 }
2649 if (index < 0 || index >= dunderLen()) {
2650 throw nb::index_error("attempt to access out of bounds attribute");
2651 }
2652 MlirNamedAttribute namedAttr =
2653 mlirOperationGetAttribute(operation->get(), index);
2654 return PyNamedAttribute(
2655 namedAttr.attribute,
2656 std::string(mlirIdentifierStr(namedAttr.name).data,
2657 mlirIdentifierStr(namedAttr.name).length));
2658 }
2659
2660 void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2661 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2662 attr);
2663 }
2664
2665 void dunderDelItem(const std::string &name) {
2666 int removed = mlirOperationRemoveAttributeByName(operation->get(),
2667 toMlirStringRef(name));
2668 if (!removed)
2669 throw nb::key_error("attempt to delete a non-existent attribute");
2670 }
2671
2672 intptr_t dunderLen() {
2673 return mlirOperationGetNumAttributes(operation->get());
2674 }
2675
2676 bool dunderContains(const std::string &name) {
2677 return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2678 operation->get(), toMlirStringRef(name)));
2679 }
2680
2681 static void bind(nb::module_ &m) {
2682 nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
2683 .def("__contains__", &PyOpAttributeMap::dunderContains)
2684 .def("__len__", &PyOpAttributeMap::dunderLen)
2685 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2686 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2687 .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2688 .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2689 }
2690
2691private:
2692 PyOperationRef operation;
2693};
2694
2695} // namespace
2696
2697//------------------------------------------------------------------------------
2698// Populates the core exports of the 'ir' submodule.
2699//------------------------------------------------------------------------------
2700
2701void mlir::python::populateIRCore(nb::module_ &m) {
2702 // disable leak warnings which tend to be false positives.
2703 nb::set_leak_warnings(false);
2704 //----------------------------------------------------------------------------
2705 // Enums.
2706 //----------------------------------------------------------------------------
2707 nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity")
2708 .value("ERROR", MlirDiagnosticError)
2709 .value("WARNING", MlirDiagnosticWarning)
2710 .value("NOTE", MlirDiagnosticNote)
2711 .value("REMARK", MlirDiagnosticRemark);
2712
2713 nb::enum_<MlirWalkOrder>(m, "WalkOrder")
2714 .value("PRE_ORDER", MlirWalkPreOrder)
2715 .value("POST_ORDER", MlirWalkPostOrder);
2716
2717 nb::enum_<MlirWalkResult>(m, "WalkResult")
2718 .value("ADVANCE", MlirWalkResultAdvance)
2719 .value("INTERRUPT", MlirWalkResultInterrupt)
2720 .value("SKIP", MlirWalkResultSkip);
2721
2722 //----------------------------------------------------------------------------
2723 // Mapping of Diagnostics.
2724 //----------------------------------------------------------------------------
2725 nb::class_<PyDiagnostic>(m, "Diagnostic")
2726 .def_prop_ro("severity", &PyDiagnostic::getSeverity)
2727 .def_prop_ro("location", &PyDiagnostic::getLocation)
2728 .def_prop_ro("message", &PyDiagnostic::getMessage)
2729 .def_prop_ro("notes", &PyDiagnostic::getNotes)
2730 .def("__str__", [](PyDiagnostic &self) -> nb::str {
2731 if (!self.isValid())
2732 return nb::str("<Invalid Diagnostic>");
2733 return self.getMessage();
2734 });
2735
2736 nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
2737 .def("__init__",
2738 [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) {
2739 new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
2740 })
2741 .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity)
2742 .def_ro("location", &PyDiagnostic::DiagnosticInfo::location)
2743 .def_ro("message", &PyDiagnostic::DiagnosticInfo::message)
2744 .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes)
2745 .def("__str__",
2746 [](PyDiagnostic::DiagnosticInfo &self) { return self.message; });
2747
2748 nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
2749 .def("detach", &PyDiagnosticHandler::detach)
2750 .def_prop_ro("attached", &PyDiagnosticHandler::isAttached)
2751 .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError)
2752 .def("__enter__", &PyDiagnosticHandler::contextEnter)
2753 .def("__exit__", &PyDiagnosticHandler::contextExit,
2754 nb::arg("exc_type").none(), nb::arg("exc_value").none(),
2755 nb::arg("traceback").none());
2756
2757 //----------------------------------------------------------------------------
2758 // Mapping of MlirContext.
2759 // Note that this is exported as _BaseContext. The containing, Python level
2760 // __init__.py will subclass it with site-specific functionality and set a
2761 // "Context" attribute on this module.
2762 //----------------------------------------------------------------------------
2763
2764 // Expose DefaultThreadPool to python
2765 nb::class_<PyThreadPool>(m, "ThreadPool")
2766 .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); })
2767 .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency)
2768 .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr);
2769
2770 nb::class_<PyMlirContext>(m, "_BaseContext")
2771 .def("__init__",
2772 [](PyMlirContext &self) {
2773 MlirContext context = mlirContextCreateWithThreading(false);
2774 new (&self) PyMlirContext(context);
2775 })
2776 .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2777 .def("_get_context_again",
2778 [](PyMlirContext &self) {
2779 PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2780 return ref.releaseObject();
2781 })
2782 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2783 .def("_get_live_operation_objects",
2784 &PyMlirContext::getLiveOperationObjects)
2785 .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2786 .def("_clear_live_operations_inside",
2787 nb::overload_cast<MlirOperation>(
2788 &PyMlirContext::clearOperationsInside))
2789 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2790 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
2791 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2792 .def("__enter__", &PyMlirContext::contextEnter)
2793 .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
2794 nb::arg("exc_value").none(), nb::arg("traceback").none())
2795 .def_prop_ro_static(
2796 "current",
2797 [](nb::object & /*class*/) {
2798 auto *context = PyThreadContextEntry::getDefaultContext();
2799 if (!context)
2800 return nb::none();
2801 return nb::cast(context);
2802 },
2803 "Gets the Context bound to the current thread or raises ValueError")
2804 .def_prop_ro(
2805 "dialects",
2806 [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2807 "Gets a container for accessing dialects by name")
2808 .def_prop_ro(
2809 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2810 "Alias for 'dialect'")
2811 .def(
2812 "get_dialect_descriptor",
2813 [=](PyMlirContext &self, std::string &name) {
2814 MlirDialect dialect = mlirContextGetOrLoadDialect(
2815 self.get(), {name.data(), name.size()});
2816 if (mlirDialectIsNull(dialect)) {
2817 throw nb::value_error(
2818 (Twine("Dialect '") + name + "' not found").str().c_str());
2819 }
2820 return PyDialectDescriptor(self.getRef(), dialect);
2821 },
2822 nb::arg("dialect_name"),
2823 "Gets or loads a dialect by name, returning its descriptor object")
2824 .def_prop_rw(
2825 "allow_unregistered_dialects",
2826 [](PyMlirContext &self) -> bool {
2827 return mlirContextGetAllowUnregisteredDialects(self.get());
2828 },
2829 [](PyMlirContext &self, bool value) {
2830 mlirContextSetAllowUnregisteredDialects(self.get(), value);
2831 })
2832 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2833 nb::arg("callback"),
2834 "Attaches a diagnostic handler that will receive callbacks")
2835 .def(
2836 "enable_multithreading",
2837 [](PyMlirContext &self, bool enable) {
2838 mlirContextEnableMultithreading(self.get(), enable);
2839 },
2840 nb::arg("enable"))
2841 .def("set_thread_pool",
2842 [](PyMlirContext &self, PyThreadPool &pool) {
2843 // we should disable multi-threading first before setting
2844 // new thread pool otherwise the assert in
2845 // MLIRContext::setThreadPool will be raised.
2846 mlirContextEnableMultithreading(self.get(), false);
2847 mlirContextSetThreadPool(self.get(), pool.get());
2848 })
2849 .def("get_num_threads",
2850 [](PyMlirContext &self) {
2851 return mlirContextGetNumThreads(self.get());
2852 })
2853 .def("_mlir_thread_pool_ptr",
2854 [](PyMlirContext &self) {
2855 MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
2856 std::stringstream ss;
2857 ss << pool.ptr;
2858 return ss.str();
2859 })
2860 .def(
2861 "is_registered_operation",
2862 [](PyMlirContext &self, std::string &name) {
2863 return mlirContextIsRegisteredOperation(
2864 self.get(), MlirStringRef{name.data(), name.size()});
2865 },
2866 nb::arg("operation_name"))
2867 .def(
2868 "append_dialect_registry",
2869 [](PyMlirContext &self, PyDialectRegistry &registry) {
2870 mlirContextAppendDialectRegistry(self.get(), registry);
2871 },
2872 nb::arg("registry"))
2873 .def_prop_rw("emit_error_diagnostics", nullptr,
2874 &PyMlirContext::setEmitErrorDiagnostics,
2875 "Emit error diagnostics to diagnostic handlers. By default "
2876 "error diagnostics are captured and reported through "
2877 "MLIRError exceptions.")
2878 .def("load_all_available_dialects", [](PyMlirContext &self) {
2879 mlirContextLoadAllAvailableDialects(self.get());
2880 });
2881
2882 //----------------------------------------------------------------------------
2883 // Mapping of PyDialectDescriptor
2884 //----------------------------------------------------------------------------
2885 nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
2886 .def_prop_ro("namespace",
2887 [](PyDialectDescriptor &self) {
2888 MlirStringRef ns = mlirDialectGetNamespace(self.get());
2889 return nb::str(ns.data, ns.length);
2890 })
2891 .def("__repr__", [](PyDialectDescriptor &self) {
2892 MlirStringRef ns = mlirDialectGetNamespace(self.get());
2893 std::string repr("<DialectDescriptor ");
2894 repr.append(ns.data, ns.length);
2895 repr.append(">");
2896 return repr;
2897 });
2898
2899 //----------------------------------------------------------------------------
2900 // Mapping of PyDialects
2901 //----------------------------------------------------------------------------
2902 nb::class_<PyDialects>(m, "Dialects")
2903 .def("__getitem__",
2904 [=](PyDialects &self, std::string keyName) {
2905 MlirDialect dialect =
2906 self.getDialectForKey(keyName, /*attrError=*/false);
2907 nb::object descriptor =
2908 nb::cast(PyDialectDescriptor{self.getContext(), dialect});
2909 return createCustomDialectWrapper(keyName, std::move(descriptor));
2910 })
2911 .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2912 MlirDialect dialect =
2913 self.getDialectForKey(attrName, /*attrError=*/true);
2914 nb::object descriptor =
2915 nb::cast(PyDialectDescriptor{self.getContext(), dialect});
2916 return createCustomDialectWrapper(attrName, std::move(descriptor));
2917 });
2918
2919 //----------------------------------------------------------------------------
2920 // Mapping of PyDialect
2921 //----------------------------------------------------------------------------
2922 nb::class_<PyDialect>(m, "Dialect")
2923 .def(nb::init<nb::object>(), nb::arg("descriptor"))
2924 .def_prop_ro("descriptor",
2925 [](PyDialect &self) { return self.getDescriptor(); })
2926 .def("__repr__", [](nb::object self) {
2927 auto clazz = self.attr("__class__");
2928 return nb::str("<Dialect ") +
2929 self.attr("descriptor").attr("namespace") + nb::str(" (class ") +
2930 clazz.attr("__module__") + nb::str(".") +
2931 clazz.attr("__name__") + nb::str(")>");
2932 });
2933
2934 //----------------------------------------------------------------------------
2935 // Mapping of PyDialectRegistry
2936 //----------------------------------------------------------------------------
2937 nb::class_<PyDialectRegistry>(m, "DialectRegistry")
2938 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule)
2939 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
2940 .def(nb::init<>());
2941
2942 //----------------------------------------------------------------------------
2943 // Mapping of Location
2944 //----------------------------------------------------------------------------
2945 nb::class_<PyLocation>(m, "Location")
2946 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2947 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2948 .def("__enter__", &PyLocation::contextEnter)
2949 .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
2950 nb::arg("exc_value").none(), nb::arg("traceback").none())
2951 .def("__eq__",
2952 [](PyLocation &self, PyLocation &other) -> bool {
2953 return mlirLocationEqual(self, other);
2954 })
2955 .def("__eq__", [](PyLocation &self, nb::object other) { return false; })
2956 .def_prop_ro_static(
2957 "current",
2958 [](nb::object & /*class*/) {
2959 auto *loc = PyThreadContextEntry::getDefaultLocation();
2960 if (!loc)
2961 throw nb::value_error("No current Location");
2962 return loc;
2963 },
2964 "Gets the Location bound to the current thread or raises ValueError")
2965 .def_static(
2966 "unknown",
2967 [](DefaultingPyMlirContext context) {
2968 return PyLocation(context->getRef(),
2969 mlirLocationUnknownGet(context->get()));
2970 },
2971 nb::arg("context").none() = nb::none(),
2972 "Gets a Location representing an unknown location")
2973 .def_static(
2974 "callsite",
2975 [](PyLocation callee, const std::vector<PyLocation> &frames,
2976 DefaultingPyMlirContext context) {
2977 if (frames.empty())
2978 throw nb::value_error("No caller frames provided");
2979 MlirLocation caller = frames.back().get();
2980 for (const PyLocation &frame :
2981 llvm::reverse(llvm::ArrayRef(frames).drop_back()))
2982 caller = mlirLocationCallSiteGet(frame.get(), caller);
2983 return PyLocation(context->getRef(),
2984 mlirLocationCallSiteGet(callee.get(), caller));
2985 },
2986 nb::arg("callee"), nb::arg("frames"),
2987 nb::arg("context").none() = nb::none(),
2988 kContextGetCallSiteLocationDocstring)
2989 .def("is_a_callsite", mlirLocationIsACallSite)
2990 .def_prop_ro("callee", mlirLocationCallSiteGetCallee)
2991 .def_prop_ro("caller", mlirLocationCallSiteGetCaller)
2992 .def_static(
2993 "file",
2994 [](std::string filename, int line, int col,
2995 DefaultingPyMlirContext context) {
2996 return PyLocation(
2997 context->getRef(),
2998 mlirLocationFileLineColGet(
2999 context->get(), toMlirStringRef(filename), line, col));
3000 },
3001 nb::arg("filename"), nb::arg("line"), nb::arg("col"),
3002 nb::arg("context").none() = nb::none(),
3003 kContextGetFileLocationDocstring)
3004 .def_static(
3005 "file",
3006 [](std::string filename, int startLine, int startCol, int endLine,
3007 int endCol, DefaultingPyMlirContext context) {
3008 return PyLocation(context->getRef(),
3009 mlirLocationFileLineColRangeGet(
3010 context->get(), toMlirStringRef(filename),
3011 startLine, startCol, endLine, endCol));
3012 },
3013 nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
3014 nb::arg("end_line"), nb::arg("end_col"),
3015 nb::arg("context").none() = nb::none(), kContextGetFileRangeDocstring)
3016 .def("is_a_file", mlirLocationIsAFileLineColRange)
3017 .def_prop_ro("filename",
3018 [](MlirLocation loc) {
3019 return mlirIdentifierStr(
3020 mlirLocationFileLineColRangeGetFilename(loc));
3021 })
3022 .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine)
3023 .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn)
3024 .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine)
3025 .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn)
3026 .def_static(
3027 "fused",
3028 [](const std::vector<PyLocation> &pyLocations,
3029 std::optional<PyAttribute> metadata,
3030 DefaultingPyMlirContext context) {
3031 llvm::SmallVector<MlirLocation, 4> locations;
3032 locations.reserve(pyLocations.size());
3033 for (auto &pyLocation : pyLocations)
3034 locations.push_back(pyLocation.get());
3035 MlirLocation location = mlirLocationFusedGet(
3036 context->get(), locations.size(), locations.data(),
3037 metadata ? metadata->get() : MlirAttribute{0});
3038 return PyLocation(context->getRef(), location);
3039 },
3040 nb::arg("locations"), nb::arg("metadata").none() = nb::none(),
3041 nb::arg("context").none() = nb::none(),
3042 kContextGetFusedLocationDocstring)
3043 .def("is_a_fused", mlirLocationIsAFused)
3044 .def_prop_ro("locations",
3045 [](MlirLocation loc) {
3046 unsigned numLocations =
3047 mlirLocationFusedGetNumLocations(loc);
3048 std::vector<MlirLocation> locations(numLocations);
3049 if (numLocations)
3050 mlirLocationFusedGetLocations(loc, locations.data());
3051 return locations;
3052 })
3053 .def_static(
3054 "name",
3055 [](std::string name, std::optional<PyLocation> childLoc,
3056 DefaultingPyMlirContext context) {
3057 return PyLocation(
3058 context->getRef(),
3059 mlirLocationNameGet(
3060 context->get(), toMlirStringRef(name),
3061 childLoc ? childLoc->get()
3062 : mlirLocationUnknownGet(context->get())));
3063 },
3064 nb::arg("name"), nb::arg("childLoc").none() = nb::none(),
3065 nb::arg("context").none() = nb::none(),
3066 kContextGetNameLocationDocString)
3067 .def("is_a_name", mlirLocationIsAName)
3068 .def_prop_ro("name_str",
3069 [](MlirLocation loc) {
3070 return mlirIdentifierStr(mlirLocationNameGetName(loc));
3071 })
3072 .def_prop_ro("child_loc", mlirLocationNameGetChildLoc)
3073 .def_static(
3074 "from_attr",
3075 [](PyAttribute &attribute, DefaultingPyMlirContext context) {
3076 return PyLocation(context->getRef(),
3077 mlirLocationFromAttribute(attribute));
3078 },
3079 nb::arg("attribute"), nb::arg("context").none() = nb::none(),
3080 "Gets a Location from a LocationAttr")
3081 .def_prop_ro(
3082 "context",
3083 [](PyLocation &self) { return self.getContext().getObject(); },
3084 "Context that owns the Location")
3085 .def_prop_ro(
3086 "attr",
3087 [](PyLocation &self) { return mlirLocationGetAttribute(self); },
3088 "Get the underlying LocationAttr")
3089 .def(
3090 "emit_error",
3091 [](PyLocation &self, std::string message) {
3092 mlirEmitError(self, message.c_str());
3093 },
3094 nb::arg("message"), "Emits an error at this location")
3095 .def("__repr__", [](PyLocation &self) {
3096 PyPrintAccumulator printAccum;
3097 mlirLocationPrint(self, printAccum.getCallback(),
3098 printAccum.getUserData());
3099 return printAccum.join();
3100 });
3101
3102 //----------------------------------------------------------------------------
3103 // Mapping of Module
3104 //----------------------------------------------------------------------------
3105 nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
3106 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
3107 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
3108 .def_static(
3109 "parse",
3110 [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
3111 PyMlirContext::ErrorCapture errors(context->getRef());
3112 MlirModule module = mlirModuleCreateParse(
3113 context->get(), toMlirStringRef(moduleAsm));
3114 if (mlirModuleIsNull(module))
3115 throw MLIRError("Unable to parse module assembly", errors.take());
3116 return PyModule::forModule(module).releaseObject();
3117 },
3118 nb::arg("asm"), nb::arg("context").none() = nb::none(),
3119 kModuleParseDocstring)
3120 .def_static(
3121 "parse",
3122 [](nb::bytes moduleAsm, DefaultingPyMlirContext context) {
3123 PyMlirContext::ErrorCapture errors(context->getRef());
3124 MlirModule module = mlirModuleCreateParse(
3125 context->get(), toMlirStringRef(moduleAsm));
3126 if (mlirModuleIsNull(module))
3127 throw MLIRError("Unable to parse module assembly", errors.take());
3128 return PyModule::forModule(module).releaseObject();
3129 },
3130 nb::arg("asm"), nb::arg("context").none() = nb::none(),
3131 kModuleParseDocstring)
3132 .def_static(
3133 "parseFile",
3134 [](const std::string &path, DefaultingPyMlirContext context) {
3135 PyMlirContext::ErrorCapture errors(context->getRef());
3136 MlirModule module = mlirModuleCreateParseFromFile(
3137 context->get(), toMlirStringRef(path));
3138 if (mlirModuleIsNull(module))
3139 throw MLIRError("Unable to parse module assembly", errors.take());
3140 return PyModule::forModule(module).releaseObject();
3141 },
3142 nb::arg("path"), nb::arg("context").none() = nb::none(),
3143 kModuleParseDocstring)
3144 .def_static(
3145 "create",
3146 [](DefaultingPyLocation loc) {
3147 MlirModule module = mlirModuleCreateEmpty(loc);
3148 return PyModule::forModule(module).releaseObject();
3149 },
3150 nb::arg("loc").none() = nb::none(), "Creates an empty module")
3151 .def_prop_ro(
3152 "context",
3153 [](PyModule &self) { return self.getContext().getObject(); },
3154 "Context that created the Module")
3155 .def_prop_ro(
3156 "operation",
3157 [](PyModule &self) {
3158 return PyOperation::forOperation(self.getContext(),
3159 mlirModuleGetOperation(self.get()),
3160 self.getRef().releaseObject())
3161 .releaseObject();
3162 },
3163 "Accesses the module as an operation")
3164 .def_prop_ro(
3165 "body",
3166 [](PyModule &self) {
3167 PyOperationRef moduleOp = PyOperation::forOperation(
3168 self.getContext(), mlirModuleGetOperation(self.get()),
3169 self.getRef().releaseObject());
3170 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
3171 return returnBlock;
3172 },
3173 "Return the block for this module")
3174 .def(
3175 "dump",
3176 [](PyModule &self) {
3177 mlirOperationDump(mlirModuleGetOperation(self.get()));
3178 },
3179 kDumpDocstring)
3180 .def(
3181 "__str__",
3182 [](nb::object self) {
3183 // Defer to the operation's __str__.
3184 return self.attr("operation").attr("__str__")();
3185 },
3186 kOperationStrDunderDocstring);
3187
3188 //----------------------------------------------------------------------------
3189 // Mapping of Operation.
3190 //----------------------------------------------------------------------------
3191 nb::class_<PyOperationBase>(m, "_OperationBase")
3192 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
3193 [](PyOperationBase &self) {
3194 return self.getOperation().getCapsule();
3195 })
3196 .def("__eq__",
3197 [](PyOperationBase &self, PyOperationBase &other) {
3198 return &self.getOperation() == &other.getOperation();
3199 })
3200 .def("__eq__",
3201 [](PyOperationBase &self, nb::object other) { return false; })
3202 .def("__hash__",
3203 [](PyOperationBase &self) {
3204 return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
3205 })
3206 .def_prop_ro("attributes",
3207 [](PyOperationBase &self) {
3208 return PyOpAttributeMap(self.getOperation().getRef());
3209 })
3210 .def_prop_ro(
3211 "context",
3212 [](PyOperationBase &self) {
3213 PyOperation &concreteOperation = self.getOperation();
3214 concreteOperation.checkValid();
3215 return concreteOperation.getContext().getObject();
3216 },
3217 "Context that owns the Operation")
3218 .def_prop_ro("name",
3219 [](PyOperationBase &self) {
3220 auto &concreteOperation = self.getOperation();
3221 concreteOperation.checkValid();
3222 MlirOperation operation = concreteOperation.get();
3223 return mlirIdentifierStr(mlirOperationGetName(operation));
3224 })
3225 .def_prop_ro("operands",
3226 [](PyOperationBase &self) {
3227 return PyOpOperandList(self.getOperation().getRef());
3228 })
3229 .def_prop_ro("regions",
3230 [](PyOperationBase &self) {
3231 return PyRegionList(self.getOperation().getRef());
3232 })
3233 .def_prop_ro(
3234 "results",
3235 [](PyOperationBase &self) {
3236 return PyOpResultList(self.getOperation().getRef());
3237 },
3238 "Returns the list of Operation results.")
3239 .def_prop_ro(
3240 "result",
3241 [](PyOperationBase &self) {
3242 auto &operation = self.getOperation();
3243 return PyOpResult(operation.getRef(), getUniqueResult(operation))
3244 .maybeDownCast();
3245 },
3246 "Shortcut to get an op result if it has only one (throws an error "
3247 "otherwise).")
3248 .def_prop_ro(
3249 "location",
3250 [](PyOperationBase &self) {
3251 PyOperation &operation = self.getOperation();
3252 return PyLocation(operation.getContext(),
3253 mlirOperationGetLocation(operation.get()));
3254 },
3255 "Returns the source location the operation was defined or derived "
3256 "from.")
3257 .def_prop_ro("parent",
3258 [](PyOperationBase &self) -> nb::object {
3259 auto parent = self.getOperation().getParentOperation();
3260 if (parent)
3261 return parent->getObject();
3262 return nb::none();
3263 })
3264 .def(
3265 "__str__",
3266 [](PyOperationBase &self) {
3267 return self.getAsm(/*binary=*/false,
3268 /*largeElementsLimit=*/std::nullopt,
3269 /*enableDebugInfo=*/false,
3270 /*prettyDebugInfo=*/false,
3271 /*printGenericOpForm=*/false,
3272 /*useLocalScope=*/false,
3273 /*useNameLocAsPrefix=*/false,
3274 /*assumeVerified=*/false,
3275 /*skipRegions=*/false);
3276 },
3277 "Returns the assembly form of the operation.")
3278 .def("print",
3279 nb::overload_cast<PyAsmState &, nb::object, bool>(
3280 &PyOperationBase::print),
3281 nb::arg("state"), nb::arg("file").none() = nb::none(),
3282 nb::arg("binary") = false, kOperationPrintStateDocstring)
3283 .def("print",
3284 nb::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
3285 bool, bool, nb::object, bool, bool>(
3286 &PyOperationBase::print),
3287 // Careful: Lots of arguments must match up with print method.
3288 nb::arg("large_elements_limit").none() = nb::none(),
3289 nb::arg("enable_debug_info") = false,
3290 nb::arg("pretty_debug_info") = false,
3291 nb::arg("print_generic_op_form") = false,
3292 nb::arg("use_local_scope") = false,
3293 nb::arg("use_name_loc_as_prefix") = false,
3294 nb::arg("assume_verified") = false,
3295 nb::arg("file").none() = nb::none(), nb::arg("binary") = false,
3296 nb::arg("skip_regions") = false, kOperationPrintDocstring)
3297 .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"),
3298 nb::arg("desired_version").none() = nb::none(),
3299 kOperationPrintBytecodeDocstring)
3300 .def("get_asm", &PyOperationBase::getAsm,
3301 // Careful: Lots of arguments must match up with get_asm method.
3302 nb::arg("binary") = false,
3303 nb::arg("large_elements_limit").none() = nb::none(),
3304 nb::arg("enable_debug_info") = false,
3305 nb::arg("pretty_debug_info") = false,
3306 nb::arg("print_generic_op_form") = false,
3307 nb::arg("use_local_scope") = false,
3308 nb::arg("use_name_loc_as_prefix") = false,
3309 nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
3310 kOperationGetAsmDocstring)
3311 .def("verify", &PyOperationBase::verify,
3312 "Verify the operation. Raises MLIRError if verification fails, and "
3313 "returns true otherwise.")
3314 .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"),
3315 "Puts self immediately after the other operation in its parent "
3316 "block.")
3317 .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"),
3318 "Puts self immediately before the other operation in its parent "
3319 "block.")
3320 .def(
3321 "clone",
3322 [](PyOperationBase &self, nb::object ip) {
3323 return self.getOperation().clone(ip);
3324 },
3325 nb::arg("ip").none() = nb::none())
3326 .def(
3327 "detach_from_parent",
3328 [](PyOperationBase &self) {
3329 PyOperation &operation = self.getOperation();
3330 operation.checkValid();
3331 if (!operation.isAttached())
3332 throw nb::value_error("Detached operation has no parent.");
3333
3334 operation.detachFromParent();
3335 return operation.createOpView();
3336 },
3337 "Detaches the operation from its parent block.")
3338 .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
3339 .def("walk", &PyOperationBase::walk, nb::arg("callback"),
3340 nb::arg("walk_order") = MlirWalkPostOrder);
3341
3342 nb::class_<PyOperation, PyOperationBase>(m, "Operation")
3343 .def_static(
3344 "create",
3345 [](std::string_view name,
3346 std::optional<std::vector<PyType *>> results,
3347 std::optional<std::vector<PyValue *>> operands,
3348 std::optional<nb::dict> attributes,
3349 std::optional<std::vector<PyBlock *>> successors, int regions,
3350 DefaultingPyLocation location, const nb::object &maybeIp,
3351 bool inferType) {
3352 // Unpack/validate operands.
3353 llvm::SmallVector<MlirValue, 4> mlirOperands;
3354 if (operands) {
3355 mlirOperands.reserve(operands->size());
3356 for (PyValue *operand : *operands) {
3357 if (!operand)
3358 throw nb::value_error("operand value cannot be None");
3359 mlirOperands.push_back(operand->get());
3360 }
3361 }
3362
3363 return PyOperation::create(name, results, mlirOperands, attributes,
3364 successors, regions, location, maybeIp,
3365 inferType);
3366 },
3367 nb::arg("name"), nb::arg("results").none() = nb::none(),
3368 nb::arg("operands").none() = nb::none(),
3369 nb::arg("attributes").none() = nb::none(),
3370 nb::arg("successors").none() = nb::none(), nb::arg("regions") = 0,
3371 nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(),
3372 nb::arg("infer_type") = false, kOperationCreateDocstring)
3373 .def_static(
3374 "parse",
3375 [](const std::string &sourceStr, const std::string &sourceName,
3376 DefaultingPyMlirContext context) {
3377 return PyOperation::parse(context->getRef(), sourceStr, sourceName)
3378 ->createOpView();
3379 },
3380 nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "",
3381 nb::arg("context").none() = nb::none(),
3382 "Parses an operation. Supports both text assembly format and binary "
3383 "bytecode format.")
3384 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule)
3385 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
3386 .def_prop_ro("operation", [](nb::object self) { return self; })
3387 .def_prop_ro("opview", &PyOperation::createOpView)
3388 .def_prop_ro(
3389 "successors",
3390 [](PyOperationBase &self) {
3391 return PyOpSuccessors(self.getOperation().getRef());
3392 },
3393 "Returns the list of Operation successors.");
3394
3395 auto opViewClass =
3396 nb::class_<PyOpView, PyOperationBase>(m, "OpView")
3397 .def(nb::init<nb::object>(), nb::arg("operation"))
3398 .def(
3399 "__init__",
3400 [](PyOpView *self, std::string_view name,
3401 std::tuple<int, bool> opRegionSpec,
3402 nb::object operandSegmentSpecObj,
3403 nb::object resultSegmentSpecObj,
3404 std::optional<nb::list> resultTypeList, nb::list operandList,
3405 std::optional<nb::dict> attributes,
3406 std::optional<std::vector<PyBlock *>> successors,
3407 std::optional<int> regions, DefaultingPyLocation location,
3408 const nb::object &maybeIp) {
3409 new (self) PyOpView(PyOpView::buildGeneric(
3410 name, opRegionSpec, operandSegmentSpecObj,
3411 resultSegmentSpecObj, resultTypeList, operandList,
3412 attributes, successors, regions, location, maybeIp));
3413 },
3414 nb::arg("name"), nb::arg("opRegionSpec"),
3415 nb::arg("operandSegmentSpecObj").none() = nb::none(),
3416 nb::arg("resultSegmentSpecObj").none() = nb::none(),
3417 nb::arg("results").none() = nb::none(),
3418 nb::arg("operands").none() = nb::none(),
3419 nb::arg("attributes").none() = nb::none(),
3420 nb::arg("successors").none() = nb::none(),
3421 nb::arg("regions").none() = nb::none(),
3422 nb::arg("loc").none() = nb::none(),
3423 nb::arg("ip").none() = nb::none())
3424
3425 .def_prop_ro("operation", &PyOpView::getOperationObject)
3426 .def_prop_ro("opview", [](nb::object self) { return self; })
3427 .def(
3428 "__str__",
3429 [](PyOpView &self) { return nb::str(self.getOperationObject()); })
3430 .def_prop_ro(
3431 "successors",
3432 [](PyOperationBase &self) {
3433 return PyOpSuccessors(self.getOperation().getRef());
3434 },
3435 "Returns the list of Operation successors.");
3436 opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
3437 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
3438 opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
3439 // It is faster to pass the operation_name, ods_regions, and
3440 // ods_operand_segments/ods_result_segments as arguments to the constructor,
3441 // rather than to access them as attributes.
3442 opViewClass.attr("build_generic") = classmethod(
3443 [](nb::handle cls, std::optional<nb::list> resultTypeList,
3444 nb::list operandList, std::optional<nb::dict> attributes,
3445 std::optional<std::vector<PyBlock *>> successors,
3446 std::optional<int> regions, DefaultingPyLocation location,
3447 const nb::object &maybeIp) {
3448 std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
3449 std::tuple<int, bool> opRegionSpec =
3450 nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
3451 nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
3452 nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
3453 return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
3454 resultSegmentSpec, resultTypeList,
3455 operandList, attributes, successors,
3456 regions, location, maybeIp);
3457 },
3458 nb::arg("cls"), nb::arg("results").none() = nb::none(),
3459 nb::arg("operands").none() = nb::none(),
3460 nb::arg("attributes").none() = nb::none(),
3461 nb::arg("successors").none() = nb::none(),
3462 nb::arg("regions").none() = nb::none(),
3463 nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(),
3464 "Builds a specific, generated OpView based on class level attributes.");
3465 opViewClass.attr("parse") = classmethod(
3466 [](const nb::object &cls, const std::string &sourceStr,
3467 const std::string &sourceName, DefaultingPyMlirContext context) {
3468 PyOperationRef parsed =
3469 PyOperation::parse(context->getRef(), sourceStr, sourceName);
3470
3471 // Check if the expected operation was parsed, and cast to to the
3472 // appropriate `OpView` subclass if successful.
3473 // NOTE: This accesses attributes that have been automatically added to
3474 // `OpView` subclasses, and is not intended to be used on `OpView`
3475 // directly.
3476 std::string clsOpName =
3477 nb::cast<std::string>(cls.attr("OPERATION_NAME"));
3478 MlirStringRef identifier =
3479 mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
3480 std::string_view parsedOpName(identifier.data, identifier.length);
3481 if (clsOpName != parsedOpName)
3482 throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
3483 parsedOpName + "'");
3484 return PyOpView::constructDerived(cls, parsed.getObject());
3485 },
3486 nb::arg("cls"), nb::arg("source"), nb::kw_only(),
3487 nb::arg("source_name") = "", nb::arg("context").none() = nb::none(),
3488 "Parses a specific, generated OpView based on class level attributes");
3489
3490 //----------------------------------------------------------------------------
3491 // Mapping of PyRegion.
3492 //----------------------------------------------------------------------------
3493 nb::class_<PyRegion>(m, "Region")
3494 .def_prop_ro(
3495 "blocks",
3496 [](PyRegion &self) {
3497 return PyBlockList(self.getParentOperation(), self.get());
3498 },
3499 "Returns a forward-optimized sequence of blocks.")
3500 .def_prop_ro(
3501 "owner",
3502 [](PyRegion &self) {
3503 return self.getParentOperation()->createOpView();
3504 },
3505 "Returns the operation owning this region.")
3506 .def(
3507 "__iter__",
3508 [](PyRegion &self) {
3509 self.checkValid();
3510 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
3511 return PyBlockIterator(self.getParentOperation(), firstBlock);
3512 },
3513 "Iterates over blocks in the region.")
3514 .def("__eq__",
3515 [](PyRegion &self, PyRegion &other) {
3516 return self.get().ptr == other.get().ptr;
3517 })
3518 .def("__eq__", [](PyRegion &self, nb::object &other) { return false; });
3519
3520 //----------------------------------------------------------------------------
3521 // Mapping of PyBlock.
3522 //----------------------------------------------------------------------------
3523 nb::class_<PyBlock>(m, "Block")
3524 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule)
3525 .def_prop_ro(
3526 "owner",
3527 [](PyBlock &self) {
3528 return self.getParentOperation()->createOpView();
3529 },
3530 "Returns the owning operation of this block.")
3531 .def_prop_ro(
3532 "region",
3533 [](PyBlock &self) {
3534 MlirRegion region = mlirBlockGetParentRegion(self.get());
3535 return PyRegion(self.getParentOperation(), region);
3536 },
3537 "Returns the owning region of this block.")
3538 .def_prop_ro(
3539 "arguments",
3540 [](PyBlock &self) {
3541 return PyBlockArgumentList(self.getParentOperation(), self.get());
3542 },
3543 "Returns a list of block arguments.")
3544 .def(
3545 "add_argument",
3546 [](PyBlock &self, const PyType &type, const PyLocation &loc) {
3547 return mlirBlockAddArgument(self.get(), type, loc);
3548 },
3549 "Append an argument of the specified type to the block and returns "
3550 "the newly added argument.")
3551 .def(
3552 "erase_argument",
3553 [](PyBlock &self, unsigned index) {
3554 return mlirBlockEraseArgument(self.get(), index);
3555 },
3556 "Erase the argument at 'index' and remove it from the argument list.")
3557 .def_prop_ro(
3558 "operations",
3559 [](PyBlock &self) {
3560 return PyOperationList(self.getParentOperation(), self.get());
3561 },
3562 "Returns a forward-optimized sequence of operations.")
3563 .def_static(
3564 "create_at_start",
3565 [](PyRegion &parent, const nb::sequence &pyArgTypes,
3566 const std::optional<nb::sequence> &pyArgLocs) {
3567 parent.checkValid();
3568 MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3569 mlirRegionInsertOwnedBlock(parent, 0, block);
3570 return PyBlock(parent.getParentOperation(), block);
3571 },
3572 nb::arg("parent"), nb::arg("arg_types") = nb::list(),
3573 nb::arg("arg_locs") = std::nullopt,
3574 "Creates and returns a new Block at the beginning of the given "
3575 "region (with given argument types and locations).")
3576 .def(
3577 "append_to",
3578 [](PyBlock &self, PyRegion &region) {
3579 MlirBlock b = self.get();
3580 if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
3581 mlirBlockDetach(b);
3582 mlirRegionAppendOwnedBlock(region.get(), b);
3583 },
3584 "Append this block to a region, transferring ownership if necessary")
3585 .def(
3586 "create_before",
3587 [](PyBlock &self, const nb::args &pyArgTypes,
3588 const std::optional<nb::sequence> &pyArgLocs) {
3589 self.checkValid();
3590 MlirBlock block =
3591 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
3592 MlirRegion region = mlirBlockGetParentRegion(self.get());
3593 mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
3594 return PyBlock(self.getParentOperation(), block);
3595 },
3596 nb::arg("arg_types"), nb::kw_only(),
3597 nb::arg("arg_locs") = std::nullopt,
3598 "Creates and returns a new Block before this block "
3599 "(with given argument types and locations).")
3600 .def(
3601 "create_after",
3602 [](PyBlock &self, const nb::args &pyArgTypes,
3603 const std::optional<nb::sequence> &pyArgLocs) {
3604 self.checkValid();
3605 MlirBlock block =
3606 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
3607 MlirRegion region = mlirBlockGetParentRegion(self.get());
3608 mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
3609 return PyBlock(self.getParentOperation(), block);
3610 },
3611 nb::arg("arg_types"), nb::kw_only(),
3612 nb::arg("arg_locs") = std::nullopt,
3613 "Creates and returns a new Block after this block "
3614 "(with given argument types and locations).")
3615 .def(
3616 "__iter__",
3617 [](PyBlock &self) {
3618 self.checkValid();
3619 MlirOperation firstOperation =
3620 mlirBlockGetFirstOperation(self.get());
3621 return PyOperationIterator(self.getParentOperation(),
3622 firstOperation);
3623 },
3624 "Iterates over operations in the block.")
3625 .def("__eq__",
3626 [](PyBlock &self, PyBlock &other) {
3627 return self.get().ptr == other.get().ptr;
3628 })
3629 .def("__eq__", [](PyBlock &self, nb::object &other) { return false; })
3630 .def("__hash__",
3631 [](PyBlock &self) {
3632 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3633 })
3634 .def(
3635 "__str__",
3636 [](PyBlock &self) {
3637 self.checkValid();
3638 PyPrintAccumulator printAccum;
3639 mlirBlockPrint(self.get(), printAccum.getCallback(),
3640 printAccum.getUserData());
3641 return printAccum.join();
3642 },
3643 "Returns the assembly form of the block.")
3644 .def(
3645 "append",
3646 [](PyBlock &self, PyOperationBase &operation) {
3647 if (operation.getOperation().isAttached())
3648 operation.getOperation().detachFromParent();
3649
3650 MlirOperation mlirOperation = operation.getOperation().get();
3651 mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
3652 operation.getOperation().setAttached(
3653 self.getParentOperation().getObject());
3654 },
3655 nb::arg("operation"),
3656 "Appends an operation to this block. If the operation is currently "
3657 "in another block, it will be moved.");
3658
3659 //----------------------------------------------------------------------------
3660 // Mapping of PyInsertionPoint.
3661 //----------------------------------------------------------------------------
3662
3663 nb::class_<PyInsertionPoint>(m, "InsertionPoint")
3664 .def(nb::init<PyBlock &>(), nb::arg("block"),
3665 "Inserts after the last operation but still inside the block.")
3666 .def("__enter__", &PyInsertionPoint::contextEnter)
3667 .def("__exit__", &PyInsertionPoint::contextExit,
3668 nb::arg("exc_type").none(), nb::arg("exc_value").none(),
3669 nb::arg("traceback").none())
3670 .def_prop_ro_static(
3671 "current",
3672 [](nb::object & /*class*/) {
3673 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
3674 if (!ip)
3675 throw nb::value_error("No current InsertionPoint");
3676 return ip;
3677 },
3678 "Gets the InsertionPoint bound to the current thread or raises "
3679 "ValueError if none has been set")
3680 .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"),
3681 "Inserts before a referenced operation.")
3682 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
3683 nb::arg("block"), "Inserts at the beginning of the block.")
3684 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
3685 nb::arg("block"), "Inserts before the block terminator.")
3686 .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
3687 "Inserts an operation.")
3688 .def_prop_ro(
3689 "block", [](PyInsertionPoint &self) { return self.getBlock(); },
3690 "Returns the block that this InsertionPoint points to.")
3691 .def_prop_ro(
3692 "ref_operation",
3693 [](PyInsertionPoint &self) -> nb::object {
3694 auto refOperation = self.getRefOperation();
3695 if (refOperation)
3696 return refOperation->getObject();
3697 return nb::none();
3698 },
3699 "The reference operation before which new operations are "
3700 "inserted, or None if the insertion point is at the end of "
3701 "the block");
3702
3703 //----------------------------------------------------------------------------
3704 // Mapping of PyAttribute.
3705 //----------------------------------------------------------------------------
3706 nb::class_<PyAttribute>(m, "Attribute")
3707 // Delegate to the PyAttribute copy constructor, which will also lifetime
3708 // extend the backing context which owns the MlirAttribute.
3709 .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
3710 "Casts the passed attribute to the generic Attribute")
3711 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule)
3712 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
3713 .def_static(
3714 "parse",
3715 [](const std::string &attrSpec, DefaultingPyMlirContext context) {
3716 PyMlirContext::ErrorCapture errors(context->getRef());
3717 MlirAttribute attr = mlirAttributeParseGet(
3718 context->get(), toMlirStringRef(attrSpec));
3719 if (mlirAttributeIsNull(attr))
3720 throw MLIRError("Unable to parse attribute", errors.take());
3721 return attr;
3722 },
3723 nb::arg("asm"), nb::arg("context").none() = nb::none(),
3724 "Parses an attribute from an assembly form. Raises an MLIRError on "
3725 "failure.")
3726 .def_prop_ro(
3727 "context",
3728 [](PyAttribute &self) { return self.getContext().getObject(); },
3729 "Context that owns the Attribute")
3730 .def_prop_ro("type",
3731 [](PyAttribute &self) { return mlirAttributeGetType(self); })
3732 .def(
3733 "get_named",
3734 [](PyAttribute &self, std::string name) {
3735 return PyNamedAttribute(self, std::move(name));
3736 },
3737 nb::keep_alive<0, 1>(), "Binds a name to the attribute")
3738 .def("__eq__",
3739 [](PyAttribute &self, PyAttribute &other) { return self == other; })
3740 .def("__eq__", [](PyAttribute &self, nb::object &other) { return false; })
3741 .def("__hash__",
3742 [](PyAttribute &self) {
3743 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3744 })
3745 .def(
3746 "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
3747 kDumpDocstring)
3748 .def(
3749 "__str__",
3750 [](PyAttribute &self) {
3751 PyPrintAccumulator printAccum;
3752 mlirAttributePrint(self, printAccum.getCallback(),
3753 printAccum.getUserData());
3754 return printAccum.join();
3755 },
3756 "Returns the assembly form of the Attribute.")
3757 .def("__repr__",
3758 [](PyAttribute &self) {
3759 // Generally, assembly formats are not printed for __repr__ because
3760 // this can cause exceptionally long debug output and exceptions.
3761 // However, attribute values are generally considered useful and
3762 // are printed. This may need to be re-evaluated if debug dumps end
3763 // up being excessive.
3764 PyPrintAccumulator printAccum;
3765 printAccum.parts.append("Attribute(");
3766 mlirAttributePrint(self, printAccum.getCallback(),
3767 printAccum.getUserData());
3768 printAccum.parts.append(")");
3769 return printAccum.join();
3770 })
3771 .def_prop_ro("typeid",
3772 [](PyAttribute &self) -> MlirTypeID {
3773 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3774 assert(!mlirTypeIDIsNull(mlirTypeID) &&
3775 "mlirTypeID was expected to be non-null.");
3776 return mlirTypeID;
3777 })
3778 .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) {
3779 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3780 assert(!mlirTypeIDIsNull(mlirTypeID) &&
3781 "mlirTypeID was expected to be non-null.");
3782 std::optional<nb::callable> typeCaster =
3783 PyGlobals::get().lookupTypeCaster(mlirTypeID,
3784 mlirAttributeGetDialect(self));
3785 if (!typeCaster)
3786 return nb::cast(self);
3787 return typeCaster.value()(self);
3788 });
3789
3790 //----------------------------------------------------------------------------
3791 // Mapping of PyNamedAttribute
3792 //----------------------------------------------------------------------------
3793 nb::class_<PyNamedAttribute>(m, "NamedAttribute")
3794 .def("__repr__",
3795 [](PyNamedAttribute &self) {
3796 PyPrintAccumulator printAccum;
3797 printAccum.parts.append("NamedAttribute(");
3798 printAccum.parts.append(
3799 nb::str(mlirIdentifierStr(self.namedAttr.name).data,
3800 mlirIdentifierStr(self.namedAttr.name).length));
3801 printAccum.parts.append("=");
3802 mlirAttributePrint(self.namedAttr.attribute,
3803 printAccum.getCallback(),
3804 printAccum.getUserData());
3805 printAccum.parts.append(")");
3806 return printAccum.join();
3807 })
3808 .def_prop_ro(
3809 "name",
3810 [](PyNamedAttribute &self) {
3811 return mlirIdentifierStr(self.namedAttr.name);
3812 },
3813 "The name of the NamedAttribute binding")
3814 .def_prop_ro(
3815 "attr",
3816 [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
3817 nb::keep_alive<0, 1>(),
3818 "The underlying generic attribute of the NamedAttribute binding");
3819
3820 //----------------------------------------------------------------------------
3821 // Mapping of PyType.
3822 //----------------------------------------------------------------------------
3823 nb::class_<PyType>(m, "Type")
3824 // Delegate to the PyType copy constructor, which will also lifetime
3825 // extend the backing context which owns the MlirType.
3826 .def(nb::init<PyType &>(), nb::arg("cast_from_type"),
3827 "Casts the passed type to the generic Type")
3828 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3829 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3830 .def_static(
3831 "parse",
3832 [](std::string typeSpec, DefaultingPyMlirContext context) {
3833 PyMlirContext::ErrorCapture errors(context->getRef());
3834 MlirType type =
3835 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3836 if (mlirTypeIsNull(type))
3837 throw MLIRError("Unable to parse type", errors.take());
3838 return type;
3839 },
3840 nb::arg("asm"), nb::arg("context").none() = nb::none(),
3841 kContextParseTypeDocstring)
3842 .def_prop_ro(
3843 "context", [](PyType &self) { return self.getContext().getObject(); },
3844 "Context that owns the Type")
3845 .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3846 .def(
3847 "__eq__", [](PyType &self, nb::object &other) { return false; },
3848 nb::arg("other").none())
3849 .def("__hash__",
3850 [](PyType &self) {
3851 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3852 })
3853 .def(
3854 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3855 .def(
3856 "__str__",
3857 [](PyType &self) {
3858 PyPrintAccumulator printAccum;
3859 mlirTypePrint(self, printAccum.getCallback(),
3860 printAccum.getUserData());
3861 return printAccum.join();
3862 },
3863 "Returns the assembly form of the type.")
3864 .def("__repr__",
3865 [](PyType &self) {
3866 // Generally, assembly formats are not printed for __repr__ because
3867 // this can cause exceptionally long debug output and exceptions.
3868 // However, types are an exception as they typically have compact
3869 // assembly forms and printing them is useful.
3870 PyPrintAccumulator printAccum;
3871 printAccum.parts.append("Type(");
3872 mlirTypePrint(self, printAccum.getCallback(),
3873 printAccum.getUserData());
3874 printAccum.parts.append(")");
3875 return printAccum.join();
3876 })
3877 .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
3878 [](PyType &self) {
3879 MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3880 assert(!mlirTypeIDIsNull(mlirTypeID) &&
3881 "mlirTypeID was expected to be non-null.");
3882 std::optional<nb::callable> typeCaster =
3883 PyGlobals::get().lookupTypeCaster(mlirTypeID,
3884 mlirTypeGetDialect(self));
3885 if (!typeCaster)
3886 return nb::cast(self);
3887 return typeCaster.value()(self);
3888 })
3889 .def_prop_ro("typeid", [](PyType &self) -> MlirTypeID {
3890 MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3891 if (!mlirTypeIDIsNull(mlirTypeID))
3892 return mlirTypeID;
3893 auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
3894 throw nb::value_error(
3895 (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
3896 });
3897
3898 //----------------------------------------------------------------------------
3899 // Mapping of PyTypeID.
3900 //----------------------------------------------------------------------------
3901 nb::class_<PyTypeID>(m, "TypeID")
3902 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
3903 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
3904 // Note, this tests whether the underlying TypeIDs are the same,
3905 // not whether the wrapper MlirTypeIDs are the same, nor whether
3906 // the Python objects are the same (i.e., PyTypeID is a value type).
3907 .def("__eq__",
3908 [](PyTypeID &self, PyTypeID &other) { return self == other; })
3909 .def("__eq__",
3910 [](PyTypeID &self, const nb::object &other) { return false; })
3911 // Note, this gives the hash value of the underlying TypeID, not the
3912 // hash value of the Python object, nor the hash value of the
3913 // MlirTypeID wrapper.
3914 .def("__hash__", [](PyTypeID &self) {
3915 return static_cast<size_t>(mlirTypeIDHashValue(self));
3916 });
3917
3918 //----------------------------------------------------------------------------
3919 // Mapping of Value.
3920 //----------------------------------------------------------------------------
3921 nb::class_<PyValue>(m, "Value")
3922 .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"))
3923 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3924 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3925 .def_prop_ro(
3926 "context",
3927 [](PyValue &self) { return self.getParentOperation()->getContext(); },
3928 "Context in which the value lives.")
3929 .def(
3930 "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3931 kDumpDocstring)
3932 .def_prop_ro(
3933 "owner",
3934 [](PyValue &self) -> nb::object {
3935 MlirValue v = self.get();
3936 if (mlirValueIsAOpResult(v)) {
3937 assert(
3938 mlirOperationEqual(self.getParentOperation()->get(),
3939 mlirOpResultGetOwner(self.get())) &&
3940 "expected the owner of the value in Python to match that in "
3941 "the IR");
3942 return self.getParentOperation().getObject();
3943 }
3944
3945 if (mlirValueIsABlockArgument(v)) {
3946 MlirBlock block = mlirBlockArgumentGetOwner(self.get());
3947 return nb::cast(PyBlock(self.getParentOperation(), block));
3948 }
3949
3950 assert(false && "Value must be a block argument or an op result");
3951 return nb::none();
3952 })
3953 .def_prop_ro("uses",
3954 [](PyValue &self) {
3955 return PyOpOperandIterator(
3956 mlirValueGetFirstUse(self.get()));
3957 })
3958 .def("__eq__",
3959 [](PyValue &self, PyValue &other) {
3960 return self.get().ptr == other.get().ptr;
3961 })
3962 .def("__eq__", [](PyValue &self, nb::object other) { return false; })
3963 .def("__hash__",
3964 [](PyValue &self) {
3965 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3966 })
3967 .def(
3968 "__str__",
3969 [](PyValue &self) {
3970 PyPrintAccumulator printAccum;
3971 printAccum.parts.append("Value(");
3972 mlirValuePrint(self.get(), printAccum.getCallback(),
3973 printAccum.getUserData());
3974 printAccum.parts.append(")");
3975 return printAccum.join();
3976 },
3977 kValueDunderStrDocstring)
3978 .def(
3979 "get_name",
3980 [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) {
3981 PyPrintAccumulator printAccum;
3982 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
3983 if (useLocalScope)
3984 mlirOpPrintingFlagsUseLocalScope(flags);
3985 if (useNameLocAsPrefix)
3986 mlirOpPrintingFlagsPrintNameLocAsPrefix(flags);
3987 MlirAsmState valueState =
3988 mlirAsmStateCreateForValue(self.get(), flags);
3989 mlirValuePrintAsOperand(self.get(), valueState,
3990 printAccum.getCallback(),
3991 printAccum.getUserData());
3992 mlirOpPrintingFlagsDestroy(flags);
3993 mlirAsmStateDestroy(valueState);
3994 return printAccum.join();
3995 },
3996 nb::arg("use_local_scope") = false,
3997 nb::arg("use_name_loc_as_prefix") = false)
3998 .def(
3999 "get_name",
4000 [](PyValue &self, PyAsmState &state) {
4001 PyPrintAccumulator printAccum;
4002 MlirAsmState valueState = state.get();
4003 mlirValuePrintAsOperand(self.get(), valueState,
4004 printAccum.getCallback(),
4005 printAccum.getUserData());
4006 return printAccum.join();
4007 },
4008 nb::arg("state"), kGetNameAsOperand)
4009 .def_prop_ro("type",
4010 [](PyValue &self) { return mlirValueGetType(self.get()); })
4011 .def(
4012 "set_type",
4013 [](PyValue &self, const PyType &type) {
4014 return mlirValueSetType(self.get(), type);
4015 },
4016 nb::arg("type"))
4017 .def(
4018 "replace_all_uses_with",
4019 [](PyValue &self, PyValue &with) {
4020 mlirValueReplaceAllUsesOfWith(self.get(), with.get());
4021 },
4022 kValueReplaceAllUsesWithDocstring)
4023 .def(
4024 "replace_all_uses_except",
4025 [](MlirValue self, MlirValue with, PyOperation &exception) {
4026 MlirOperation exceptedUser = exception.get();
4027 mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
4028 },
4029 nb::arg("with"), nb::arg("exceptions"),
4030 kValueReplaceAllUsesExceptDocstring)
4031 .def(
4032 "replace_all_uses_except",
4033 [](MlirValue self, MlirValue with, nb::list exceptions) {
4034 // Convert Python list to a SmallVector of MlirOperations
4035 llvm::SmallVector<MlirOperation> exceptionOps;
4036 for (nb::handle exception : exceptions) {
4037 exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
4038 }
4039
4040 mlirValueReplaceAllUsesExcept(
4041 self, with, static_cast<intptr_t>(exceptionOps.size()),
4042 exceptionOps.data());
4043 },
4044 nb::arg("with"), nb::arg("exceptions"),
4045 kValueReplaceAllUsesExceptDocstring)
4046 .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
4047 [](PyValue &self) { return self.maybeDownCast(); })
4048 .def_prop_ro(
4049 "location",
4050 [](MlirValue self) {
4051 return PyLocation(
4052 PyMlirContext::forContext(mlirValueGetContext(self)),
4053 mlirValueGetLocation(self));
4054 },
4055 "Returns the source location the value");
4056
4057 PyBlockArgument::bind(m);
4058 PyOpResult::bind(m);
4059 PyOpOperand::bind(m);
4060
4061 nb::class_<PyAsmState>(m, "AsmState")
4062 .def(nb::init<PyValue &, bool>(), nb::arg("value"),
4063 nb::arg("use_local_scope") = false)
4064 .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
4065 nb::arg("use_local_scope") = false);
4066
4067 //----------------------------------------------------------------------------
4068 // Mapping of SymbolTable.
4069 //----------------------------------------------------------------------------
4070 nb::class_<PySymbolTable>(m, "SymbolTable")
4071 .def(nb::init<PyOperationBase &>())
4072 .def("__getitem__", &PySymbolTable::dunderGetItem)
4073 .def("insert", &PySymbolTable::insert, nb::arg("operation"))
4074 .def("erase", &PySymbolTable::erase, nb::arg("operation"))
4075 .def("__delitem__", &PySymbolTable::dunderDel)
4076 .def("__contains__",
4077 [](PySymbolTable &table, const std::string &name) {
4078 return !mlirOperationIsNull(mlirSymbolTableLookup(
4079 table, mlirStringRefCreate(name.data(), name.length())));
4080 })
4081 // Static helpers.
4082 .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
4083 nb::arg("symbol"), nb::arg("name"))
4084 .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
4085 nb::arg("symbol"))
4086 .def_static("get_visibility", &PySymbolTable::getVisibility,
4087 nb::arg("symbol"))
4088 .def_static("set_visibility", &PySymbolTable::setVisibility,
4089 nb::arg("symbol"), nb::arg("visibility"))
4090 .def_static("replace_all_symbol_uses",
4091 &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"),
4092 nb::arg("new_symbol"), nb::arg("from_op"))
4093 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
4094 nb::arg("from_op"), nb::arg("all_sym_uses_visible"),
4095 nb::arg("callback"));
4096
4097 // Container bindings.
4098 PyBlockArgumentList::bind(m);
4099 PyBlockIterator::bind(m);
4100 PyBlockList::bind(m);
4101 PyOperationIterator::bind(m);
4102 PyOperationList::bind(m);
4103 PyOpAttributeMap::bind(m);
4104 PyOpOperandIterator::bind(m);
4105 PyOpOperandList::bind(m);
4106 PyOpResultList::bind(m);
4107 PyOpSuccessors::bind(m);
4108 PyRegionIterator::bind(m);
4109 PyRegionList::bind(m);
4110
4111 // Debug bindings.
4112 PyGlobalDebugFlag::bind(m);
4113
4114 // Attribute builder getter.
4115 PyAttrBuilderMap::bind(m);
4116
4117 nb::register_exception_translator([](const std::exception_ptr &p,
4118 void *payload) {
4119 // We can't define exceptions with custom fields through pybind, so instead
4120 // the exception class is defined in python and imported here.
4121 try {
4122 if (p)
4123 std::rethrow_exception(p);
4124 } catch (const MLIRError &e) {
4125 nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
4126 .attr("MLIRError")(e.message, e.errorDiagnostics);
4127 PyErr_SetObject(PyExc_Exception, obj.ptr());
4128 }
4129 });
4130}
4131

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Bindings/Python/IRCore.cpp