1//===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Globals.h"
10#include "IRModule.h"
11#include "NanobindUtils.h"
12#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
13#include "mlir-c/BuiltinAttributes.h"
14#include "mlir-c/Debug.h"
15#include "mlir-c/Diagnostics.h"
16#include "mlir-c/IR.h"
17#include "mlir-c/Support.h"
18#include "mlir/Bindings/Python/Nanobind.h"
19#include "mlir/Bindings/Python/NanobindAdaptors.h"
20#include "nanobind/nanobind.h"
21#include "llvm/ADT/ArrayRef.h"
22#include "llvm/ADT/SmallVector.h"
23#include "llvm/Support/raw_ostream.h"
24
25#include <optional>
26#include <system_error>
27#include <utility>
28
29namespace nb = nanobind;
30using namespace nb::literals;
31using namespace mlir;
32using namespace mlir::python;
33
34using llvm::SmallVector;
35using llvm::StringRef;
36using llvm::Twine;
37
38//------------------------------------------------------------------------------
39// Docstrings (trivial, non-duplicated docstrings are included inline).
40//------------------------------------------------------------------------------
41
// Docstring text constants. Each is a raw string literal, so everything
// between R"( and )" -- including line breaks and leading spaces -- is part of
// the text.

// Docstring: parsing a type from its assembly form.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises an MLIRError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Docstring: call-site location factory.
static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

// Docstring: file/line/column location factory.
static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

// Docstring: file/line/column-range location factory.
static const char kContextGetFileRangeDocstring[] =
    R"(Gets a Location representing a file, line and column range)";

// Docstring: fused location factory.
static const char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

// Docstring: named location factory.
static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

// Docstring: parsing a module from its assembly form.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises an MLIRError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Docstring: generic operation creation, documenting each keyword argument.
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
  infer_type: Whether to infer result types.
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
91
// Docstring for the operation print method. The keyword names listed here must
// match the real Python keyword arguments (e.g. `use_local_scope`, all lower
// case) so users can copy them verbatim.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Elide elements attributes with more than this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
  skip_regions: Whether to skip printing regions. Defaults to False.
)";
118
// Docstring: printing an operation with a caller-supplied AsmState.
static const char kOperationPrintStateDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  state: AsmState capturing the operation numbering and flags.
)";

// Docstring: returning the assembly form instead of writing it to a file.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

// Docstring: writing the bytecode (not textual) form of an operation.
static const char kOperationPrintBytecodeDocstring[] =
    R"(Write the bytecode form of the operation to a file like object.

Args:
  file: The file like object to write to.
  desired_version: The version of bytecode to emit.
Returns:
  The bytecode writer status.
)";

// Docstring: str() of an operation (default printing options).
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Docstring shared by dump() methods.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Docstring: BlockList.append (argument types are passed positionally).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

// Docstring: str() of a value.
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";

// Docstring: printing a value as an operand reference.
static const char kGetNameAsOperand[] =
    R"(Returns the string form of value as an operand (i.e., the ValueID).
)";

// Docstring: replacing every use of a value.
static const char kValueReplaceAllUsesWithDocstring[] =
    R"(Replace all uses of value with the new value, updating anything in
the IR that uses 'self' to use the other value instead.
)";
185
// Docstring: replacing all uses of a value except those in the given
// operation(s). The text begins directly with the sentence (no stray quote).
static const char kValueReplaceAllUsesExceptDocstring[] =
    R"(Replace all uses of this value with the 'with' value, except for those
in 'exceptions'. 'exceptions' can be either a single operation or a list of
operations.
)";
191
192//------------------------------------------------------------------------------
193// Utilities.
194//------------------------------------------------------------------------------
195
196/// Helper for creating an @classmethod.
197template <class Func, typename... Args>
198nb::object classmethod(Func f, Args... args) {
199 nb::object cf = nb::cpp_function(f, args...);
200 return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
201}
202
203static nb::object
204createCustomDialectWrapper(const std::string &dialectNamespace,
205 nb::object dialectDescriptor) {
206 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
207 if (!dialectClass) {
208 // Use the base class.
209 return nb::cast(PyDialect(std::move(dialectDescriptor)));
210 }
211
212 // Create the custom implementation.
213 return (*dialectClass)(std::move(dialectDescriptor));
214}
215
216static MlirStringRef toMlirStringRef(const std::string &s) {
217 return mlirStringRefCreate(s.data(), s.size());
218}
219
220static MlirStringRef toMlirStringRef(std::string_view s) {
221 return mlirStringRefCreate(s.data(), s.size());
222}
223
224static MlirStringRef toMlirStringRef(const nb::bytes &s) {
225 return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
226}
227
228/// Create a block, using the current location context if no locations are
229/// specified.
230static MlirBlock createBlock(const nb::sequence &pyArgTypes,
231 const std::optional<nb::sequence> &pyArgLocs) {
232 SmallVector<MlirType> argTypes;
233 argTypes.reserve(nb::len(pyArgTypes));
234 for (const auto &pyType : pyArgTypes)
235 argTypes.push_back(nb::cast<PyType &>(pyType));
236
237 SmallVector<MlirLocation> argLocs;
238 if (pyArgLocs) {
239 argLocs.reserve(nb::len(*pyArgLocs));
240 for (const auto &pyLoc : *pyArgLocs)
241 argLocs.push_back(nb::cast<PyLocation &>(pyLoc));
242 } else if (!argTypes.empty()) {
243 argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
244 }
245
246 if (argTypes.size() != argLocs.size())
247 throw nb::value_error(("Expected " + Twine(argTypes.size()) +
248 " locations, got: " + Twine(argLocs.size()))
249 .str()
250 .c_str());
251 return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
252}
253
254/// Wrapper for the global LLVM debugging flag.
255struct PyGlobalDebugFlag {
256 static void set(nb::object &o, bool enable) {
257 nb::ft_lock_guard lock(mutex);
258 mlirEnableGlobalDebug(enable);
259 }
260
261 static bool get(const nb::object &) {
262 nb::ft_lock_guard lock(mutex);
263 return mlirIsGlobalDebugEnabled();
264 }
265
266 static void bind(nb::module_ &m) {
267 // Debug flags.
268 nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
269 .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
270 &PyGlobalDebugFlag::set, "LLVM-wide debug flag")
271 .def_static(
272 "set_types",
273 [](const std::string &type) {
274 nb::ft_lock_guard lock(mutex);
275 mlirSetGlobalDebugType(type.c_str());
276 },
277 "types"_a, "Sets specific debug types to be produced by LLVM")
278 .def_static("set_types", [](const std::vector<std::string> &types) {
279 std::vector<const char *> pointers;
280 pointers.reserve(types.size());
281 for (const std::string &str : types)
282 pointers.push_back(str.c_str());
283 nb::ft_lock_guard lock(mutex);
284 mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
285 });
286 }
287
288private:
289 static nb::ft_mutex mutex;
290};
291
// Out-of-line definition of the mutex that PyGlobalDebugFlag locks around every
// call into the global LLVM debug state.
nb::ft_mutex PyGlobalDebugFlag::mutex;
293
294struct PyAttrBuilderMap {
295 static bool dunderContains(const std::string &attributeKind) {
296 return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
297 }
298 static nb::callable dunderGetItemNamed(const std::string &attributeKind) {
299 auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
300 if (!builder)
301 throw nb::key_error(attributeKind.c_str());
302 return *builder;
303 }
304 static void dunderSetItemNamed(const std::string &attributeKind,
305 nb::callable func, bool replace) {
306 PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
307 replace);
308 }
309
310 static void bind(nb::module_ &m) {
311 nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
312 .def_static("contains", &PyAttrBuilderMap::dunderContains)
313 .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed)
314 .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
315 "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
316 "Register an attribute builder for building MLIR "
317 "attributes from python values.");
318 }
319};
320
321//------------------------------------------------------------------------------
322// PyBlock
323//------------------------------------------------------------------------------
324
325nb::object PyBlock::getCapsule() {
326 return nb::steal<nb::object>(mlirPythonBlockToCapsule(get()));
327}
328
329//------------------------------------------------------------------------------
330// Collections.
331//------------------------------------------------------------------------------
332
333namespace {
334
335class PyRegionIterator {
336public:
337 PyRegionIterator(PyOperationRef operation)
338 : operation(std::move(operation)) {}
339
340 PyRegionIterator &dunderIter() { return *this; }
341
342 PyRegion dunderNext() {
343 operation->checkValid();
344 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
345 throw nb::stop_iteration();
346 }
347 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
348 return PyRegion(operation, region);
349 }
350
351 static void bind(nb::module_ &m) {
352 nb::class_<PyRegionIterator>(m, "RegionIterator")
353 .def("__iter__", &PyRegionIterator::dunderIter)
354 .def("__next__", &PyRegionIterator::dunderNext);
355 }
356
357private:
358 PyOperationRef operation;
359 int nextIndex = 0;
360};
361
362/// Regions of an op are fixed length and indexed numerically so are represented
363/// with a sequence-like container.
364class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
365public:
366 static constexpr const char *pyClassName = "RegionSequence";
367
368 PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
369 intptr_t length = -1, intptr_t step = 1)
370 : Sliceable(startIndex,
371 length == -1 ? mlirOperationGetNumRegions(operation->get())
372 : length,
373 step),
374 operation(std::move(operation)) {}
375
376 PyRegionIterator dunderIter() {
377 operation->checkValid();
378 return PyRegionIterator(operation);
379 }
380
381 static void bindDerived(ClassTy &c) {
382 c.def("__iter__", &PyRegionList::dunderIter);
383 }
384
385private:
386 /// Give the parent CRTP class access to hook implementations below.
387 friend class Sliceable<PyRegionList, PyRegion>;
388
389 intptr_t getRawNumElements() {
390 operation->checkValid();
391 return mlirOperationGetNumRegions(operation->get());
392 }
393
394 PyRegion getRawElement(intptr_t pos) {
395 operation->checkValid();
396 return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
397 }
398
399 PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
400 return PyRegionList(operation, startIndex, length, step);
401 }
402
403 PyOperationRef operation;
404};
405
406class PyBlockIterator {
407public:
408 PyBlockIterator(PyOperationRef operation, MlirBlock next)
409 : operation(std::move(operation)), next(next) {}
410
411 PyBlockIterator &dunderIter() { return *this; }
412
413 PyBlock dunderNext() {
414 operation->checkValid();
415 if (mlirBlockIsNull(next)) {
416 throw nb::stop_iteration();
417 }
418
419 PyBlock returnBlock(operation, next);
420 next = mlirBlockGetNextInRegion(next);
421 return returnBlock;
422 }
423
424 static void bind(nb::module_ &m) {
425 nb::class_<PyBlockIterator>(m, "BlockIterator")
426 .def("__iter__", &PyBlockIterator::dunderIter)
427 .def("__next__", &PyBlockIterator::dunderNext);
428 }
429
430private:
431 PyOperationRef operation;
432 MlirBlock next;
433};
434
435/// Blocks are exposed by the C-API as a forward-only linked list. In Python,
436/// we present them as a more full-featured list-like container but optimize
437/// it for forward iteration. Blocks are always owned by a region.
438class PyBlockList {
439public:
440 PyBlockList(PyOperationRef operation, MlirRegion region)
441 : operation(std::move(operation)), region(region) {}
442
443 PyBlockIterator dunderIter() {
444 operation->checkValid();
445 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
446 }
447
448 intptr_t dunderLen() {
449 operation->checkValid();
450 intptr_t count = 0;
451 MlirBlock block = mlirRegionGetFirstBlock(region);
452 while (!mlirBlockIsNull(block)) {
453 count += 1;
454 block = mlirBlockGetNextInRegion(block);
455 }
456 return count;
457 }
458
459 PyBlock dunderGetItem(intptr_t index) {
460 operation->checkValid();
461 if (index < 0) {
462 index += dunderLen();
463 }
464 if (index < 0) {
465 throw nb::index_error("attempt to access out of bounds block");
466 }
467 MlirBlock block = mlirRegionGetFirstBlock(region);
468 while (!mlirBlockIsNull(block)) {
469 if (index == 0) {
470 return PyBlock(operation, block);
471 }
472 block = mlirBlockGetNextInRegion(block);
473 index -= 1;
474 }
475 throw nb::index_error("attempt to access out of bounds block");
476 }
477
478 PyBlock appendBlock(const nb::args &pyArgTypes,
479 const std::optional<nb::sequence> &pyArgLocs) {
480 operation->checkValid();
481 MlirBlock block =
482 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
483 mlirRegionAppendOwnedBlock(region, block);
484 return PyBlock(operation, block);
485 }
486
487 static void bind(nb::module_ &m) {
488 nb::class_<PyBlockList>(m, "BlockList")
489 .def("__getitem__", &PyBlockList::dunderGetItem)
490 .def("__iter__", &PyBlockList::dunderIter)
491 .def("__len__", &PyBlockList::dunderLen)
492 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring,
493 nb::arg("args"), nb::kw_only(),
494 nb::arg("arg_locs") = std::nullopt);
495 }
496
497private:
498 PyOperationRef operation;
499 MlirRegion region;
500};
501
502class PyOperationIterator {
503public:
504 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
505 : parentOperation(std::move(parentOperation)), next(next) {}
506
507 PyOperationIterator &dunderIter() { return *this; }
508
509 nb::object dunderNext() {
510 parentOperation->checkValid();
511 if (mlirOperationIsNull(next)) {
512 throw nb::stop_iteration();
513 }
514
515 PyOperationRef returnOperation =
516 PyOperation::forOperation(parentOperation->getContext(), next);
517 next = mlirOperationGetNextInBlock(next);
518 return returnOperation->createOpView();
519 }
520
521 static void bind(nb::module_ &m) {
522 nb::class_<PyOperationIterator>(m, "OperationIterator")
523 .def("__iter__", &PyOperationIterator::dunderIter)
524 .def("__next__", &PyOperationIterator::dunderNext);
525 }
526
527private:
528 PyOperationRef parentOperation;
529 MlirOperation next;
530};
531
532/// Operations are exposed by the C-API as a forward-only linked list. In
533/// Python, we present them as a more full-featured list-like container but
534/// optimize it for forward iteration. Iterable operations are always owned
535/// by a block.
536class PyOperationList {
537public:
538 PyOperationList(PyOperationRef parentOperation, MlirBlock block)
539 : parentOperation(std::move(parentOperation)), block(block) {}
540
541 PyOperationIterator dunderIter() {
542 parentOperation->checkValid();
543 return PyOperationIterator(parentOperation,
544 mlirBlockGetFirstOperation(block));
545 }
546
547 intptr_t dunderLen() {
548 parentOperation->checkValid();
549 intptr_t count = 0;
550 MlirOperation childOp = mlirBlockGetFirstOperation(block);
551 while (!mlirOperationIsNull(childOp)) {
552 count += 1;
553 childOp = mlirOperationGetNextInBlock(childOp);
554 }
555 return count;
556 }
557
558 nb::object dunderGetItem(intptr_t index) {
559 parentOperation->checkValid();
560 if (index < 0) {
561 index += dunderLen();
562 }
563 if (index < 0) {
564 throw nb::index_error("attempt to access out of bounds operation");
565 }
566 MlirOperation childOp = mlirBlockGetFirstOperation(block);
567 while (!mlirOperationIsNull(childOp)) {
568 if (index == 0) {
569 return PyOperation::forOperation(parentOperation->getContext(), childOp)
570 ->createOpView();
571 }
572 childOp = mlirOperationGetNextInBlock(childOp);
573 index -= 1;
574 }
575 throw nb::index_error("attempt to access out of bounds operation");
576 }
577
578 static void bind(nb::module_ &m) {
579 nb::class_<PyOperationList>(m, "OperationList")
580 .def("__getitem__", &PyOperationList::dunderGetItem)
581 .def("__iter__", &PyOperationList::dunderIter)
582 .def("__len__", &PyOperationList::dunderLen);
583 }
584
585private:
586 PyOperationRef parentOperation;
587 MlirBlock block;
588};
589
590class PyOpOperand {
591public:
592 PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
593
594 nb::object getOwner() {
595 MlirOperation owner = mlirOpOperandGetOwner(opOperand);
596 PyMlirContextRef context =
597 PyMlirContext::forContext(mlirOperationGetContext(owner));
598 return PyOperation::forOperation(context, owner)->createOpView();
599 }
600
601 size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
602
603 static void bind(nb::module_ &m) {
604 nb::class_<PyOpOperand>(m, "OpOperand")
605 .def_prop_ro("owner", &PyOpOperand::getOwner)
606 .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber);
607 }
608
609private:
610 MlirOpOperand opOperand;
611};
612
613class PyOpOperandIterator {
614public:
615 PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
616
617 PyOpOperandIterator &dunderIter() { return *this; }
618
619 PyOpOperand dunderNext() {
620 if (mlirOpOperandIsNull(opOperand))
621 throw nb::stop_iteration();
622
623 PyOpOperand returnOpOperand(opOperand);
624 opOperand = mlirOpOperandGetNextUse(opOperand);
625 return returnOpOperand;
626 }
627
628 static void bind(nb::module_ &m) {
629 nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
630 .def("__iter__", &PyOpOperandIterator::dunderIter)
631 .def("__next__", &PyOpOperandIterator::dunderNext);
632 }
633
634private:
635 MlirOpOperand opOperand;
636};
637
638} // namespace
639
640//------------------------------------------------------------------------------
641// PyMlirContext
642//------------------------------------------------------------------------------
643
/// Wraps `context` and registers `this` in the process-wide live-context map,
/// keyed by the raw context pointer, so later lookups (e.g. forContext) find
/// this wrapper. The GIL and `live_contexts_mutex` are held for the update.
///
/// NOTE(review): PyMlirContext::forContext calls this constructor while it
/// already holds `live_contexts_mutex`, so the lock below is taken a second
/// time on that path. If nb::ft_mutex is non-recursive in free-threaded builds
/// this would self-deadlock -- confirm.
PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
  nb::gil_scoped_acquire acquire;
  nb::ft_lock_guard lock(live_contexts_mutex);
  auto &liveContexts = getLiveContexts();
  liveContexts[context.ptr] = this;
}
650
651PyMlirContext::~PyMlirContext() {
652 // Note that the only public way to construct an instance is via the
653 // forContext method, which always puts the associated handle into
654 // liveContexts.
655 nb::gil_scoped_acquire acquire;
656 {
657 nb::ft_lock_guard lock(live_contexts_mutex);
658 getLiveContexts().erase(context.ptr);
659 }
660 mlirContextDestroy(context);
661}
662
663nb::object PyMlirContext::getCapsule() {
664 return nb::steal<nb::object>(mlirPythonContextToCapsule(get()));
665}
666
667nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
668 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
669 if (mlirContextIsNull(rawContext))
670 throw nb::python_error();
671 return forContext(rawContext).releaseObject();
672}
673
/// Returns a reference to the PyMlirContext registered for `context`. If no
/// wrapper is live, a new one is created, registered and cast to a Python
/// object. The GIL and `live_contexts_mutex` are held throughout.
///
/// NOTE(review): the PyMlirContext constructor itself locks
/// `live_contexts_mutex` and inserts `this` into the live-context map. So
/// (a) the explicit `liveContexts[context.ptr] = ...` below looks redundant,
/// and (b) on the creation path the mutex is taken recursively. If
/// nb::ft_mutex is non-recursive in free-threaded builds this deadlocks --
/// confirm, and if so restructure so the constructor is not run under the lock
/// (without allowing two wrappers for the same MlirContext).
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
  nb::gil_scoped_acquire acquire;
  nb::ft_lock_guard lock(live_contexts_mutex);
  auto &liveContexts = getLiveContexts();
  auto it = liveContexts.find(context.ptr);
  if (it == liveContexts.end()) {
    // Create.
    PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
    // Presumably the default cast policy for a raw pointer hands ownership of
    // the new wrapper to the Python object -- confirm.
    nb::object pyRef = nb::cast(unownedContextWrapper);
    assert(pyRef && "cast to nb::object failed");
    liveContexts[context.ptr] = unownedContextWrapper;
    return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
  }
  // Use existing.
  nb::object pyRef = nb::cast(it->second);
  return PyMlirContextRef(it->second, std::move(pyRef));
}
691
// Guards the map returned by PyMlirContext::getLiveContexts(); the accessors
// in this file lock it before reading or modifying that map.
nb::ft_mutex PyMlirContext::live_contexts_mutex;
693
694PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
695 static LiveContextMap liveContexts;
696 return liveContexts;
697}
698
699size_t PyMlirContext::getLiveCount() {
700 nb::ft_lock_guard lock(live_contexts_mutex);
701 return getLiveContexts().size();
702}
703
704size_t PyMlirContext::getLiveOperationCount() {
705 nb::ft_lock_guard lock(liveOperationsMutex);
706 return liveOperations.size();
707}
708
709std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
710 std::vector<PyOperation *> liveObjects;
711 nb::ft_lock_guard lock(liveOperationsMutex);
712 for (auto &entry : liveOperations)
713 liveObjects.push_back(entry.second.second);
714 return liveObjects;
715}
716
717size_t PyMlirContext::clearLiveOperations() {
718
719 LiveOperationMap operations;
720 {
721 nb::ft_lock_guard lock(liveOperationsMutex);
722 std::swap(operations, liveOperations);
723 }
724 for (auto &op : operations)
725 op.second.second->setInvalid();
726 size_t numInvalidated = operations.size();
727 return numInvalidated;
728}
729
730void PyMlirContext::clearOperation(MlirOperation op) {
731 PyOperation *py_op;
732 {
733 nb::ft_lock_guard lock(liveOperationsMutex);
734 auto it = liveOperations.find(op.ptr);
735 if (it == liveOperations.end()) {
736 return;
737 }
738 py_op = it->second.second;
739 liveOperations.erase(it);
740 }
741 py_op->setInvalid();
742}
743
744void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
745 typedef struct {
746 PyOperation &rootOp;
747 bool rootSeen;
748 } callBackData;
749 callBackData data{.rootOp: op.getOperation(), .rootSeen: false};
750 // Mark all ops below the op that the passmanager will be rooted
751 // at (but not op itself - note the preorder) as invalid.
752 MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
753 void *userData) {
754 callBackData *data = static_cast<callBackData *>(userData);
755 if (LLVM_LIKELY(data->rootSeen))
756 data->rootOp.getOperation().getContext()->clearOperation(op);
757 else
758 data->rootSeen = true;
759 return MlirWalkResult::MlirWalkResultAdvance;
760 };
761 mlirOperationWalk(op.getOperation(), invalidatingCallback,
762 static_cast<void *>(&data), MlirWalkPreOrder);
763}
764void PyMlirContext::clearOperationsInside(MlirOperation op) {
765 PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
766 clearOperationsInside(opRef->getOperation());
767}
768
769void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
770 MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
771 void *userData) {
772 PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
773 contextRef->clearOperation(op);
774 return MlirWalkResult::MlirWalkResultAdvance;
775 };
776 mlirOperationWalk(op.getOperation(), invalidatingCallback,
777 &op.getOperation().getContext(), MlirWalkPreOrder);
778}
779
780size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
781
782nb::object PyMlirContext::contextEnter(nb::object context) {
783 return PyThreadContextEntry::pushContext(context);
784}
785
786void PyMlirContext::contextExit(const nb::object &excType,
787 const nb::object &excVal,
788 const nb::object &excTb) {
789 PyThreadContextEntry::popContext(context&: *this);
790}
791
792nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) {
793 // Note that ownership is transferred to the delete callback below by way of
794 // an explicit inc_ref (borrow).
795 PyDiagnosticHandler *pyHandler =
796 new PyDiagnosticHandler(get(), std::move(callback));
797 nb::object pyHandlerObject =
798 nb::cast(pyHandler, nb::rv_policy::take_ownership);
799 pyHandlerObject.inc_ref();
800
801 // In these C callbacks, the userData is a PyDiagnosticHandler* that is
802 // guaranteed to be known to pybind.
803 auto handlerCallback =
804 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
805 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
806 nb::object pyDiagnosticObject =
807 nb::cast(pyDiagnostic, nb::rv_policy::take_ownership);
808
809 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
810 bool result = false;
811 {
812 // Since this can be called from arbitrary C++ contexts, always get the
813 // gil.
814 nb::gil_scoped_acquire gil;
815 try {
816 result = nb::cast<bool>(pyHandler->callback(pyDiagnostic));
817 } catch (std::exception &e) {
818 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
819 e.what());
820 pyHandler->hadError = true;
821 }
822 }
823
824 pyDiagnostic->invalidate();
825 return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
826 };
827 auto deleteCallback = +[](void *userData) {
828 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
829 assert(pyHandler->registeredID && "handler is not registered");
830 pyHandler->registeredID.reset();
831
832 // Decrement reference, balancing the inc_ref() above.
833 nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference);
834 pyHandlerObject.dec_ref();
835 };
836
837 pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
838 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
839 return pyHandlerObject;
840}
841
842MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
843 void *userData) {
844 auto *self = static_cast<ErrorCapture *>(userData);
845 // Check if the context requested we emit errors instead of capturing them.
846 if (self->ctx->emitErrorDiagnostics)
847 return mlirLogicalResultFailure();
848
849 if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError)
850 return mlirLogicalResultFailure();
851
852 self->errors.emplace_back(args: PyDiagnostic(diag).getInfo());
853 return mlirLogicalResultSuccess();
854}
855
856PyMlirContext &DefaultingPyMlirContext::resolve() {
857 PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
858 if (!context) {
859 throw std::runtime_error(
860 "An MLIR function requires a Context but none was provided in the call "
861 "or from the surrounding environment. Either pass to the function with "
862 "a 'context=' argument or establish a default using 'with Context():'");
863 }
864 return *context;
865}
866
867//------------------------------------------------------------------------------
868// PyThreadContextEntry management
869//------------------------------------------------------------------------------
870
871std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
872 static thread_local std::vector<PyThreadContextEntry> stack;
873 return stack;
874}
875
876PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
877 auto &stack = getStack();
878 if (stack.empty())
879 return nullptr;
880 return &stack.back();
881}
882
883void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
884 nb::object insertionPoint,
885 nb::object location) {
886 auto &stack = getStack();
887 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
888 std::move(location));
889 // If the new stack has more than one entry and the context of the new top
890 // entry matches the previous, copy the insertionPoint and location from the
891 // previous entry if missing from the new top entry.
892 if (stack.size() > 1) {
893 auto &prev = *(stack.rbegin() + 1);
894 auto &current = stack.back();
895 if (current.context.is(prev.context)) {
896 // Default non-context objects from the previous entry.
897 if (!current.insertionPoint)
898 current.insertionPoint = prev.insertionPoint;
899 if (!current.location)
900 current.location = prev.location;
901 }
902 }
903}
904
905PyMlirContext *PyThreadContextEntry::getContext() {
906 if (!context)
907 return nullptr;
908 return nb::cast<PyMlirContext *>(context);
909}
910
911PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
912 if (!insertionPoint)
913 return nullptr;
914 return nb::cast<PyInsertionPoint *>(insertionPoint);
915}
916
917PyLocation *PyThreadContextEntry::getLocation() {
918 if (!location)
919 return nullptr;
920 return nb::cast<PyLocation *>(location);
921}
922
923PyMlirContext *PyThreadContextEntry::getDefaultContext() {
924 auto *tos = getTopOfStack();
925 return tos ? tos->getContext() : nullptr;
926}
927
928PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
929 auto *tos = getTopOfStack();
930 return tos ? tos->getInsertionPoint() : nullptr;
931}
932
933PyLocation *PyThreadContextEntry::getDefaultLocation() {
934 auto *tos = getTopOfStack();
935 return tos ? tos->getLocation() : nullptr;
936}
937
938nb::object PyThreadContextEntry::pushContext(nb::object context) {
939 push(FrameKind::Context, /*context=*/context,
940 /*insertionPoint=*/nb::object(),
941 /*location=*/nb::object());
942 return context;
943}
944
945void PyThreadContextEntry::popContext(PyMlirContext &context) {
946 auto &stack = getStack();
947 if (stack.empty())
948 throw std::runtime_error("Unbalanced Context enter/exit");
949 auto &tos = stack.back();
950 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
951 throw std::runtime_error("Unbalanced Context enter/exit");
952 stack.pop_back();
953}
954
955nb::object
956PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) {
957 PyInsertionPoint &insertionPoint =
958 nb::cast<PyInsertionPoint &>(insertionPointObj);
959 nb::object contextObj =
960 insertionPoint.getBlock().getParentOperation()->getContext().getObject();
961 push(FrameKind::InsertionPoint,
962 /*context=*/contextObj,
963 /*insertionPoint=*/insertionPointObj,
964 /*location=*/nb::object());
965 return insertionPointObj;
966}
967
968void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
969 auto &stack = getStack();
970 if (stack.empty())
971 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
972 auto &tos = stack.back();
973 if (tos.frameKind != FrameKind::InsertionPoint &&
974 tos.getInsertionPoint() != &insertionPoint)
975 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
976 stack.pop_back();
977}
978
979nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) {
980 PyLocation &location = nb::cast<PyLocation &>(locationObj);
981 nb::object contextObj = location.getContext().getObject();
982 push(FrameKind::Location, /*context=*/contextObj,
983 /*insertionPoint=*/nb::object(),
984 /*location=*/locationObj);
985 return locationObj;
986}
987
988void PyThreadContextEntry::popLocation(PyLocation &location) {
989 auto &stack = getStack();
990 if (stack.empty())
991 throw std::runtime_error("Unbalanced Location enter/exit");
992 auto &tos = stack.back();
993 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
994 throw std::runtime_error("Unbalanced Location enter/exit");
995 stack.pop_back();
996}
997
998//------------------------------------------------------------------------------
999// PyDiagnostic*
1000//------------------------------------------------------------------------------
1001
1002void PyDiagnostic::invalidate() {
1003 valid = false;
1004 if (materializedNotes) {
1005 for (nb::handle noteObject : *materializedNotes) {
1006 PyDiagnostic *note = nb::cast<PyDiagnostic *>(noteObject);
1007 note->invalidate();
1008 }
1009 }
1010}
1011
/// Creates a handler bound to `context` that forwards diagnostics to the
/// Python `callback`. Construction does not register the handler itself.
/// NOTE(review): registration appears to happen elsewhere (see
/// `registeredID`); confirm.
PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
                                         nb::object callback)
    : context(context), callback(std::move(callback)) {}

PyDiagnosticHandler::~PyDiagnosticHandler() = default;
1017
/// Detaches this handler from its MLIR context, if it is still registered.
/// Calling it on an already-detached handler is a no-op.
void PyDiagnosticHandler::detach() {
  if (!registeredID)
    return;
  // Copy the ID first: detaching is expected to clear `registeredID` as a side
  // effect (checked by the assert below), so it must not be read afterwards.
  MlirDiagnosticHandlerID localID = *registeredID;
  mlirContextDetachDiagnosticHandler(context, localID);
  assert(!registeredID && "should have unregistered");
  // Not strictly necessary but keeps stale pointers from being around to cause
  // issues.
  context = {nullptr};
}
1028
1029void PyDiagnostic::checkValid() {
1030 if (!valid) {
1031 throw std::invalid_argument(
1032 "Diagnostic is invalid (used outside of callback)");
1033 }
1034}
1035
1036MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
1037 checkValid();
1038 return mlirDiagnosticGetSeverity(diagnostic);
1039}
1040
1041PyLocation PyDiagnostic::getLocation() {
1042 checkValid();
1043 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
1044 MlirContext context = mlirLocationGetContext(loc);
1045 return PyLocation(PyMlirContext::forContext(context), loc);
1046}
1047
1048nb::str PyDiagnostic::getMessage() {
1049 checkValid();
1050 nb::object fileObject = nb::module_::import_("io").attr("StringIO")();
1051 PyFileAccumulator accum(fileObject, /*binary=*/false);
1052 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
1053 return nb::cast<nb::str>(fileObject.attr("getvalue")());
1054}
1055
1056nb::tuple PyDiagnostic::getNotes() {
1057 checkValid();
1058 if (materializedNotes)
1059 return *materializedNotes;
1060 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
1061 nb::tuple notes = nb::steal<nb::tuple>(PyTuple_New(numNotes));
1062 for (intptr_t i = 0; i < numNotes; ++i) {
1063 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
1064 nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag));
1065 PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr());
1066 }
1067 materializedNotes = std::move(notes);
1068
1069 return *materializedNotes;
1070}
1071
1072PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() {
1073 std::vector<DiagnosticInfo> notes;
1074 for (nb::handle n : getNotes())
1075 notes.emplace_back(nb::cast<PyDiagnostic>(n).getInfo());
1076 return {getSeverity(), getLocation(), nb::cast<std::string>(getMessage()),
1077 std::move(notes)};
1078}
1079
1080//------------------------------------------------------------------------------
1081// PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
1082//------------------------------------------------------------------------------
1083
1084MlirDialect PyDialects::getDialectForKey(const std::string &key,
1085 bool attrError) {
1086 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
1087 {key.data(), key.size()});
1088 if (mlirDialectIsNull(dialect)) {
1089 std::string msg = (Twine("Dialect '") + key + "' not found").str();
1090 if (attrError)
1091 throw nb::attribute_error(msg.c_str());
1092 throw nb::index_error(msg.c_str());
1093 }
1094 return dialect;
1095}
1096
1097nb::object PyDialectRegistry::getCapsule() {
1098 return nb::steal<nb::object>(mlirPythonDialectRegistryToCapsule(*this));
1099}
1100
1101PyDialectRegistry PyDialectRegistry::createFromCapsule(nb::object capsule) {
1102 MlirDialectRegistry rawRegistry =
1103 mlirPythonCapsuleToDialectRegistry(capsule.ptr());
1104 if (mlirDialectRegistryIsNull(rawRegistry))
1105 throw nb::python_error();
1106 return PyDialectRegistry(rawRegistry);
1107}
1108
1109//------------------------------------------------------------------------------
1110// PyLocation
1111//------------------------------------------------------------------------------
1112
1113nb::object PyLocation::getCapsule() {
1114 return nb::steal<nb::object>(mlirPythonLocationToCapsule(*this));
1115}
1116
1117PyLocation PyLocation::createFromCapsule(nb::object capsule) {
1118 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
1119 if (mlirLocationIsNull(rawLoc))
1120 throw nb::python_error();
1121 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
1122 rawLoc);
1123}
1124
1125nb::object PyLocation::contextEnter(nb::object locationObj) {
1126 return PyThreadContextEntry::pushLocation(locationObj);
1127}
1128
1129void PyLocation::contextExit(const nb::object &excType,
1130 const nb::object &excVal,
1131 const nb::object &excTb) {
1132 PyThreadContextEntry::popLocation(location&: *this);
1133}
1134
1135PyLocation &DefaultingPyLocation::resolve() {
1136 auto *location = PyThreadContextEntry::getDefaultLocation();
1137 if (!location) {
1138 throw std::runtime_error(
1139 "An MLIR function requires a Location but none was provided in the "
1140 "call or from the surrounding environment. Either pass to the function "
1141 "with a 'loc=' argument or establish a default using 'with loc:'");
1142 }
1143 return *location;
1144}
1145
1146//------------------------------------------------------------------------------
1147// PyModule
1148//------------------------------------------------------------------------------
1149
1150PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
1151 : BaseContextObject(std::move(contextRef)), module(module) {}
1152
1153PyModule::~PyModule() {
1154 nb::gil_scoped_acquire acquire;
1155 auto &liveModules = getContext()->liveModules;
1156 assert(liveModules.count(module.ptr) == 1 &&
1157 "destroying module not in live map");
1158 liveModules.erase(module.ptr);
1159 mlirModuleDestroy(module);
1160}
1161
1162PyModuleRef PyModule::forModule(MlirModule module) {
1163 MlirContext context = mlirModuleGetContext(module);
1164 PyMlirContextRef contextRef = PyMlirContext::forContext(context);
1165
1166 nb::gil_scoped_acquire acquire;
1167 auto &liveModules = contextRef->liveModules;
1168 auto it = liveModules.find(module.ptr);
1169 if (it == liveModules.end()) {
1170 // Create.
1171 PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1172 // Note that the default return value policy on cast is automatic_reference,
1173 // which does not take ownership (delete will not be called).
1174 // Just be explicit.
1175 nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
1176 unownedModule->handle = pyRef;
1177 liveModules[module.ptr] =
1178 std::make_pair(unownedModule->handle, unownedModule);
1179 return PyModuleRef(unownedModule, std::move(pyRef));
1180 }
1181 // Use existing.
1182 PyModule *existing = it->second.second;
1183 nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1184 return PyModuleRef(existing, std::move(pyRef));
1185}
1186
1187nb::object PyModule::createFromCapsule(nb::object capsule) {
1188 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
1189 if (mlirModuleIsNull(rawModule))
1190 throw nb::python_error();
1191 return forModule(rawModule).releaseObject();
1192}
1193
1194nb::object PyModule::getCapsule() {
1195 return nb::steal<nb::object>(mlirPythonModuleToCapsule(get()));
1196}
1197
1198//------------------------------------------------------------------------------
1199// PyOperation
1200//------------------------------------------------------------------------------
1201
1202PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
1203 : BaseContextObject(std::move(contextRef)), operation(operation) {}
1204
1205PyOperation::~PyOperation() {
1206 // If the operation has already been invalidated there is nothing to do.
1207 if (!valid)
1208 return;
1209
1210 // Otherwise, invalidate the operation and remove it from live map when it is
1211 // attached.
1212 if (isAttached()) {
1213 getContext()->clearOperation(*this);
1214 } else {
1215 // And destroy it when it is detached, i.e. owned by Python, in which case
1216 // all nested operations must be invalidated at removed from the live map as
1217 // well.
1218 erase();
1219 }
1220}
1221
namespace {

// Constructs a new object of type T in-place on the Python heap, returning a
// PyObjectRef to it, loosely analogous to std::make_shared<T>().
//
// The sequence is:
//   1. Allocate an uninitialized nanobind instance of the bound type T.
//   2. Placement-new the C++ object into that instance's storage.
//   3. Mark the instance ready (nanobind's "constructed" state).
// The returned ref holds the instance's Python reference.
template <typename T, class... Args>
PyObjectRef<T> makeObjectRef(Args &&...args) {
  nb::handle type = nb::type<T>();
  nb::object instance = nb::inst_alloc(type);
  T *ptr = nb::inst_ptr<T>(instance);
  new (ptr) T(std::forward<Args>(args)...);
  nb::inst_mark_ready(instance);
  return PyObjectRef<T>(ptr, std::move(instance));
}

} // namespace
1237
1238PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
1239 MlirOperation operation,
1240 nb::object parentKeepAlive) {
1241 // Create.
1242 PyOperationRef unownedOperation =
1243 makeObjectRef<PyOperation>(std::move(contextRef), operation);
1244 unownedOperation->handle = unownedOperation.getObject();
1245 if (parentKeepAlive) {
1246 unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
1247 }
1248 return unownedOperation;
1249}
1250
1251PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
1252 MlirOperation operation,
1253 nb::object parentKeepAlive) {
1254 nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
1255 auto &liveOperations = contextRef->liveOperations;
1256 auto it = liveOperations.find(operation.ptr);
1257 if (it == liveOperations.end()) {
1258 // Create.
1259 PyOperationRef result = createInstance(std::move(contextRef), operation,
1260 std::move(parentKeepAlive));
1261 liveOperations[operation.ptr] =
1262 std::make_pair(result.getObject(), result.get());
1263 return result;
1264 }
1265 // Use existing.
1266 PyOperation *existing = it->second.second;
1267 nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1268 return PyOperationRef(existing, std::move(pyRef));
1269}
1270
1271PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
1272 MlirOperation operation,
1273 nb::object parentKeepAlive) {
1274 nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
1275 auto &liveOperations = contextRef->liveOperations;
1276 assert(liveOperations.count(operation.ptr) == 0 &&
1277 "cannot create detached operation that already exists");
1278 (void)liveOperations;
1279 PyOperationRef created = createInstance(std::move(contextRef), operation,
1280 std::move(parentKeepAlive));
1281 liveOperations[operation.ptr] =
1282 std::make_pair(created.getObject(), created.get());
1283 created->attached = false;
1284 return created;
1285}
1286
1287PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
1288 const std::string &sourceStr,
1289 const std::string &sourceName) {
1290 PyMlirContext::ErrorCapture errors(contextRef);
1291 MlirOperation op =
1292 mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
1293 toMlirStringRef(sourceName));
1294 if (mlirOperationIsNull(op))
1295 throw MLIRError("Unable to parse operation assembly", errors.take());
1296 return PyOperation::createDetached(std::move(contextRef), op);
1297}
1298
1299void PyOperation::checkValid() const {
1300 if (!valid) {
1301 throw std::runtime_error("the operation has been invalidated");
1302 }
1303}
1304
1305void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
1306 bool enableDebugInfo, bool prettyDebugInfo,
1307 bool printGenericOpForm, bool useLocalScope,
1308 bool useNameLocAsPrefix, bool assumeVerified,
1309 nb::object fileObject, bool binary,
1310 bool skipRegions) {
1311 PyOperation &operation = getOperation();
1312 operation.checkValid();
1313 if (fileObject.is_none())
1314 fileObject = nb::module_::import_("sys").attr("stdout");
1315
1316 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1317 if (largeElementsLimit) {
1318 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1319 mlirOpPrintingFlagsElideLargeResourceString(flags, *largeElementsLimit);
1320 }
1321 if (enableDebugInfo)
1322 mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
1323 /*prettyForm=*/prettyDebugInfo);
1324 if (printGenericOpForm)
1325 mlirOpPrintingFlagsPrintGenericOpForm(flags);
1326 if (useLocalScope)
1327 mlirOpPrintingFlagsUseLocalScope(flags);
1328 if (assumeVerified)
1329 mlirOpPrintingFlagsAssumeVerified(flags);
1330 if (skipRegions)
1331 mlirOpPrintingFlagsSkipRegions(flags);
1332 if (useNameLocAsPrefix)
1333 mlirOpPrintingFlagsPrintNameLocAsPrefix(flags);
1334
1335 PyFileAccumulator accum(fileObject, binary);
1336 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1337 accum.getUserData());
1338 mlirOpPrintingFlagsDestroy(flags);
1339}
1340
1341void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
1342 bool binary) {
1343 PyOperation &operation = getOperation();
1344 operation.checkValid();
1345 if (fileObject.is_none())
1346 fileObject = nb::module_::import_("sys").attr("stdout");
1347 PyFileAccumulator accum(fileObject, binary);
1348 mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1349 accum.getUserData());
1350}
1351
1352void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject,
1353 std::optional<int64_t> bytecodeVersion) {
1354 PyOperation &operation = getOperation();
1355 operation.checkValid();
1356 PyFileAccumulator accum(fileOrStringObject, /*binary=*/true);
1357
1358 if (!bytecodeVersion.has_value())
1359 return mlirOperationWriteBytecode(operation, accum.getCallback(),
1360 accum.getUserData());
1361
1362 MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
1363 mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
1364 MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig(
1365 operation, config, accum.getCallback(), accum.getUserData());
1366 mlirBytecodeWriterConfigDestroy(config);
1367 if (mlirLogicalResultIsFailure(res))
1368 throw nb::value_error((Twine("Unable to honor desired bytecode version ") +
1369 Twine(*bytecodeVersion))
1370 .str()
1371 .c_str());
1372}
1373
1374void PyOperationBase::walk(
1375 std::function<MlirWalkResult(MlirOperation)> callback,
1376 MlirWalkOrder walkOrder) {
1377 PyOperation &operation = getOperation();
1378 operation.checkValid();
1379 struct UserData {
1380 std::function<MlirWalkResult(MlirOperation)> callback;
1381 bool gotException;
1382 std::string exceptionWhat;
1383 nb::object exceptionType;
1384 };
1385 UserData userData{callback, false, {}, {}};
1386 MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1387 void *userData) {
1388 UserData *calleeUserData = static_cast<UserData *>(userData);
1389 try {
1390 return (calleeUserData->callback)(op);
1391 } catch (nb::python_error &e) {
1392 calleeUserData->gotException = true;
1393 calleeUserData->exceptionWhat = std::string(e.what());
1394 calleeUserData->exceptionType = nb::borrow(e.type());
1395 return MlirWalkResult::MlirWalkResultInterrupt;
1396 }
1397 };
1398 mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
1399 if (userData.gotException) {
1400 std::string message("Exception raised in callback: ");
1401 message.append(str: userData.exceptionWhat);
1402 throw std::runtime_error(message);
1403 }
1404}
1405
1406nb::object PyOperationBase::getAsm(bool binary,
1407 std::optional<int64_t> largeElementsLimit,
1408 bool enableDebugInfo, bool prettyDebugInfo,
1409 bool printGenericOpForm, bool useLocalScope,
1410 bool useNameLocAsPrefix, bool assumeVerified,
1411 bool skipRegions) {
1412 nb::object fileObject;
1413 if (binary) {
1414 fileObject = nb::module_::import_("io").attr("BytesIO")();
1415 } else {
1416 fileObject = nb::module_::import_("io").attr("StringIO")();
1417 }
1418 print(/*largeElementsLimit=*/largeElementsLimit,
1419 /*enableDebugInfo=*/enableDebugInfo,
1420 /*prettyDebugInfo=*/prettyDebugInfo,
1421 /*printGenericOpForm=*/printGenericOpForm,
1422 /*useLocalScope=*/useLocalScope,
1423 /*useNameLocAsPrefix=*/useNameLocAsPrefix,
1424 /*assumeVerified=*/assumeVerified,
1425 /*fileObject=*/fileObject,
1426 /*binary=*/binary,
1427 /*skipRegions=*/skipRegions);
1428
1429 return fileObject.attr("getvalue")();
1430}
1431
1432void PyOperationBase::moveAfter(PyOperationBase &other) {
1433 PyOperation &operation = getOperation();
1434 PyOperation &otherOp = other.getOperation();
1435 operation.checkValid();
1436 otherOp.checkValid();
1437 mlirOperationMoveAfter(operation, otherOp);
1438 operation.parentKeepAlive = otherOp.parentKeepAlive;
1439}
1440
1441void PyOperationBase::moveBefore(PyOperationBase &other) {
1442 PyOperation &operation = getOperation();
1443 PyOperation &otherOp = other.getOperation();
1444 operation.checkValid();
1445 otherOp.checkValid();
1446 mlirOperationMoveBefore(operation, otherOp);
1447 operation.parentKeepAlive = otherOp.parentKeepAlive;
1448}
1449
1450bool PyOperationBase::verify() {
1451 PyOperation &op = getOperation();
1452 PyMlirContext::ErrorCapture errors(op.getContext());
1453 if (!mlirOperationVerify(op.get()))
1454 throw MLIRError("Verification failed", errors.take());
1455 return true;
1456}
1457
1458std::optional<PyOperationRef> PyOperation::getParentOperation() {
1459 checkValid();
1460 if (!isAttached())
1461 throw nb::value_error("Detached operations have no parent");
1462 MlirOperation operation = mlirOperationGetParentOperation(get());
1463 if (mlirOperationIsNull(operation))
1464 return {};
1465 return PyOperation::forOperation(getContext(), operation);
1466}
1467
1468PyBlock PyOperation::getBlock() {
1469 checkValid();
1470 std::optional<PyOperationRef> parentOperation = getParentOperation();
1471 MlirBlock block = mlirOperationGetBlock(get());
1472 assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1473 assert(parentOperation && "Operation has no parent");
1474 return PyBlock{std::move(*parentOperation), block};
1475}
1476
1477nb::object PyOperation::getCapsule() {
1478 checkValid();
1479 return nb::steal<nb::object>(mlirPythonOperationToCapsule(get()));
1480}
1481
1482nb::object PyOperation::createFromCapsule(nb::object capsule) {
1483 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1484 if (mlirOperationIsNull(rawOperation))
1485 throw nb::python_error();
1486 MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1487 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1488 .releaseObject();
1489}
1490
1491static void maybeInsertOperation(PyOperationRef &op,
1492 const nb::object &maybeIp) {
1493 // InsertPoint active?
1494 if (!maybeIp.is(nb::cast(false))) {
1495 PyInsertionPoint *ip;
1496 if (maybeIp.is_none()) {
1497 ip = PyThreadContextEntry::getDefaultInsertionPoint();
1498 } else {
1499 ip = nb::cast<PyInsertionPoint *>(maybeIp);
1500 }
1501 if (ip)
1502 ip->insert(*op.get());
1503 }
1504}
1505
/// Creates a new operation named `name` and returns its Python object.
///
/// All inputs are validated and converted to C API values first, raising on
/// bad input. Only then is an MlirOperationState populated and the
/// operation created (detached). Afterwards, `maybeIp` controls insertion:
///   - `False` skips insertion.
///   - `None` uses the thread's default insertion point, if any.
///   - Any other value is cast to an InsertionPoint.
///
/// Raises:
///   - ValueError for a negative region count, a None result type or
///     successor, or when operation creation fails.
///   - TypeError for a non-string attribute key or a non-Attribute value.
///   - std::runtime_error for an attribute value that fails with
///     runtime_error (apparently `None`).
nb::object PyOperation::create(std::string_view name,
                               std::optional<std::vector<PyType *>> results,
                               llvm::ArrayRef<MlirValue> operands,
                               std::optional<nb::dict> attributes,
                               std::optional<std::vector<PyBlock *>> successors,
                               int regions, DefaultingPyLocation location,
                               const nb::object &maybeIp, bool inferType) {
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  // Keys are owned std::strings here; MlirIdentifiers built later point into
  // them.
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw nb::value_error("number of regions must be >= 0");

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw nb::value_error("result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (std::pair<nb::handle, nb::handle> it : *attributes) {
      std::string key;
      try {
        key = nb::cast<std::string>(it.first);
      } catch (nb::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          std::string(name) + "\" (" + err.what() + ")";
        throw nb::type_error(msg.c_str());
      }
      try {
        auto &attribute = nb::cast<PyAttribute &>(it.second);
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (nb::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          std::string(name) + "\" (" + err.what() + ")";
        throw nb::type_error(msg.c_str());
      } catch (std::runtime_error &) {
        // This exception seems thrown when the value is "None".
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" +
            std::string(name) + "\"";
        throw std::runtime_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw nb::value_error("successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!operands.empty())
    mlirOperationStateAddOperands(&state, operands.size(), operands.data());
  state.enableResultTypeInference = inferType;
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Fresh, empty regions; the name of the C API call indicates the state
    // takes ownership of them.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  MlirOperation operation = mlirOperationCreate(&state);
  if (!operation.ptr)
    throw nb::value_error("Operation creation failed");
  // New operations start detached (owned by Python) and are then optionally
  // inserted according to maybeIp.
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);
  maybeInsertOperation(created, maybeIp);

  return created.getObject();
}
1620
1621nb::object PyOperation::clone(const nb::object &maybeIp) {
1622 MlirOperation clonedOperation = mlirOperationClone(operation);
1623 PyOperationRef cloned =
1624 PyOperation::createDetached(getContext(), clonedOperation);
1625 maybeInsertOperation(cloned, maybeIp);
1626
1627 return cloned->createOpView();
1628}
1629
1630nb::object PyOperation::createOpView() {
1631 checkValid();
1632 MlirIdentifier ident = mlirOperationGetName(get());
1633 MlirStringRef identStr = mlirIdentifierStr(ident);
1634 auto operationCls = PyGlobals::get().lookupOperationClass(
1635 StringRef(identStr.data, identStr.length));
1636 if (operationCls)
1637 return PyOpView::constructDerived(*operationCls, getRef().getObject());
1638 return nb::cast(PyOpView(getRef().getObject()));
1639}
1640
1641void PyOperation::erase() {
1642 checkValid();
1643 getContext()->clearOperationAndInside(*this);
1644 mlirOperationDestroy(operation);
1645}
1646
1647namespace {
1648/// CRTP base class for Python MLIR values that subclass Value and should be
1649/// castable from it. The value hierarchy is one level deep and is not supposed
1650/// to accommodate other levels unless core MLIR changes.
1651template <typename DerivedTy>
1652class PyConcreteValue : public PyValue {
1653public:
1654 // Derived classes must define statics for:
1655 // IsAFunctionTy isaFunction
1656 // const char *pyClassName
1657 // and redefine bindDerived.
1658 using ClassTy = nb::class_<DerivedTy, PyValue>;
1659 using IsAFunctionTy = bool (*)(MlirValue);
1660
1661 PyConcreteValue() = default;
1662 PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1663 : PyValue(operationRef, value) {}
1664 PyConcreteValue(PyValue &orig)
1665 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1666
1667 /// Attempts to cast the original value to the derived type and throws on
1668 /// type mismatches.
1669 static MlirValue castFrom(PyValue &orig) {
1670 if (!DerivedTy::isaFunction(orig.get())) {
1671 auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
1672 throw nb::value_error((Twine("Cannot cast value to ") +
1673 DerivedTy::pyClassName + " (from " + origRepr +
1674 ")")
1675 .str()
1676 .c_str());
1677 }
1678 return orig.get();
1679 }
1680
1681 /// Binds the Python module objects to functions of this class.
1682 static void bind(nb::module_ &m) {
1683 auto cls = ClassTy(m, DerivedTy::pyClassName);
1684 cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
1685 cls.def_static(
1686 "isinstance",
1687 [](PyValue &otherValue) -> bool {
1688 return DerivedTy::isaFunction(otherValue);
1689 },
1690 nb::arg("other_value"));
1691 cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
1692 [](DerivedTy &self) { return self.maybeDownCast(); });
1693 DerivedTy::bindDerived(cls);
1694 }
1695
1696 /// Implemented by derived classes to add methods to the Python subclass.
1697 static void bindDerived(ClassTy &m) {}
1698};
1699
1700} // namespace
1701
1702/// Python wrapper for MlirOpResult.
1703class PyOpResult : public PyConcreteValue<PyOpResult> {
1704public:
1705 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1706 static constexpr const char *pyClassName = "OpResult";
1707 using PyConcreteValue::PyConcreteValue;
1708
1709 static void bindDerived(ClassTy &c) {
1710 c.def_prop_ro("owner", [](PyOpResult &self) {
1711 assert(
1712 mlirOperationEqual(self.getParentOperation()->get(),
1713 mlirOpResultGetOwner(self.get())) &&
1714 "expected the owner of the value in Python to match that in the IR");
1715 return self.getParentOperation().getObject();
1716 });
1717 c.def_prop_ro("result_number", [](PyOpResult &self) {
1718 return mlirOpResultGetResultNumber(self.get());
1719 });
1720 }
1721};
1722
1723/// Returns the list of types of the values held by container.
1724template <typename Container>
1725static std::vector<MlirType> getValueTypes(Container &container,
1726 PyMlirContextRef &context) {
1727 std::vector<MlirType> result;
1728 result.reserve(container.size());
1729 for (int i = 0, e = container.size(); i < e; ++i) {
1730 result.push_back(mlirValueGetType(container.getElement(i).get()));
1731 }
1732 return result;
1733}
1734
1735/// A list of operation results. Internally, these are stored as consecutive
1736/// elements, random access is cheap. The (returned) result list is associated
1737/// with the operation whose results these are, and thus extends the lifetime of
1738/// this operation.
1739class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1740public:
1741 static constexpr const char *pyClassName = "OpResultList";
1742 using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
1743
1744 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1745 intptr_t length = -1, intptr_t step = 1)
1746 : Sliceable(startIndex,
1747 length == -1 ? mlirOperationGetNumResults(operation->get())
1748 : length,
1749 step),
1750 operation(std::move(operation)) {}
1751
1752 static void bindDerived(ClassTy &c) {
1753 c.def_prop_ro("types", [](PyOpResultList &self) {
1754 return getValueTypes(self, self.operation->getContext());
1755 });
1756 c.def_prop_ro("owner", [](PyOpResultList &self) {
1757 return self.operation->createOpView();
1758 });
1759 }
1760
1761 PyOperationRef &getOperation() { return operation; }
1762
1763private:
1764 /// Give the parent CRTP class access to hook implementations below.
1765 friend class Sliceable<PyOpResultList, PyOpResult>;
1766
1767 intptr_t getRawNumElements() {
1768 operation->checkValid();
1769 return mlirOperationGetNumResults(operation->get());
1770 }
1771
1772 PyOpResult getRawElement(intptr_t index) {
1773 PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1774 return PyOpResult(value);
1775 }
1776
1777 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1778 return PyOpResultList(operation, startIndex, length, step);
1779 }
1780
1781 PyOperationRef operation;
1782};
1783
1784//------------------------------------------------------------------------------
1785// PyOpView
1786//------------------------------------------------------------------------------
1787
1788static void populateResultTypes(StringRef name, nb::list resultTypeList,
1789 const nb::object &resultSegmentSpecObj,
1790 std::vector<int32_t> &resultSegmentLengths,
1791 std::vector<PyType *> &resultTypes) {
1792 resultTypes.reserve(n: resultTypeList.size());
1793 if (resultSegmentSpecObj.is_none()) {
1794 // Non-variadic result unpacking.
1795 for (const auto &it : llvm::enumerate(resultTypeList)) {
1796 try {
1797 resultTypes.push_back(nb::cast<PyType *>(it.value()));
1798 if (!resultTypes.back())
1799 throw nb::cast_error();
1800 } catch (nb::cast_error &err) {
1801 throw nb::value_error((llvm::Twine("Result ") +
1802 llvm::Twine(it.index()) + " of operation \"" +
1803 name + "\" must be a Type (" + err.what() + ")")
1804 .str()
1805 .c_str());
1806 }
1807 }
1808 } else {
1809 // Sized result unpacking.
1810 auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj);
1811 if (resultSegmentSpec.size() != resultTypeList.size()) {
1812 throw nb::value_error((llvm::Twine("Operation \"") + name +
1813 "\" requires " +
1814 llvm::Twine(resultSegmentSpec.size()) +
1815 " result segments but was provided " +
1816 llvm::Twine(resultTypeList.size()))
1817 .str()
1818 .c_str());
1819 }
1820 resultSegmentLengths.reserve(n: resultTypeList.size());
1821 for (const auto &it :
1822 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1823 int segmentSpec = std::get<1>(it.value());
1824 if (segmentSpec == 1 || segmentSpec == 0) {
1825 // Unpack unary element.
1826 try {
1827 auto *resultType = nb::cast<PyType *>(std::get<0>(it.value()));
1828 if (resultType) {
1829 resultTypes.push_back(resultType);
1830 resultSegmentLengths.push_back(1);
1831 } else if (segmentSpec == 0) {
1832 // Allowed to be optional.
1833 resultSegmentLengths.push_back(0);
1834 } else {
1835 throw nb::value_error(
1836 (llvm::Twine("Result ") + llvm::Twine(it.index()) +
1837 " of operation \"" + name +
1838 "\" must be a Type (was None and result is not optional)")
1839 .str()
1840 .c_str());
1841 }
1842 } catch (nb::cast_error &err) {
1843 throw nb::value_error((llvm::Twine("Result ") +
1844 llvm::Twine(it.index()) + " of operation \"" +
1845 name + "\" must be a Type (" + err.what() +
1846 ")")
1847 .str()
1848 .c_str());
1849 }
1850 } else if (segmentSpec == -1) {
1851 // Unpack sequence by appending.
1852 try {
1853 if (std::get<0>(it.value()).is_none()) {
1854 // Treat it as an empty list.
1855 resultSegmentLengths.push_back(0);
1856 } else {
1857 // Unpack the list.
1858 auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
1859 for (nb::handle segmentItem : segment) {
1860 resultTypes.push_back(nb::cast<PyType *>(segmentItem));
1861 if (!resultTypes.back()) {
1862 throw nb::type_error("contained a None item");
1863 }
1864 }
1865 resultSegmentLengths.push_back(nb::len(segment));
1866 }
1867 } catch (std::exception &err) {
1868 // NOTE: Sloppy to be using a catch-all here, but there are at least
1869 // three different unrelated exceptions that can be thrown in the
1870 // above "casts". Just keep the scope above small and catch them all.
1871 throw nb::value_error((llvm::Twine("Result ") +
1872 llvm::Twine(it.index()) + " of operation \"" +
1873 name + "\" must be a Sequence of Types (" +
1874 err.what() + ")")
1875 .str()
1876 .c_str());
1877 }
1878 } else {
1879 throw nb::value_error("Unexpected segment spec");
1880 }
1881 }
1882 }
1883}
1884
1885static MlirValue getUniqueResult(MlirOperation operation) {
1886 auto numResults = mlirOperationGetNumResults(operation);
1887 if (numResults != 1) {
1888 auto name = mlirIdentifierStr(mlirOperationGetName(operation));
1889 throw nb::value_error((Twine("Cannot call .result on operation ") +
1890 StringRef(name.data, name.length) + " which has " +
1891 Twine(numResults) +
1892 " results (it is only valid for operations with a "
1893 "single result)")
1894 .str()
1895 .c_str());
1896 }
1897 return mlirOperationGetResult(operation, 0);
1898}
1899
1900static MlirValue getOpResultOrValue(nb::handle operand) {
1901 if (operand.is_none()) {
1902 throw nb::value_error("contained a None item");
1903 }
1904 PyOperationBase *op;
1905 if (nb::try_cast<PyOperationBase *>(operand, op)) {
1906 return getUniqueResult(op->getOperation());
1907 }
1908 PyOpResultList *opResultList;
1909 if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
1910 return getUniqueResult(opResultList->getOperation()->get());
1911 }
1912 PyValue *value;
1913 if (nb::try_cast<PyValue *>(operand, value)) {
1914 return value->get();
1915 }
1916 throw nb::value_error("is not a Value");
1917}
1918
1919nb::object PyOpView::buildGeneric(
1920 std::string_view name, std::tuple<int, bool> opRegionSpec,
1921 nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
1922 std::optional<nb::list> resultTypeList, nb::list operandList,
1923 std::optional<nb::dict> attributes,
1924 std::optional<std::vector<PyBlock *>> successors,
1925 std::optional<int> regions, DefaultingPyLocation location,
1926 const nb::object &maybeIp) {
1927 PyMlirContextRef context = location->getContext();
1928
1929 // Class level operation construction metadata.
1930 // Operand and result segment specs are either none, which does no
1931 // variadic unpacking, or a list of ints with segment sizes, where each
1932 // element is either a positive number (typically 1 for a scalar) or -1 to
1933 // indicate that it is derived from the length of the same-indexed operand
1934 // or result (implying that it is a list at that position).
1935 std::vector<int32_t> operandSegmentLengths;
1936 std::vector<int32_t> resultSegmentLengths;
1937
1938 // Validate/determine region count.
1939 int opMinRegionCount = std::get<0>(t&: opRegionSpec);
1940 bool opHasNoVariadicRegions = std::get<1>(t&: opRegionSpec);
1941 if (!regions) {
1942 regions = opMinRegionCount;
1943 }
1944 if (*regions < opMinRegionCount) {
1945 throw nb::value_error(
1946 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1947 llvm::Twine(opMinRegionCount) +
1948 " regions but was built with regions=" + llvm::Twine(*regions))
1949 .str()
1950 .c_str());
1951 }
1952 if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1953 throw nb::value_error(
1954 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1955 llvm::Twine(opMinRegionCount) +
1956 " regions but was built with regions=" + llvm::Twine(*regions))
1957 .str()
1958 .c_str());
1959 }
1960
1961 // Unpack results.
1962 std::vector<PyType *> resultTypes;
1963 if (resultTypeList.has_value()) {
1964 populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1965 resultSegmentLengths, resultTypes);
1966 }
1967
1968 // Unpack operands.
1969 llvm::SmallVector<MlirValue, 4> operands;
1970 operands.reserve(operands.size());
1971 if (operandSegmentSpecObj.is_none()) {
1972 // Non-sized operand unpacking.
1973 for (const auto &it : llvm::enumerate(operandList)) {
1974 try {
1975 operands.push_back(getOpResultOrValue(it.value()));
1976 } catch (nb::builtin_exception &err) {
1977 throw nb::value_error((llvm::Twine("Operand ") +
1978 llvm::Twine(it.index()) + " of operation \"" +
1979 name + "\" must be a Value (" + err.what() + ")")
1980 .str()
1981 .c_str());
1982 }
1983 }
1984 } else {
1985 // Sized operand unpacking.
1986 auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj);
1987 if (operandSegmentSpec.size() != operandList.size()) {
1988 throw nb::value_error((llvm::Twine("Operation \"") + name +
1989 "\" requires " +
1990 llvm::Twine(operandSegmentSpec.size()) +
1991 "operand segments but was provided " +
1992 llvm::Twine(operandList.size()))
1993 .str()
1994 .c_str());
1995 }
1996 operandSegmentLengths.reserve(n: operandList.size());
1997 for (const auto &it :
1998 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1999 int segmentSpec = std::get<1>(it.value());
2000 if (segmentSpec == 1 || segmentSpec == 0) {
2001 // Unpack unary element.
2002 auto &operand = std::get<0>(it.value());
2003 if (!operand.is_none()) {
2004 try {
2005
2006 operands.push_back(getOpResultOrValue(operand));
2007 } catch (nb::builtin_exception &err) {
2008 throw nb::value_error((llvm::Twine("Operand ") +
2009 llvm::Twine(it.index()) +
2010 " of operation \"" + name +
2011 "\" must be a Value (" + err.what() + ")")
2012 .str()
2013 .c_str());
2014 }
2015
2016 operandSegmentLengths.push_back(1);
2017 } else if (segmentSpec == 0) {
2018 // Allowed to be optional.
2019 operandSegmentLengths.push_back(0);
2020 } else {
2021 throw nb::value_error(
2022 (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
2023 " of operation \"" + name +
2024 "\" must be a Value (was None and operand is not optional)")
2025 .str()
2026 .c_str());
2027 }
2028 } else if (segmentSpec == -1) {
2029 // Unpack sequence by appending.
2030 try {
2031 if (std::get<0>(it.value()).is_none()) {
2032 // Treat it as an empty list.
2033 operandSegmentLengths.push_back(0);
2034 } else {
2035 // Unpack the list.
2036 auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
2037 for (nb::handle segmentItem : segment) {
2038 operands.push_back(getOpResultOrValue(segmentItem));
2039 }
2040 operandSegmentLengths.push_back(nb::len(segment));
2041 }
2042 } catch (std::exception &err) {
2043 // NOTE: Sloppy to be using a catch-all here, but there are at least
2044 // three different unrelated exceptions that can be thrown in the
2045 // above "casts". Just keep the scope above small and catch them all.
2046 throw nb::value_error((llvm::Twine("Operand ") +
2047 llvm::Twine(it.index()) + " of operation \"" +
2048 name + "\" must be a Sequence of Values (" +
2049 err.what() + ")")
2050 .str()
2051 .c_str());
2052 }
2053 } else {
2054 throw nb::value_error("Unexpected segment spec");
2055 }
2056 }
2057 }
2058
2059 // Merge operand/result segment lengths into attributes if needed.
2060 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
2061 // Dup.
2062 if (attributes) {
2063 attributes = nb::dict(*attributes);
2064 } else {
2065 attributes = nb::dict();
2066 }
2067 if (attributes->contains("resultSegmentSizes") ||
2068 attributes->contains("operandSegmentSizes")) {
2069 throw nb::value_error("Manually setting a 'resultSegmentSizes' or "
2070 "'operandSegmentSizes' attribute is unsupported. "
2071 "Use Operation.create for such low-level access.");
2072 }
2073
2074 // Add resultSegmentSizes attribute.
2075 if (!resultSegmentLengths.empty()) {
2076 MlirAttribute segmentLengthAttr =
2077 mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
2078 resultSegmentLengths.data());
2079 (*attributes)["resultSegmentSizes"] =
2080 PyAttribute(context, segmentLengthAttr);
2081 }
2082
2083 // Add operandSegmentSizes attribute.
2084 if (!operandSegmentLengths.empty()) {
2085 MlirAttribute segmentLengthAttr =
2086 mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
2087 operandSegmentLengths.data());
2088 (*attributes)["operandSegmentSizes"] =
2089 PyAttribute(context, segmentLengthAttr);
2090 }
2091 }
2092
2093 // Delegate to create.
2094 return PyOperation::create(name,
2095 /*results=*/std::move(resultTypes),
2096 /*operands=*/std::move(operands),
2097 /*attributes=*/std::move(attributes),
2098 /*successors=*/std::move(successors),
2099 /*regions=*/*regions, location, maybeIp,
2100 !resultTypeList);
2101}
2102
2103nb::object PyOpView::constructDerived(const nb::object &cls,
2104 const nb::object &operation) {
2105 nb::handle opViewType = nb::type<PyOpView>();
2106 nb::object instance = cls.attr("__new__")(cls);
2107 opViewType.attr("__init__")(instance, operation);
2108 return instance;
2109}
2110
/// Wraps the Python operation object `operationObject` in this OpView. The
/// object may be an Operation or any other PyOperationBase subclass.
///
/// The `operationObject` member is taken from the underlying PyOperation's own
/// reference (`operation.getRef()`), not from the argument. It therefore
/// refers to the Operation object even when an OpView was passed in.
/// NOTE(review): the second initializer reads the already-initialized
/// `operation` member. This relies on `operation` being declared before
/// `operationObject` in the class (members initialize in declaration order);
/// confirm in IRModule.h.
PyOpView::PyOpView(const nb::object &operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
2116
2117//------------------------------------------------------------------------------
2118// PyInsertionPoint.
2119//------------------------------------------------------------------------------
2120
/// Creates an insertion point at the end of `block`. No reference operation is
/// recorded, and insert() treats a missing reference operation as "append at
/// the end of the block" (assuming `refOperation` default-constructs empty).
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
2122
/// Creates an insertion point immediately before `beforeOperationBase`, inside
/// the block that contains that operation.
/// NOTE(review): `block` is initialized through `refOperation`, which relies on
/// `refOperation` being declared before `block` in the class (members
/// initialize in declaration order); confirm in IRModule.h.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      block((*refOperation)->getBlock()) {}
2126
2127void PyInsertionPoint::insert(PyOperationBase &operationBase) {
2128 PyOperation &operation = operationBase.getOperation();
2129 if (operation.isAttached())
2130 throw nb::value_error(
2131 "Attempt to insert operation that is already attached");
2132 block.getParentOperation()->checkValid();
2133 MlirOperation beforeOp = {nullptr};
2134 if (refOperation) {
2135 // Insert before operation.
2136 (*refOperation)->checkValid();
2137 beforeOp = (*refOperation)->get();
2138 } else {
2139 // Insert at end (before null) is only valid if the block does not
2140 // already end in a known terminator (violating this will cause assertion
2141 // failures later).
2142 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
2143 throw nb::index_error("Cannot insert operation at the end of a block "
2144 "that already has a terminator. Did you mean to "
2145 "use 'InsertionPoint.at_block_terminator(block)' "
2146 "versus 'InsertionPoint(block)'?");
2147 }
2148 }
2149 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
2150 operation.setAttached();
2151}
2152
2153PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
2154 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
2155 if (mlirOperationIsNull(firstOp)) {
2156 // Just insert at end.
2157 return PyInsertionPoint(block);
2158 }
2159
2160 // Insert before first op.
2161 PyOperationRef firstOpRef = PyOperation::forOperation(
2162 block.getParentOperation()->getContext(), firstOp);
2163 return PyInsertionPoint{block, std::move(firstOpRef)};
2164}
2165
2166PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
2167 MlirOperation terminator = mlirBlockGetTerminator(block.get());
2168 if (mlirOperationIsNull(terminator))
2169 throw nb::value_error("Block has no terminator");
2170 PyOperationRef terminatorOpRef = PyOperation::forOperation(
2171 block.getParentOperation()->getContext(), terminator);
2172 return PyInsertionPoint{block, std::move(terminatorOpRef)};
2173}
2174
2175nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
2176 return PyThreadContextEntry::pushInsertionPoint(insertPoint);
2177}
2178
2179void PyInsertionPoint::contextExit(const nb::object &excType,
2180 const nb::object &excVal,
2181 const nb::object &excTb) {
2182 PyThreadContextEntry::popInsertionPoint(insertionPoint&: *this);
2183}
2184
2185//------------------------------------------------------------------------------
2186// PyAttribute.
2187//------------------------------------------------------------------------------
2188
2189bool PyAttribute::operator==(const PyAttribute &other) const {
2190 return mlirAttributeEqual(attr, other.attr);
2191}
2192
2193nb::object PyAttribute::getCapsule() {
2194 return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this));
2195}
2196
2197PyAttribute PyAttribute::createFromCapsule(nb::object capsule) {
2198 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
2199 if (mlirAttributeIsNull(rawAttr))
2200 throw nb::python_error();
2201 return PyAttribute(
2202 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
2203}
2204
2205//------------------------------------------------------------------------------
2206// PyNamedAttribute.
2207//------------------------------------------------------------------------------
2208
2209PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
2210 : ownedName(new std::string(std::move(ownedName))) {
2211 namedAttr = mlirNamedAttributeGet(
2212 mlirIdentifierGet(mlirAttributeGetContext(attr),
2213 toMlirStringRef(*this->ownedName)),
2214 attr);
2215}
2216
2217//------------------------------------------------------------------------------
2218// PyType.
2219//------------------------------------------------------------------------------
2220
2221bool PyType::operator==(const PyType &other) const {
2222 return mlirTypeEqual(type, other.type);
2223}
2224
2225nb::object PyType::getCapsule() {
2226 return nb::steal<nb::object>(mlirPythonTypeToCapsule(*this));
2227}
2228
2229PyType PyType::createFromCapsule(nb::object capsule) {
2230 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
2231 if (mlirTypeIsNull(rawType))
2232 throw nb::python_error();
2233 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
2234 rawType);
2235}
2236
2237//------------------------------------------------------------------------------
2238// PyTypeID.
2239//------------------------------------------------------------------------------
2240
2241nb::object PyTypeID::getCapsule() {
2242 return nb::steal<nb::object>(mlirPythonTypeIDToCapsule(*this));
2243}
2244
2245PyTypeID PyTypeID::createFromCapsule(nb::object capsule) {
2246 MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
2247 if (mlirTypeIDIsNull(mlirTypeID))
2248 throw nb::python_error();
2249 return PyTypeID(mlirTypeID);
2250}
2251bool PyTypeID::operator==(const PyTypeID &other) const {
2252 return mlirTypeIDEqual(typeID, other.typeID);
2253}
2254
2255//------------------------------------------------------------------------------
2256// PyValue and subclasses.
2257//------------------------------------------------------------------------------
2258
2259nb::object PyValue::getCapsule() {
2260 return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
2261}
2262
2263nb::object PyValue::maybeDownCast() {
2264 MlirType type = mlirValueGetType(get());
2265 MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
2266 assert(!mlirTypeIDIsNull(mlirTypeID) &&
2267 "mlirTypeID was expected to be non-null.");
2268 std::optional<nb::callable> valueCaster =
2269 PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
2270 // nb::rv_policy::move means use std::move to move the return value
2271 // contents into a new instance that will be owned by Python.
2272 nb::object thisObj = nb::cast(this, nb::rv_policy::move);
2273 if (!valueCaster)
2274 return thisObj;
2275 return valueCaster.value()(thisObj);
2276}
2277
2278PyValue PyValue::createFromCapsule(nb::object capsule) {
2279 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
2280 if (mlirValueIsNull(value))
2281 throw nb::python_error();
2282 MlirOperation owner;
2283 if (mlirValueIsAOpResult(value))
2284 owner = mlirOpResultGetOwner(value);
2285 if (mlirValueIsABlockArgument(value))
2286 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
2287 if (mlirOperationIsNull(owner))
2288 throw nb::python_error();
2289 MlirContext ctx = mlirOperationGetContext(owner);
2290 PyOperationRef ownerRef =
2291 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
2292 return PyValue(ownerRef, value);
2293}
2294
2295//------------------------------------------------------------------------------
2296// PySymbolTable.
2297//------------------------------------------------------------------------------
2298
2299PySymbolTable::PySymbolTable(PyOperationBase &operation)
2300 : operation(operation.getOperation().getRef()) {
2301 symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
2302 if (mlirSymbolTableIsNull(symbolTable)) {
2303 throw nb::type_error("Operation is not a Symbol Table.");
2304 }
2305}
2306
2307nb::object PySymbolTable::dunderGetItem(const std::string &name) {
2308 operation->checkValid();
2309 MlirOperation symbol = mlirSymbolTableLookup(
2310 symbolTable, mlirStringRefCreate(name.data(), name.length()));
2311 if (mlirOperationIsNull(symbol))
2312 throw nb::key_error(
2313 ("Symbol '" + name + "' not in the symbol table.").c_str());
2314
2315 return PyOperation::forOperation(operation->getContext(), symbol,
2316 operation.getObject())
2317 ->createOpView();
2318}
2319
2320void PySymbolTable::erase(PyOperationBase &symbol) {
2321 operation->checkValid();
2322 symbol.getOperation().checkValid();
2323 mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
2324 // The operation is also erased, so we must invalidate it. There may be Python
2325 // references to this operation so we don't want to delete it from the list of
2326 // live operations here.
2327 symbol.getOperation().valid = false;
2328}
2329
2330void PySymbolTable::dunderDel(const std::string &name) {
2331 nb::object operation = dunderGetItem(name);
2332 erase(nb::cast<PyOperationBase &>(operation));
2333}
2334
2335MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) {
2336 operation->checkValid();
2337 symbol.getOperation().checkValid();
2338 MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
2339 symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
2340 if (mlirAttributeIsNull(symbolAttr))
2341 throw nb::value_error("Expected operation to have a symbol name.");
2342 return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get());
2343}
2344
2345MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
2346 // Op must already be a symbol.
2347 PyOperation &operation = symbol.getOperation();
2348 operation.checkValid();
2349 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
2350 MlirAttribute existingNameAttr =
2351 mlirOperationGetAttributeByName(operation.get(), attrName);
2352 if (mlirAttributeIsNull(existingNameAttr))
2353 throw nb::value_error("Expected operation to have a symbol name.");
2354 return existingNameAttr;
2355}
2356
2357void PySymbolTable::setSymbolName(PyOperationBase &symbol,
2358 const std::string &name) {
2359 // Op must already be a symbol.
2360 PyOperation &operation = symbol.getOperation();
2361 operation.checkValid();
2362 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
2363 MlirAttribute existingNameAttr =
2364 mlirOperationGetAttributeByName(operation.get(), attrName);
2365 if (mlirAttributeIsNull(existingNameAttr))
2366 throw nb::value_error("Expected operation to have a symbol name.");
2367 MlirAttribute newNameAttr =
2368 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
2369 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
2370}
2371
2372MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
2373 PyOperation &operation = symbol.getOperation();
2374 operation.checkValid();
2375 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
2376 MlirAttribute existingVisAttr =
2377 mlirOperationGetAttributeByName(operation.get(), attrName);
2378 if (mlirAttributeIsNull(existingVisAttr))
2379 throw nb::value_error("Expected operation to have a symbol visibility.");
2380 return existingVisAttr;
2381}
2382
2383void PySymbolTable::setVisibility(PyOperationBase &symbol,
2384 const std::string &visibility) {
2385 if (visibility != "public" && visibility != "private" &&
2386 visibility != "nested")
2387 throw nb::value_error(
2388 "Expected visibility to be 'public', 'private' or 'nested'");
2389 PyOperation &operation = symbol.getOperation();
2390 operation.checkValid();
2391 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
2392 MlirAttribute existingVisAttr =
2393 mlirOperationGetAttributeByName(operation.get(), attrName);
2394 if (mlirAttributeIsNull(existingVisAttr))
2395 throw nb::value_error("Expected operation to have a symbol visibility.");
2396 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
2397 toMlirStringRef(visibility));
2398 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
2399}
2400
2401void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
2402 const std::string &newSymbol,
2403 PyOperationBase &from) {
2404 PyOperation &fromOperation = from.getOperation();
2405 fromOperation.checkValid();
2406 if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
2407 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
2408 from.getOperation())))
2409
2410 throw nb::value_error("Symbol rename failed");
2411}
2412
2413void PySymbolTable::walkSymbolTables(PyOperationBase &from,
2414 bool allSymUsesVisible,
2415 nb::object callback) {
2416 PyOperation &fromOperation = from.getOperation();
2417 fromOperation.checkValid();
2418 struct UserData {
2419 PyMlirContextRef context;
2420 nb::object callback;
2421 bool gotException;
2422 std::string exceptionWhat;
2423 nb::object exceptionType;
2424 };
2425 UserData userData{
2426 fromOperation.getContext(), std::move(callback), false, {}, {}};
2427 mlirSymbolTableWalkSymbolTables(
2428 fromOperation.get(), allSymUsesVisible,
2429 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
2430 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
2431 auto pyFoundOp =
2432 PyOperation::forOperation(calleeUserData->context, foundOp);
2433 if (calleeUserData->gotException)
2434 return;
2435 try {
2436 calleeUserData->callback(pyFoundOp.getObject(), isVisible);
2437 } catch (nb::python_error &e) {
2438 calleeUserData->gotException = true;
2439 calleeUserData->exceptionWhat = e.what();
2440 calleeUserData->exceptionType = nb::borrow(e.type());
2441 }
2442 },
2443 static_cast<void *>(&userData));
2444 if (userData.gotException) {
2445 std::string message("Exception raised in callback: ");
2446 message.append(str: userData.exceptionWhat);
2447 throw std::runtime_error(message);
2448 }
2449}
2450
2451namespace {
2452
2453/// Python wrapper for MlirBlockArgument.
2454class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
2455public:
2456 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
2457 static constexpr const char *pyClassName = "BlockArgument";
2458 using PyConcreteValue::PyConcreteValue;
2459
2460 static void bindDerived(ClassTy &c) {
2461 c.def_prop_ro("owner", [](PyBlockArgument &self) {
2462 return PyBlock(self.getParentOperation(),
2463 mlirBlockArgumentGetOwner(self.get()));
2464 });
2465 c.def_prop_ro("arg_number", [](PyBlockArgument &self) {
2466 return mlirBlockArgumentGetArgNumber(self.get());
2467 });
2468 c.def(
2469 "set_type",
2470 [](PyBlockArgument &self, PyType type) {
2471 return mlirBlockArgumentSetType(self.get(), type);
2472 },
2473 nb::arg("type"));
2474 }
2475};
2476
2477/// A list of block arguments. Internally, these are stored as consecutive
2478/// elements, random access is cheap. The argument list is associated with the
2479/// operation that contains the block (detached blocks are not allowed in
2480/// Python bindings) and extends its lifetime.
2481class PyBlockArgumentList
2482 : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
2483public:
2484 static constexpr const char *pyClassName = "BlockArgumentList";
2485 using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
2486
2487 PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
2488 intptr_t startIndex = 0, intptr_t length = -1,
2489 intptr_t step = 1)
2490 : Sliceable(startIndex,
2491 length == -1 ? mlirBlockGetNumArguments(block) : length,
2492 step),
2493 operation(std::move(operation)), block(block) {}
2494
2495 static void bindDerived(ClassTy &c) {
2496 c.def_prop_ro("types", [](PyBlockArgumentList &self) {
2497 return getValueTypes(self, self.operation->getContext());
2498 });
2499 }
2500
2501private:
2502 /// Give the parent CRTP class access to hook implementations below.
2503 friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
2504
2505 /// Returns the number of arguments in the list.
2506 intptr_t getRawNumElements() {
2507 operation->checkValid();
2508 return mlirBlockGetNumArguments(block);
2509 }
2510
2511 /// Returns `pos`-the element in the list.
2512 PyBlockArgument getRawElement(intptr_t pos) {
2513 MlirValue argument = mlirBlockGetArgument(block, pos);
2514 return PyBlockArgument(operation, argument);
2515 }
2516
2517 /// Returns a sublist of this list.
2518 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2519 intptr_t step) {
2520 return PyBlockArgumentList(operation, block, startIndex, length, step);
2521 }
2522
2523 PyOperationRef operation;
2524 MlirBlock block;
2525};
2526
2527/// A list of operation operands. Internally, these are stored as consecutive
2528/// elements, random access is cheap. The (returned) operand list is associated
2529/// with the operation whose operands these are, and thus extends the lifetime
2530/// of this operation.
2531class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2532public:
2533 static constexpr const char *pyClassName = "OpOperandList";
2534 using SliceableT = Sliceable<PyOpOperandList, PyValue>;
2535
2536 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2537 intptr_t length = -1, intptr_t step = 1)
2538 : Sliceable(startIndex,
2539 length == -1 ? mlirOperationGetNumOperands(operation->get())
2540 : length,
2541 step),
2542 operation(operation) {}
2543
2544 void dunderSetItem(intptr_t index, PyValue value) {
2545 index = wrapIndex(index);
2546 mlirOperationSetOperand(operation->get(), index, value.get());
2547 }
2548
2549 static void bindDerived(ClassTy &c) {
2550 c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2551 }
2552
2553private:
2554 /// Give the parent CRTP class access to hook implementations below.
2555 friend class Sliceable<PyOpOperandList, PyValue>;
2556
2557 intptr_t getRawNumElements() {
2558 operation->checkValid();
2559 return mlirOperationGetNumOperands(operation->get());
2560 }
2561
2562 PyValue getRawElement(intptr_t pos) {
2563 MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2564 MlirOperation owner;
2565 if (mlirValueIsAOpResult(operand))
2566 owner = mlirOpResultGetOwner(operand);
2567 else if (mlirValueIsABlockArgument(operand))
2568 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
2569 else
2570 assert(false && "Value must be an block arg or op result.");
2571 PyOperationRef pyOwner =
2572 PyOperation::forOperation(operation->getContext(), owner);
2573 return PyValue(pyOwner, operand);
2574 }
2575
2576 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2577 return PyOpOperandList(operation, startIndex, length, step);
2578 }
2579
2580 PyOperationRef operation;
2581};
2582
2583/// A list of operation successors. Internally, these are stored as consecutive
2584/// elements, random access is cheap. The (returned) successor list is
2585/// associated with the operation whose successors these are, and thus extends
2586/// the lifetime of this operation.
2587class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
2588public:
2589 static constexpr const char *pyClassName = "OpSuccessors";
2590
2591 PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
2592 intptr_t length = -1, intptr_t step = 1)
2593 : Sliceable(startIndex,
2594 length == -1 ? mlirOperationGetNumSuccessors(operation->get())
2595 : length,
2596 step),
2597 operation(operation) {}
2598
2599 void dunderSetItem(intptr_t index, PyBlock block) {
2600 index = wrapIndex(index);
2601 mlirOperationSetSuccessor(operation->get(), index, block.get());
2602 }
2603
2604 static void bindDerived(ClassTy &c) {
2605 c.def("__setitem__", &PyOpSuccessors::dunderSetItem);
2606 }
2607
2608private:
2609 /// Give the parent CRTP class access to hook implementations below.
2610 friend class Sliceable<PyOpSuccessors, PyBlock>;
2611
2612 intptr_t getRawNumElements() {
2613 operation->checkValid();
2614 return mlirOperationGetNumSuccessors(operation->get());
2615 }
2616
2617 PyBlock getRawElement(intptr_t pos) {
2618 MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2619 return PyBlock(operation, block);
2620 }
2621
2622 PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2623 return PyOpSuccessors(operation, startIndex, length, step);
2624 }
2625
2626 PyOperationRef operation;
2627};
2628
2629/// A list of block successors. Internally, these are stored as consecutive
2630/// elements, random access is cheap. The (returned) successor list is
2631/// associated with the operation and block whose successors these are, and thus
2632/// extends the lifetime of this operation and block.
2633class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
2634public:
2635 static constexpr const char *pyClassName = "BlockSuccessors";
2636
2637 PyBlockSuccessors(PyBlock block, PyOperationRef operation,
2638 intptr_t startIndex = 0, intptr_t length = -1,
2639 intptr_t step = 1)
2640 : Sliceable(startIndex,
2641 length == -1 ? mlirBlockGetNumSuccessors(block.get())
2642 : length,
2643 step),
2644 operation(operation), block(block) {}
2645
2646private:
2647 /// Give the parent CRTP class access to hook implementations below.
2648 friend class Sliceable<PyBlockSuccessors, PyBlock>;
2649
2650 intptr_t getRawNumElements() {
2651 block.checkValid();
2652 return mlirBlockGetNumSuccessors(block.get());
2653 }
2654
2655 PyBlock getRawElement(intptr_t pos) {
2656 MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
2657 return PyBlock(operation, block);
2658 }
2659
2660 PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2661 return PyBlockSuccessors(block, operation, startIndex, length, step);
2662 }
2663
2664 PyOperationRef operation;
2665 PyBlock block;
2666};
2667
2668/// A list of block predecessors. The (returned) predecessor list is
2669/// associated with the operation and block whose predecessors these are, and
2670/// thus extends the lifetime of this operation and block.
2671///
2672/// WARNING: This Sliceable is more expensive than the others here because
2673/// mlirBlockGetPredecessor actually iterates the use-def chain (of block
2674/// operands) anew for each indexed access.
2675class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
2676public:
2677 static constexpr const char *pyClassName = "BlockPredecessors";
2678
2679 PyBlockPredecessors(PyBlock block, PyOperationRef operation,
2680 intptr_t startIndex = 0, intptr_t length = -1,
2681 intptr_t step = 1)
2682 : Sliceable(startIndex,
2683 length == -1 ? mlirBlockGetNumPredecessors(block.get())
2684 : length,
2685 step),
2686 operation(operation), block(block) {}
2687
2688private:
2689 /// Give the parent CRTP class access to hook implementations below.
2690 friend class Sliceable<PyBlockPredecessors, PyBlock>;
2691
2692 intptr_t getRawNumElements() {
2693 block.checkValid();
2694 return mlirBlockGetNumPredecessors(block.get());
2695 }
2696
2697 PyBlock getRawElement(intptr_t pos) {
2698 MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
2699 return PyBlock(operation, block);
2700 }
2701
2702 PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
2703 intptr_t step) {
2704 return PyBlockPredecessors(block, operation, startIndex, length, step);
2705 }
2706
2707 PyOperationRef operation;
2708 PyBlock block;
2709};
2710
2711/// A list of operation attributes. Can be indexed by name, producing
2712/// attributes, or by index, producing named attributes.
2713class PyOpAttributeMap {
2714public:
2715 PyOpAttributeMap(PyOperationRef operation)
2716 : operation(std::move(operation)) {}
2717
2718 MlirAttribute dunderGetItemNamed(const std::string &name) {
2719 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2720 toMlirStringRef(name));
2721 if (mlirAttributeIsNull(attr)) {
2722 throw nb::key_error("attempt to access a non-existent attribute");
2723 }
2724 return attr;
2725 }
2726
2727 PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2728 if (index < 0) {
2729 index += dunderLen();
2730 }
2731 if (index < 0 || index >= dunderLen()) {
2732 throw nb::index_error("attempt to access out of bounds attribute");
2733 }
2734 MlirNamedAttribute namedAttr =
2735 mlirOperationGetAttribute(operation->get(), index);
2736 return PyNamedAttribute(
2737 namedAttr.attribute,
2738 std::string(mlirIdentifierStr(namedAttr.name).data,
2739 mlirIdentifierStr(namedAttr.name).length));
2740 }
2741
2742 void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2743 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2744 attr);
2745 }
2746
2747 void dunderDelItem(const std::string &name) {
2748 int removed = mlirOperationRemoveAttributeByName(operation->get(),
2749 toMlirStringRef(name));
2750 if (!removed)
2751 throw nb::key_error("attempt to delete a non-existent attribute");
2752 }
2753
2754 intptr_t dunderLen() {
2755 return mlirOperationGetNumAttributes(operation->get());
2756 }
2757
2758 bool dunderContains(const std::string &name) {
2759 return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2760 operation->get(), toMlirStringRef(name)));
2761 }
2762
2763 static void bind(nb::module_ &m) {
2764 nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
2765 .def("__contains__", &PyOpAttributeMap::dunderContains)
2766 .def("__len__", &PyOpAttributeMap::dunderLen)
2767 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2768 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2769 .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2770 .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2771 }
2772
2773private:
2774 PyOperationRef operation;
2775};
2776
2777} // namespace
2778
2779//------------------------------------------------------------------------------
2780// Populates the core exports of the 'ir' submodule.
2781//------------------------------------------------------------------------------
2782
2783void mlir::python::populateIRCore(nb::module_ &m) {
2784 // disable leak warnings which tend to be false positives.
2785 nb::set_leak_warnings(false);
2786 //----------------------------------------------------------------------------
2787 // Enums.
2788 //----------------------------------------------------------------------------
2789 nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity")
2790 .value("ERROR", MlirDiagnosticError)
2791 .value("WARNING", MlirDiagnosticWarning)
2792 .value("NOTE", MlirDiagnosticNote)
2793 .value("REMARK", MlirDiagnosticRemark);
2794
2795 nb::enum_<MlirWalkOrder>(m, "WalkOrder")
2796 .value("PRE_ORDER", MlirWalkPreOrder)
2797 .value("POST_ORDER", MlirWalkPostOrder);
2798
2799 nb::enum_<MlirWalkResult>(m, "WalkResult")
2800 .value("ADVANCE", MlirWalkResultAdvance)
2801 .value("INTERRUPT", MlirWalkResultInterrupt)
2802 .value("SKIP", MlirWalkResultSkip);
2803
2804 //----------------------------------------------------------------------------
2805 // Mapping of Diagnostics.
2806 //----------------------------------------------------------------------------
2807 nb::class_<PyDiagnostic>(m, "Diagnostic")
2808 .def_prop_ro("severity", &PyDiagnostic::getSeverity)
2809 .def_prop_ro("location", &PyDiagnostic::getLocation)
2810 .def_prop_ro("message", &PyDiagnostic::getMessage)
2811 .def_prop_ro("notes", &PyDiagnostic::getNotes)
2812 .def("__str__", [](PyDiagnostic &self) -> nb::str {
2813 if (!self.isValid())
2814 return nb::str("<Invalid Diagnostic>");
2815 return self.getMessage();
2816 });
2817
2818 nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
2819 .def("__init__",
2820 [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) {
2821 new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
2822 })
2823 .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity)
2824 .def_ro("location", &PyDiagnostic::DiagnosticInfo::location)
2825 .def_ro("message", &PyDiagnostic::DiagnosticInfo::message)
2826 .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes)
2827 .def("__str__",
2828 [](PyDiagnostic::DiagnosticInfo &self) { return self.message; });
2829
2830 nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
2831 .def("detach", &PyDiagnosticHandler::detach)
2832 .def_prop_ro("attached", &PyDiagnosticHandler::isAttached)
2833 .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError)
2834 .def("__enter__", &PyDiagnosticHandler::contextEnter)
2835 .def("__exit__", &PyDiagnosticHandler::contextExit,
2836 nb::arg("exc_type").none(), nb::arg("exc_value").none(),
2837 nb::arg("traceback").none());
2838
2839 //----------------------------------------------------------------------------
2840 // Mapping of MlirContext.
2841 // Note that this is exported as _BaseContext. The containing, Python level
2842 // __init__.py will subclass it with site-specific functionality and set a
2843 // "Context" attribute on this module.
2844 //----------------------------------------------------------------------------
2845
2846 // Expose DefaultThreadPool to python
2847 nb::class_<PyThreadPool>(m, "ThreadPool")
2848 .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); })
2849 .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency)
2850 .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr);
2851
2852 nb::class_<PyMlirContext>(m, "_BaseContext")
2853 .def("__init__",
2854 [](PyMlirContext &self) {
2855 MlirContext context = mlirContextCreateWithThreading(false);
2856 new (&self) PyMlirContext(context);
2857 })
2858 .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2859 .def("_get_context_again",
2860 [](PyMlirContext &self) {
2861 PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2862 return ref.releaseObject();
2863 })
2864 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2865 .def("_get_live_operation_objects",
2866 &PyMlirContext::getLiveOperationObjects)
2867 .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2868 .def("_clear_live_operations_inside",
2869 nb::overload_cast<MlirOperation>(
2870 &PyMlirContext::clearOperationsInside))
2871 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2872 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
2873 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2874 .def("__enter__", &PyMlirContext::contextEnter)
2875 .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
2876 nb::arg("exc_value").none(), nb::arg("traceback").none())
2877 .def_prop_ro_static(
2878 "current",
2879 [](nb::object & /*class*/) {
2880 auto *context = PyThreadContextEntry::getDefaultContext();
2881 if (!context)
2882 return nb::none();
2883 return nb::cast(context);
2884 },
2885 "Gets the Context bound to the current thread or raises ValueError")
2886 .def_prop_ro(
2887 "dialects",
2888 [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2889 "Gets a container for accessing dialects by name")
2890 .def_prop_ro(
2891 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2892 "Alias for 'dialect'")
2893 .def(
2894 "get_dialect_descriptor",
2895 [=](PyMlirContext &self, std::string &name) {
2896 MlirDialect dialect = mlirContextGetOrLoadDialect(
2897 self.get(), {name.data(), name.size()});
2898 if (mlirDialectIsNull(dialect)) {
2899 throw nb::value_error(
2900 (Twine("Dialect '") + name + "' not found").str().c_str());
2901 }
2902 return PyDialectDescriptor(self.getRef(), dialect);
2903 },
2904 nb::arg("dialect_name"),
2905 "Gets or loads a dialect by name, returning its descriptor object")
2906 .def_prop_rw(
2907 "allow_unregistered_dialects",
2908 [](PyMlirContext &self) -> bool {
2909 return mlirContextGetAllowUnregisteredDialects(self.get());
2910 },
2911 [](PyMlirContext &self, bool value) {
2912 mlirContextSetAllowUnregisteredDialects(self.get(), value);
2913 })
2914 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2915 nb::arg("callback"),
2916 "Attaches a diagnostic handler that will receive callbacks")
2917 .def(
2918 "enable_multithreading",
2919 [](PyMlirContext &self, bool enable) {
2920 mlirContextEnableMultithreading(self.get(), enable);
2921 },
2922 nb::arg("enable"))
2923 .def("set_thread_pool",
2924 [](PyMlirContext &self, PyThreadPool &pool) {
2925 // we should disable multi-threading first before setting
2926 // new thread pool otherwise the assert in
2927 // MLIRContext::setThreadPool will be raised.
2928 mlirContextEnableMultithreading(self.get(), false);
2929 mlirContextSetThreadPool(self.get(), pool.get());
2930 })
2931 .def("get_num_threads",
2932 [](PyMlirContext &self) {
2933 return mlirContextGetNumThreads(self.get());
2934 })
2935 .def("_mlir_thread_pool_ptr",
2936 [](PyMlirContext &self) {
2937 MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
2938 std::stringstream ss;
2939 ss << pool.ptr;
2940 return ss.str();
2941 })
2942 .def(
2943 "is_registered_operation",
2944 [](PyMlirContext &self, std::string &name) {
2945 return mlirContextIsRegisteredOperation(
2946 self.get(), MlirStringRef{name.data(), name.size()});
2947 },
2948 nb::arg("operation_name"))
2949 .def(
2950 "append_dialect_registry",
2951 [](PyMlirContext &self, PyDialectRegistry &registry) {
2952 mlirContextAppendDialectRegistry(self.get(), registry);
2953 },
2954 nb::arg("registry"))
2955 .def_prop_rw("emit_error_diagnostics", nullptr,
2956 &PyMlirContext::setEmitErrorDiagnostics,
2957 "Emit error diagnostics to diagnostic handlers. By default "
2958 "error diagnostics are captured and reported through "
2959 "MLIRError exceptions.")
2960 .def("load_all_available_dialects", [](PyMlirContext &self) {
2961 mlirContextLoadAllAvailableDialects(self.get());
2962 });
2963
2964 //----------------------------------------------------------------------------
2965 // Mapping of PyDialectDescriptor
2966 //----------------------------------------------------------------------------
2967 nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
2968 .def_prop_ro("namespace",
2969 [](PyDialectDescriptor &self) {
2970 MlirStringRef ns = mlirDialectGetNamespace(self.get());
2971 return nb::str(ns.data, ns.length);
2972 })
2973 .def("__repr__", [](PyDialectDescriptor &self) {
2974 MlirStringRef ns = mlirDialectGetNamespace(self.get());
2975 std::string repr("<DialectDescriptor ");
2976 repr.append(ns.data, ns.length);
2977 repr.append(">");
2978 return repr;
2979 });
2980
2981 //----------------------------------------------------------------------------
2982 // Mapping of PyDialects
2983 //----------------------------------------------------------------------------
2984 nb::class_<PyDialects>(m, "Dialects")
2985 .def("__getitem__",
2986 [=](PyDialects &self, std::string keyName) {
2987 MlirDialect dialect =
2988 self.getDialectForKey(keyName, /*attrError=*/false);
2989 nb::object descriptor =
2990 nb::cast(PyDialectDescriptor{self.getContext(), dialect});
2991 return createCustomDialectWrapper(keyName, std::move(descriptor));
2992 })
2993 .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2994 MlirDialect dialect =
2995 self.getDialectForKey(attrName, /*attrError=*/true);
2996 nb::object descriptor =
2997 nb::cast(PyDialectDescriptor{self.getContext(), dialect});
2998 return createCustomDialectWrapper(attrName, std::move(descriptor));
2999 });
3000
3001 //----------------------------------------------------------------------------
3002 // Mapping of PyDialect
3003 //----------------------------------------------------------------------------
3004 nb::class_<PyDialect>(m, "Dialect")
3005 .def(nb::init<nb::object>(), nb::arg("descriptor"))
3006 .def_prop_ro("descriptor",
3007 [](PyDialect &self) { return self.getDescriptor(); })
3008 .def("__repr__", [](nb::object self) {
3009 auto clazz = self.attr("__class__");
3010 return nb::str("<Dialect ") +
3011 self.attr("descriptor").attr("namespace") + nb::str(" (class ") +
3012 clazz.attr("__module__") + nb::str(".") +
3013 clazz.attr("__name__") + nb::str(")>");
3014 });
3015
3016 //----------------------------------------------------------------------------
3017 // Mapping of PyDialectRegistry
3018 //----------------------------------------------------------------------------
3019 nb::class_<PyDialectRegistry>(m, "DialectRegistry")
3020 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule)
3021 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
3022 .def(nb::init<>());
3023
3024 //----------------------------------------------------------------------------
3025 // Mapping of Location
3026 //----------------------------------------------------------------------------
3027 nb::class_<PyLocation>(m, "Location")
3028 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
3029 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
3030 .def("__enter__", &PyLocation::contextEnter)
3031 .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
3032 nb::arg("exc_value").none(), nb::arg("traceback").none())
3033 .def("__eq__",
3034 [](PyLocation &self, PyLocation &other) -> bool {
3035 return mlirLocationEqual(self, other);
3036 })
3037 .def("__eq__", [](PyLocation &self, nb::object other) { return false; })
3038 .def_prop_ro_static(
3039 "current",
3040 [](nb::object & /*class*/) {
3041 auto *loc = PyThreadContextEntry::getDefaultLocation();
3042 if (!loc)
3043 throw nb::value_error("No current Location");
3044 return loc;
3045 },
3046 "Gets the Location bound to the current thread or raises ValueError")
3047 .def_static(
3048 "unknown",
3049 [](DefaultingPyMlirContext context) {
3050 return PyLocation(context->getRef(),
3051 mlirLocationUnknownGet(context->get()));
3052 },
3053 nb::arg("context").none() = nb::none(),
3054 "Gets a Location representing an unknown location")
3055 .def_static(
3056 "callsite",
3057 [](PyLocation callee, const std::vector<PyLocation> &frames,
3058 DefaultingPyMlirContext context) {
3059 if (frames.empty())
3060 throw nb::value_error("No caller frames provided");
3061 MlirLocation caller = frames.back().get();
3062 for (const PyLocation &frame :
3063 llvm::reverse(llvm::ArrayRef(frames).drop_back()))
3064 caller = mlirLocationCallSiteGet(frame.get(), caller);
3065 return PyLocation(context->getRef(),
3066 mlirLocationCallSiteGet(callee.get(), caller));
3067 },
3068 nb::arg("callee"), nb::arg("frames"),
3069 nb::arg("context").none() = nb::none(),
3070 kContextGetCallSiteLocationDocstring)
3071 .def("is_a_callsite", mlirLocationIsACallSite)
3072 .def_prop_ro("callee", mlirLocationCallSiteGetCallee)
3073 .def_prop_ro("caller", mlirLocationCallSiteGetCaller)
3074 .def_static(
3075 "file",
3076 [](std::string filename, int line, int col,
3077 DefaultingPyMlirContext context) {
3078 return PyLocation(
3079 context->getRef(),
3080 mlirLocationFileLineColGet(
3081 context->get(), toMlirStringRef(filename), line, col));
3082 },
3083 nb::arg("filename"), nb::arg("line"), nb::arg("col"),
3084 nb::arg("context").none() = nb::none(),
3085 kContextGetFileLocationDocstring)
3086 .def_static(
3087 "file",
3088 [](std::string filename, int startLine, int startCol, int endLine,
3089 int endCol, DefaultingPyMlirContext context) {
3090 return PyLocation(context->getRef(),
3091 mlirLocationFileLineColRangeGet(
3092 context->get(), toMlirStringRef(filename),
3093 startLine, startCol, endLine, endCol));
3094 },
3095 nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
3096 nb::arg("end_line"), nb::arg("end_col"),
3097 nb::arg("context").none() = nb::none(), kContextGetFileRangeDocstring)
3098 .def("is_a_file", mlirLocationIsAFileLineColRange)
3099 .def_prop_ro("filename",
3100 [](MlirLocation loc) {
3101 return mlirIdentifierStr(
3102 mlirLocationFileLineColRangeGetFilename(loc));
3103 })
3104 .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine)
3105 .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn)
3106 .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine)
3107 .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn)
3108 .def_static(
3109 "fused",
3110 [](const std::vector<PyLocation> &pyLocations,
3111 std::optional<PyAttribute> metadata,
3112 DefaultingPyMlirContext context) {
3113 llvm::SmallVector<MlirLocation, 4> locations;
3114 locations.reserve(pyLocations.size());
3115 for (auto &pyLocation : pyLocations)
3116 locations.push_back(pyLocation.get());
3117 MlirLocation location = mlirLocationFusedGet(
3118 context->get(), locations.size(), locations.data(),
3119 metadata ? metadata->get() : MlirAttribute{0});
3120 return PyLocation(context->getRef(), location);
3121 },
3122 nb::arg("locations"), nb::arg("metadata").none() = nb::none(),
3123 nb::arg("context").none() = nb::none(),
3124 kContextGetFusedLocationDocstring)
3125 .def("is_a_fused", mlirLocationIsAFused)
3126 .def_prop_ro("locations",
3127 [](MlirLocation loc) {
3128 unsigned numLocations =
3129 mlirLocationFusedGetNumLocations(loc);
3130 std::vector<MlirLocation> locations(numLocations);
3131 if (numLocations)
3132 mlirLocationFusedGetLocations(loc, locations.data());
3133 return locations;
3134 })
3135 .def_static(
3136 "name",
3137 [](std::string name, std::optional<PyLocation> childLoc,
3138 DefaultingPyMlirContext context) {
3139 return PyLocation(
3140 context->getRef(),
3141 mlirLocationNameGet(
3142 context->get(), toMlirStringRef(name),
3143 childLoc ? childLoc->get()
3144 : mlirLocationUnknownGet(context->get())));
3145 },
3146 nb::arg("name"), nb::arg("childLoc").none() = nb::none(),
3147 nb::arg("context").none() = nb::none(),
3148 kContextGetNameLocationDocString)
3149 .def("is_a_name", mlirLocationIsAName)
3150 .def_prop_ro("name_str",
3151 [](MlirLocation loc) {
3152 return mlirIdentifierStr(mlirLocationNameGetName(loc));
3153 })
3154 .def_prop_ro("child_loc", mlirLocationNameGetChildLoc)
3155 .def_static(
3156 "from_attr",
3157 [](PyAttribute &attribute, DefaultingPyMlirContext context) {
3158 return PyLocation(context->getRef(),
3159 mlirLocationFromAttribute(attribute));
3160 },
3161 nb::arg("attribute"), nb::arg("context").none() = nb::none(),
3162 "Gets a Location from a LocationAttr")
3163 .def_prop_ro(
3164 "context",
3165 [](PyLocation &self) { return self.getContext().getObject(); },
3166 "Context that owns the Location")
3167 .def_prop_ro(
3168 "attr",
3169 [](PyLocation &self) { return mlirLocationGetAttribute(self); },
3170 "Get the underlying LocationAttr")
3171 .def(
3172 "emit_error",
3173 [](PyLocation &self, std::string message) {
3174 mlirEmitError(self, message.c_str());
3175 },
3176 nb::arg("message"), "Emits an error at this location")
3177 .def("__repr__", [](PyLocation &self) {
3178 PyPrintAccumulator printAccum;
3179 mlirLocationPrint(self, printAccum.getCallback(),
3180 printAccum.getUserData());
3181 return printAccum.join();
3182 });
3183
3184 //----------------------------------------------------------------------------
3185 // Mapping of Module
3186 //----------------------------------------------------------------------------
3187 nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
3188 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
3189 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
3190 .def_static(
3191 "parse",
3192 [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
3193 PyMlirContext::ErrorCapture errors(context->getRef());
3194 MlirModule module = mlirModuleCreateParse(
3195 context->get(), toMlirStringRef(moduleAsm));
3196 if (mlirModuleIsNull(module))
3197 throw MLIRError("Unable to parse module assembly", errors.take());
3198 return PyModule::forModule(module).releaseObject();
3199 },
3200 nb::arg("asm"), nb::arg("context").none() = nb::none(),
3201 kModuleParseDocstring)
3202 .def_static(
3203 "parse",
3204 [](nb::bytes moduleAsm, DefaultingPyMlirContext context) {
3205 PyMlirContext::ErrorCapture errors(context->getRef());
3206 MlirModule module = mlirModuleCreateParse(
3207 context->get(), toMlirStringRef(moduleAsm));
3208 if (mlirModuleIsNull(module))
3209 throw MLIRError("Unable to parse module assembly", errors.take());
3210 return PyModule::forModule(module).releaseObject();
3211 },
3212 nb::arg("asm"), nb::arg("context").none() = nb::none(),
3213 kModuleParseDocstring)
3214 .def_static(
3215 "parseFile",
3216 [](const std::string &path, DefaultingPyMlirContext context) {
3217 PyMlirContext::ErrorCapture errors(context->getRef());
3218 MlirModule module = mlirModuleCreateParseFromFile(
3219 context->get(), toMlirStringRef(path));
3220 if (mlirModuleIsNull(module))
3221 throw MLIRError("Unable to parse module assembly", errors.take());
3222 return PyModule::forModule(module).releaseObject();
3223 },
3224 nb::arg("path"), nb::arg("context").none() = nb::none(),
3225 kModuleParseDocstring)
3226 .def_static(
3227 "create",
3228 [](DefaultingPyLocation loc) {
3229 MlirModule module = mlirModuleCreateEmpty(loc);
3230 return PyModule::forModule(module).releaseObject();
3231 },
3232 nb::arg("loc").none() = nb::none(), "Creates an empty module")
3233 .def_prop_ro(
3234 "context",
3235 [](PyModule &self) { return self.getContext().getObject(); },
3236 "Context that created the Module")
3237 .def_prop_ro(
3238 "operation",
3239 [](PyModule &self) {
3240 return PyOperation::forOperation(self.getContext(),
3241 mlirModuleGetOperation(self.get()),
3242 self.getRef().releaseObject())
3243 .releaseObject();
3244 },
3245 "Accesses the module as an operation")
3246 .def_prop_ro(
3247 "body",
3248 [](PyModule &self) {
3249 PyOperationRef moduleOp = PyOperation::forOperation(
3250 self.getContext(), mlirModuleGetOperation(self.get()),
3251 self.getRef().releaseObject());
3252 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
3253 return returnBlock;
3254 },
3255 "Return the block for this module")
3256 .def(
3257 "dump",
3258 [](PyModule &self) {
3259 mlirOperationDump(mlirModuleGetOperation(self.get()));
3260 },
3261 kDumpDocstring)
3262 .def(
3263 "__str__",
3264 [](nb::object self) {
3265 // Defer to the operation's __str__.
3266 return self.attr("operation").attr("__str__")();
3267 },
3268 kOperationStrDunderDocstring);
3269
3270 //----------------------------------------------------------------------------
3271 // Mapping of Operation.
3272 //----------------------------------------------------------------------------
3273 nb::class_<PyOperationBase>(m, "_OperationBase")
3274 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
3275 [](PyOperationBase &self) {
3276 return self.getOperation().getCapsule();
3277 })
3278 .def("__eq__",
3279 [](PyOperationBase &self, PyOperationBase &other) {
3280 return &self.getOperation() == &other.getOperation();
3281 })
3282 .def("__eq__",
3283 [](PyOperationBase &self, nb::object other) { return false; })
3284 .def("__hash__",
3285 [](PyOperationBase &self) {
3286 return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
3287 })
3288 .def_prop_ro("attributes",
3289 [](PyOperationBase &self) {
3290 return PyOpAttributeMap(self.getOperation().getRef());
3291 })
3292 .def_prop_ro(
3293 "context",
3294 [](PyOperationBase &self) {
3295 PyOperation &concreteOperation = self.getOperation();
3296 concreteOperation.checkValid();
3297 return concreteOperation.getContext().getObject();
3298 },
3299 "Context that owns the Operation")
3300 .def_prop_ro("name",
3301 [](PyOperationBase &self) {
3302 auto &concreteOperation = self.getOperation();
3303 concreteOperation.checkValid();
3304 MlirOperation operation = concreteOperation.get();
3305 return mlirIdentifierStr(mlirOperationGetName(operation));
3306 })
3307 .def_prop_ro("operands",
3308 [](PyOperationBase &self) {
3309 return PyOpOperandList(self.getOperation().getRef());
3310 })
3311 .def_prop_ro("regions",
3312 [](PyOperationBase &self) {
3313 return PyRegionList(self.getOperation().getRef());
3314 })
3315 .def_prop_ro(
3316 "results",
3317 [](PyOperationBase &self) {
3318 return PyOpResultList(self.getOperation().getRef());
3319 },
3320 "Returns the list of Operation results.")
3321 .def_prop_ro(
3322 "result",
3323 [](PyOperationBase &self) {
3324 auto &operation = self.getOperation();
3325 return PyOpResult(operation.getRef(), getUniqueResult(operation))
3326 .maybeDownCast();
3327 },
3328 "Shortcut to get an op result if it has only one (throws an error "
3329 "otherwise).")
3330 .def_prop_ro(
3331 "location",
3332 [](PyOperationBase &self) {
3333 PyOperation &operation = self.getOperation();
3334 return PyLocation(operation.getContext(),
3335 mlirOperationGetLocation(operation.get()));
3336 },
3337 "Returns the source location the operation was defined or derived "
3338 "from.")
3339 .def_prop_ro("parent",
3340 [](PyOperationBase &self) -> nb::object {
3341 auto parent = self.getOperation().getParentOperation();
3342 if (parent)
3343 return parent->getObject();
3344 return nb::none();
3345 })
3346 .def(
3347 "__str__",
3348 [](PyOperationBase &self) {
3349 return self.getAsm(/*binary=*/false,
3350 /*largeElementsLimit=*/std::nullopt,
3351 /*enableDebugInfo=*/false,
3352 /*prettyDebugInfo=*/false,
3353 /*printGenericOpForm=*/false,
3354 /*useLocalScope=*/false,
3355 /*useNameLocAsPrefix=*/false,
3356 /*assumeVerified=*/false,
3357 /*skipRegions=*/false);
3358 },
3359 "Returns the assembly form of the operation.")
3360 .def("print",
3361 nb::overload_cast<PyAsmState &, nb::object, bool>(
3362 &PyOperationBase::print),
3363 nb::arg("state"), nb::arg("file").none() = nb::none(),
3364 nb::arg("binary") = false, kOperationPrintStateDocstring)
3365 .def("print",
3366 nb::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
3367 bool, bool, nb::object, bool, bool>(
3368 &PyOperationBase::print),
3369 // Careful: Lots of arguments must match up with print method.
3370 nb::arg("large_elements_limit").none() = nb::none(),
3371 nb::arg("enable_debug_info") = false,
3372 nb::arg("pretty_debug_info") = false,
3373 nb::arg("print_generic_op_form") = false,
3374 nb::arg("use_local_scope") = false,
3375 nb::arg("use_name_loc_as_prefix") = false,
3376 nb::arg("assume_verified") = false,
3377 nb::arg("file").none() = nb::none(), nb::arg("binary") = false,
3378 nb::arg("skip_regions") = false, kOperationPrintDocstring)
3379 .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"),
3380 nb::arg("desired_version").none() = nb::none(),
3381 kOperationPrintBytecodeDocstring)
3382 .def("get_asm", &PyOperationBase::getAsm,
3383 // Careful: Lots of arguments must match up with get_asm method.
3384 nb::arg("binary") = false,
3385 nb::arg("large_elements_limit").none() = nb::none(),
3386 nb::arg("enable_debug_info") = false,
3387 nb::arg("pretty_debug_info") = false,
3388 nb::arg("print_generic_op_form") = false,
3389 nb::arg("use_local_scope") = false,
3390 nb::arg("use_name_loc_as_prefix") = false,
3391 nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
3392 kOperationGetAsmDocstring)
3393 .def("verify", &PyOperationBase::verify,
3394 "Verify the operation. Raises MLIRError if verification fails, and "
3395 "returns true otherwise.")
3396 .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"),
3397 "Puts self immediately after the other operation in its parent "
3398 "block.")
3399 .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"),
3400 "Puts self immediately before the other operation in its parent "
3401 "block.")
3402 .def(
3403 "clone",
3404 [](PyOperationBase &self, nb::object ip) {
3405 return self.getOperation().clone(ip);
3406 },
3407 nb::arg("ip").none() = nb::none())
3408 .def(
3409 "detach_from_parent",
3410 [](PyOperationBase &self) {
3411 PyOperation &operation = self.getOperation();
3412 operation.checkValid();
3413 if (!operation.isAttached())
3414 throw nb::value_error("Detached operation has no parent.");
3415
3416 operation.detachFromParent();
3417 return operation.createOpView();
3418 },
3419 "Detaches the operation from its parent block.")
3420 .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
3421 .def("walk", &PyOperationBase::walk, nb::arg("callback"),
3422 nb::arg("walk_order") = MlirWalkPostOrder);
3423
3424 nb::class_<PyOperation, PyOperationBase>(m, "Operation")
3425 .def_static(
3426 "create",
3427 [](std::string_view name,
3428 std::optional<std::vector<PyType *>> results,
3429 std::optional<std::vector<PyValue *>> operands,
3430 std::optional<nb::dict> attributes,
3431 std::optional<std::vector<PyBlock *>> successors, int regions,
3432 DefaultingPyLocation location, const nb::object &maybeIp,
3433 bool inferType) {
3434 // Unpack/validate operands.
3435 llvm::SmallVector<MlirValue, 4> mlirOperands;
3436 if (operands) {
3437 mlirOperands.reserve(operands->size());
3438 for (PyValue *operand : *operands) {
3439 if (!operand)
3440 throw nb::value_error("operand value cannot be None");
3441 mlirOperands.push_back(operand->get());
3442 }
3443 }
3444
3445 return PyOperation::create(name, results, mlirOperands, attributes,
3446 successors, regions, location, maybeIp,
3447 inferType);
3448 },
3449 nb::arg("name"), nb::arg("results").none() = nb::none(),
3450 nb::arg("operands").none() = nb::none(),
3451 nb::arg("attributes").none() = nb::none(),
3452 nb::arg("successors").none() = nb::none(), nb::arg("regions") = 0,
3453 nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(),
3454 nb::arg("infer_type") = false, kOperationCreateDocstring)
3455 .def_static(
3456 "parse",
3457 [](const std::string &sourceStr, const std::string &sourceName,
3458 DefaultingPyMlirContext context) {
3459 return PyOperation::parse(context->getRef(), sourceStr, sourceName)
3460 ->createOpView();
3461 },
3462 nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "",
3463 nb::arg("context").none() = nb::none(),
3464 "Parses an operation. Supports both text assembly format and binary "
3465 "bytecode format.")
3466 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule)
3467 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
3468 .def_prop_ro("operation", [](nb::object self) { return self; })
3469 .def_prop_ro("opview", &PyOperation::createOpView)
3470 .def_prop_ro("block", &PyOperation::getBlock)
3471 .def_prop_ro(
3472 "successors",
3473 [](PyOperationBase &self) {
3474 return PyOpSuccessors(self.getOperation().getRef());
3475 },
3476 "Returns the list of Operation successors.");
3477
3478 auto opViewClass =
3479 nb::class_<PyOpView, PyOperationBase>(m, "OpView")
3480 .def(nb::init<nb::object>(), nb::arg("operation"))
3481 .def(
3482 "__init__",
3483 [](PyOpView *self, std::string_view name,
3484 std::tuple<int, bool> opRegionSpec,
3485 nb::object operandSegmentSpecObj,
3486 nb::object resultSegmentSpecObj,
3487 std::optional<nb::list> resultTypeList, nb::list operandList,
3488 std::optional<nb::dict> attributes,
3489 std::optional<std::vector<PyBlock *>> successors,
3490 std::optional<int> regions, DefaultingPyLocation location,
3491 const nb::object &maybeIp) {
3492 new (self) PyOpView(PyOpView::buildGeneric(
3493 name, opRegionSpec, operandSegmentSpecObj,
3494 resultSegmentSpecObj, resultTypeList, operandList,
3495 attributes, successors, regions, location, maybeIp));
3496 },
3497 nb::arg("name"), nb::arg("opRegionSpec"),
3498 nb::arg("operandSegmentSpecObj").none() = nb::none(),
3499 nb::arg("resultSegmentSpecObj").none() = nb::none(),
3500 nb::arg("results").none() = nb::none(),
3501 nb::arg("operands").none() = nb::none(),
3502 nb::arg("attributes").none() = nb::none(),
3503 nb::arg("successors").none() = nb::none(),
3504 nb::arg("regions").none() = nb::none(),
3505 nb::arg("loc").none() = nb::none(),
3506 nb::arg("ip").none() = nb::none())
3507
3508 .def_prop_ro("operation", &PyOpView::getOperationObject)
3509 .def_prop_ro("opview", [](nb::object self) { return self; })
3510 .def(
3511 "__str__",
3512 [](PyOpView &self) { return nb::str(self.getOperationObject()); })
3513 .def_prop_ro(
3514 "successors",
3515 [](PyOperationBase &self) {
3516 return PyOpSuccessors(self.getOperation().getRef());
3517 },
3518 "Returns the list of Operation successors.");
3519 opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
3520 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
3521 opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
3522 // It is faster to pass the operation_name, ods_regions, and
3523 // ods_operand_segments/ods_result_segments as arguments to the constructor,
3524 // rather than to access them as attributes.
3525 opViewClass.attr("build_generic") = classmethod(
3526 [](nb::handle cls, std::optional<nb::list> resultTypeList,
3527 nb::list operandList, std::optional<nb::dict> attributes,
3528 std::optional<std::vector<PyBlock *>> successors,
3529 std::optional<int> regions, DefaultingPyLocation location,
3530 const nb::object &maybeIp) {
3531 std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
3532 std::tuple<int, bool> opRegionSpec =
3533 nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
3534 nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
3535 nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
3536 return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
3537 resultSegmentSpec, resultTypeList,
3538 operandList, attributes, successors,
3539 regions, location, maybeIp);
3540 },
3541 nb::arg("cls"), nb::arg("results").none() = nb::none(),
3542 nb::arg("operands").none() = nb::none(),
3543 nb::arg("attributes").none() = nb::none(),
3544 nb::arg("successors").none() = nb::none(),
3545 nb::arg("regions").none() = nb::none(),
3546 nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(),
3547 "Builds a specific, generated OpView based on class level attributes.");
3548 opViewClass.attr("parse") = classmethod(
3549 [](const nb::object &cls, const std::string &sourceStr,
3550 const std::string &sourceName, DefaultingPyMlirContext context) {
3551 PyOperationRef parsed =
3552 PyOperation::parse(context->getRef(), sourceStr, sourceName);
3553
3554 // Check if the expected operation was parsed, and cast to to the
3555 // appropriate `OpView` subclass if successful.
3556 // NOTE: This accesses attributes that have been automatically added to
3557 // `OpView` subclasses, and is not intended to be used on `OpView`
3558 // directly.
3559 std::string clsOpName =
3560 nb::cast<std::string>(cls.attr("OPERATION_NAME"));
3561 MlirStringRef identifier =
3562 mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
3563 std::string_view parsedOpName(identifier.data, identifier.length);
3564 if (clsOpName != parsedOpName)
3565 throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
3566 parsedOpName + "'");
3567 return PyOpView::constructDerived(cls, parsed.getObject());
3568 },
3569 nb::arg("cls"), nb::arg("source"), nb::kw_only(),
3570 nb::arg("source_name") = "", nb::arg("context").none() = nb::none(),
3571 "Parses a specific, generated OpView based on class level attributes");
3572
3573 //----------------------------------------------------------------------------
3574 // Mapping of PyRegion.
3575 //----------------------------------------------------------------------------
3576 nb::class_<PyRegion>(m, "Region")
3577 .def_prop_ro(
3578 "blocks",
3579 [](PyRegion &self) {
3580 return PyBlockList(self.getParentOperation(), self.get());
3581 },
3582 "Returns a forward-optimized sequence of blocks.")
3583 .def_prop_ro(
3584 "owner",
3585 [](PyRegion &self) {
3586 return self.getParentOperation()->createOpView();
3587 },
3588 "Returns the operation owning this region.")
3589 .def(
3590 "__iter__",
3591 [](PyRegion &self) {
3592 self.checkValid();
3593 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
3594 return PyBlockIterator(self.getParentOperation(), firstBlock);
3595 },
3596 "Iterates over blocks in the region.")
3597 .def("__eq__",
3598 [](PyRegion &self, PyRegion &other) {
3599 return self.get().ptr == other.get().ptr;
3600 })
3601 .def("__eq__", [](PyRegion &self, nb::object &other) { return false; });
3602
3603 //----------------------------------------------------------------------------
3604 // Mapping of PyBlock.
3605 //----------------------------------------------------------------------------
3606 nb::class_<PyBlock>(m, "Block")
3607 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule)
3608 .def_prop_ro(
3609 "owner",
3610 [](PyBlock &self) {
3611 return self.getParentOperation()->createOpView();
3612 },
3613 "Returns the owning operation of this block.")
3614 .def_prop_ro(
3615 "region",
3616 [](PyBlock &self) {
3617 MlirRegion region = mlirBlockGetParentRegion(self.get());
3618 return PyRegion(self.getParentOperation(), region);
3619 },
3620 "Returns the owning region of this block.")
3621 .def_prop_ro(
3622 "arguments",
3623 [](PyBlock &self) {
3624 return PyBlockArgumentList(self.getParentOperation(), self.get());
3625 },
3626 "Returns a list of block arguments.")
3627 .def(
3628 "add_argument",
3629 [](PyBlock &self, const PyType &type, const PyLocation &loc) {
3630 return mlirBlockAddArgument(self.get(), type, loc);
3631 },
3632 "Append an argument of the specified type to the block and returns "
3633 "the newly added argument.")
3634 .def(
3635 "erase_argument",
3636 [](PyBlock &self, unsigned index) {
3637 return mlirBlockEraseArgument(self.get(), index);
3638 },
3639 "Erase the argument at 'index' and remove it from the argument list.")
3640 .def_prop_ro(
3641 "operations",
3642 [](PyBlock &self) {
3643 return PyOperationList(self.getParentOperation(), self.get());
3644 },
3645 "Returns a forward-optimized sequence of operations.")
3646 .def_static(
3647 "create_at_start",
3648 [](PyRegion &parent, const nb::sequence &pyArgTypes,
3649 const std::optional<nb::sequence> &pyArgLocs) {
3650 parent.checkValid();
3651 MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3652 mlirRegionInsertOwnedBlock(parent, 0, block);
3653 return PyBlock(parent.getParentOperation(), block);
3654 },
3655 nb::arg("parent"), nb::arg("arg_types") = nb::list(),
3656 nb::arg("arg_locs") = std::nullopt,
3657 "Creates and returns a new Block at the beginning of the given "
3658 "region (with given argument types and locations).")
3659 .def(
3660 "append_to",
3661 [](PyBlock &self, PyRegion &region) {
3662 MlirBlock b = self.get();
3663 if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
3664 mlirBlockDetach(b);
3665 mlirRegionAppendOwnedBlock(region.get(), b);
3666 },
3667 "Append this block to a region, transferring ownership if necessary")
3668 .def(
3669 "create_before",
3670 [](PyBlock &self, const nb::args &pyArgTypes,
3671 const std::optional<nb::sequence> &pyArgLocs) {
3672 self.checkValid();
3673 MlirBlock block =
3674 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
3675 MlirRegion region = mlirBlockGetParentRegion(self.get());
3676 mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
3677 return PyBlock(self.getParentOperation(), block);
3678 },
3679 nb::arg("arg_types"), nb::kw_only(),
3680 nb::arg("arg_locs") = std::nullopt,
3681 "Creates and returns a new Block before this block "
3682 "(with given argument types and locations).")
3683 .def(
3684 "create_after",
3685 [](PyBlock &self, const nb::args &pyArgTypes,
3686 const std::optional<nb::sequence> &pyArgLocs) {
3687 self.checkValid();
3688 MlirBlock block =
3689 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
3690 MlirRegion region = mlirBlockGetParentRegion(self.get());
3691 mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
3692 return PyBlock(self.getParentOperation(), block);
3693 },
3694 nb::arg("arg_types"), nb::kw_only(),
3695 nb::arg("arg_locs") = std::nullopt,
3696 "Creates and returns a new Block after this block "
3697 "(with given argument types and locations).")
3698 .def(
3699 "__iter__",
3700 [](PyBlock &self) {
3701 self.checkValid();
3702 MlirOperation firstOperation =
3703 mlirBlockGetFirstOperation(self.get());
3704 return PyOperationIterator(self.getParentOperation(),
3705 firstOperation);
3706 },
3707 "Iterates over operations in the block.")
3708 .def("__eq__",
3709 [](PyBlock &self, PyBlock &other) {
3710 return self.get().ptr == other.get().ptr;
3711 })
3712 .def("__eq__", [](PyBlock &self, nb::object &other) { return false; })
3713 .def("__hash__",
3714 [](PyBlock &self) {
3715 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3716 })
3717 .def(
3718 "__str__",
3719 [](PyBlock &self) {
3720 self.checkValid();
3721 PyPrintAccumulator printAccum;
3722 mlirBlockPrint(self.get(), printAccum.getCallback(),
3723 printAccum.getUserData());
3724 return printAccum.join();
3725 },
3726 "Returns the assembly form of the block.")
3727 .def(
3728 "append",
3729 [](PyBlock &self, PyOperationBase &operation) {
3730 if (operation.getOperation().isAttached())
3731 operation.getOperation().detachFromParent();
3732
3733 MlirOperation mlirOperation = operation.getOperation().get();
3734 mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
3735 operation.getOperation().setAttached(
3736 self.getParentOperation().getObject());
3737 },
3738 nb::arg("operation"),
3739 "Appends an operation to this block. If the operation is currently "
3740 "in another block, it will be moved.")
3741 .def_prop_ro(
3742 "successors",
3743 [](PyBlock &self) {
3744 return PyBlockSuccessors(self, self.getParentOperation());
3745 },
3746 "Returns the list of Block successors.")
3747 .def_prop_ro(
3748 "predecessors",
3749 [](PyBlock &self) {
3750 return PyBlockPredecessors(self, self.getParentOperation());
3751 },
3752 "Returns the list of Block predecessors.");
3753
3754 //----------------------------------------------------------------------------
3755 // Mapping of PyInsertionPoint.
3756 //----------------------------------------------------------------------------
3757
3758 nb::class_<PyInsertionPoint>(m, "InsertionPoint")
3759 .def(nb::init<PyBlock &>(), nb::arg("block"),
3760 "Inserts after the last operation but still inside the block.")
3761 .def("__enter__", &PyInsertionPoint::contextEnter)
3762 .def("__exit__", &PyInsertionPoint::contextExit,
3763 nb::arg("exc_type").none(), nb::arg("exc_value").none(),
3764 nb::arg("traceback").none())
3765 .def_prop_ro_static(
3766 "current",
3767 [](nb::object & /*class*/) {
3768 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
3769 if (!ip)
3770 throw nb::value_error("No current InsertionPoint");
3771 return ip;
3772 },
3773 "Gets the InsertionPoint bound to the current thread or raises "
3774 "ValueError if none has been set")
3775 .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"),
3776 "Inserts before a referenced operation.")
3777 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
3778 nb::arg("block"), "Inserts at the beginning of the block.")
3779 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
3780 nb::arg("block"), "Inserts before the block terminator.")
3781 .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
3782 "Inserts an operation.")
3783 .def_prop_ro(
3784 "block", [](PyInsertionPoint &self) { return self.getBlock(); },
3785 "Returns the block that this InsertionPoint points to.")
3786 .def_prop_ro(
3787 "ref_operation",
3788 [](PyInsertionPoint &self) -> nb::object {
3789 auto refOperation = self.getRefOperation();
3790 if (refOperation)
3791 return refOperation->getObject();
3792 return nb::none();
3793 },
3794 "The reference operation before which new operations are "
3795 "inserted, or None if the insertion point is at the end of "
3796 "the block");
3797
3798 //----------------------------------------------------------------------------
3799 // Mapping of PyAttribute.
3800 //----------------------------------------------------------------------------
3801 nb::class_<PyAttribute>(m, "Attribute")
3802 // Delegate to the PyAttribute copy constructor, which will also lifetime
3803 // extend the backing context which owns the MlirAttribute.
3804 .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
3805 "Casts the passed attribute to the generic Attribute")
3806 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule)
3807 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
3808 .def_static(
3809 "parse",
3810 [](const std::string &attrSpec, DefaultingPyMlirContext context) {
3811 PyMlirContext::ErrorCapture errors(context->getRef());
3812 MlirAttribute attr = mlirAttributeParseGet(
3813 context->get(), toMlirStringRef(attrSpec));
3814 if (mlirAttributeIsNull(attr))
3815 throw MLIRError("Unable to parse attribute", errors.take());
3816 return attr;
3817 },
3818 nb::arg("asm"), nb::arg("context").none() = nb::none(),
3819 "Parses an attribute from an assembly form. Raises an MLIRError on "
3820 "failure.")
3821 .def_prop_ro(
3822 "context",
3823 [](PyAttribute &self) { return self.getContext().getObject(); },
3824 "Context that owns the Attribute")
3825 .def_prop_ro("type",
3826 [](PyAttribute &self) { return mlirAttributeGetType(self); })
3827 .def(
3828 "get_named",
3829 [](PyAttribute &self, std::string name) {
3830 return PyNamedAttribute(self, std::move(name));
3831 },
3832 nb::keep_alive<0, 1>(), "Binds a name to the attribute")
3833 .def("__eq__",
3834 [](PyAttribute &self, PyAttribute &other) { return self == other; })
3835 .def("__eq__", [](PyAttribute &self, nb::object &other) { return false; })
3836 .def("__hash__",
3837 [](PyAttribute &self) {
3838 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3839 })
3840 .def(
3841 "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
3842 kDumpDocstring)
3843 .def(
3844 "__str__",
3845 [](PyAttribute &self) {
3846 PyPrintAccumulator printAccum;
3847 mlirAttributePrint(self, printAccum.getCallback(),
3848 printAccum.getUserData());
3849 return printAccum.join();
3850 },
3851 "Returns the assembly form of the Attribute.")
3852 .def("__repr__",
3853 [](PyAttribute &self) {
3854 // Generally, assembly formats are not printed for __repr__ because
3855 // this can cause exceptionally long debug output and exceptions.
3856 // However, attribute values are generally considered useful and
3857 // are printed. This may need to be re-evaluated if debug dumps end
3858 // up being excessive.
3859 PyPrintAccumulator printAccum;
3860 printAccum.parts.append("Attribute(");
3861 mlirAttributePrint(self, printAccum.getCallback(),
3862 printAccum.getUserData());
3863 printAccum.parts.append(")");
3864 return printAccum.join();
3865 })
3866 .def_prop_ro("typeid",
3867 [](PyAttribute &self) -> MlirTypeID {
3868 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3869 assert(!mlirTypeIDIsNull(mlirTypeID) &&
3870 "mlirTypeID was expected to be non-null.");
3871 return mlirTypeID;
3872 })
3873 .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) {
3874 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3875 assert(!mlirTypeIDIsNull(mlirTypeID) &&
3876 "mlirTypeID was expected to be non-null.");
3877 std::optional<nb::callable> typeCaster =
3878 PyGlobals::get().lookupTypeCaster(mlirTypeID,
3879 mlirAttributeGetDialect(self));
3880 if (!typeCaster)
3881 return nb::cast(self);
3882 return typeCaster.value()(self);
3883 });
3884
3885 //----------------------------------------------------------------------------
3886 // Mapping of PyNamedAttribute
3887 //----------------------------------------------------------------------------
3888 nb::class_<PyNamedAttribute>(m, "NamedAttribute")
3889 .def("__repr__",
3890 [](PyNamedAttribute &self) {
3891 PyPrintAccumulator printAccum;
3892 printAccum.parts.append("NamedAttribute(");
3893 printAccum.parts.append(
3894 nb::str(mlirIdentifierStr(self.namedAttr.name).data,
3895 mlirIdentifierStr(self.namedAttr.name).length));
3896 printAccum.parts.append("=");
3897 mlirAttributePrint(self.namedAttr.attribute,
3898 printAccum.getCallback(),
3899 printAccum.getUserData());
3900 printAccum.parts.append(")");
3901 return printAccum.join();
3902 })
3903 .def_prop_ro(
3904 "name",
3905 [](PyNamedAttribute &self) {
3906 return mlirIdentifierStr(self.namedAttr.name);
3907 },
3908 "The name of the NamedAttribute binding")
3909 .def_prop_ro(
3910 "attr",
3911 [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
3912 nb::keep_alive<0, 1>(),
3913 "The underlying generic attribute of the NamedAttribute binding");
3914
3915 //----------------------------------------------------------------------------
3916 // Mapping of PyType.
3917 //----------------------------------------------------------------------------
3918 nb::class_<PyType>(m, "Type")
3919 // Delegate to the PyType copy constructor, which will also lifetime
3920 // extend the backing context which owns the MlirType.
3921 .def(nb::init<PyType &>(), nb::arg("cast_from_type"),
3922 "Casts the passed type to the generic Type")
3923 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3924 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3925 .def_static(
3926 "parse",
3927 [](std::string typeSpec, DefaultingPyMlirContext context) {
3928 PyMlirContext::ErrorCapture errors(context->getRef());
3929 MlirType type =
3930 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3931 if (mlirTypeIsNull(type))
3932 throw MLIRError("Unable to parse type", errors.take());
3933 return type;
3934 },
3935 nb::arg("asm"), nb::arg("context").none() = nb::none(),
3936 kContextParseTypeDocstring)
3937 .def_prop_ro(
3938 "context", [](PyType &self) { return self.getContext().getObject(); },
3939 "Context that owns the Type")
3940 .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3941 .def(
3942 "__eq__", [](PyType &self, nb::object &other) { return false; },
3943 nb::arg("other").none())
3944 .def("__hash__",
3945 [](PyType &self) {
3946 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3947 })
3948 .def(
3949 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3950 .def(
3951 "__str__",
3952 [](PyType &self) {
3953 PyPrintAccumulator printAccum;
3954 mlirTypePrint(self, printAccum.getCallback(),
3955 printAccum.getUserData());
3956 return printAccum.join();
3957 },
3958 "Returns the assembly form of the type.")
3959 .def("__repr__",
3960 [](PyType &self) {
3961 // Generally, assembly formats are not printed for __repr__ because
3962 // this can cause exceptionally long debug output and exceptions.
3963 // However, types are an exception as they typically have compact
3964 // assembly forms and printing them is useful.
3965 PyPrintAccumulator printAccum;
3966 printAccum.parts.append("Type(");
3967 mlirTypePrint(self, printAccum.getCallback(),
3968 printAccum.getUserData());
3969 printAccum.parts.append(")");
3970 return printAccum.join();
3971 })
3972 .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
3973 [](PyType &self) {
3974 MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3975 assert(!mlirTypeIDIsNull(mlirTypeID) &&
3976 "mlirTypeID was expected to be non-null.");
3977 std::optional<nb::callable> typeCaster =
3978 PyGlobals::get().lookupTypeCaster(mlirTypeID,
3979 mlirTypeGetDialect(self));
3980 if (!typeCaster)
3981 return nb::cast(self);
3982 return typeCaster.value()(self);
3983 })
3984 .def_prop_ro("typeid", [](PyType &self) -> MlirTypeID {
3985 MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3986 if (!mlirTypeIDIsNull(mlirTypeID))
3987 return mlirTypeID;
3988 auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
3989 throw nb::value_error(
3990 (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
3991 });
3992
3993 //----------------------------------------------------------------------------
3994 // Mapping of PyTypeID.
3995 //----------------------------------------------------------------------------
3996 nb::class_<PyTypeID>(m, "TypeID")
3997 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
3998 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
3999 // Note, this tests whether the underlying TypeIDs are the same,
4000 // not whether the wrapper MlirTypeIDs are the same, nor whether
4001 // the Python objects are the same (i.e., PyTypeID is a value type).
4002 .def("__eq__",
4003 [](PyTypeID &self, PyTypeID &other) { return self == other; })
4004 .def("__eq__",
4005 [](PyTypeID &self, const nb::object &other) { return false; })
4006 // Note, this gives the hash value of the underlying TypeID, not the
4007 // hash value of the Python object, nor the hash value of the
4008 // MlirTypeID wrapper.
4009 .def("__hash__", [](PyTypeID &self) {
4010 return static_cast<size_t>(mlirTypeIDHashValue(self));
4011 });
4012
4013 //----------------------------------------------------------------------------
4014 // Mapping of Value.
4015 //----------------------------------------------------------------------------
4016 nb::class_<PyValue>(m, "Value")
4017 .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"))
4018 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
4019 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
4020 .def_prop_ro(
4021 "context",
4022 [](PyValue &self) { return self.getParentOperation()->getContext(); },
4023 "Context in which the value lives.")
4024 .def(
4025 "dump", [](PyValue &self) { mlirValueDump(self.get()); },
4026 kDumpDocstring)
4027 .def_prop_ro(
4028 "owner",
4029 [](PyValue &self) -> nb::object {
4030 MlirValue v = self.get();
4031 if (mlirValueIsAOpResult(v)) {
4032 assert(
4033 mlirOperationEqual(self.getParentOperation()->get(),
4034 mlirOpResultGetOwner(self.get())) &&
4035 "expected the owner of the value in Python to match that in "
4036 "the IR");
4037 return self.getParentOperation().getObject();
4038 }
4039
4040 if (mlirValueIsABlockArgument(v)) {
4041 MlirBlock block = mlirBlockArgumentGetOwner(self.get());
4042 return nb::cast(PyBlock(self.getParentOperation(), block));
4043 }
4044
4045 assert(false && "Value must be a block argument or an op result");
4046 return nb::none();
4047 })
4048 .def_prop_ro("uses",
4049 [](PyValue &self) {
4050 return PyOpOperandIterator(
4051 mlirValueGetFirstUse(self.get()));
4052 })
4053 .def("__eq__",
4054 [](PyValue &self, PyValue &other) {
4055 return self.get().ptr == other.get().ptr;
4056 })
4057 .def("__eq__", [](PyValue &self, nb::object other) { return false; })
4058 .def("__hash__",
4059 [](PyValue &self) {
4060 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
4061 })
4062 .def(
4063 "__str__",
4064 [](PyValue &self) {
4065 PyPrintAccumulator printAccum;
4066 printAccum.parts.append("Value(");
4067 mlirValuePrint(self.get(), printAccum.getCallback(),
4068 printAccum.getUserData());
4069 printAccum.parts.append(")");
4070 return printAccum.join();
4071 },
4072 kValueDunderStrDocstring)
4073 .def(
4074 "get_name",
4075 [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) {
4076 PyPrintAccumulator printAccum;
4077 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
4078 if (useLocalScope)
4079 mlirOpPrintingFlagsUseLocalScope(flags);
4080 if (useNameLocAsPrefix)
4081 mlirOpPrintingFlagsPrintNameLocAsPrefix(flags);
4082 MlirAsmState valueState =
4083 mlirAsmStateCreateForValue(self.get(), flags);
4084 mlirValuePrintAsOperand(self.get(), valueState,
4085 printAccum.getCallback(),
4086 printAccum.getUserData());
4087 mlirOpPrintingFlagsDestroy(flags);
4088 mlirAsmStateDestroy(valueState);
4089 return printAccum.join();
4090 },
4091 nb::arg("use_local_scope") = false,
4092 nb::arg("use_name_loc_as_prefix") = false)
4093 .def(
4094 "get_name",
4095 [](PyValue &self, PyAsmState &state) {
4096 PyPrintAccumulator printAccum;
4097 MlirAsmState valueState = state.get();
4098 mlirValuePrintAsOperand(self.get(), valueState,
4099 printAccum.getCallback(),
4100 printAccum.getUserData());
4101 return printAccum.join();
4102 },
4103 nb::arg("state"), kGetNameAsOperand)
4104 .def_prop_ro("type",
4105 [](PyValue &self) { return mlirValueGetType(self.get()); })
4106 .def(
4107 "set_type",
4108 [](PyValue &self, const PyType &type) {
4109 return mlirValueSetType(self.get(), type);
4110 },
4111 nb::arg("type"))
4112 .def(
4113 "replace_all_uses_with",
4114 [](PyValue &self, PyValue &with) {
4115 mlirValueReplaceAllUsesOfWith(self.get(), with.get());
4116 },
4117 kValueReplaceAllUsesWithDocstring)
4118 .def(
4119 "replace_all_uses_except",
4120 [](MlirValue self, MlirValue with, PyOperation &exception) {
4121 MlirOperation exceptedUser = exception.get();
4122 mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
4123 },
4124 nb::arg("with"), nb::arg("exceptions"),
4125 kValueReplaceAllUsesExceptDocstring)
4126 .def(
4127 "replace_all_uses_except",
4128 [](MlirValue self, MlirValue with, nb::list exceptions) {
4129 // Convert Python list to a SmallVector of MlirOperations
4130 llvm::SmallVector<MlirOperation> exceptionOps;
4131 for (nb::handle exception : exceptions) {
4132 exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
4133 }
4134
4135 mlirValueReplaceAllUsesExcept(
4136 self, with, static_cast<intptr_t>(exceptionOps.size()),
4137 exceptionOps.data());
4138 },
4139 nb::arg("with"), nb::arg("exceptions"),
4140 kValueReplaceAllUsesExceptDocstring)
4141 .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
4142 [](PyValue &self) { return self.maybeDownCast(); })
4143 .def_prop_ro(
4144 "location",
4145 [](MlirValue self) {
4146 return PyLocation(
4147 PyMlirContext::forContext(mlirValueGetContext(self)),
4148 mlirValueGetLocation(self));
4149 },
4150 "Returns the source location the value");
4151
4152 PyBlockArgument::bind(m);
4153 PyOpResult::bind(m);
4154 PyOpOperand::bind(m);
4155
4156 nb::class_<PyAsmState>(m, "AsmState")
4157 .def(nb::init<PyValue &, bool>(), nb::arg("value"),
4158 nb::arg("use_local_scope") = false)
4159 .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
4160 nb::arg("use_local_scope") = false);
4161
4162 //----------------------------------------------------------------------------
4163 // Mapping of SymbolTable.
4164 //----------------------------------------------------------------------------
4165 nb::class_<PySymbolTable>(m, "SymbolTable")
4166 .def(nb::init<PyOperationBase &>())
4167 .def("__getitem__", &PySymbolTable::dunderGetItem)
4168 .def("insert", &PySymbolTable::insert, nb::arg("operation"))
4169 .def("erase", &PySymbolTable::erase, nb::arg("operation"))
4170 .def("__delitem__", &PySymbolTable::dunderDel)
4171 .def("__contains__",
4172 [](PySymbolTable &table, const std::string &name) {
4173 return !mlirOperationIsNull(mlirSymbolTableLookup(
4174 table, mlirStringRefCreate(name.data(), name.length())));
4175 })
4176 // Static helpers.
4177 .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
4178 nb::arg("symbol"), nb::arg("name"))
4179 .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
4180 nb::arg("symbol"))
4181 .def_static("get_visibility", &PySymbolTable::getVisibility,
4182 nb::arg("symbol"))
4183 .def_static("set_visibility", &PySymbolTable::setVisibility,
4184 nb::arg("symbol"), nb::arg("visibility"))
4185 .def_static("replace_all_symbol_uses",
4186 &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"),
4187 nb::arg("new_symbol"), nb::arg("from_op"))
4188 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
4189 nb::arg("from_op"), nb::arg("all_sym_uses_visible"),
4190 nb::arg("callback"));
4191
// Container bindings: register the helper classes returned by properties and
// iterators bound above (e.g. `Block.arguments` returns PyBlockArgumentList,
// `Region.blocks` returns PyBlockList, `Operation.successors` returns
// PyOpSuccessors).
// NOTE(review): the registration order is kept as-is; whether any bind()
// depends on an earlier one is not visible from this file section.
PyBlockArgumentList::bind(m);
PyBlockIterator::bind(m);
PyBlockList::bind(m);
PyBlockSuccessors::bind(m);
PyBlockPredecessors::bind(m);
PyOperationIterator::bind(m);
PyOperationList::bind(m);
PyOpAttributeMap::bind(m);
PyOpOperandIterator::bind(m);
PyOpOperandList::bind(m);
PyOpResultList::bind(m);
PyOpSuccessors::bind(m);
PyRegionIterator::bind(m);
PyRegionList::bind(m);

// Debug bindings.
PyGlobalDebugFlag::bind(m);

// Attribute builder getter.
PyAttrBuilderMap::bind(m);
4213
4214 nb::register_exception_translator([](const std::exception_ptr &p,
4215 void *payload) {
4216 // We can't define exceptions with custom fields through pybind, so instead
4217 // the exception class is defined in python and imported here.
4218 try {
4219 if (p)
4220 std::rethrow_exception(p);
4221 } catch (const MLIRError &e) {
4222 nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
4223 .attr("MLIRError")(e.message, e.errorDiagnostics);
4224 PyErr_SetObject(PyExc_Exception, obj.ptr());
4225 }
4226 });
4227}
4228

source code of mlir/lib/Bindings/Python/IRCore.cpp