1 | //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "Globals.h" |
10 | #include "IRModule.h" |
11 | #include "NanobindUtils.h" |
12 | #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. |
13 | #include "mlir-c/BuiltinAttributes.h" |
14 | #include "mlir-c/Debug.h" |
15 | #include "mlir-c/Diagnostics.h" |
16 | #include "mlir-c/IR.h" |
17 | #include "mlir-c/Support.h" |
18 | #include "mlir/Bindings/Python/Nanobind.h" |
19 | #include "mlir/Bindings/Python/NanobindAdaptors.h" |
20 | #include "nanobind/nanobind.h" |
21 | #include "llvm/ADT/ArrayRef.h" |
22 | #include "llvm/ADT/SmallVector.h" |
23 | #include "llvm/Support/raw_ostream.h" |
24 | |
25 | #include <optional> |
26 | #include <system_error> |
27 | #include <utility> |
28 | |
29 | namespace nb = nanobind; |
30 | using namespace nb::literals; |
31 | using namespace mlir; |
32 | using namespace mlir::python; |
33 | |
34 | using llvm::SmallVector; |
35 | using llvm::StringRef; |
36 | using llvm::Twine; |
37 | |
38 | //------------------------------------------------------------------------------ |
39 | // Docstrings (trivial, non-duplicated docstrings are included inline). |
40 | //------------------------------------------------------------------------------ |
41 | |
// Docstring text describing how a type is parsed from its assembly form.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises an MLIRError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// One-line docstrings for the Location factory helpers.
static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kContextGetFileRangeDocstring[] =
    R"(Gets a Location representing a file, line and column range)";

static const char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

// NOTE(review): this identifier is spelled "DocString" while its siblings use
// "Docstring"; the name is left untouched because users of it are not visible
// in this part of the file.
static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

// Docstring text describing module parsing from a string.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises an MLIRError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Docstring text listing the arguments accepted when creating an operation.
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
name: Operation name (e.g. "dialect.operation").
results: Sequence of Type representing op result types.
attributes: Dict of str:Attribute.
successors: List of Block for the operation's successors.
regions: Number of regions to create.
location: A Location object (defaults to resolve from context manager).
ip: An InsertionPoint (defaults to resolve from context manager or set to
False to disable insertion, even with an insertion point set in the
context manager).
infer_type: Whether to infer result types.
Returns:
A new "detached" Operation object. Detached operations can be added
to blocks, which causes them to become "attached."
)";
91 | |
// Docstring text describing the keyword arguments of the operation printer.
// The `use_local_scope` entry was previously spelled `use_local_Scope`, which
// did not match the snake_case spelling of every other documented keyword.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
file: The file like object to write to. Defaults to sys.stdout.
binary: Whether to write bytes (True) or str (False). Defaults to False.
large_elements_limit: Whether to elide elements attributes above this
number of elements. Defaults to None (no limit).
enable_debug_info: Whether to print debug/location information. Defaults
to False.
pretty_debug_info: Whether to format debug information for easier reading
by a human (warning: the result is unparseable).
print_generic_op_form: Whether to print the generic assembly forms of all
ops. Defaults to False.
use_local_scope: Whether to print in a way that is more optimized for
multi-threaded access but may not be consistent with how the overall
module prints.
assume_verified: By default, if not printing generic form, the verifier
will be run and if it fails, generic form will be printed with a comment
about failed verification. While a reasonable default for interactive use,
for systematic use, it is often better for the caller to verify explicitly
and report failures in a more robust fashion. Set this to True if doing this
in order to avoid running a redundant verification. If the IR is actually
invalid, behavior is undefined.
skip_regions: Whether to skip printing regions. Defaults to False.
)";
118 | |
// Docstring text for the printer variant that takes an AsmState.
static const char kOperationPrintStateDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
file: The file like object to write to. Defaults to sys.stdout.
binary: Whether to write bytes (True) or str (False). Defaults to False.
state: AsmState capturing the operation numbering and flags.
)";

// Docstring text for obtaining the assembly form as a str/bytes object.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
binary: Whether to return a bytes (True) or str (False) object. Defaults to
False.
... others ...: See the print() method for common keyword arguments for
configuring the printout.
Returns:
Either a bytes or str object, depending on the setting of the 'binary'
argument.
)";

// Docstring text for writing the bytecode form of an operation.
static const char kOperationPrintBytecodeDocstring[] =
    R"(Write the bytecode form of the operation to a file like object.

Args:
file: The file like object to write to.
desired_version: The version of bytecode to emit.
Returns:
The bytecode writer status.
)";

// Docstring text for the default string conversion of an operation.
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Shared docstring text for debug-dump helpers.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Docstring text for appending a block (used by the block list binding).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
The created block.
)";

// Docstring text for the string conversion of a value.
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";

// Docstring text for obtaining a value's name as an operand.
static const char kGetNameAsOperand[] =
    R"(Returns the string form of value as an operand (i.e., the ValueID).
)";

// Docstring text for replacing every use of a value.
static const char kValueReplaceAllUsesWithDocstring[] =
    R"(Replace all uses of value with the new value, updating anything in
the IR that uses 'self' to use the other value instead.
)";
185 | |
// Docstring text for replacing uses of a value while sparing some users.
// The text previously began with a stray '"' character that leaked into the
// user-visible docstring; it has been removed.
static const char kValueReplaceAllUsesExceptDocstring[] =
    R"(Replace all uses of this value with the 'with' value, except for those
in 'exceptions'. 'exceptions' can be either a single operation or a list of
operations.
)";
191 | |
192 | //------------------------------------------------------------------------------ |
193 | // Utilities. |
194 | //------------------------------------------------------------------------------ |
195 | |
196 | /// Helper for creating an @classmethod. |
197 | template <class Func, typename... Args> |
198 | nb::object classmethod(Func f, Args... args) { |
199 | nb::object cf = nb::cpp_function(f, args...); |
200 | return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr()))); |
201 | } |
202 | |
203 | static nb::object |
204 | createCustomDialectWrapper(const std::string &dialectNamespace, |
205 | nb::object dialectDescriptor) { |
206 | auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); |
207 | if (!dialectClass) { |
208 | // Use the base class. |
209 | return nb::cast(PyDialect(std::move(dialectDescriptor))); |
210 | } |
211 | |
212 | // Create the custom implementation. |
213 | return (*dialectClass)(std::move(dialectDescriptor)); |
214 | } |
215 | |
216 | static MlirStringRef toMlirStringRef(const std::string &s) { |
217 | return mlirStringRefCreate(s.data(), s.size()); |
218 | } |
219 | |
220 | static MlirStringRef toMlirStringRef(std::string_view s) { |
221 | return mlirStringRefCreate(s.data(), s.size()); |
222 | } |
223 | |
224 | static MlirStringRef toMlirStringRef(const nb::bytes &s) { |
225 | return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size()); |
226 | } |
227 | |
228 | /// Create a block, using the current location context if no locations are |
229 | /// specified. |
230 | static MlirBlock createBlock(const nb::sequence &pyArgTypes, |
231 | const std::optional<nb::sequence> &pyArgLocs) { |
232 | SmallVector<MlirType> argTypes; |
233 | argTypes.reserve(nb::len(pyArgTypes)); |
234 | for (const auto &pyType : pyArgTypes) |
235 | argTypes.push_back(nb::cast<PyType &>(pyType)); |
236 | |
237 | SmallVector<MlirLocation> argLocs; |
238 | if (pyArgLocs) { |
239 | argLocs.reserve(nb::len(*pyArgLocs)); |
240 | for (const auto &pyLoc : *pyArgLocs) |
241 | argLocs.push_back(nb::cast<PyLocation &>(pyLoc)); |
242 | } else if (!argTypes.empty()) { |
243 | argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve()); |
244 | } |
245 | |
246 | if (argTypes.size() != argLocs.size()) |
247 | throw nb::value_error(("Expected "+ Twine(argTypes.size()) + |
248 | " locations, got: "+ Twine(argLocs.size())) |
249 | .str() |
250 | .c_str()); |
251 | return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); |
252 | } |
253 | |
254 | /// Wrapper for the global LLVM debugging flag. |
255 | struct PyGlobalDebugFlag { |
256 | static void set(nb::object &o, bool enable) { |
257 | nb::ft_lock_guard lock(mutex); |
258 | mlirEnableGlobalDebug(enable); |
259 | } |
260 | |
261 | static bool get(const nb::object &) { |
262 | nb::ft_lock_guard lock(mutex); |
263 | return mlirIsGlobalDebugEnabled(); |
264 | } |
265 | |
266 | static void bind(nb::module_ &m) { |
267 | // Debug flags. |
268 | nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug") |
269 | .def_prop_rw_static("flag", &PyGlobalDebugFlag::get, |
270 | &PyGlobalDebugFlag::set, "LLVM-wide debug flag") |
271 | .def_static( |
272 | "set_types", |
273 | [](const std::string &type) { |
274 | nb::ft_lock_guard lock(mutex); |
275 | mlirSetGlobalDebugType(type.c_str()); |
276 | }, |
277 | "types"_a, "Sets specific debug types to be produced by LLVM") |
278 | .def_static("set_types", [](const std::vector<std::string> &types) { |
279 | std::vector<const char *> pointers; |
280 | pointers.reserve(types.size()); |
281 | for (const std::string &str : types) |
282 | pointers.push_back(str.c_str()); |
283 | nb::ft_lock_guard lock(mutex); |
284 | mlirSetGlobalDebugTypes(pointers.data(), pointers.size()); |
285 | }); |
286 | } |
287 | |
288 | private: |
289 | static nb::ft_mutex mutex; |
290 | }; |
291 | |
292 | nb::ft_mutex PyGlobalDebugFlag::mutex; |
293 | |
294 | struct PyAttrBuilderMap { |
295 | static bool dunderContains(const std::string &attributeKind) { |
296 | return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); |
297 | } |
298 | static nb::callable dunderGetItemNamed(const std::string &attributeKind) { |
299 | auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind); |
300 | if (!builder) |
301 | throw nb::key_error(attributeKind.c_str()); |
302 | return *builder; |
303 | } |
304 | static void dunderSetItemNamed(const std::string &attributeKind, |
305 | nb::callable func, bool replace) { |
306 | PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), |
307 | replace); |
308 | } |
309 | |
310 | static void bind(nb::module_ &m) { |
311 | nb::class_<PyAttrBuilderMap>(m, "AttrBuilder") |
312 | .def_static("contains", &PyAttrBuilderMap::dunderContains) |
313 | .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed) |
314 | .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed, |
315 | "attribute_kind"_a, "attr_builder"_a, "replace"_a= false, |
316 | "Register an attribute builder for building MLIR " |
317 | "attributes from python values."); |
318 | } |
319 | }; |
320 | |
321 | //------------------------------------------------------------------------------ |
322 | // PyBlock |
323 | //------------------------------------------------------------------------------ |
324 | |
325 | nb::object PyBlock::getCapsule() { |
326 | return nb::steal<nb::object>(mlirPythonBlockToCapsule(get())); |
327 | } |
328 | |
329 | //------------------------------------------------------------------------------ |
330 | // Collections. |
331 | //------------------------------------------------------------------------------ |
332 | |
333 | namespace { |
334 | |
335 | class PyRegionIterator { |
336 | public: |
337 | PyRegionIterator(PyOperationRef operation) |
338 | : operation(std::move(operation)) {} |
339 | |
340 | PyRegionIterator &dunderIter() { return *this; } |
341 | |
342 | PyRegion dunderNext() { |
343 | operation->checkValid(); |
344 | if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { |
345 | throw nb::stop_iteration(); |
346 | } |
347 | MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); |
348 | return PyRegion(operation, region); |
349 | } |
350 | |
351 | static void bind(nb::module_ &m) { |
352 | nb::class_<PyRegionIterator>(m, "RegionIterator") |
353 | .def("__iter__", &PyRegionIterator::dunderIter) |
354 | .def("__next__", &PyRegionIterator::dunderNext); |
355 | } |
356 | |
357 | private: |
358 | PyOperationRef operation; |
359 | int nextIndex = 0; |
360 | }; |
361 | |
362 | /// Regions of an op are fixed length and indexed numerically so are represented |
363 | /// with a sequence-like container. |
364 | class PyRegionList : public Sliceable<PyRegionList, PyRegion> { |
365 | public: |
366 | static constexpr const char *pyClassName = "RegionSequence"; |
367 | |
368 | PyRegionList(PyOperationRef operation, intptr_t startIndex = 0, |
369 | intptr_t length = -1, intptr_t step = 1) |
370 | : Sliceable(startIndex, |
371 | length == -1 ? mlirOperationGetNumRegions(operation->get()) |
372 | : length, |
373 | step), |
374 | operation(std::move(operation)) {} |
375 | |
376 | PyRegionIterator dunderIter() { |
377 | operation->checkValid(); |
378 | return PyRegionIterator(operation); |
379 | } |
380 | |
381 | static void bindDerived(ClassTy &c) { |
382 | c.def("__iter__", &PyRegionList::dunderIter); |
383 | } |
384 | |
385 | private: |
386 | /// Give the parent CRTP class access to hook implementations below. |
387 | friend class Sliceable<PyRegionList, PyRegion>; |
388 | |
389 | intptr_t getRawNumElements() { |
390 | operation->checkValid(); |
391 | return mlirOperationGetNumRegions(operation->get()); |
392 | } |
393 | |
394 | PyRegion getRawElement(intptr_t pos) { |
395 | operation->checkValid(); |
396 | return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos)); |
397 | } |
398 | |
399 | PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) { |
400 | return PyRegionList(operation, startIndex, length, step); |
401 | } |
402 | |
403 | PyOperationRef operation; |
404 | }; |
405 | |
406 | class PyBlockIterator { |
407 | public: |
408 | PyBlockIterator(PyOperationRef operation, MlirBlock next) |
409 | : operation(std::move(operation)), next(next) {} |
410 | |
411 | PyBlockIterator &dunderIter() { return *this; } |
412 | |
413 | PyBlock dunderNext() { |
414 | operation->checkValid(); |
415 | if (mlirBlockIsNull(next)) { |
416 | throw nb::stop_iteration(); |
417 | } |
418 | |
419 | PyBlock returnBlock(operation, next); |
420 | next = mlirBlockGetNextInRegion(next); |
421 | return returnBlock; |
422 | } |
423 | |
424 | static void bind(nb::module_ &m) { |
425 | nb::class_<PyBlockIterator>(m, "BlockIterator") |
426 | .def("__iter__", &PyBlockIterator::dunderIter) |
427 | .def("__next__", &PyBlockIterator::dunderNext); |
428 | } |
429 | |
430 | private: |
431 | PyOperationRef operation; |
432 | MlirBlock next; |
433 | }; |
434 | |
435 | /// Blocks are exposed by the C-API as a forward-only linked list. In Python, |
436 | /// we present them as a more full-featured list-like container but optimize |
437 | /// it for forward iteration. Blocks are always owned by a region. |
438 | class PyBlockList { |
439 | public: |
440 | PyBlockList(PyOperationRef operation, MlirRegion region) |
441 | : operation(std::move(operation)), region(region) {} |
442 | |
443 | PyBlockIterator dunderIter() { |
444 | operation->checkValid(); |
445 | return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); |
446 | } |
447 | |
448 | intptr_t dunderLen() { |
449 | operation->checkValid(); |
450 | intptr_t count = 0; |
451 | MlirBlock block = mlirRegionGetFirstBlock(region); |
452 | while (!mlirBlockIsNull(block)) { |
453 | count += 1; |
454 | block = mlirBlockGetNextInRegion(block); |
455 | } |
456 | return count; |
457 | } |
458 | |
459 | PyBlock dunderGetItem(intptr_t index) { |
460 | operation->checkValid(); |
461 | if (index < 0) { |
462 | index += dunderLen(); |
463 | } |
464 | if (index < 0) { |
465 | throw nb::index_error("attempt to access out of bounds block"); |
466 | } |
467 | MlirBlock block = mlirRegionGetFirstBlock(region); |
468 | while (!mlirBlockIsNull(block)) { |
469 | if (index == 0) { |
470 | return PyBlock(operation, block); |
471 | } |
472 | block = mlirBlockGetNextInRegion(block); |
473 | index -= 1; |
474 | } |
475 | throw nb::index_error("attempt to access out of bounds block"); |
476 | } |
477 | |
478 | PyBlock appendBlock(const nb::args &pyArgTypes, |
479 | const std::optional<nb::sequence> &pyArgLocs) { |
480 | operation->checkValid(); |
481 | MlirBlock block = |
482 | createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs); |
483 | mlirRegionAppendOwnedBlock(region, block); |
484 | return PyBlock(operation, block); |
485 | } |
486 | |
487 | static void bind(nb::module_ &m) { |
488 | nb::class_<PyBlockList>(m, "BlockList") |
489 | .def("__getitem__", &PyBlockList::dunderGetItem) |
490 | .def("__iter__", &PyBlockList::dunderIter) |
491 | .def("__len__", &PyBlockList::dunderLen) |
492 | .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring, |
493 | nb::arg("args"), nb::kw_only(), |
494 | nb::arg("arg_locs") = std::nullopt); |
495 | } |
496 | |
497 | private: |
498 | PyOperationRef operation; |
499 | MlirRegion region; |
500 | }; |
501 | |
502 | class PyOperationIterator { |
503 | public: |
504 | PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) |
505 | : parentOperation(std::move(parentOperation)), next(next) {} |
506 | |
507 | PyOperationIterator &dunderIter() { return *this; } |
508 | |
509 | nb::object dunderNext() { |
510 | parentOperation->checkValid(); |
511 | if (mlirOperationIsNull(next)) { |
512 | throw nb::stop_iteration(); |
513 | } |
514 | |
515 | PyOperationRef returnOperation = |
516 | PyOperation::forOperation(parentOperation->getContext(), next); |
517 | next = mlirOperationGetNextInBlock(next); |
518 | return returnOperation->createOpView(); |
519 | } |
520 | |
521 | static void bind(nb::module_ &m) { |
522 | nb::class_<PyOperationIterator>(m, "OperationIterator") |
523 | .def("__iter__", &PyOperationIterator::dunderIter) |
524 | .def("__next__", &PyOperationIterator::dunderNext); |
525 | } |
526 | |
527 | private: |
528 | PyOperationRef parentOperation; |
529 | MlirOperation next; |
530 | }; |
531 | |
532 | /// Operations are exposed by the C-API as a forward-only linked list. In |
533 | /// Python, we present them as a more full-featured list-like container but |
534 | /// optimize it for forward iteration. Iterable operations are always owned |
535 | /// by a block. |
536 | class PyOperationList { |
537 | public: |
538 | PyOperationList(PyOperationRef parentOperation, MlirBlock block) |
539 | : parentOperation(std::move(parentOperation)), block(block) {} |
540 | |
541 | PyOperationIterator dunderIter() { |
542 | parentOperation->checkValid(); |
543 | return PyOperationIterator(parentOperation, |
544 | mlirBlockGetFirstOperation(block)); |
545 | } |
546 | |
547 | intptr_t dunderLen() { |
548 | parentOperation->checkValid(); |
549 | intptr_t count = 0; |
550 | MlirOperation childOp = mlirBlockGetFirstOperation(block); |
551 | while (!mlirOperationIsNull(childOp)) { |
552 | count += 1; |
553 | childOp = mlirOperationGetNextInBlock(childOp); |
554 | } |
555 | return count; |
556 | } |
557 | |
558 | nb::object dunderGetItem(intptr_t index) { |
559 | parentOperation->checkValid(); |
560 | if (index < 0) { |
561 | index += dunderLen(); |
562 | } |
563 | if (index < 0) { |
564 | throw nb::index_error("attempt to access out of bounds operation"); |
565 | } |
566 | MlirOperation childOp = mlirBlockGetFirstOperation(block); |
567 | while (!mlirOperationIsNull(childOp)) { |
568 | if (index == 0) { |
569 | return PyOperation::forOperation(parentOperation->getContext(), childOp) |
570 | ->createOpView(); |
571 | } |
572 | childOp = mlirOperationGetNextInBlock(childOp); |
573 | index -= 1; |
574 | } |
575 | throw nb::index_error("attempt to access out of bounds operation"); |
576 | } |
577 | |
578 | static void bind(nb::module_ &m) { |
579 | nb::class_<PyOperationList>(m, "OperationList") |
580 | .def("__getitem__", &PyOperationList::dunderGetItem) |
581 | .def("__iter__", &PyOperationList::dunderIter) |
582 | .def("__len__", &PyOperationList::dunderLen); |
583 | } |
584 | |
585 | private: |
586 | PyOperationRef parentOperation; |
587 | MlirBlock block; |
588 | }; |
589 | |
590 | class PyOpOperand { |
591 | public: |
592 | PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {} |
593 | |
594 | nb::object getOwner() { |
595 | MlirOperation owner = mlirOpOperandGetOwner(opOperand); |
596 | PyMlirContextRef context = |
597 | PyMlirContext::forContext(mlirOperationGetContext(owner)); |
598 | return PyOperation::forOperation(context, owner)->createOpView(); |
599 | } |
600 | |
601 | size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); } |
602 | |
603 | static void bind(nb::module_ &m) { |
604 | nb::class_<PyOpOperand>(m, "OpOperand") |
605 | .def_prop_ro("owner", &PyOpOperand::getOwner) |
606 | .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber); |
607 | } |
608 | |
609 | private: |
610 | MlirOpOperand opOperand; |
611 | }; |
612 | |
613 | class PyOpOperandIterator { |
614 | public: |
615 | PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {} |
616 | |
617 | PyOpOperandIterator &dunderIter() { return *this; } |
618 | |
619 | PyOpOperand dunderNext() { |
620 | if (mlirOpOperandIsNull(opOperand)) |
621 | throw nb::stop_iteration(); |
622 | |
623 | PyOpOperand returnOpOperand(opOperand); |
624 | opOperand = mlirOpOperandGetNextUse(opOperand); |
625 | return returnOpOperand; |
626 | } |
627 | |
628 | static void bind(nb::module_ &m) { |
629 | nb::class_<PyOpOperandIterator>(m, "OpOperandIterator") |
630 | .def("__iter__", &PyOpOperandIterator::dunderIter) |
631 | .def("__next__", &PyOpOperandIterator::dunderNext); |
632 | } |
633 | |
634 | private: |
635 | MlirOpOperand opOperand; |
636 | }; |
637 | |
638 | } // namespace |
639 | |
640 | //------------------------------------------------------------------------------ |
641 | // PyMlirContext |
642 | //------------------------------------------------------------------------------ |
643 | |
644 | PyMlirContext::PyMlirContext(MlirContext context) : context(context) { |
645 | nb::gil_scoped_acquire acquire; |
646 | nb::ft_lock_guard lock(live_contexts_mutex); |
647 | auto &liveContexts = getLiveContexts(); |
648 | liveContexts[context.ptr] = this; |
649 | } |
650 | |
/// Removes this wrapper from the live-context map and then destroys the
/// underlying MLIR context.
PyMlirContext::~PyMlirContext() {
  // Note that the only public way to construct an instance is via the
  // forContext method, which always puts the associated handle into
  // liveContexts.
  nb::gil_scoped_acquire acquire;
  {
    // The map mutex is held only for the erase; it is released before the
    // context itself is destroyed.
    nb::ft_lock_guard lock(live_contexts_mutex);
    getLiveContexts().erase(context.ptr);
  }
  mlirContextDestroy(context);
}
662 | |
663 | nb::object PyMlirContext::getCapsule() { |
664 | return nb::steal<nb::object>(mlirPythonContextToCapsule(get())); |
665 | } |
666 | |
667 | nb::object PyMlirContext::createFromCapsule(nb::object capsule) { |
668 | MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); |
669 | if (mlirContextIsNull(rawContext)) |
670 | throw nb::python_error(); |
671 | return forContext(rawContext).releaseObject(); |
672 | } |
673 | |
/// Returns a reference to the wrapper associated with `context`, creating
/// and registering a new wrapper when the live-context map has no entry for
/// the handle.
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
  nb::gil_scoped_acquire acquire;
  nb::ft_lock_guard lock(live_contexts_mutex);
  auto &liveContexts = getLiveContexts();
  auto it = liveContexts.find(context.ptr);
  if (it == liveContexts.end()) {
    // Create.
    // NOTE(review): the PyMlirContext constructor also locks
    // live_contexts_mutex and inserts `this` into the map, while this
    // function is still holding the same mutex. If nb::ft_mutex is not
    // recursive in free-threaded builds this would self-deadlock — confirm.
    PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
    nb::object pyRef = nb::cast(unownedContextWrapper);
    assert(pyRef && "cast to nb::object failed");
    // Writes the same key/value the constructor already stored.
    liveContexts[context.ptr] = unownedContextWrapper;
    return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
  }
  // Use existing.
  nb::object pyRef = nb::cast(it->second);
  return PyMlirContextRef(it->second, std::move(pyRef));
}
691 | |
692 | nb::ft_mutex PyMlirContext::live_contexts_mutex; |
693 | |
694 | PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { |
695 | static LiveContextMap liveContexts; |
696 | return liveContexts; |
697 | } |
698 | |
699 | size_t PyMlirContext::getLiveCount() { |
700 | nb::ft_lock_guard lock(live_contexts_mutex); |
701 | return getLiveContexts().size(); |
702 | } |
703 | |
704 | size_t PyMlirContext::getLiveOperationCount() { |
705 | nb::ft_lock_guard lock(liveOperationsMutex); |
706 | return liveOperations.size(); |
707 | } |
708 | |
709 | std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() { |
710 | std::vector<PyOperation *> liveObjects; |
711 | nb::ft_lock_guard lock(liveOperationsMutex); |
712 | for (auto &entry : liveOperations) |
713 | liveObjects.push_back(entry.second.second); |
714 | return liveObjects; |
715 | } |
716 | |
717 | size_t PyMlirContext::clearLiveOperations() { |
718 | |
719 | LiveOperationMap operations; |
720 | { |
721 | nb::ft_lock_guard lock(liveOperationsMutex); |
722 | std::swap(operations, liveOperations); |
723 | } |
724 | for (auto &op : operations) |
725 | op.second.second->setInvalid(); |
726 | size_t numInvalidated = operations.size(); |
727 | return numInvalidated; |
728 | } |
729 | |
730 | void PyMlirContext::clearOperation(MlirOperation op) { |
731 | PyOperation *py_op; |
732 | { |
733 | nb::ft_lock_guard lock(liveOperationsMutex); |
734 | auto it = liveOperations.find(op.ptr); |
735 | if (it == liveOperations.end()) { |
736 | return; |
737 | } |
738 | py_op = it->second.second; |
739 | liveOperations.erase(it); |
740 | } |
741 | py_op->setInvalid(); |
742 | } |
743 | |
744 | void PyMlirContext::clearOperationsInside(PyOperationBase &op) { |
745 | typedef struct { |
746 | PyOperation &rootOp; |
747 | bool rootSeen; |
748 | } callBackData; |
749 | callBackData data{.rootOp: op.getOperation(), .rootSeen: false}; |
750 | // Mark all ops below the op that the passmanager will be rooted |
751 | // at (but not op itself - note the preorder) as invalid. |
752 | MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, |
753 | void *userData) { |
754 | callBackData *data = static_cast<callBackData *>(userData); |
755 | if (LLVM_LIKELY(data->rootSeen)) |
756 | data->rootOp.getOperation().getContext()->clearOperation(op); |
757 | else |
758 | data->rootSeen = true; |
759 | return MlirWalkResult::MlirWalkResultAdvance; |
760 | }; |
761 | mlirOperationWalk(op.getOperation(), invalidatingCallback, |
762 | static_cast<void *>(&data), MlirWalkPreOrder); |
763 | } |
764 | void PyMlirContext::clearOperationsInside(MlirOperation op) { |
765 | PyOperationRef opRef = PyOperation::forOperation(getRef(), op); |
766 | clearOperationsInside(opRef->getOperation()); |
767 | } |
768 | |
769 | void PyMlirContext::clearOperationAndInside(PyOperationBase &op) { |
770 | MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, |
771 | void *userData) { |
772 | PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData); |
773 | contextRef->clearOperation(op); |
774 | return MlirWalkResult::MlirWalkResultAdvance; |
775 | }; |
776 | mlirOperationWalk(op.getOperation(), invalidatingCallback, |
777 | &op.getOperation().getContext(), MlirWalkPreOrder); |
778 | } |
779 | |
780 | size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } |
781 | |
782 | nb::object PyMlirContext::contextEnter(nb::object context) { |
783 | return PyThreadContextEntry::pushContext(context); |
784 | } |
785 | |
786 | void PyMlirContext::contextExit(const nb::object &excType, |
787 | const nb::object &excVal, |
788 | const nb::object &excTb) { |
789 | PyThreadContextEntry::popContext(context&: *this); |
790 | } |
791 | |
792 | nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) { |
793 | // Note that ownership is transferred to the delete callback below by way of |
794 | // an explicit inc_ref (borrow). |
795 | PyDiagnosticHandler *pyHandler = |
796 | new PyDiagnosticHandler(get(), std::move(callback)); |
797 | nb::object pyHandlerObject = |
798 | nb::cast(pyHandler, nb::rv_policy::take_ownership); |
799 | pyHandlerObject.inc_ref(); |
800 | |
801 | // In these C callbacks, the userData is a PyDiagnosticHandler* that is |
802 | // guaranteed to be known to pybind. |
803 | auto handlerCallback = |
804 | +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult { |
805 | PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic); |
806 | nb::object pyDiagnosticObject = |
807 | nb::cast(pyDiagnostic, nb::rv_policy::take_ownership); |
808 | |
809 | auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); |
810 | bool result = false; |
811 | { |
812 | // Since this can be called from arbitrary C++ contexts, always get the |
813 | // gil. |
814 | nb::gil_scoped_acquire gil; |
815 | try { |
816 | result = nb::cast<bool>(pyHandler->callback(pyDiagnostic)); |
817 | } catch (std::exception &e) { |
818 | fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n", |
819 | e.what()); |
820 | pyHandler->hadError = true; |
821 | } |
822 | } |
823 | |
824 | pyDiagnostic->invalidate(); |
825 | return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure(); |
826 | }; |
827 | auto deleteCallback = +[](void *userData) { |
828 | auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); |
829 | assert(pyHandler->registeredID && "handler is not registered"); |
830 | pyHandler->registeredID.reset(); |
831 | |
832 | // Decrement reference, balancing the inc_ref() above. |
833 | nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference); |
834 | pyHandlerObject.dec_ref(); |
835 | }; |
836 | |
837 | pyHandler->registeredID = mlirContextAttachDiagnosticHandler( |
838 | get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback); |
839 | return pyHandlerObject; |
840 | } |
841 | |
842 | MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag, |
843 | void *userData) { |
844 | auto *self = static_cast<ErrorCapture *>(userData); |
845 | // Check if the context requested we emit errors instead of capturing them. |
846 | if (self->ctx->emitErrorDiagnostics) |
847 | return mlirLogicalResultFailure(); |
848 | |
849 | if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError) |
850 | return mlirLogicalResultFailure(); |
851 | |
852 | self->errors.emplace_back(args: PyDiagnostic(diag).getInfo()); |
853 | return mlirLogicalResultSuccess(); |
854 | } |
855 | |
856 | PyMlirContext &DefaultingPyMlirContext::resolve() { |
857 | PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); |
858 | if (!context) { |
859 | throw std::runtime_error( |
860 | "An MLIR function requires a Context but none was provided in the call " |
861 | "or from the surrounding environment. Either pass to the function with " |
862 | "a 'context=' argument or establish a default using 'with Context():'"); |
863 | } |
864 | return *context; |
865 | } |
866 | |
867 | //------------------------------------------------------------------------------ |
868 | // PyThreadContextEntry management |
869 | //------------------------------------------------------------------------------ |
870 | |
871 | std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { |
872 | static thread_local std::vector<PyThreadContextEntry> stack; |
873 | return stack; |
874 | } |
875 | |
876 | PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { |
877 | auto &stack = getStack(); |
878 | if (stack.empty()) |
879 | return nullptr; |
880 | return &stack.back(); |
881 | } |
882 | |
883 | void PyThreadContextEntry::push(FrameKind frameKind, nb::object context, |
884 | nb::object insertionPoint, |
885 | nb::object location) { |
886 | auto &stack = getStack(); |
887 | stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), |
888 | std::move(location)); |
889 | // If the new stack has more than one entry and the context of the new top |
890 | // entry matches the previous, copy the insertionPoint and location from the |
891 | // previous entry if missing from the new top entry. |
892 | if (stack.size() > 1) { |
893 | auto &prev = *(stack.rbegin() + 1); |
894 | auto ¤t = stack.back(); |
895 | if (current.context.is(prev.context)) { |
896 | // Default non-context objects from the previous entry. |
897 | if (!current.insertionPoint) |
898 | current.insertionPoint = prev.insertionPoint; |
899 | if (!current.location) |
900 | current.location = prev.location; |
901 | } |
902 | } |
903 | } |
904 | |
905 | PyMlirContext *PyThreadContextEntry::getContext() { |
906 | if (!context) |
907 | return nullptr; |
908 | return nb::cast<PyMlirContext *>(context); |
909 | } |
910 | |
911 | PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { |
912 | if (!insertionPoint) |
913 | return nullptr; |
914 | return nb::cast<PyInsertionPoint *>(insertionPoint); |
915 | } |
916 | |
917 | PyLocation *PyThreadContextEntry::getLocation() { |
918 | if (!location) |
919 | return nullptr; |
920 | return nb::cast<PyLocation *>(location); |
921 | } |
922 | |
923 | PyMlirContext *PyThreadContextEntry::getDefaultContext() { |
924 | auto *tos = getTopOfStack(); |
925 | return tos ? tos->getContext() : nullptr; |
926 | } |
927 | |
928 | PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { |
929 | auto *tos = getTopOfStack(); |
930 | return tos ? tos->getInsertionPoint() : nullptr; |
931 | } |
932 | |
933 | PyLocation *PyThreadContextEntry::getDefaultLocation() { |
934 | auto *tos = getTopOfStack(); |
935 | return tos ? tos->getLocation() : nullptr; |
936 | } |
937 | |
938 | nb::object PyThreadContextEntry::pushContext(nb::object context) { |
939 | push(FrameKind::Context, /*context=*/context, |
940 | /*insertionPoint=*/nb::object(), |
941 | /*location=*/nb::object()); |
942 | return context; |
943 | } |
944 | |
945 | void PyThreadContextEntry::popContext(PyMlirContext &context) { |
946 | auto &stack = getStack(); |
947 | if (stack.empty()) |
948 | throw std::runtime_error("Unbalanced Context enter/exit"); |
949 | auto &tos = stack.back(); |
950 | if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) |
951 | throw std::runtime_error("Unbalanced Context enter/exit"); |
952 | stack.pop_back(); |
953 | } |
954 | |
955 | nb::object |
956 | PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) { |
957 | PyInsertionPoint &insertionPoint = |
958 | nb::cast<PyInsertionPoint &>(insertionPointObj); |
959 | nb::object contextObj = |
960 | insertionPoint.getBlock().getParentOperation()->getContext().getObject(); |
961 | push(FrameKind::InsertionPoint, |
962 | /*context=*/contextObj, |
963 | /*insertionPoint=*/insertionPointObj, |
964 | /*location=*/nb::object()); |
965 | return insertionPointObj; |
966 | } |
967 | |
968 | void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { |
969 | auto &stack = getStack(); |
970 | if (stack.empty()) |
971 | throw std::runtime_error("Unbalanced InsertionPoint enter/exit"); |
972 | auto &tos = stack.back(); |
973 | if (tos.frameKind != FrameKind::InsertionPoint && |
974 | tos.getInsertionPoint() != &insertionPoint) |
975 | throw std::runtime_error("Unbalanced InsertionPoint enter/exit"); |
976 | stack.pop_back(); |
977 | } |
978 | |
979 | nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) { |
980 | PyLocation &location = nb::cast<PyLocation &>(locationObj); |
981 | nb::object contextObj = location.getContext().getObject(); |
982 | push(FrameKind::Location, /*context=*/contextObj, |
983 | /*insertionPoint=*/nb::object(), |
984 | /*location=*/locationObj); |
985 | return locationObj; |
986 | } |
987 | |
988 | void PyThreadContextEntry::popLocation(PyLocation &location) { |
989 | auto &stack = getStack(); |
990 | if (stack.empty()) |
991 | throw std::runtime_error("Unbalanced Location enter/exit"); |
992 | auto &tos = stack.back(); |
993 | if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) |
994 | throw std::runtime_error("Unbalanced Location enter/exit"); |
995 | stack.pop_back(); |
996 | } |
997 | |
998 | //------------------------------------------------------------------------------ |
999 | // PyDiagnostic* |
1000 | //------------------------------------------------------------------------------ |
1001 | |
1002 | void PyDiagnostic::invalidate() { |
1003 | valid = false; |
1004 | if (materializedNotes) { |
1005 | for (nb::handle noteObject : *materializedNotes) { |
1006 | PyDiagnostic *note = nb::cast<PyDiagnostic *>(noteObject); |
1007 | note->invalidate(); |
1008 | } |
1009 | } |
1010 | } |
1011 | |
/// Stores the context the handler belongs to and takes ownership of the Python
/// callback object. Registration with the context is not performed here.
PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
                                         nb::object callback)
    : context(context), callback(std::move(callback)) {}

/// Defaulted: members clean up after themselves; no explicit detach is done.
PyDiagnosticHandler::~PyDiagnosticHandler() = default;
1017 | |
1018 | void PyDiagnosticHandler::detach() { |
1019 | if (!registeredID) |
1020 | return; |
1021 | MlirDiagnosticHandlerID localID = *registeredID; |
1022 | mlirContextDetachDiagnosticHandler(context, localID); |
1023 | assert(!registeredID && "should have unregistered"); |
1024 | // Not strictly necessary but keeps stale pointers from being around to cause |
1025 | // issues. |
1026 | context = {nullptr}; |
1027 | } |
1028 | |
1029 | void PyDiagnostic::checkValid() { |
1030 | if (!valid) { |
1031 | throw std::invalid_argument( |
1032 | "Diagnostic is invalid (used outside of callback)"); |
1033 | } |
1034 | } |
1035 | |
1036 | MlirDiagnosticSeverity PyDiagnostic::getSeverity() { |
1037 | checkValid(); |
1038 | return mlirDiagnosticGetSeverity(diagnostic); |
1039 | } |
1040 | |
1041 | PyLocation PyDiagnostic::getLocation() { |
1042 | checkValid(); |
1043 | MlirLocation loc = mlirDiagnosticGetLocation(diagnostic); |
1044 | MlirContext context = mlirLocationGetContext(loc); |
1045 | return PyLocation(PyMlirContext::forContext(context), loc); |
1046 | } |
1047 | |
1048 | nb::str PyDiagnostic::getMessage() { |
1049 | checkValid(); |
1050 | nb::object fileObject = nb::module_::import_("io").attr( "StringIO")(); |
1051 | PyFileAccumulator accum(fileObject, /*binary=*/false); |
1052 | mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData()); |
1053 | return nb::cast<nb::str>(fileObject.attr("getvalue")()); |
1054 | } |
1055 | |
1056 | nb::tuple PyDiagnostic::getNotes() { |
1057 | checkValid(); |
1058 | if (materializedNotes) |
1059 | return *materializedNotes; |
1060 | intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic); |
1061 | nb::tuple notes = nb::steal<nb::tuple>(PyTuple_New(numNotes)); |
1062 | for (intptr_t i = 0; i < numNotes; ++i) { |
1063 | MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i); |
1064 | nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag)); |
1065 | PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr()); |
1066 | } |
1067 | materializedNotes = std::move(notes); |
1068 | |
1069 | return *materializedNotes; |
1070 | } |
1071 | |
1072 | PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() { |
1073 | std::vector<DiagnosticInfo> notes; |
1074 | for (nb::handle n : getNotes()) |
1075 | notes.emplace_back(nb::cast<PyDiagnostic>(n).getInfo()); |
1076 | return {getSeverity(), getLocation(), nb::cast<std::string>(getMessage()), |
1077 | std::move(notes)}; |
1078 | } |
1079 | |
1080 | //------------------------------------------------------------------------------ |
1081 | // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry |
1082 | //------------------------------------------------------------------------------ |
1083 | |
1084 | MlirDialect PyDialects::getDialectForKey(const std::string &key, |
1085 | bool attrError) { |
1086 | MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), |
1087 | {key.data(), key.size()}); |
1088 | if (mlirDialectIsNull(dialect)) { |
1089 | std::string msg = (Twine("Dialect '") + key + "' not found").str(); |
1090 | if (attrError) |
1091 | throw nb::attribute_error(msg.c_str()); |
1092 | throw nb::index_error(msg.c_str()); |
1093 | } |
1094 | return dialect; |
1095 | } |
1096 | |
1097 | nb::object PyDialectRegistry::getCapsule() { |
1098 | return nb::steal<nb::object>(mlirPythonDialectRegistryToCapsule(*this)); |
1099 | } |
1100 | |
1101 | PyDialectRegistry PyDialectRegistry::createFromCapsule(nb::object capsule) { |
1102 | MlirDialectRegistry rawRegistry = |
1103 | mlirPythonCapsuleToDialectRegistry(capsule.ptr()); |
1104 | if (mlirDialectRegistryIsNull(rawRegistry)) |
1105 | throw nb::python_error(); |
1106 | return PyDialectRegistry(rawRegistry); |
1107 | } |
1108 | |
1109 | //------------------------------------------------------------------------------ |
1110 | // PyLocation |
1111 | //------------------------------------------------------------------------------ |
1112 | |
1113 | nb::object PyLocation::getCapsule() { |
1114 | return nb::steal<nb::object>(mlirPythonLocationToCapsule(*this)); |
1115 | } |
1116 | |
1117 | PyLocation PyLocation::createFromCapsule(nb::object capsule) { |
1118 | MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); |
1119 | if (mlirLocationIsNull(rawLoc)) |
1120 | throw nb::python_error(); |
1121 | return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), |
1122 | rawLoc); |
1123 | } |
1124 | |
1125 | nb::object PyLocation::contextEnter(nb::object locationObj) { |
1126 | return PyThreadContextEntry::pushLocation(locationObj); |
1127 | } |
1128 | |
1129 | void PyLocation::contextExit(const nb::object &excType, |
1130 | const nb::object &excVal, |
1131 | const nb::object &excTb) { |
1132 | PyThreadContextEntry::popLocation(location&: *this); |
1133 | } |
1134 | |
1135 | PyLocation &DefaultingPyLocation::resolve() { |
1136 | auto *location = PyThreadContextEntry::getDefaultLocation(); |
1137 | if (!location) { |
1138 | throw std::runtime_error( |
1139 | "An MLIR function requires a Location but none was provided in the " |
1140 | "call or from the surrounding environment. Either pass to the function " |
1141 | "with a 'loc=' argument or establish a default using 'with loc:'"); |
1142 | } |
1143 | return *location; |
1144 | } |
1145 | |
1146 | //------------------------------------------------------------------------------ |
1147 | // PyModule |
1148 | //------------------------------------------------------------------------------ |
1149 | |
/// Binds `module` to the context referenced by `contextRef`. The module is not
/// registered in the context's live-module map here; forModule() does that.
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
    : BaseContextObject(std::move(contextRef)), module(module) {}
1152 | |
1153 | PyModule::~PyModule() { |
1154 | nb::gil_scoped_acquire acquire; |
1155 | auto &liveModules = getContext()->liveModules; |
1156 | assert(liveModules.count(module.ptr) == 1 && |
1157 | "destroying module not in live map"); |
1158 | liveModules.erase(module.ptr); |
1159 | mlirModuleDestroy(module); |
1160 | } |
1161 | |
1162 | PyModuleRef PyModule::forModule(MlirModule module) { |
1163 | MlirContext context = mlirModuleGetContext(module); |
1164 | PyMlirContextRef contextRef = PyMlirContext::forContext(context); |
1165 | |
1166 | nb::gil_scoped_acquire acquire; |
1167 | auto &liveModules = contextRef->liveModules; |
1168 | auto it = liveModules.find(module.ptr); |
1169 | if (it == liveModules.end()) { |
1170 | // Create. |
1171 | PyModule *unownedModule = new PyModule(std::move(contextRef), module); |
1172 | // Note that the default return value policy on cast is automatic_reference, |
1173 | // which does not take ownership (delete will not be called). |
1174 | // Just be explicit. |
1175 | nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership); |
1176 | unownedModule->handle = pyRef; |
1177 | liveModules[module.ptr] = |
1178 | std::make_pair(unownedModule->handle, unownedModule); |
1179 | return PyModuleRef(unownedModule, std::move(pyRef)); |
1180 | } |
1181 | // Use existing. |
1182 | PyModule *existing = it->second.second; |
1183 | nb::object pyRef = nb::borrow<nb::object>(it->second.first); |
1184 | return PyModuleRef(existing, std::move(pyRef)); |
1185 | } |
1186 | |
1187 | nb::object PyModule::createFromCapsule(nb::object capsule) { |
1188 | MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); |
1189 | if (mlirModuleIsNull(rawModule)) |
1190 | throw nb::python_error(); |
1191 | return forModule(rawModule).releaseObject(); |
1192 | } |
1193 | |
1194 | nb::object PyModule::getCapsule() { |
1195 | return nb::steal<nb::object>(mlirPythonModuleToCapsule(get())); |
1196 | } |
1197 | |
1198 | //------------------------------------------------------------------------------ |
1199 | // PyOperation |
1200 | //------------------------------------------------------------------------------ |
1201 | |
/// Binds `operation` to the context referenced by `contextRef`. No live-map
/// registration happens here; see forOperation() and createDetached().
PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
    : BaseContextObject(std::move(contextRef)), operation(operation) {}
1204 | |
1205 | PyOperation::~PyOperation() { |
1206 | // If the operation has already been invalidated there is nothing to do. |
1207 | if (!valid) |
1208 | return; |
1209 | |
1210 | // Otherwise, invalidate the operation and remove it from live map when it is |
1211 | // attached. |
1212 | if (isAttached()) { |
1213 | getContext()->clearOperation(*this); |
1214 | } else { |
1215 | // And destroy it when it is detached, i.e. owned by Python, in which case |
1216 | // all nested operations must be invalidated at removed from the live map as |
1217 | // well. |
1218 | erase(); |
1219 | } |
1220 | } |
1221 | |
1222 | namespace { |
1223 | |
1224 | // Constructs a new object of type T in-place on the Python heap, returning a |
1225 | // PyObjectRef to it, loosely analogous to std::make_shared<T>(). |
1226 | template <typename T, class... Args> |
1227 | PyObjectRef<T> makeObjectRef(Args &&...args) { |
1228 | nb::handle type = nb::type<T>(); |
1229 | nb::object instance = nb::inst_alloc(type); |
1230 | T *ptr = nb::inst_ptr<T>(instance); |
1231 | new (ptr) T(std::forward<Args>(args)...); |
1232 | nb::inst_mark_ready(instance); |
1233 | return PyObjectRef<T>(ptr, std::move(instance)); |
1234 | } |
1235 | |
1236 | } // namespace |
1237 | |
1238 | PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, |
1239 | MlirOperation operation, |
1240 | nb::object parentKeepAlive) { |
1241 | // Create. |
1242 | PyOperationRef unownedOperation = |
1243 | makeObjectRef<PyOperation>(std::move(contextRef), operation); |
1244 | unownedOperation->handle = unownedOperation.getObject(); |
1245 | if (parentKeepAlive) { |
1246 | unownedOperation->parentKeepAlive = std::move(parentKeepAlive); |
1247 | } |
1248 | return unownedOperation; |
1249 | } |
1250 | |
1251 | PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, |
1252 | MlirOperation operation, |
1253 | nb::object parentKeepAlive) { |
1254 | nb::ft_lock_guard lock(contextRef->liveOperationsMutex); |
1255 | auto &liveOperations = contextRef->liveOperations; |
1256 | auto it = liveOperations.find(operation.ptr); |
1257 | if (it == liveOperations.end()) { |
1258 | // Create. |
1259 | PyOperationRef result = createInstance(std::move(contextRef), operation, |
1260 | std::move(parentKeepAlive)); |
1261 | liveOperations[operation.ptr] = |
1262 | std::make_pair(result.getObject(), result.get()); |
1263 | return result; |
1264 | } |
1265 | // Use existing. |
1266 | PyOperation *existing = it->second.second; |
1267 | nb::object pyRef = nb::borrow<nb::object>(it->second.first); |
1268 | return PyOperationRef(existing, std::move(pyRef)); |
1269 | } |
1270 | |
1271 | PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, |
1272 | MlirOperation operation, |
1273 | nb::object parentKeepAlive) { |
1274 | nb::ft_lock_guard lock(contextRef->liveOperationsMutex); |
1275 | auto &liveOperations = contextRef->liveOperations; |
1276 | assert(liveOperations.count(operation.ptr) == 0 && |
1277 | "cannot create detached operation that already exists"); |
1278 | (void)liveOperations; |
1279 | PyOperationRef created = createInstance(std::move(contextRef), operation, |
1280 | std::move(parentKeepAlive)); |
1281 | liveOperations[operation.ptr] = |
1282 | std::make_pair(created.getObject(), created.get()); |
1283 | created->attached = false; |
1284 | return created; |
1285 | } |
1286 | |
1287 | PyOperationRef PyOperation::parse(PyMlirContextRef contextRef, |
1288 | const std::string &sourceStr, |
1289 | const std::string &sourceName) { |
1290 | PyMlirContext::ErrorCapture errors(contextRef); |
1291 | MlirOperation op = |
1292 | mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr), |
1293 | toMlirStringRef(sourceName)); |
1294 | if (mlirOperationIsNull(op)) |
1295 | throw MLIRError("Unable to parse operation assembly", errors.take()); |
1296 | return PyOperation::createDetached(std::move(contextRef), op); |
1297 | } |
1298 | |
1299 | void PyOperation::checkValid() const { |
1300 | if (!valid) { |
1301 | throw std::runtime_error("the operation has been invalidated"); |
1302 | } |
1303 | } |
1304 | |
1305 | void PyOperationBase::print(std::optional<int64_t> largeElementsLimit, |
1306 | bool enableDebugInfo, bool prettyDebugInfo, |
1307 | bool printGenericOpForm, bool useLocalScope, |
1308 | bool useNameLocAsPrefix, bool assumeVerified, |
1309 | nb::object fileObject, bool binary, |
1310 | bool skipRegions) { |
1311 | PyOperation &operation = getOperation(); |
1312 | operation.checkValid(); |
1313 | if (fileObject.is_none()) |
1314 | fileObject = nb::module_::import_("sys").attr( "stdout"); |
1315 | |
1316 | MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); |
1317 | if (largeElementsLimit) { |
1318 | mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); |
1319 | mlirOpPrintingFlagsElideLargeResourceString(flags, *largeElementsLimit); |
1320 | } |
1321 | if (enableDebugInfo) |
1322 | mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, |
1323 | /*prettyForm=*/prettyDebugInfo); |
1324 | if (printGenericOpForm) |
1325 | mlirOpPrintingFlagsPrintGenericOpForm(flags); |
1326 | if (useLocalScope) |
1327 | mlirOpPrintingFlagsUseLocalScope(flags); |
1328 | if (assumeVerified) |
1329 | mlirOpPrintingFlagsAssumeVerified(flags); |
1330 | if (skipRegions) |
1331 | mlirOpPrintingFlagsSkipRegions(flags); |
1332 | if (useNameLocAsPrefix) |
1333 | mlirOpPrintingFlagsPrintNameLocAsPrefix(flags); |
1334 | |
1335 | PyFileAccumulator accum(fileObject, binary); |
1336 | mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), |
1337 | accum.getUserData()); |
1338 | mlirOpPrintingFlagsDestroy(flags); |
1339 | } |
1340 | |
1341 | void PyOperationBase::print(PyAsmState &state, nb::object fileObject, |
1342 | bool binary) { |
1343 | PyOperation &operation = getOperation(); |
1344 | operation.checkValid(); |
1345 | if (fileObject.is_none()) |
1346 | fileObject = nb::module_::import_("sys").attr( "stdout"); |
1347 | PyFileAccumulator accum(fileObject, binary); |
1348 | mlirOperationPrintWithState(operation, state.get(), accum.getCallback(), |
1349 | accum.getUserData()); |
1350 | } |
1351 | |
1352 | void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject, |
1353 | std::optional<int64_t> bytecodeVersion) { |
1354 | PyOperation &operation = getOperation(); |
1355 | operation.checkValid(); |
1356 | PyFileAccumulator accum(fileOrStringObject, /*binary=*/true); |
1357 | |
1358 | if (!bytecodeVersion.has_value()) |
1359 | return mlirOperationWriteBytecode(operation, accum.getCallback(), |
1360 | accum.getUserData()); |
1361 | |
1362 | MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate(); |
1363 | mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion); |
1364 | MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig( |
1365 | operation, config, accum.getCallback(), accum.getUserData()); |
1366 | mlirBytecodeWriterConfigDestroy(config); |
1367 | if (mlirLogicalResultIsFailure(res)) |
1368 | throw nb::value_error((Twine("Unable to honor desired bytecode version ") + |
1369 | Twine(*bytecodeVersion)) |
1370 | .str() |
1371 | .c_str()); |
1372 | } |
1373 | |
1374 | void PyOperationBase::walk( |
1375 | std::function<MlirWalkResult(MlirOperation)> callback, |
1376 | MlirWalkOrder walkOrder) { |
1377 | PyOperation &operation = getOperation(); |
1378 | operation.checkValid(); |
1379 | struct UserData { |
1380 | std::function<MlirWalkResult(MlirOperation)> callback; |
1381 | bool gotException; |
1382 | std::string exceptionWhat; |
1383 | nb::object exceptionType; |
1384 | }; |
1385 | UserData userData{callback, false, {}, {}}; |
1386 | MlirOperationWalkCallback walkCallback = [](MlirOperation op, |
1387 | void *userData) { |
1388 | UserData *calleeUserData = static_cast<UserData *>(userData); |
1389 | try { |
1390 | return (calleeUserData->callback)(op); |
1391 | } catch (nb::python_error &e) { |
1392 | calleeUserData->gotException = true; |
1393 | calleeUserData->exceptionWhat = std::string(e.what()); |
1394 | calleeUserData->exceptionType = nb::borrow(e.type()); |
1395 | return MlirWalkResult::MlirWalkResultInterrupt; |
1396 | } |
1397 | }; |
1398 | mlirOperationWalk(operation, walkCallback, &userData, walkOrder); |
1399 | if (userData.gotException) { |
1400 | std::string message("Exception raised in callback: "); |
1401 | message.append(str: userData.exceptionWhat); |
1402 | throw std::runtime_error(message); |
1403 | } |
1404 | } |
1405 | |
1406 | nb::object PyOperationBase::getAsm(bool binary, |
1407 | std::optional<int64_t> largeElementsLimit, |
1408 | bool enableDebugInfo, bool prettyDebugInfo, |
1409 | bool printGenericOpForm, bool useLocalScope, |
1410 | bool useNameLocAsPrefix, bool assumeVerified, |
1411 | bool skipRegions) { |
1412 | nb::object fileObject; |
1413 | if (binary) { |
1414 | fileObject = nb::module_::import_("io").attr( "BytesIO")(); |
1415 | } else { |
1416 | fileObject = nb::module_::import_("io").attr( "StringIO")(); |
1417 | } |
1418 | print(/*largeElementsLimit=*/largeElementsLimit, |
1419 | /*enableDebugInfo=*/enableDebugInfo, |
1420 | /*prettyDebugInfo=*/prettyDebugInfo, |
1421 | /*printGenericOpForm=*/printGenericOpForm, |
1422 | /*useLocalScope=*/useLocalScope, |
1423 | /*useNameLocAsPrefix=*/useNameLocAsPrefix, |
1424 | /*assumeVerified=*/assumeVerified, |
1425 | /*fileObject=*/fileObject, |
1426 | /*binary=*/binary, |
1427 | /*skipRegions=*/skipRegions); |
1428 | |
1429 | return fileObject.attr("getvalue")(); |
1430 | } |
1431 | |
1432 | void PyOperationBase::moveAfter(PyOperationBase &other) { |
1433 | PyOperation &operation = getOperation(); |
1434 | PyOperation &otherOp = other.getOperation(); |
1435 | operation.checkValid(); |
1436 | otherOp.checkValid(); |
1437 | mlirOperationMoveAfter(operation, otherOp); |
1438 | operation.parentKeepAlive = otherOp.parentKeepAlive; |
1439 | } |
1440 | |
1441 | void PyOperationBase::moveBefore(PyOperationBase &other) { |
1442 | PyOperation &operation = getOperation(); |
1443 | PyOperation &otherOp = other.getOperation(); |
1444 | operation.checkValid(); |
1445 | otherOp.checkValid(); |
1446 | mlirOperationMoveBefore(operation, otherOp); |
1447 | operation.parentKeepAlive = otherOp.parentKeepAlive; |
1448 | } |
1449 | |
1450 | bool PyOperationBase::verify() { |
1451 | PyOperation &op = getOperation(); |
1452 | PyMlirContext::ErrorCapture errors(op.getContext()); |
1453 | if (!mlirOperationVerify(op.get())) |
1454 | throw MLIRError("Verification failed", errors.take()); |
1455 | return true; |
1456 | } |
1457 | |
1458 | std::optional<PyOperationRef> PyOperation::getParentOperation() { |
1459 | checkValid(); |
1460 | if (!isAttached()) |
1461 | throw nb::value_error("Detached operations have no parent"); |
1462 | MlirOperation operation = mlirOperationGetParentOperation(get()); |
1463 | if (mlirOperationIsNull(operation)) |
1464 | return {}; |
1465 | return PyOperation::forOperation(getContext(), operation); |
1466 | } |
1467 | |
1468 | PyBlock PyOperation::getBlock() { |
1469 | checkValid(); |
1470 | std::optional<PyOperationRef> parentOperation = getParentOperation(); |
1471 | MlirBlock block = mlirOperationGetBlock(get()); |
1472 | assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); |
1473 | assert(parentOperation && "Operation has no parent"); |
1474 | return PyBlock{std::move(*parentOperation), block}; |
1475 | } |
1476 | |
1477 | nb::object PyOperation::getCapsule() { |
1478 | checkValid(); |
1479 | return nb::steal<nb::object>(mlirPythonOperationToCapsule(get())); |
1480 | } |
1481 | |
1482 | nb::object PyOperation::createFromCapsule(nb::object capsule) { |
1483 | MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); |
1484 | if (mlirOperationIsNull(rawOperation)) |
1485 | throw nb::python_error(); |
1486 | MlirContext rawCtxt = mlirOperationGetContext(rawOperation); |
1487 | return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) |
1488 | .releaseObject(); |
1489 | } |
1490 | |
1491 | static void maybeInsertOperation(PyOperationRef &op, |
1492 | const nb::object &maybeIp) { |
1493 | // InsertPoint active? |
1494 | if (!maybeIp.is(nb::cast(false))) { |
1495 | PyInsertionPoint *ip; |
1496 | if (maybeIp.is_none()) { |
1497 | ip = PyThreadContextEntry::getDefaultInsertionPoint(); |
1498 | } else { |
1499 | ip = nb::cast<PyInsertionPoint *>(maybeIp); |
1500 | } |
1501 | if (ip) |
1502 | ip->insert(*op.get()); |
1503 | } |
1504 | } |
1505 | |
/// Creates a new operation named `name` and returns its Python object.
///
/// Inputs are validated and unpacked first (throwing on bad input), then
/// applied to an MlirOperationState from which the operation is created. The
/// result is registered as a detached operation and then optionally inserted
/// according to `maybeIp` (see maybeInsertOperation).
///
/// Throws nb::value_error for a negative region count, a null result type, a
/// null successor or a failed creation; nb::type_error for an attribute key
/// that is not a string or an attribute value that fails to cast; and
/// std::runtime_error when casting an attribute value raises one.
nb::object PyOperation::create(std::string_view name,
                               std::optional<std::vector<PyType *>> results,
                               llvm::ArrayRef<MlirValue> operands,
                               std::optional<nb::dict> attributes,
                               std::optional<std::vector<PyBlock *>> successors,
                               int regions, DefaultingPyLocation location,
                               const nb::object &maybeIp, bool inferType) {
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw nb::value_error("number of regions must be >= 0");

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw nb::value_error("result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (std::pair<nb::handle, nb::handle> it : *attributes) {
      // Keys must be convertible to std::string.
      std::string key;
      try {
        key = nb::cast<std::string>(it.first);
      } catch (nb::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          std::string(name) + "\" (" + err.what() + ")";
        throw nb::type_error(msg.c_str());
      }
      // Values must be PyAttribute instances.
      try {
        auto &attribute = nb::cast<PyAttribute &>(it.second);
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (nb::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          std::string(name) + "\" (" + err.what() + ")";
        throw nb::type_error(msg.c_str());
      } catch (std::runtime_error &) {
        // This exception seems thrown when the value is "None".
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" +
            std::string(name) + "\"";
        throw std::runtime_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw nb::value_error("successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!operands.empty())
    mlirOperationStateAddOperands(&state, operands.size(), operands.data());
  state.enableResultTypeInference = inferType;
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Create `regions` fresh regions and hand them to the state as owned
    // regions.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  MlirOperation operation = mlirOperationCreate(&state);
  if (!operation.ptr)
    throw nb::value_error("Operation creation failed");
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);
  maybeInsertOperation(created, maybeIp);

  return created.getObject();
}
1620 | |
1621 | nb::object PyOperation::clone(const nb::object &maybeIp) { |
1622 | MlirOperation clonedOperation = mlirOperationClone(operation); |
1623 | PyOperationRef cloned = |
1624 | PyOperation::createDetached(getContext(), clonedOperation); |
1625 | maybeInsertOperation(cloned, maybeIp); |
1626 | |
1627 | return cloned->createOpView(); |
1628 | } |
1629 | |
1630 | nb::object PyOperation::createOpView() { |
1631 | checkValid(); |
1632 | MlirIdentifier ident = mlirOperationGetName(get()); |
1633 | MlirStringRef identStr = mlirIdentifierStr(ident); |
1634 | auto operationCls = PyGlobals::get().lookupOperationClass( |
1635 | StringRef(identStr.data, identStr.length)); |
1636 | if (operationCls) |
1637 | return PyOpView::constructDerived(*operationCls, getRef().getObject()); |
1638 | return nb::cast(PyOpView(getRef().getObject())); |
1639 | } |
1640 | |
1641 | void PyOperation::erase() { |
1642 | checkValid(); |
1643 | getContext()->clearOperationAndInside(*this); |
1644 | mlirOperationDestroy(operation); |
1645 | } |
1646 | |
1647 | namespace { |
1648 | /// CRTP base class for Python MLIR values that subclass Value and should be |
1649 | /// castable from it. The value hierarchy is one level deep and is not supposed |
1650 | /// to accommodate other levels unless core MLIR changes. |
1651 | template <typename DerivedTy> |
1652 | class PyConcreteValue : public PyValue { |
1653 | public: |
1654 | // Derived classes must define statics for: |
1655 | // IsAFunctionTy isaFunction |
1656 | // const char *pyClassName |
1657 | // and redefine bindDerived. |
1658 | using ClassTy = nb::class_<DerivedTy, PyValue>; |
1659 | using IsAFunctionTy = bool (*)(MlirValue); |
1660 | |
1661 | PyConcreteValue() = default; |
1662 | PyConcreteValue(PyOperationRef operationRef, MlirValue value) |
1663 | : PyValue(operationRef, value) {} |
1664 | PyConcreteValue(PyValue &orig) |
1665 | : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} |
1666 | |
1667 | /// Attempts to cast the original value to the derived type and throws on |
1668 | /// type mismatches. |
1669 | static MlirValue castFrom(PyValue &orig) { |
1670 | if (!DerivedTy::isaFunction(orig.get())) { |
1671 | auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig))); |
1672 | throw nb::value_error((Twine("Cannot cast value to ") + |
1673 | DerivedTy::pyClassName + " (from "+ origRepr + |
1674 | ")") |
1675 | .str() |
1676 | .c_str()); |
1677 | } |
1678 | return orig.get(); |
1679 | } |
1680 | |
1681 | /// Binds the Python module objects to functions of this class. |
1682 | static void bind(nb::module_ &m) { |
1683 | auto cls = ClassTy(m, DerivedTy::pyClassName); |
1684 | cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value")); |
1685 | cls.def_static( |
1686 | "isinstance", |
1687 | [](PyValue &otherValue) -> bool { |
1688 | return DerivedTy::isaFunction(otherValue); |
1689 | }, |
1690 | nb::arg("other_value")); |
1691 | cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, |
1692 | [](DerivedTy &self) { return self.maybeDownCast(); }); |
1693 | DerivedTy::bindDerived(cls); |
1694 | } |
1695 | |
1696 | /// Implemented by derived classes to add methods to the Python subclass. |
1697 | static void bindDerived(ClassTy &m) {} |
1698 | }; |
1699 | |
1700 | } // namespace |
1701 | |
1702 | /// Python wrapper for MlirOpResult. |
1703 | class PyOpResult : public PyConcreteValue<PyOpResult> { |
1704 | public: |
1705 | static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; |
1706 | static constexpr const char *pyClassName = "OpResult"; |
1707 | using PyConcreteValue::PyConcreteValue; |
1708 | |
1709 | static void bindDerived(ClassTy &c) { |
1710 | c.def_prop_ro("owner", [](PyOpResult &self) { |
1711 | assert( |
1712 | mlirOperationEqual(self.getParentOperation()->get(), |
1713 | mlirOpResultGetOwner(self.get())) && |
1714 | "expected the owner of the value in Python to match that in the IR"); |
1715 | return self.getParentOperation().getObject(); |
1716 | }); |
1717 | c.def_prop_ro("result_number", [](PyOpResult &self) { |
1718 | return mlirOpResultGetResultNumber(self.get()); |
1719 | }); |
1720 | } |
1721 | }; |
1722 | |
1723 | /// Returns the list of types of the values held by container. |
1724 | template <typename Container> |
1725 | static std::vector<MlirType> getValueTypes(Container &container, |
1726 | PyMlirContextRef &context) { |
1727 | std::vector<MlirType> result; |
1728 | result.reserve(container.size()); |
1729 | for (int i = 0, e = container.size(); i < e; ++i) { |
1730 | result.push_back(mlirValueGetType(container.getElement(i).get())); |
1731 | } |
1732 | return result; |
1733 | } |
1734 | |
1735 | /// A list of operation results. Internally, these are stored as consecutive |
1736 | /// elements, random access is cheap. The (returned) result list is associated |
1737 | /// with the operation whose results these are, and thus extends the lifetime of |
1738 | /// this operation. |
1739 | class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { |
1740 | public: |
1741 | static constexpr const char *pyClassName = "OpResultList"; |
1742 | using SliceableT = Sliceable<PyOpResultList, PyOpResult>; |
1743 | |
1744 | PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, |
1745 | intptr_t length = -1, intptr_t step = 1) |
1746 | : Sliceable(startIndex, |
1747 | length == -1 ? mlirOperationGetNumResults(operation->get()) |
1748 | : length, |
1749 | step), |
1750 | operation(std::move(operation)) {} |
1751 | |
1752 | static void bindDerived(ClassTy &c) { |
1753 | c.def_prop_ro("types", [](PyOpResultList &self) { |
1754 | return getValueTypes(self, self.operation->getContext()); |
1755 | }); |
1756 | c.def_prop_ro("owner", [](PyOpResultList &self) { |
1757 | return self.operation->createOpView(); |
1758 | }); |
1759 | } |
1760 | |
1761 | PyOperationRef &getOperation() { return operation; } |
1762 | |
1763 | private: |
1764 | /// Give the parent CRTP class access to hook implementations below. |
1765 | friend class Sliceable<PyOpResultList, PyOpResult>; |
1766 | |
1767 | intptr_t getRawNumElements() { |
1768 | operation->checkValid(); |
1769 | return mlirOperationGetNumResults(operation->get()); |
1770 | } |
1771 | |
1772 | PyOpResult getRawElement(intptr_t index) { |
1773 | PyValue value(operation, mlirOperationGetResult(operation->get(), index)); |
1774 | return PyOpResult(value); |
1775 | } |
1776 | |
1777 | PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { |
1778 | return PyOpResultList(operation, startIndex, length, step); |
1779 | } |
1780 | |
1781 | PyOperationRef operation; |
1782 | }; |
1783 | |
1784 | //------------------------------------------------------------------------------ |
1785 | // PyOpView |
1786 | //------------------------------------------------------------------------------ |
1787 | |
1788 | static void populateResultTypes(StringRef name, nb::list resultTypeList, |
1789 | const nb::object &resultSegmentSpecObj, |
1790 | std::vector<int32_t> &resultSegmentLengths, |
1791 | std::vector<PyType *> &resultTypes) { |
1792 | resultTypes.reserve(n: resultTypeList.size()); |
1793 | if (resultSegmentSpecObj.is_none()) { |
1794 | // Non-variadic result unpacking. |
1795 | for (const auto &it : llvm::enumerate(resultTypeList)) { |
1796 | try { |
1797 | resultTypes.push_back(nb::cast<PyType *>(it.value())); |
1798 | if (!resultTypes.back()) |
1799 | throw nb::cast_error(); |
1800 | } catch (nb::cast_error &err) { |
1801 | throw nb::value_error((llvm::Twine("Result ") + |
1802 | llvm::Twine(it.index()) + " of operation \""+ |
1803 | name + "\" must be a Type ("+ err.what() + ")") |
1804 | .str() |
1805 | .c_str()); |
1806 | } |
1807 | } |
1808 | } else { |
1809 | // Sized result unpacking. |
1810 | auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj); |
1811 | if (resultSegmentSpec.size() != resultTypeList.size()) { |
1812 | throw nb::value_error((llvm::Twine("Operation \"") + name + |
1813 | "\" requires "+ |
1814 | llvm::Twine(resultSegmentSpec.size()) + |
1815 | " result segments but was provided "+ |
1816 | llvm::Twine(resultTypeList.size())) |
1817 | .str() |
1818 | .c_str()); |
1819 | } |
1820 | resultSegmentLengths.reserve(n: resultTypeList.size()); |
1821 | for (const auto &it : |
1822 | llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { |
1823 | int segmentSpec = std::get<1>(it.value()); |
1824 | if (segmentSpec == 1 || segmentSpec == 0) { |
1825 | // Unpack unary element. |
1826 | try { |
1827 | auto *resultType = nb::cast<PyType *>(std::get<0>(it.value())); |
1828 | if (resultType) { |
1829 | resultTypes.push_back(resultType); |
1830 | resultSegmentLengths.push_back(1); |
1831 | } else if (segmentSpec == 0) { |
1832 | // Allowed to be optional. |
1833 | resultSegmentLengths.push_back(0); |
1834 | } else { |
1835 | throw nb::value_error( |
1836 | (llvm::Twine("Result ") + llvm::Twine(it.index()) + |
1837 | " of operation \""+ name + |
1838 | "\" must be a Type (was None and result is not optional)") |
1839 | .str() |
1840 | .c_str()); |
1841 | } |
1842 | } catch (nb::cast_error &err) { |
1843 | throw nb::value_error((llvm::Twine("Result ") + |
1844 | llvm::Twine(it.index()) + " of operation \""+ |
1845 | name + "\" must be a Type ("+ err.what() + |
1846 | ")") |
1847 | .str() |
1848 | .c_str()); |
1849 | } |
1850 | } else if (segmentSpec == -1) { |
1851 | // Unpack sequence by appending. |
1852 | try { |
1853 | if (std::get<0>(it.value()).is_none()) { |
1854 | // Treat it as an empty list. |
1855 | resultSegmentLengths.push_back(0); |
1856 | } else { |
1857 | // Unpack the list. |
1858 | auto segment = nb::cast<nb::sequence>(std::get<0>(it.value())); |
1859 | for (nb::handle segmentItem : segment) { |
1860 | resultTypes.push_back(nb::cast<PyType *>(segmentItem)); |
1861 | if (!resultTypes.back()) { |
1862 | throw nb::type_error("contained a None item"); |
1863 | } |
1864 | } |
1865 | resultSegmentLengths.push_back(nb::len(segment)); |
1866 | } |
1867 | } catch (std::exception &err) { |
1868 | // NOTE: Sloppy to be using a catch-all here, but there are at least |
1869 | // three different unrelated exceptions that can be thrown in the |
1870 | // above "casts". Just keep the scope above small and catch them all. |
1871 | throw nb::value_error((llvm::Twine("Result ") + |
1872 | llvm::Twine(it.index()) + " of operation \""+ |
1873 | name + "\" must be a Sequence of Types ("+ |
1874 | err.what() + ")") |
1875 | .str() |
1876 | .c_str()); |
1877 | } |
1878 | } else { |
1879 | throw nb::value_error("Unexpected segment spec"); |
1880 | } |
1881 | } |
1882 | } |
1883 | } |
1884 | |
1885 | static MlirValue getUniqueResult(MlirOperation operation) { |
1886 | auto numResults = mlirOperationGetNumResults(operation); |
1887 | if (numResults != 1) { |
1888 | auto name = mlirIdentifierStr(mlirOperationGetName(operation)); |
1889 | throw nb::value_error((Twine("Cannot call .result on operation ") + |
1890 | StringRef(name.data, name.length) + " which has "+ |
1891 | Twine(numResults) + |
1892 | " results (it is only valid for operations with a " |
1893 | "single result)") |
1894 | .str() |
1895 | .c_str()); |
1896 | } |
1897 | return mlirOperationGetResult(operation, 0); |
1898 | } |
1899 | |
1900 | static MlirValue getOpResultOrValue(nb::handle operand) { |
1901 | if (operand.is_none()) { |
1902 | throw nb::value_error("contained a None item"); |
1903 | } |
1904 | PyOperationBase *op; |
1905 | if (nb::try_cast<PyOperationBase *>(operand, op)) { |
1906 | return getUniqueResult(op->getOperation()); |
1907 | } |
1908 | PyOpResultList *opResultList; |
1909 | if (nb::try_cast<PyOpResultList *>(operand, opResultList)) { |
1910 | return getUniqueResult(opResultList->getOperation()->get()); |
1911 | } |
1912 | PyValue *value; |
1913 | if (nb::try_cast<PyValue *>(operand, value)) { |
1914 | return value->get(); |
1915 | } |
1916 | throw nb::value_error("is not a Value"); |
1917 | } |
1918 | |
1919 | nb::object PyOpView::buildGeneric( |
1920 | std::string_view name, std::tuple<int, bool> opRegionSpec, |
1921 | nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj, |
1922 | std::optional<nb::list> resultTypeList, nb::list operandList, |
1923 | std::optional<nb::dict> attributes, |
1924 | std::optional<std::vector<PyBlock *>> successors, |
1925 | std::optional<int> regions, DefaultingPyLocation location, |
1926 | const nb::object &maybeIp) { |
1927 | PyMlirContextRef context = location->getContext(); |
1928 | |
1929 | // Class level operation construction metadata. |
1930 | // Operand and result segment specs are either none, which does no |
1931 | // variadic unpacking, or a list of ints with segment sizes, where each |
1932 | // element is either a positive number (typically 1 for a scalar) or -1 to |
1933 | // indicate that it is derived from the length of the same-indexed operand |
1934 | // or result (implying that it is a list at that position). |
1935 | std::vector<int32_t> operandSegmentLengths; |
1936 | std::vector<int32_t> resultSegmentLengths; |
1937 | |
1938 | // Validate/determine region count. |
1939 | int opMinRegionCount = std::get<0>(t&: opRegionSpec); |
1940 | bool opHasNoVariadicRegions = std::get<1>(t&: opRegionSpec); |
1941 | if (!regions) { |
1942 | regions = opMinRegionCount; |
1943 | } |
1944 | if (*regions < opMinRegionCount) { |
1945 | throw nb::value_error( |
1946 | (llvm::Twine("Operation \"") + name + "\" requires a minimum of "+ |
1947 | llvm::Twine(opMinRegionCount) + |
1948 | " regions but was built with regions="+ llvm::Twine(*regions)) |
1949 | .str() |
1950 | .c_str()); |
1951 | } |
1952 | if (opHasNoVariadicRegions && *regions > opMinRegionCount) { |
1953 | throw nb::value_error( |
1954 | (llvm::Twine("Operation \"") + name + "\" requires a maximum of "+ |
1955 | llvm::Twine(opMinRegionCount) + |
1956 | " regions but was built with regions="+ llvm::Twine(*regions)) |
1957 | .str() |
1958 | .c_str()); |
1959 | } |
1960 | |
1961 | // Unpack results. |
1962 | std::vector<PyType *> resultTypes; |
1963 | if (resultTypeList.has_value()) { |
1964 | populateResultTypes(name, *resultTypeList, resultSegmentSpecObj, |
1965 | resultSegmentLengths, resultTypes); |
1966 | } |
1967 | |
1968 | // Unpack operands. |
1969 | llvm::SmallVector<MlirValue, 4> operands; |
1970 | operands.reserve(operands.size()); |
1971 | if (operandSegmentSpecObj.is_none()) { |
1972 | // Non-sized operand unpacking. |
1973 | for (const auto &it : llvm::enumerate(operandList)) { |
1974 | try { |
1975 | operands.push_back(getOpResultOrValue(it.value())); |
1976 | } catch (nb::builtin_exception &err) { |
1977 | throw nb::value_error((llvm::Twine("Operand ") + |
1978 | llvm::Twine(it.index()) + " of operation \""+ |
1979 | name + "\" must be a Value ("+ err.what() + ")") |
1980 | .str() |
1981 | .c_str()); |
1982 | } |
1983 | } |
1984 | } else { |
1985 | // Sized operand unpacking. |
1986 | auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj); |
1987 | if (operandSegmentSpec.size() != operandList.size()) { |
1988 | throw nb::value_error((llvm::Twine("Operation \"") + name + |
1989 | "\" requires "+ |
1990 | llvm::Twine(operandSegmentSpec.size()) + |
1991 | "operand segments but was provided "+ |
1992 | llvm::Twine(operandList.size())) |
1993 | .str() |
1994 | .c_str()); |
1995 | } |
1996 | operandSegmentLengths.reserve(n: operandList.size()); |
1997 | for (const auto &it : |
1998 | llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { |
1999 | int segmentSpec = std::get<1>(it.value()); |
2000 | if (segmentSpec == 1 || segmentSpec == 0) { |
2001 | // Unpack unary element. |
2002 | auto &operand = std::get<0>(it.value()); |
2003 | if (!operand.is_none()) { |
2004 | try { |
2005 | |
2006 | operands.push_back(getOpResultOrValue(operand)); |
2007 | } catch (nb::builtin_exception &err) { |
2008 | throw nb::value_error((llvm::Twine("Operand ") + |
2009 | llvm::Twine(it.index()) + |
2010 | " of operation \""+ name + |
2011 | "\" must be a Value ("+ err.what() + ")") |
2012 | .str() |
2013 | .c_str()); |
2014 | } |
2015 | |
2016 | operandSegmentLengths.push_back(1); |
2017 | } else if (segmentSpec == 0) { |
2018 | // Allowed to be optional. |
2019 | operandSegmentLengths.push_back(0); |
2020 | } else { |
2021 | throw nb::value_error( |
2022 | (llvm::Twine("Operand ") + llvm::Twine(it.index()) + |
2023 | " of operation \""+ name + |
2024 | "\" must be a Value (was None and operand is not optional)") |
2025 | .str() |
2026 | .c_str()); |
2027 | } |
2028 | } else if (segmentSpec == -1) { |
2029 | // Unpack sequence by appending. |
2030 | try { |
2031 | if (std::get<0>(it.value()).is_none()) { |
2032 | // Treat it as an empty list. |
2033 | operandSegmentLengths.push_back(0); |
2034 | } else { |
2035 | // Unpack the list. |
2036 | auto segment = nb::cast<nb::sequence>(std::get<0>(it.value())); |
2037 | for (nb::handle segmentItem : segment) { |
2038 | operands.push_back(getOpResultOrValue(segmentItem)); |
2039 | } |
2040 | operandSegmentLengths.push_back(nb::len(segment)); |
2041 | } |
2042 | } catch (std::exception &err) { |
2043 | // NOTE: Sloppy to be using a catch-all here, but there are at least |
2044 | // three different unrelated exceptions that can be thrown in the |
2045 | // above "casts". Just keep the scope above small and catch them all. |
2046 | throw nb::value_error((llvm::Twine("Operand ") + |
2047 | llvm::Twine(it.index()) + " of operation \""+ |
2048 | name + "\" must be a Sequence of Values ("+ |
2049 | err.what() + ")") |
2050 | .str() |
2051 | .c_str()); |
2052 | } |
2053 | } else { |
2054 | throw nb::value_error("Unexpected segment spec"); |
2055 | } |
2056 | } |
2057 | } |
2058 | |
2059 | // Merge operand/result segment lengths into attributes if needed. |
2060 | if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { |
2061 | // Dup. |
2062 | if (attributes) { |
2063 | attributes = nb::dict(*attributes); |
2064 | } else { |
2065 | attributes = nb::dict(); |
2066 | } |
2067 | if (attributes->contains("resultSegmentSizes") || |
2068 | attributes->contains("operandSegmentSizes")) { |
2069 | throw nb::value_error("Manually setting a 'resultSegmentSizes' or " |
2070 | "'operandSegmentSizes' attribute is unsupported. " |
2071 | "Use Operation.create for such low-level access."); |
2072 | } |
2073 | |
2074 | // Add resultSegmentSizes attribute. |
2075 | if (!resultSegmentLengths.empty()) { |
2076 | MlirAttribute segmentLengthAttr = |
2077 | mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(), |
2078 | resultSegmentLengths.data()); |
2079 | (*attributes)["resultSegmentSizes"] = |
2080 | PyAttribute(context, segmentLengthAttr); |
2081 | } |
2082 | |
2083 | // Add operandSegmentSizes attribute. |
2084 | if (!operandSegmentLengths.empty()) { |
2085 | MlirAttribute segmentLengthAttr = |
2086 | mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(), |
2087 | operandSegmentLengths.data()); |
2088 | (*attributes)["operandSegmentSizes"] = |
2089 | PyAttribute(context, segmentLengthAttr); |
2090 | } |
2091 | } |
2092 | |
2093 | // Delegate to create. |
2094 | return PyOperation::create(name, |
2095 | /*results=*/std::move(resultTypes), |
2096 | /*operands=*/std::move(operands), |
2097 | /*attributes=*/std::move(attributes), |
2098 | /*successors=*/std::move(successors), |
2099 | /*regions=*/*regions, location, maybeIp, |
2100 | !resultTypeList); |
2101 | } |
2102 | |
2103 | nb::object PyOpView::constructDerived(const nb::object &cls, |
2104 | const nb::object &operation) { |
2105 | nb::handle opViewType = nb::type<PyOpView>(); |
2106 | nb::object instance = cls.attr("__new__")(cls); |
2107 | opViewType.attr("__init__")(instance, operation); |
2108 | return instance; |
2109 | } |
2110 | |
2111 | PyOpView::PyOpView(const nb::object &operationObject) |
2112 | // Casting through the PyOperationBase base-class and then back to the |
2113 | // Operation lets us accept any PyOperationBase subclass. |
2114 | : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()), |
2115 | operationObject(operation.getRef().getObject()) {} |
2116 | |
2117 | //------------------------------------------------------------------------------ |
2118 | // PyInsertionPoint. |
2119 | //------------------------------------------------------------------------------ |
2120 | |
2121 | PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} |
2122 | |
2123 | PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) |
2124 | : refOperation(beforeOperationBase.getOperation().getRef()), |
2125 | block((*refOperation)->getBlock()) {} |
2126 | |
2127 | void PyInsertionPoint::insert(PyOperationBase &operationBase) { |
2128 | PyOperation &operation = operationBase.getOperation(); |
2129 | if (operation.isAttached()) |
2130 | throw nb::value_error( |
2131 | "Attempt to insert operation that is already attached"); |
2132 | block.getParentOperation()->checkValid(); |
2133 | MlirOperation beforeOp = {nullptr}; |
2134 | if (refOperation) { |
2135 | // Insert before operation. |
2136 | (*refOperation)->checkValid(); |
2137 | beforeOp = (*refOperation)->get(); |
2138 | } else { |
2139 | // Insert at end (before null) is only valid if the block does not |
2140 | // already end in a known terminator (violating this will cause assertion |
2141 | // failures later). |
2142 | if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { |
2143 | throw nb::index_error("Cannot insert operation at the end of a block " |
2144 | "that already has a terminator. Did you mean to " |
2145 | "use 'InsertionPoint.at_block_terminator(block)' " |
2146 | "versus 'InsertionPoint(block)'?"); |
2147 | } |
2148 | } |
2149 | mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); |
2150 | operation.setAttached(); |
2151 | } |
2152 | |
2153 | PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { |
2154 | MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); |
2155 | if (mlirOperationIsNull(firstOp)) { |
2156 | // Just insert at end. |
2157 | return PyInsertionPoint(block); |
2158 | } |
2159 | |
2160 | // Insert before first op. |
2161 | PyOperationRef firstOpRef = PyOperation::forOperation( |
2162 | block.getParentOperation()->getContext(), firstOp); |
2163 | return PyInsertionPoint{block, std::move(firstOpRef)}; |
2164 | } |
2165 | |
2166 | PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { |
2167 | MlirOperation terminator = mlirBlockGetTerminator(block.get()); |
2168 | if (mlirOperationIsNull(terminator)) |
2169 | throw nb::value_error("Block has no terminator"); |
2170 | PyOperationRef terminatorOpRef = PyOperation::forOperation( |
2171 | block.getParentOperation()->getContext(), terminator); |
2172 | return PyInsertionPoint{block, std::move(terminatorOpRef)}; |
2173 | } |
2174 | |
2175 | nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) { |
2176 | return PyThreadContextEntry::pushInsertionPoint(insertPoint); |
2177 | } |
2178 | |
2179 | void PyInsertionPoint::contextExit(const nb::object &excType, |
2180 | const nb::object &excVal, |
2181 | const nb::object &excTb) { |
2182 | PyThreadContextEntry::popInsertionPoint(insertionPoint&: *this); |
2183 | } |
2184 | |
2185 | //------------------------------------------------------------------------------ |
2186 | // PyAttribute. |
2187 | //------------------------------------------------------------------------------ |
2188 | |
2189 | bool PyAttribute::operator==(const PyAttribute &other) const { |
2190 | return mlirAttributeEqual(attr, other.attr); |
2191 | } |
2192 | |
2193 | nb::object PyAttribute::getCapsule() { |
2194 | return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this)); |
2195 | } |
2196 | |
2197 | PyAttribute PyAttribute::createFromCapsule(nb::object capsule) { |
2198 | MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); |
2199 | if (mlirAttributeIsNull(rawAttr)) |
2200 | throw nb::python_error(); |
2201 | return PyAttribute( |
2202 | PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); |
2203 | } |
2204 | |
2205 | //------------------------------------------------------------------------------ |
2206 | // PyNamedAttribute. |
2207 | //------------------------------------------------------------------------------ |
2208 | |
2209 | PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) |
2210 | : ownedName(new std::string(std::move(ownedName))) { |
2211 | namedAttr = mlirNamedAttributeGet( |
2212 | mlirIdentifierGet(mlirAttributeGetContext(attr), |
2213 | toMlirStringRef(*this->ownedName)), |
2214 | attr); |
2215 | } |
2216 | |
2217 | //------------------------------------------------------------------------------ |
2218 | // PyType. |
2219 | //------------------------------------------------------------------------------ |
2220 | |
2221 | bool PyType::operator==(const PyType &other) const { |
2222 | return mlirTypeEqual(type, other.type); |
2223 | } |
2224 | |
2225 | nb::object PyType::getCapsule() { |
2226 | return nb::steal<nb::object>(mlirPythonTypeToCapsule(*this)); |
2227 | } |
2228 | |
2229 | PyType PyType::createFromCapsule(nb::object capsule) { |
2230 | MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); |
2231 | if (mlirTypeIsNull(rawType)) |
2232 | throw nb::python_error(); |
2233 | return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), |
2234 | rawType); |
2235 | } |
2236 | |
2237 | //------------------------------------------------------------------------------ |
2238 | // PyTypeID. |
2239 | //------------------------------------------------------------------------------ |
2240 | |
2241 | nb::object PyTypeID::getCapsule() { |
2242 | return nb::steal<nb::object>(mlirPythonTypeIDToCapsule(*this)); |
2243 | } |
2244 | |
2245 | PyTypeID PyTypeID::createFromCapsule(nb::object capsule) { |
2246 | MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr()); |
2247 | if (mlirTypeIDIsNull(mlirTypeID)) |
2248 | throw nb::python_error(); |
2249 | return PyTypeID(mlirTypeID); |
2250 | } |
2251 | bool PyTypeID::operator==(const PyTypeID &other) const { |
2252 | return mlirTypeIDEqual(typeID, other.typeID); |
2253 | } |
2254 | |
2255 | //------------------------------------------------------------------------------ |
2256 | // PyValue and subclasses. |
2257 | //------------------------------------------------------------------------------ |
2258 | |
2259 | nb::object PyValue::getCapsule() { |
2260 | return nb::steal<nb::object>(mlirPythonValueToCapsule(get())); |
2261 | } |
2262 | |
2263 | nb::object PyValue::maybeDownCast() { |
2264 | MlirType type = mlirValueGetType(get()); |
2265 | MlirTypeID mlirTypeID = mlirTypeGetTypeID(type); |
2266 | assert(!mlirTypeIDIsNull(mlirTypeID) && |
2267 | "mlirTypeID was expected to be non-null."); |
2268 | std::optional<nb::callable> valueCaster = |
2269 | PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type)); |
2270 | // nb::rv_policy::move means use std::move to move the return value |
2271 | // contents into a new instance that will be owned by Python. |
2272 | nb::object thisObj = nb::cast(this, nb::rv_policy::move); |
2273 | if (!valueCaster) |
2274 | return thisObj; |
2275 | return valueCaster.value()(thisObj); |
2276 | } |
2277 | |
2278 | PyValue PyValue::createFromCapsule(nb::object capsule) { |
2279 | MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); |
2280 | if (mlirValueIsNull(value)) |
2281 | throw nb::python_error(); |
2282 | MlirOperation owner; |
2283 | if (mlirValueIsAOpResult(value)) |
2284 | owner = mlirOpResultGetOwner(value); |
2285 | if (mlirValueIsABlockArgument(value)) |
2286 | owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); |
2287 | if (mlirOperationIsNull(owner)) |
2288 | throw nb::python_error(); |
2289 | MlirContext ctx = mlirOperationGetContext(owner); |
2290 | PyOperationRef ownerRef = |
2291 | PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); |
2292 | return PyValue(ownerRef, value); |
2293 | } |
2294 | |
2295 | //------------------------------------------------------------------------------ |
2296 | // PySymbolTable. |
2297 | //------------------------------------------------------------------------------ |
2298 | |
2299 | PySymbolTable::PySymbolTable(PyOperationBase &operation) |
2300 | : operation(operation.getOperation().getRef()) { |
2301 | symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); |
2302 | if (mlirSymbolTableIsNull(symbolTable)) { |
2303 | throw nb::type_error("Operation is not a Symbol Table."); |
2304 | } |
2305 | } |
2306 | |
2307 | nb::object PySymbolTable::dunderGetItem(const std::string &name) { |
2308 | operation->checkValid(); |
2309 | MlirOperation symbol = mlirSymbolTableLookup( |
2310 | symbolTable, mlirStringRefCreate(name.data(), name.length())); |
2311 | if (mlirOperationIsNull(symbol)) |
2312 | throw nb::key_error( |
2313 | ("Symbol '"+ name + "' not in the symbol table.").c_str()); |
2314 | |
2315 | return PyOperation::forOperation(operation->getContext(), symbol, |
2316 | operation.getObject()) |
2317 | ->createOpView(); |
2318 | } |
2319 | |
2320 | void PySymbolTable::erase(PyOperationBase &symbol) { |
2321 | operation->checkValid(); |
2322 | symbol.getOperation().checkValid(); |
2323 | mlirSymbolTableErase(symbolTable, symbol.getOperation().get()); |
2324 | // The operation is also erased, so we must invalidate it. There may be Python |
2325 | // references to this operation so we don't want to delete it from the list of |
2326 | // live operations here. |
2327 | symbol.getOperation().valid = false; |
2328 | } |
2329 | |
2330 | void PySymbolTable::dunderDel(const std::string &name) { |
2331 | nb::object operation = dunderGetItem(name); |
2332 | erase(nb::cast<PyOperationBase &>(operation)); |
2333 | } |
2334 | |
2335 | MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) { |
2336 | operation->checkValid(); |
2337 | symbol.getOperation().checkValid(); |
2338 | MlirAttribute symbolAttr = mlirOperationGetAttributeByName( |
2339 | symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); |
2340 | if (mlirAttributeIsNull(symbolAttr)) |
2341 | throw nb::value_error("Expected operation to have a symbol name."); |
2342 | return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()); |
2343 | } |
2344 | |
2345 | MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { |
2346 | // Op must already be a symbol. |
2347 | PyOperation &operation = symbol.getOperation(); |
2348 | operation.checkValid(); |
2349 | MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); |
2350 | MlirAttribute existingNameAttr = |
2351 | mlirOperationGetAttributeByName(operation.get(), attrName); |
2352 | if (mlirAttributeIsNull(existingNameAttr)) |
2353 | throw nb::value_error("Expected operation to have a symbol name."); |
2354 | return existingNameAttr; |
2355 | } |
2356 | |
2357 | void PySymbolTable::setSymbolName(PyOperationBase &symbol, |
2358 | const std::string &name) { |
2359 | // Op must already be a symbol. |
2360 | PyOperation &operation = symbol.getOperation(); |
2361 | operation.checkValid(); |
2362 | MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); |
2363 | MlirAttribute existingNameAttr = |
2364 | mlirOperationGetAttributeByName(operation.get(), attrName); |
2365 | if (mlirAttributeIsNull(existingNameAttr)) |
2366 | throw nb::value_error("Expected operation to have a symbol name."); |
2367 | MlirAttribute newNameAttr = |
2368 | mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name)); |
2369 | mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); |
2370 | } |
2371 | |
2372 | MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { |
2373 | PyOperation &operation = symbol.getOperation(); |
2374 | operation.checkValid(); |
2375 | MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); |
2376 | MlirAttribute existingVisAttr = |
2377 | mlirOperationGetAttributeByName(operation.get(), attrName); |
2378 | if (mlirAttributeIsNull(existingVisAttr)) |
2379 | throw nb::value_error("Expected operation to have a symbol visibility."); |
2380 | return existingVisAttr; |
2381 | } |
2382 | |
2383 | void PySymbolTable::setVisibility(PyOperationBase &symbol, |
2384 | const std::string &visibility) { |
2385 | if (visibility != "public"&& visibility != "private"&& |
2386 | visibility != "nested") |
2387 | throw nb::value_error( |
2388 | "Expected visibility to be 'public', 'private' or 'nested'"); |
2389 | PyOperation &operation = symbol.getOperation(); |
2390 | operation.checkValid(); |
2391 | MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); |
2392 | MlirAttribute existingVisAttr = |
2393 | mlirOperationGetAttributeByName(operation.get(), attrName); |
2394 | if (mlirAttributeIsNull(existingVisAttr)) |
2395 | throw nb::value_error("Expected operation to have a symbol visibility."); |
2396 | MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(), |
2397 | toMlirStringRef(visibility)); |
2398 | mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr); |
2399 | } |
2400 | |
2401 | void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol, |
2402 | const std::string &newSymbol, |
2403 | PyOperationBase &from) { |
2404 | PyOperation &fromOperation = from.getOperation(); |
2405 | fromOperation.checkValid(); |
2406 | if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses( |
2407 | toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol), |
2408 | from.getOperation()))) |
2409 | |
2410 | throw nb::value_error("Symbol rename failed"); |
2411 | } |
2412 | |
2413 | void PySymbolTable::walkSymbolTables(PyOperationBase &from, |
2414 | bool allSymUsesVisible, |
2415 | nb::object callback) { |
2416 | PyOperation &fromOperation = from.getOperation(); |
2417 | fromOperation.checkValid(); |
2418 | struct UserData { |
2419 | PyMlirContextRef context; |
2420 | nb::object callback; |
2421 | bool gotException; |
2422 | std::string exceptionWhat; |
2423 | nb::object exceptionType; |
2424 | }; |
2425 | UserData userData{ |
2426 | fromOperation.getContext(), std::move(callback), false, {}, {}}; |
2427 | mlirSymbolTableWalkSymbolTables( |
2428 | fromOperation.get(), allSymUsesVisible, |
2429 | [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) { |
2430 | UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid); |
2431 | auto pyFoundOp = |
2432 | PyOperation::forOperation(calleeUserData->context, foundOp); |
2433 | if (calleeUserData->gotException) |
2434 | return; |
2435 | try { |
2436 | calleeUserData->callback(pyFoundOp.getObject(), isVisible); |
2437 | } catch (nb::python_error &e) { |
2438 | calleeUserData->gotException = true; |
2439 | calleeUserData->exceptionWhat = e.what(); |
2440 | calleeUserData->exceptionType = nb::borrow(e.type()); |
2441 | } |
2442 | }, |
2443 | static_cast<void *>(&userData)); |
2444 | if (userData.gotException) { |
2445 | std::string message("Exception raised in callback: "); |
2446 | message.append(str: userData.exceptionWhat); |
2447 | throw std::runtime_error(message); |
2448 | } |
2449 | } |
2450 | |
2451 | namespace { |
2452 | |
2453 | /// Python wrapper for MlirBlockArgument. |
2454 | class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { |
2455 | public: |
2456 | static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; |
2457 | static constexpr const char *pyClassName = "BlockArgument"; |
2458 | using PyConcreteValue::PyConcreteValue; |
2459 | |
2460 | static void bindDerived(ClassTy &c) { |
2461 | c.def_prop_ro("owner", [](PyBlockArgument &self) { |
2462 | return PyBlock(self.getParentOperation(), |
2463 | mlirBlockArgumentGetOwner(self.get())); |
2464 | }); |
2465 | c.def_prop_ro("arg_number", [](PyBlockArgument &self) { |
2466 | return mlirBlockArgumentGetArgNumber(self.get()); |
2467 | }); |
2468 | c.def( |
2469 | "set_type", |
2470 | [](PyBlockArgument &self, PyType type) { |
2471 | return mlirBlockArgumentSetType(self.get(), type); |
2472 | }, |
2473 | nb::arg("type")); |
2474 | } |
2475 | }; |
2476 | |
2477 | /// A list of block arguments. Internally, these are stored as consecutive |
2478 | /// elements, random access is cheap. The argument list is associated with the |
2479 | /// operation that contains the block (detached blocks are not allowed in |
2480 | /// Python bindings) and extends its lifetime. |
2481 | class PyBlockArgumentList |
2482 | : public Sliceable<PyBlockArgumentList, PyBlockArgument> { |
2483 | public: |
2484 | static constexpr const char *pyClassName = "BlockArgumentList"; |
2485 | using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>; |
2486 | |
2487 | PyBlockArgumentList(PyOperationRef operation, MlirBlock block, |
2488 | intptr_t startIndex = 0, intptr_t length = -1, |
2489 | intptr_t step = 1) |
2490 | : Sliceable(startIndex, |
2491 | length == -1 ? mlirBlockGetNumArguments(block) : length, |
2492 | step), |
2493 | operation(std::move(operation)), block(block) {} |
2494 | |
2495 | static void bindDerived(ClassTy &c) { |
2496 | c.def_prop_ro("types", [](PyBlockArgumentList &self) { |
2497 | return getValueTypes(self, self.operation->getContext()); |
2498 | }); |
2499 | } |
2500 | |
2501 | private: |
2502 | /// Give the parent CRTP class access to hook implementations below. |
2503 | friend class Sliceable<PyBlockArgumentList, PyBlockArgument>; |
2504 | |
2505 | /// Returns the number of arguments in the list. |
2506 | intptr_t getRawNumElements() { |
2507 | operation->checkValid(); |
2508 | return mlirBlockGetNumArguments(block); |
2509 | } |
2510 | |
2511 | /// Returns `pos`-the element in the list. |
2512 | PyBlockArgument getRawElement(intptr_t pos) { |
2513 | MlirValue argument = mlirBlockGetArgument(block, pos); |
2514 | return PyBlockArgument(operation, argument); |
2515 | } |
2516 | |
2517 | /// Returns a sublist of this list. |
2518 | PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, |
2519 | intptr_t step) { |
2520 | return PyBlockArgumentList(operation, block, startIndex, length, step); |
2521 | } |
2522 | |
2523 | PyOperationRef operation; |
2524 | MlirBlock block; |
2525 | }; |
2526 | |
2527 | /// A list of operation operands. Internally, these are stored as consecutive |
2528 | /// elements, random access is cheap. The (returned) operand list is associated |
2529 | /// with the operation whose operands these are, and thus extends the lifetime |
2530 | /// of this operation. |
2531 | class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { |
2532 | public: |
2533 | static constexpr const char *pyClassName = "OpOperandList"; |
2534 | using SliceableT = Sliceable<PyOpOperandList, PyValue>; |
2535 | |
2536 | PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, |
2537 | intptr_t length = -1, intptr_t step = 1) |
2538 | : Sliceable(startIndex, |
2539 | length == -1 ? mlirOperationGetNumOperands(operation->get()) |
2540 | : length, |
2541 | step), |
2542 | operation(operation) {} |
2543 | |
2544 | void dunderSetItem(intptr_t index, PyValue value) { |
2545 | index = wrapIndex(index); |
2546 | mlirOperationSetOperand(operation->get(), index, value.get()); |
2547 | } |
2548 | |
2549 | static void bindDerived(ClassTy &c) { |
2550 | c.def("__setitem__", &PyOpOperandList::dunderSetItem); |
2551 | } |
2552 | |
2553 | private: |
2554 | /// Give the parent CRTP class access to hook implementations below. |
2555 | friend class Sliceable<PyOpOperandList, PyValue>; |
2556 | |
2557 | intptr_t getRawNumElements() { |
2558 | operation->checkValid(); |
2559 | return mlirOperationGetNumOperands(operation->get()); |
2560 | } |
2561 | |
2562 | PyValue getRawElement(intptr_t pos) { |
2563 | MlirValue operand = mlirOperationGetOperand(operation->get(), pos); |
2564 | MlirOperation owner; |
2565 | if (mlirValueIsAOpResult(operand)) |
2566 | owner = mlirOpResultGetOwner(operand); |
2567 | else if (mlirValueIsABlockArgument(operand)) |
2568 | owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); |
2569 | else |
2570 | assert(false && "Value must be an block arg or op result."); |
2571 | PyOperationRef pyOwner = |
2572 | PyOperation::forOperation(operation->getContext(), owner); |
2573 | return PyValue(pyOwner, operand); |
2574 | } |
2575 | |
2576 | PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { |
2577 | return PyOpOperandList(operation, startIndex, length, step); |
2578 | } |
2579 | |
2580 | PyOperationRef operation; |
2581 | }; |
2582 | |
2583 | /// A list of operation successors. Internally, these are stored as consecutive |
2584 | /// elements, random access is cheap. The (returned) successor list is |
2585 | /// associated with the operation whose successors these are, and thus extends |
2586 | /// the lifetime of this operation. |
2587 | class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> { |
2588 | public: |
2589 | static constexpr const char *pyClassName = "OpSuccessors"; |
2590 | |
2591 | PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0, |
2592 | intptr_t length = -1, intptr_t step = 1) |
2593 | : Sliceable(startIndex, |
2594 | length == -1 ? mlirOperationGetNumSuccessors(operation->get()) |
2595 | : length, |
2596 | step), |
2597 | operation(operation) {} |
2598 | |
2599 | void dunderSetItem(intptr_t index, PyBlock block) { |
2600 | index = wrapIndex(index); |
2601 | mlirOperationSetSuccessor(operation->get(), index, block.get()); |
2602 | } |
2603 | |
2604 | static void bindDerived(ClassTy &c) { |
2605 | c.def("__setitem__", &PyOpSuccessors::dunderSetItem); |
2606 | } |
2607 | |
2608 | private: |
2609 | /// Give the parent CRTP class access to hook implementations below. |
2610 | friend class Sliceable<PyOpSuccessors, PyBlock>; |
2611 | |
2612 | intptr_t getRawNumElements() { |
2613 | operation->checkValid(); |
2614 | return mlirOperationGetNumSuccessors(operation->get()); |
2615 | } |
2616 | |
2617 | PyBlock getRawElement(intptr_t pos) { |
2618 | MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos); |
2619 | return PyBlock(operation, block); |
2620 | } |
2621 | |
2622 | PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) { |
2623 | return PyOpSuccessors(operation, startIndex, length, step); |
2624 | } |
2625 | |
2626 | PyOperationRef operation; |
2627 | }; |
2628 | |
2629 | /// A list of operation attributes. Can be indexed by name, producing |
2630 | /// attributes, or by index, producing named attributes. |
2631 | class PyOpAttributeMap { |
2632 | public: |
2633 | PyOpAttributeMap(PyOperationRef operation) |
2634 | : operation(std::move(operation)) {} |
2635 | |
2636 | MlirAttribute dunderGetItemNamed(const std::string &name) { |
2637 | MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), |
2638 | toMlirStringRef(name)); |
2639 | if (mlirAttributeIsNull(attr)) { |
2640 | throw nb::key_error("attempt to access a non-existent attribute"); |
2641 | } |
2642 | return attr; |
2643 | } |
2644 | |
2645 | PyNamedAttribute dunderGetItemIndexed(intptr_t index) { |
2646 | if (index < 0) { |
2647 | index += dunderLen(); |
2648 | } |
2649 | if (index < 0 || index >= dunderLen()) { |
2650 | throw nb::index_error("attempt to access out of bounds attribute"); |
2651 | } |
2652 | MlirNamedAttribute namedAttr = |
2653 | mlirOperationGetAttribute(operation->get(), index); |
2654 | return PyNamedAttribute( |
2655 | namedAttr.attribute, |
2656 | std::string(mlirIdentifierStr(namedAttr.name).data, |
2657 | mlirIdentifierStr(namedAttr.name).length)); |
2658 | } |
2659 | |
2660 | void dunderSetItem(const std::string &name, const PyAttribute &attr) { |
2661 | mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), |
2662 | attr); |
2663 | } |
2664 | |
2665 | void dunderDelItem(const std::string &name) { |
2666 | int removed = mlirOperationRemoveAttributeByName(operation->get(), |
2667 | toMlirStringRef(name)); |
2668 | if (!removed) |
2669 | throw nb::key_error("attempt to delete a non-existent attribute"); |
2670 | } |
2671 | |
2672 | intptr_t dunderLen() { |
2673 | return mlirOperationGetNumAttributes(operation->get()); |
2674 | } |
2675 | |
2676 | bool dunderContains(const std::string &name) { |
2677 | return !mlirAttributeIsNull(mlirOperationGetAttributeByName( |
2678 | operation->get(), toMlirStringRef(name))); |
2679 | } |
2680 | |
2681 | static void bind(nb::module_ &m) { |
2682 | nb::class_<PyOpAttributeMap>(m, "OpAttributeMap") |
2683 | .def("__contains__", &PyOpAttributeMap::dunderContains) |
2684 | .def("__len__", &PyOpAttributeMap::dunderLen) |
2685 | .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) |
2686 | .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) |
2687 | .def("__setitem__", &PyOpAttributeMap::dunderSetItem) |
2688 | .def("__delitem__", &PyOpAttributeMap::dunderDelItem); |
2689 | } |
2690 | |
2691 | private: |
2692 | PyOperationRef operation; |
2693 | }; |
2694 | |
2695 | } // namespace |
2696 | |
2697 | //------------------------------------------------------------------------------ |
2698 | // Populates the core exports of the 'ir' submodule. |
2699 | //------------------------------------------------------------------------------ |
2700 | |
2701 | void mlir::python::populateIRCore(nb::module_ &m) { |
2702 | // disable leak warnings which tend to be false positives. |
2703 | nb::set_leak_warnings(false); |
2704 | //---------------------------------------------------------------------------- |
2705 | // Enums. |
2706 | //---------------------------------------------------------------------------- |
2707 | nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity") |
2708 | .value("ERROR", MlirDiagnosticError) |
2709 | .value("WARNING", MlirDiagnosticWarning) |
2710 | .value("NOTE", MlirDiagnosticNote) |
2711 | .value("REMARK", MlirDiagnosticRemark); |
2712 | |
2713 | nb::enum_<MlirWalkOrder>(m, "WalkOrder") |
2714 | .value("PRE_ORDER", MlirWalkPreOrder) |
2715 | .value("POST_ORDER", MlirWalkPostOrder); |
2716 | |
2717 | nb::enum_<MlirWalkResult>(m, "WalkResult") |
2718 | .value("ADVANCE", MlirWalkResultAdvance) |
2719 | .value("INTERRUPT", MlirWalkResultInterrupt) |
2720 | .value("SKIP", MlirWalkResultSkip); |
2721 | |
2722 | //---------------------------------------------------------------------------- |
2723 | // Mapping of Diagnostics. |
2724 | //---------------------------------------------------------------------------- |
2725 | nb::class_<PyDiagnostic>(m, "Diagnostic") |
2726 | .def_prop_ro("severity", &PyDiagnostic::getSeverity) |
2727 | .def_prop_ro("location", &PyDiagnostic::getLocation) |
2728 | .def_prop_ro("message", &PyDiagnostic::getMessage) |
2729 | .def_prop_ro("notes", &PyDiagnostic::getNotes) |
2730 | .def("__str__", [](PyDiagnostic &self) -> nb::str { |
2731 | if (!self.isValid()) |
2732 | return nb::str("<Invalid Diagnostic>"); |
2733 | return self.getMessage(); |
2734 | }); |
2735 | |
2736 | nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo") |
2737 | .def("__init__", |
2738 | [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) { |
2739 | new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo()); |
2740 | }) |
2741 | .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity) |
2742 | .def_ro("location", &PyDiagnostic::DiagnosticInfo::location) |
2743 | .def_ro("message", &PyDiagnostic::DiagnosticInfo::message) |
2744 | .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes) |
2745 | .def("__str__", |
2746 | [](PyDiagnostic::DiagnosticInfo &self) { return self.message; }); |
2747 | |
2748 | nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler") |
2749 | .def("detach", &PyDiagnosticHandler::detach) |
2750 | .def_prop_ro("attached", &PyDiagnosticHandler::isAttached) |
2751 | .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError) |
2752 | .def("__enter__", &PyDiagnosticHandler::contextEnter) |
2753 | .def("__exit__", &PyDiagnosticHandler::contextExit, |
2754 | nb::arg("exc_type").none(), nb::arg( "exc_value").none(), |
2755 | nb::arg("traceback").none()); |
2756 | |
2757 | //---------------------------------------------------------------------------- |
2758 | // Mapping of MlirContext. |
2759 | // Note that this is exported as _BaseContext. The containing, Python level |
2760 | // __init__.py will subclass it with site-specific functionality and set a |
2761 | // "Context" attribute on this module. |
2762 | //---------------------------------------------------------------------------- |
2763 | |
2764 | // Expose DefaultThreadPool to python |
2765 | nb::class_<PyThreadPool>(m, "ThreadPool") |
2766 | .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); }) |
2767 | .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency) |
2768 | .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr); |
2769 | |
2770 | nb::class_<PyMlirContext>(m, "_BaseContext") |
2771 | .def("__init__", |
2772 | [](PyMlirContext &self) { |
2773 | MlirContext context = mlirContextCreateWithThreading(false); |
2774 | new (&self) PyMlirContext(context); |
2775 | }) |
2776 | .def_static("_get_live_count", &PyMlirContext::getLiveCount) |
2777 | .def("_get_context_again", |
2778 | [](PyMlirContext &self) { |
2779 | PyMlirContextRef ref = PyMlirContext::forContext(self.get()); |
2780 | return ref.releaseObject(); |
2781 | }) |
2782 | .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) |
2783 | .def("_get_live_operation_objects", |
2784 | &PyMlirContext::getLiveOperationObjects) |
2785 | .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) |
2786 | .def("_clear_live_operations_inside", |
2787 | nb::overload_cast<MlirOperation>( |
2788 | &PyMlirContext::clearOperationsInside)) |
2789 | .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) |
2790 | .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) |
2791 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) |
2792 | .def("__enter__", &PyMlirContext::contextEnter) |
2793 | .def("__exit__", &PyMlirContext::contextExit, nb::arg( "exc_type").none(), |
2794 | nb::arg("exc_value").none(), nb::arg( "traceback").none()) |
2795 | .def_prop_ro_static( |
2796 | "current", |
2797 | [](nb::object & /*class*/) { |
2798 | auto *context = PyThreadContextEntry::getDefaultContext(); |
2799 | if (!context) |
2800 | return nb::none(); |
2801 | return nb::cast(context); |
2802 | }, |
2803 | "Gets the Context bound to the current thread or raises ValueError") |
2804 | .def_prop_ro( |
2805 | "dialects", |
2806 | [](PyMlirContext &self) { return PyDialects(self.getRef()); }, |
2807 | "Gets a container for accessing dialects by name") |
2808 | .def_prop_ro( |
2809 | "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, |
2810 | "Alias for 'dialect'") |
2811 | .def( |
2812 | "get_dialect_descriptor", |
2813 | [=](PyMlirContext &self, std::string &name) { |
2814 | MlirDialect dialect = mlirContextGetOrLoadDialect( |
2815 | self.get(), {name.data(), name.size()}); |
2816 | if (mlirDialectIsNull(dialect)) { |
2817 | throw nb::value_error( |
2818 | (Twine("Dialect '") + name + "' not found").str().c_str()); |
2819 | } |
2820 | return PyDialectDescriptor(self.getRef(), dialect); |
2821 | }, |
2822 | nb::arg("dialect_name"), |
2823 | "Gets or loads a dialect by name, returning its descriptor object") |
2824 | .def_prop_rw( |
2825 | "allow_unregistered_dialects", |
2826 | [](PyMlirContext &self) -> bool { |
2827 | return mlirContextGetAllowUnregisteredDialects(self.get()); |
2828 | }, |
2829 | [](PyMlirContext &self, bool value) { |
2830 | mlirContextSetAllowUnregisteredDialects(self.get(), value); |
2831 | }) |
2832 | .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler, |
2833 | nb::arg("callback"), |
2834 | "Attaches a diagnostic handler that will receive callbacks") |
2835 | .def( |
2836 | "enable_multithreading", |
2837 | [](PyMlirContext &self, bool enable) { |
2838 | mlirContextEnableMultithreading(self.get(), enable); |
2839 | }, |
2840 | nb::arg("enable")) |
2841 | .def("set_thread_pool", |
2842 | [](PyMlirContext &self, PyThreadPool &pool) { |
2843 | // we should disable multi-threading first before setting |
2844 | // new thread pool otherwise the assert in |
2845 | // MLIRContext::setThreadPool will be raised. |
2846 | mlirContextEnableMultithreading(self.get(), false); |
2847 | mlirContextSetThreadPool(self.get(), pool.get()); |
2848 | }) |
2849 | .def("get_num_threads", |
2850 | [](PyMlirContext &self) { |
2851 | return mlirContextGetNumThreads(self.get()); |
2852 | }) |
2853 | .def("_mlir_thread_pool_ptr", |
2854 | [](PyMlirContext &self) { |
2855 | MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get()); |
2856 | std::stringstream ss; |
2857 | ss << pool.ptr; |
2858 | return ss.str(); |
2859 | }) |
2860 | .def( |
2861 | "is_registered_operation", |
2862 | [](PyMlirContext &self, std::string &name) { |
2863 | return mlirContextIsRegisteredOperation( |
2864 | self.get(), MlirStringRef{name.data(), name.size()}); |
2865 | }, |
2866 | nb::arg("operation_name")) |
2867 | .def( |
2868 | "append_dialect_registry", |
2869 | [](PyMlirContext &self, PyDialectRegistry ®istry) { |
2870 | mlirContextAppendDialectRegistry(self.get(), registry); |
2871 | }, |
2872 | nb::arg("registry")) |
2873 | .def_prop_rw("emit_error_diagnostics", nullptr, |
2874 | &PyMlirContext::setEmitErrorDiagnostics, |
2875 | "Emit error diagnostics to diagnostic handlers. By default " |
2876 | "error diagnostics are captured and reported through " |
2877 | "MLIRError exceptions.") |
2878 | .def("load_all_available_dialects", [](PyMlirContext &self) { |
2879 | mlirContextLoadAllAvailableDialects(self.get()); |
2880 | }); |
2881 | |
2882 | //---------------------------------------------------------------------------- |
2883 | // Mapping of PyDialectDescriptor |
2884 | //---------------------------------------------------------------------------- |
2885 | nb::class_<PyDialectDescriptor>(m, "DialectDescriptor") |
2886 | .def_prop_ro("namespace", |
2887 | [](PyDialectDescriptor &self) { |
2888 | MlirStringRef ns = mlirDialectGetNamespace(self.get()); |
2889 | return nb::str(ns.data, ns.length); |
2890 | }) |
2891 | .def("__repr__", [](PyDialectDescriptor &self) { |
2892 | MlirStringRef ns = mlirDialectGetNamespace(self.get()); |
2893 | std::string repr("<DialectDescriptor "); |
2894 | repr.append(ns.data, ns.length); |
2895 | repr.append(">"); |
2896 | return repr; |
2897 | }); |
2898 | |
2899 | //---------------------------------------------------------------------------- |
2900 | // Mapping of PyDialects |
2901 | //---------------------------------------------------------------------------- |
2902 | nb::class_<PyDialects>(m, "Dialects") |
2903 | .def("__getitem__", |
2904 | [=](PyDialects &self, std::string keyName) { |
2905 | MlirDialect dialect = |
2906 | self.getDialectForKey(keyName, /*attrError=*/false); |
2907 | nb::object descriptor = |
2908 | nb::cast(PyDialectDescriptor{self.getContext(), dialect}); |
2909 | return createCustomDialectWrapper(keyName, std::move(descriptor)); |
2910 | }) |
2911 | .def("__getattr__", [=](PyDialects &self, std::string attrName) { |
2912 | MlirDialect dialect = |
2913 | self.getDialectForKey(attrName, /*attrError=*/true); |
2914 | nb::object descriptor = |
2915 | nb::cast(PyDialectDescriptor{self.getContext(), dialect}); |
2916 | return createCustomDialectWrapper(attrName, std::move(descriptor)); |
2917 | }); |
2918 | |
2919 | //---------------------------------------------------------------------------- |
2920 | // Mapping of PyDialect |
2921 | //---------------------------------------------------------------------------- |
2922 | nb::class_<PyDialect>(m, "Dialect") |
2923 | .def(nb::init<nb::object>(), nb::arg("descriptor")) |
2924 | .def_prop_ro("descriptor", |
2925 | [](PyDialect &self) { return self.getDescriptor(); }) |
2926 | .def("__repr__", [](nb::object self) { |
2927 | auto clazz = self.attr("__class__"); |
2928 | return nb::str("<Dialect ") + |
2929 | self.attr("descriptor").attr( "namespace") + nb::str( " (class ") + |
2930 | clazz.attr("__module__") + nb::str( ".") + |
2931 | clazz.attr("__name__") + nb::str( ")>"); |
2932 | }); |
2933 | |
2934 | //---------------------------------------------------------------------------- |
2935 | // Mapping of PyDialectRegistry |
2936 | //---------------------------------------------------------------------------- |
2937 | nb::class_<PyDialectRegistry>(m, "DialectRegistry") |
2938 | .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule) |
2939 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule) |
2940 | .def(nb::init<>()); |
2941 | |
2942 | //---------------------------------------------------------------------------- |
2943 | // Mapping of Location |
2944 | //---------------------------------------------------------------------------- |
2945 | nb::class_<PyLocation>(m, "Location") |
2946 | .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) |
2947 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) |
2948 | .def("__enter__", &PyLocation::contextEnter) |
2949 | .def("__exit__", &PyLocation::contextExit, nb::arg( "exc_type").none(), |
2950 | nb::arg("exc_value").none(), nb::arg( "traceback").none()) |
2951 | .def("__eq__", |
2952 | [](PyLocation &self, PyLocation &other) -> bool { |
2953 | return mlirLocationEqual(self, other); |
2954 | }) |
2955 | .def("__eq__", [](PyLocation &self, nb::object other) { return false; }) |
2956 | .def_prop_ro_static( |
2957 | "current", |
2958 | [](nb::object & /*class*/) { |
2959 | auto *loc = PyThreadContextEntry::getDefaultLocation(); |
2960 | if (!loc) |
2961 | throw nb::value_error("No current Location"); |
2962 | return loc; |
2963 | }, |
2964 | "Gets the Location bound to the current thread or raises ValueError") |
2965 | .def_static( |
2966 | "unknown", |
2967 | [](DefaultingPyMlirContext context) { |
2968 | return PyLocation(context->getRef(), |
2969 | mlirLocationUnknownGet(context->get())); |
2970 | }, |
2971 | nb::arg("context").none() = nb::none(), |
2972 | "Gets a Location representing an unknown location") |
2973 | .def_static( |
2974 | "callsite", |
2975 | [](PyLocation callee, const std::vector<PyLocation> &frames, |
2976 | DefaultingPyMlirContext context) { |
2977 | if (frames.empty()) |
2978 | throw nb::value_error("No caller frames provided"); |
2979 | MlirLocation caller = frames.back().get(); |
2980 | for (const PyLocation &frame : |
2981 | llvm::reverse(llvm::ArrayRef(frames).drop_back())) |
2982 | caller = mlirLocationCallSiteGet(frame.get(), caller); |
2983 | return PyLocation(context->getRef(), |
2984 | mlirLocationCallSiteGet(callee.get(), caller)); |
2985 | }, |
2986 | nb::arg("callee"), nb::arg( "frames"), |
2987 | nb::arg("context").none() = nb::none(), |
2988 | kContextGetCallSiteLocationDocstring) |
2989 | .def("is_a_callsite", mlirLocationIsACallSite) |
2990 | .def_prop_ro("callee", mlirLocationCallSiteGetCallee) |
2991 | .def_prop_ro("caller", mlirLocationCallSiteGetCaller) |
2992 | .def_static( |
2993 | "file", |
2994 | [](std::string filename, int line, int col, |
2995 | DefaultingPyMlirContext context) { |
2996 | return PyLocation( |
2997 | context->getRef(), |
2998 | mlirLocationFileLineColGet( |
2999 | context->get(), toMlirStringRef(filename), line, col)); |
3000 | }, |
3001 | nb::arg("filename"), nb::arg( "line"), nb::arg( "col"), |
3002 | nb::arg("context").none() = nb::none(), |
3003 | kContextGetFileLocationDocstring) |
3004 | .def_static( |
3005 | "file", |
3006 | [](std::string filename, int startLine, int startCol, int endLine, |
3007 | int endCol, DefaultingPyMlirContext context) { |
3008 | return PyLocation(context->getRef(), |
3009 | mlirLocationFileLineColRangeGet( |
3010 | context->get(), toMlirStringRef(filename), |
3011 | startLine, startCol, endLine, endCol)); |
3012 | }, |
3013 | nb::arg("filename"), nb::arg( "start_line"), nb::arg( "start_col"), |
3014 | nb::arg("end_line"), nb::arg( "end_col"), |
3015 | nb::arg("context").none() = nb::none(), kContextGetFileRangeDocstring) |
3016 | .def("is_a_file", mlirLocationIsAFileLineColRange) |
3017 | .def_prop_ro("filename", |
3018 | [](MlirLocation loc) { |
3019 | return mlirIdentifierStr( |
3020 | mlirLocationFileLineColRangeGetFilename(loc)); |
3021 | }) |
3022 | .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine) |
3023 | .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn) |
3024 | .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine) |
3025 | .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn) |
3026 | .def_static( |
3027 | "fused", |
3028 | [](const std::vector<PyLocation> &pyLocations, |
3029 | std::optional<PyAttribute> metadata, |
3030 | DefaultingPyMlirContext context) { |
3031 | llvm::SmallVector<MlirLocation, 4> locations; |
3032 | locations.reserve(pyLocations.size()); |
3033 | for (auto &pyLocation : pyLocations) |
3034 | locations.push_back(pyLocation.get()); |
3035 | MlirLocation location = mlirLocationFusedGet( |
3036 | context->get(), locations.size(), locations.data(), |
3037 | metadata ? metadata->get() : MlirAttribute{0}); |
3038 | return PyLocation(context->getRef(), location); |
3039 | }, |
3040 | nb::arg("locations"), nb::arg( "metadata").none() = nb::none(), |
3041 | nb::arg("context").none() = nb::none(), |
3042 | kContextGetFusedLocationDocstring) |
3043 | .def("is_a_fused", mlirLocationIsAFused) |
3044 | .def_prop_ro("locations", |
3045 | [](MlirLocation loc) { |
3046 | unsigned numLocations = |
3047 | mlirLocationFusedGetNumLocations(loc); |
3048 | std::vector<MlirLocation> locations(numLocations); |
3049 | if (numLocations) |
3050 | mlirLocationFusedGetLocations(loc, locations.data()); |
3051 | return locations; |
3052 | }) |
3053 | .def_static( |
3054 | "name", |
3055 | [](std::string name, std::optional<PyLocation> childLoc, |
3056 | DefaultingPyMlirContext context) { |
3057 | return PyLocation( |
3058 | context->getRef(), |
3059 | mlirLocationNameGet( |
3060 | context->get(), toMlirStringRef(name), |
3061 | childLoc ? childLoc->get() |
3062 | : mlirLocationUnknownGet(context->get()))); |
3063 | }, |
3064 | nb::arg("name"), nb::arg( "childLoc").none() = nb::none(), |
3065 | nb::arg("context").none() = nb::none(), |
3066 | kContextGetNameLocationDocString) |
3067 | .def("is_a_name", mlirLocationIsAName) |
3068 | .def_prop_ro("name_str", |
3069 | [](MlirLocation loc) { |
3070 | return mlirIdentifierStr(mlirLocationNameGetName(loc)); |
3071 | }) |
3072 | .def_prop_ro("child_loc", mlirLocationNameGetChildLoc) |
3073 | .def_static( |
3074 | "from_attr", |
3075 | [](PyAttribute &attribute, DefaultingPyMlirContext context) { |
3076 | return PyLocation(context->getRef(), |
3077 | mlirLocationFromAttribute(attribute)); |
3078 | }, |
3079 | nb::arg("attribute"), nb::arg( "context").none() = nb::none(), |
3080 | "Gets a Location from a LocationAttr") |
3081 | .def_prop_ro( |
3082 | "context", |
3083 | [](PyLocation &self) { return self.getContext().getObject(); }, |
3084 | "Context that owns the Location") |
3085 | .def_prop_ro( |
3086 | "attr", |
3087 | [](PyLocation &self) { return mlirLocationGetAttribute(self); }, |
3088 | "Get the underlying LocationAttr") |
3089 | .def( |
3090 | "emit_error", |
3091 | [](PyLocation &self, std::string message) { |
3092 | mlirEmitError(self, message.c_str()); |
3093 | }, |
3094 | nb::arg("message"), "Emits an error at this location") |
3095 | .def("__repr__", [](PyLocation &self) { |
3096 | PyPrintAccumulator printAccum; |
3097 | mlirLocationPrint(self, printAccum.getCallback(), |
3098 | printAccum.getUserData()); |
3099 | return printAccum.join(); |
3100 | }); |
3101 | |
3102 | //---------------------------------------------------------------------------- |
3103 | // Mapping of Module |
3104 | //---------------------------------------------------------------------------- |
3105 | nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable()) |
3106 | .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) |
3107 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) |
3108 | .def_static( |
3109 | "parse", |
3110 | [](const std::string &moduleAsm, DefaultingPyMlirContext context) { |
3111 | PyMlirContext::ErrorCapture errors(context->getRef()); |
3112 | MlirModule module = mlirModuleCreateParse( |
3113 | context->get(), toMlirStringRef(moduleAsm)); |
3114 | if (mlirModuleIsNull(module)) |
3115 | throw MLIRError("Unable to parse module assembly", errors.take()); |
3116 | return PyModule::forModule(module).releaseObject(); |
3117 | }, |
3118 | nb::arg("asm"), nb::arg( "context").none() = nb::none(), |
3119 | kModuleParseDocstring) |
3120 | .def_static( |
3121 | "parse", |
3122 | [](nb::bytes moduleAsm, DefaultingPyMlirContext context) { |
3123 | PyMlirContext::ErrorCapture errors(context->getRef()); |
3124 | MlirModule module = mlirModuleCreateParse( |
3125 | context->get(), toMlirStringRef(moduleAsm)); |
3126 | if (mlirModuleIsNull(module)) |
3127 | throw MLIRError("Unable to parse module assembly", errors.take()); |
3128 | return PyModule::forModule(module).releaseObject(); |
3129 | }, |
3130 | nb::arg("asm"), nb::arg( "context").none() = nb::none(), |
3131 | kModuleParseDocstring) |
3132 | .def_static( |
3133 | "parseFile", |
3134 | [](const std::string &path, DefaultingPyMlirContext context) { |
3135 | PyMlirContext::ErrorCapture errors(context->getRef()); |
3136 | MlirModule module = mlirModuleCreateParseFromFile( |
3137 | context->get(), toMlirStringRef(path)); |
3138 | if (mlirModuleIsNull(module)) |
3139 | throw MLIRError("Unable to parse module assembly", errors.take()); |
3140 | return PyModule::forModule(module).releaseObject(); |
3141 | }, |
3142 | nb::arg("path"), nb::arg( "context").none() = nb::none(), |
3143 | kModuleParseDocstring) |
3144 | .def_static( |
3145 | "create", |
3146 | [](DefaultingPyLocation loc) { |
3147 | MlirModule module = mlirModuleCreateEmpty(loc); |
3148 | return PyModule::forModule(module).releaseObject(); |
3149 | }, |
3150 | nb::arg("loc").none() = nb::none(), "Creates an empty module") |
3151 | .def_prop_ro( |
3152 | "context", |
3153 | [](PyModule &self) { return self.getContext().getObject(); }, |
3154 | "Context that created the Module") |
3155 | .def_prop_ro( |
3156 | "operation", |
3157 | [](PyModule &self) { |
3158 | return PyOperation::forOperation(self.getContext(), |
3159 | mlirModuleGetOperation(self.get()), |
3160 | self.getRef().releaseObject()) |
3161 | .releaseObject(); |
3162 | }, |
3163 | "Accesses the module as an operation") |
3164 | .def_prop_ro( |
3165 | "body", |
3166 | [](PyModule &self) { |
3167 | PyOperationRef moduleOp = PyOperation::forOperation( |
3168 | self.getContext(), mlirModuleGetOperation(self.get()), |
3169 | self.getRef().releaseObject()); |
3170 | PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get())); |
3171 | return returnBlock; |
3172 | }, |
3173 | "Return the block for this module") |
3174 | .def( |
3175 | "dump", |
3176 | [](PyModule &self) { |
3177 | mlirOperationDump(mlirModuleGetOperation(self.get())); |
3178 | }, |
3179 | kDumpDocstring) |
3180 | .def( |
3181 | "__str__", |
3182 | [](nb::object self) { |
3183 | // Defer to the operation's __str__. |
3184 | return self.attr("operation").attr( "__str__")(); |
3185 | }, |
3186 | kOperationStrDunderDocstring); |
3187 | |
3188 | //---------------------------------------------------------------------------- |
3189 | // Mapping of Operation. |
3190 | //---------------------------------------------------------------------------- |
3191 | nb::class_<PyOperationBase>(m, "_OperationBase") |
3192 | .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, |
3193 | [](PyOperationBase &self) { |
3194 | return self.getOperation().getCapsule(); |
3195 | }) |
3196 | .def("__eq__", |
3197 | [](PyOperationBase &self, PyOperationBase &other) { |
3198 | return &self.getOperation() == &other.getOperation(); |
3199 | }) |
3200 | .def("__eq__", |
3201 | [](PyOperationBase &self, nb::object other) { return false; }) |
3202 | .def("__hash__", |
3203 | [](PyOperationBase &self) { |
3204 | return static_cast<size_t>(llvm::hash_value(&self.getOperation())); |
3205 | }) |
3206 | .def_prop_ro("attributes", |
3207 | [](PyOperationBase &self) { |
3208 | return PyOpAttributeMap(self.getOperation().getRef()); |
3209 | }) |
3210 | .def_prop_ro( |
3211 | "context", |
3212 | [](PyOperationBase &self) { |
3213 | PyOperation &concreteOperation = self.getOperation(); |
3214 | concreteOperation.checkValid(); |
3215 | return concreteOperation.getContext().getObject(); |
3216 | }, |
3217 | "Context that owns the Operation") |
3218 | .def_prop_ro("name", |
3219 | [](PyOperationBase &self) { |
3220 | auto &concreteOperation = self.getOperation(); |
3221 | concreteOperation.checkValid(); |
3222 | MlirOperation operation = concreteOperation.get(); |
3223 | return mlirIdentifierStr(mlirOperationGetName(operation)); |
3224 | }) |
3225 | .def_prop_ro("operands", |
3226 | [](PyOperationBase &self) { |
3227 | return PyOpOperandList(self.getOperation().getRef()); |
3228 | }) |
3229 | .def_prop_ro("regions", |
3230 | [](PyOperationBase &self) { |
3231 | return PyRegionList(self.getOperation().getRef()); |
3232 | }) |
3233 | .def_prop_ro( |
3234 | "results", |
3235 | [](PyOperationBase &self) { |
3236 | return PyOpResultList(self.getOperation().getRef()); |
3237 | }, |
3238 | "Returns the list of Operation results.") |
3239 | .def_prop_ro( |
3240 | "result", |
3241 | [](PyOperationBase &self) { |
3242 | auto &operation = self.getOperation(); |
3243 | return PyOpResult(operation.getRef(), getUniqueResult(operation)) |
3244 | .maybeDownCast(); |
3245 | }, |
3246 | "Shortcut to get an op result if it has only one (throws an error " |
3247 | "otherwise).") |
3248 | .def_prop_ro( |
3249 | "location", |
3250 | [](PyOperationBase &self) { |
3251 | PyOperation &operation = self.getOperation(); |
3252 | return PyLocation(operation.getContext(), |
3253 | mlirOperationGetLocation(operation.get())); |
3254 | }, |
3255 | "Returns the source location the operation was defined or derived " |
3256 | "from.") |
3257 | .def_prop_ro("parent", |
3258 | [](PyOperationBase &self) -> nb::object { |
3259 | auto parent = self.getOperation().getParentOperation(); |
3260 | if (parent) |
3261 | return parent->getObject(); |
3262 | return nb::none(); |
3263 | }) |
3264 | .def( |
3265 | "__str__", |
3266 | [](PyOperationBase &self) { |
3267 | return self.getAsm(/*binary=*/false, |
3268 | /*largeElementsLimit=*/std::nullopt, |
3269 | /*enableDebugInfo=*/false, |
3270 | /*prettyDebugInfo=*/false, |
3271 | /*printGenericOpForm=*/false, |
3272 | /*useLocalScope=*/false, |
3273 | /*useNameLocAsPrefix=*/false, |
3274 | /*assumeVerified=*/false, |
3275 | /*skipRegions=*/false); |
3276 | }, |
3277 | "Returns the assembly form of the operation.") |
3278 | .def("print", |
3279 | nb::overload_cast<PyAsmState &, nb::object, bool>( |
3280 | &PyOperationBase::print), |
3281 | nb::arg("state"), nb::arg( "file").none() = nb::none(), |
3282 | nb::arg("binary") = false, kOperationPrintStateDocstring) |
3283 | .def("print", |
3284 | nb::overload_cast<std::optional<int64_t>, bool, bool, bool, bool, |
3285 | bool, bool, nb::object, bool, bool>( |
3286 | &PyOperationBase::print), |
3287 | // Careful: Lots of arguments must match up with print method. |
3288 | nb::arg("large_elements_limit").none() = nb::none(), |
3289 | nb::arg("enable_debug_info") = false, |
3290 | nb::arg("pretty_debug_info") = false, |
3291 | nb::arg("print_generic_op_form") = false, |
3292 | nb::arg("use_local_scope") = false, |
3293 | nb::arg("use_name_loc_as_prefix") = false, |
3294 | nb::arg("assume_verified") = false, |
3295 | nb::arg("file").none() = nb::none(), nb::arg( "binary") = false, |
3296 | nb::arg("skip_regions") = false, kOperationPrintDocstring) |
3297 | .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg( "file"), |
3298 | nb::arg("desired_version").none() = nb::none(), |
3299 | kOperationPrintBytecodeDocstring) |
3300 | .def("get_asm", &PyOperationBase::getAsm, |
3301 | // Careful: Lots of arguments must match up with get_asm method. |
3302 | nb::arg("binary") = false, |
3303 | nb::arg("large_elements_limit").none() = nb::none(), |
3304 | nb::arg("enable_debug_info") = false, |
3305 | nb::arg("pretty_debug_info") = false, |
3306 | nb::arg("print_generic_op_form") = false, |
3307 | nb::arg("use_local_scope") = false, |
3308 | nb::arg("use_name_loc_as_prefix") = false, |
3309 | nb::arg("assume_verified") = false, nb::arg( "skip_regions") = false, |
3310 | kOperationGetAsmDocstring) |
3311 | .def("verify", &PyOperationBase::verify, |
3312 | "Verify the operation. Raises MLIRError if verification fails, and " |
3313 | "returns true otherwise.") |
3314 | .def("move_after", &PyOperationBase::moveAfter, nb::arg( "other"), |
3315 | "Puts self immediately after the other operation in its parent " |
3316 | "block.") |
3317 | .def("move_before", &PyOperationBase::moveBefore, nb::arg( "other"), |
3318 | "Puts self immediately before the other operation in its parent " |
3319 | "block.") |
3320 | .def( |
3321 | "clone", |
3322 | [](PyOperationBase &self, nb::object ip) { |
3323 | return self.getOperation().clone(ip); |
3324 | }, |
3325 | nb::arg("ip").none() = nb::none()) |
3326 | .def( |
3327 | "detach_from_parent", |
3328 | [](PyOperationBase &self) { |
3329 | PyOperation &operation = self.getOperation(); |
3330 | operation.checkValid(); |
3331 | if (!operation.isAttached()) |
3332 | throw nb::value_error("Detached operation has no parent."); |
3333 | |
3334 | operation.detachFromParent(); |
3335 | return operation.createOpView(); |
3336 | }, |
3337 | "Detaches the operation from its parent block.") |
3338 | .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }) |
3339 | .def("walk", &PyOperationBase::walk, nb::arg( "callback"), |
3340 | nb::arg("walk_order") = MlirWalkPostOrder); |
3341 | |
3342 | nb::class_<PyOperation, PyOperationBase>(m, "Operation") |
3343 | .def_static( |
3344 | "create", |
3345 | [](std::string_view name, |
3346 | std::optional<std::vector<PyType *>> results, |
3347 | std::optional<std::vector<PyValue *>> operands, |
3348 | std::optional<nb::dict> attributes, |
3349 | std::optional<std::vector<PyBlock *>> successors, int regions, |
3350 | DefaultingPyLocation location, const nb::object &maybeIp, |
3351 | bool inferType) { |
3352 | // Unpack/validate operands. |
3353 | llvm::SmallVector<MlirValue, 4> mlirOperands; |
3354 | if (operands) { |
3355 | mlirOperands.reserve(operands->size()); |
3356 | for (PyValue *operand : *operands) { |
3357 | if (!operand) |
3358 | throw nb::value_error("operand value cannot be None"); |
3359 | mlirOperands.push_back(operand->get()); |
3360 | } |
3361 | } |
3362 | |
3363 | return PyOperation::create(name, results, mlirOperands, attributes, |
3364 | successors, regions, location, maybeIp, |
3365 | inferType); |
3366 | }, |
3367 | nb::arg("name"), nb::arg( "results").none() = nb::none(), |
3368 | nb::arg("operands").none() = nb::none(), |
3369 | nb::arg("attributes").none() = nb::none(), |
3370 | nb::arg("successors").none() = nb::none(), nb::arg( "regions") = 0, |
3371 | nb::arg("loc").none() = nb::none(), nb::arg( "ip").none() = nb::none(), |
3372 | nb::arg("infer_type") = false, kOperationCreateDocstring) |
3373 | .def_static( |
3374 | "parse", |
3375 | [](const std::string &sourceStr, const std::string &sourceName, |
3376 | DefaultingPyMlirContext context) { |
3377 | return PyOperation::parse(context->getRef(), sourceStr, sourceName) |
3378 | ->createOpView(); |
3379 | }, |
3380 | nb::arg("source"), nb::kw_only(), nb::arg( "source_name") = "", |
3381 | nb::arg("context").none() = nb::none(), |
3382 | "Parses an operation. Supports both text assembly format and binary " |
3383 | "bytecode format.") |
3384 | .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) |
3385 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) |
3386 | .def_prop_ro("operation", [](nb::object self) { return self; }) |
3387 | .def_prop_ro("opview", &PyOperation::createOpView) |
3388 | .def_prop_ro( |
3389 | "successors", |
3390 | [](PyOperationBase &self) { |
3391 | return PyOpSuccessors(self.getOperation().getRef()); |
3392 | }, |
3393 | "Returns the list of Operation successors."); |
3394 | |
3395 | auto opViewClass = |
3396 | nb::class_<PyOpView, PyOperationBase>(m, "OpView") |
3397 | .def(nb::init<nb::object>(), nb::arg("operation")) |
3398 | .def( |
3399 | "__init__", |
3400 | [](PyOpView *self, std::string_view name, |
3401 | std::tuple<int, bool> opRegionSpec, |
3402 | nb::object operandSegmentSpecObj, |
3403 | nb::object resultSegmentSpecObj, |
3404 | std::optional<nb::list> resultTypeList, nb::list operandList, |
3405 | std::optional<nb::dict> attributes, |
3406 | std::optional<std::vector<PyBlock *>> successors, |
3407 | std::optional<int> regions, DefaultingPyLocation location, |
3408 | const nb::object &maybeIp) { |
3409 | new (self) PyOpView(PyOpView::buildGeneric( |
3410 | name, opRegionSpec, operandSegmentSpecObj, |
3411 | resultSegmentSpecObj, resultTypeList, operandList, |
3412 | attributes, successors, regions, location, maybeIp)); |
3413 | }, |
3414 | nb::arg("name"), nb::arg( "opRegionSpec"), |
3415 | nb::arg("operandSegmentSpecObj").none() = nb::none(), |
3416 | nb::arg("resultSegmentSpecObj").none() = nb::none(), |
3417 | nb::arg("results").none() = nb::none(), |
3418 | nb::arg("operands").none() = nb::none(), |
3419 | nb::arg("attributes").none() = nb::none(), |
3420 | nb::arg("successors").none() = nb::none(), |
3421 | nb::arg("regions").none() = nb::none(), |
3422 | nb::arg("loc").none() = nb::none(), |
3423 | nb::arg("ip").none() = nb::none()) |
3424 | |
3425 | .def_prop_ro("operation", &PyOpView::getOperationObject) |
3426 | .def_prop_ro("opview", [](nb::object self) { return self; }) |
3427 | .def( |
3428 | "__str__", |
3429 | [](PyOpView &self) { return nb::str(self.getOperationObject()); }) |
3430 | .def_prop_ro( |
3431 | "successors", |
3432 | [](PyOperationBase &self) { |
3433 | return PyOpSuccessors(self.getOperation().getRef()); |
3434 | }, |
3435 | "Returns the list of Operation successors."); |
3436 | opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true); |
3437 | opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none(); |
3438 | opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none(); |
3439 | // It is faster to pass the operation_name, ods_regions, and |
3440 | // ods_operand_segments/ods_result_segments as arguments to the constructor, |
3441 | // rather than to access them as attributes. |
3442 | opViewClass.attr("build_generic") = classmethod( |
3443 | [](nb::handle cls, std::optional<nb::list> resultTypeList, |
3444 | nb::list operandList, std::optional<nb::dict> attributes, |
3445 | std::optional<std::vector<PyBlock *>> successors, |
3446 | std::optional<int> regions, DefaultingPyLocation location, |
3447 | const nb::object &maybeIp) { |
3448 | std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME")); |
3449 | std::tuple<int, bool> opRegionSpec = |
3450 | nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); |
3451 | nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS"); |
3452 | nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS"); |
3453 | return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec, |
3454 | resultSegmentSpec, resultTypeList, |
3455 | operandList, attributes, successors, |
3456 | regions, location, maybeIp); |
3457 | }, |
3458 | nb::arg("cls"), nb::arg( "results").none() = nb::none(), |
3459 | nb::arg("operands").none() = nb::none(), |
3460 | nb::arg("attributes").none() = nb::none(), |
3461 | nb::arg("successors").none() = nb::none(), |
3462 | nb::arg("regions").none() = nb::none(), |
3463 | nb::arg("loc").none() = nb::none(), nb::arg( "ip").none() = nb::none(), |
3464 | "Builds a specific, generated OpView based on class level attributes."); |
3465 | opViewClass.attr("parse") = classmethod( |
3466 | [](const nb::object &cls, const std::string &sourceStr, |
3467 | const std::string &sourceName, DefaultingPyMlirContext context) { |
3468 | PyOperationRef parsed = |
3469 | PyOperation::parse(context->getRef(), sourceStr, sourceName); |
3470 | |
3471 | // Check if the expected operation was parsed, and cast to to the |
3472 | // appropriate `OpView` subclass if successful. |
3473 | // NOTE: This accesses attributes that have been automatically added to |
3474 | // `OpView` subclasses, and is not intended to be used on `OpView` |
3475 | // directly. |
3476 | std::string clsOpName = |
3477 | nb::cast<std::string>(cls.attr("OPERATION_NAME")); |
3478 | MlirStringRef identifier = |
3479 | mlirIdentifierStr(mlirOperationGetName(*parsed.get())); |
3480 | std::string_view parsedOpName(identifier.data, identifier.length); |
3481 | if (clsOpName != parsedOpName) |
3482 | throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '"+ |
3483 | parsedOpName + "'"); |
3484 | return PyOpView::constructDerived(cls, parsed.getObject()); |
3485 | }, |
3486 | nb::arg("cls"), nb::arg( "source"), nb::kw_only(), |
3487 | nb::arg("source_name") = "", nb::arg( "context").none() = nb::none(), |
3488 | "Parses a specific, generated OpView based on class level attributes"); |
3489 | |
3490 | //---------------------------------------------------------------------------- |
3491 | // Mapping of PyRegion. |
3492 | //---------------------------------------------------------------------------- |
3493 | nb::class_<PyRegion>(m, "Region") |
3494 | .def_prop_ro( |
3495 | "blocks", |
3496 | [](PyRegion &self) { |
3497 | return PyBlockList(self.getParentOperation(), self.get()); |
3498 | }, |
3499 | "Returns a forward-optimized sequence of blocks.") |
3500 | .def_prop_ro( |
3501 | "owner", |
3502 | [](PyRegion &self) { |
3503 | return self.getParentOperation()->createOpView(); |
3504 | }, |
3505 | "Returns the operation owning this region.") |
3506 | .def( |
3507 | "__iter__", |
3508 | [](PyRegion &self) { |
3509 | self.checkValid(); |
3510 | MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); |
3511 | return PyBlockIterator(self.getParentOperation(), firstBlock); |
3512 | }, |
3513 | "Iterates over blocks in the region.") |
3514 | .def("__eq__", |
3515 | [](PyRegion &self, PyRegion &other) { |
3516 | return self.get().ptr == other.get().ptr; |
3517 | }) |
3518 | .def("__eq__", [](PyRegion &self, nb::object &other) { return false; }); |
3519 | |
3520 | //---------------------------------------------------------------------------- |
3521 | // Mapping of PyBlock. |
3522 | //---------------------------------------------------------------------------- |
3523 | nb::class_<PyBlock>(m, "Block") |
3524 | .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule) |
3525 | .def_prop_ro( |
3526 | "owner", |
3527 | [](PyBlock &self) { |
3528 | return self.getParentOperation()->createOpView(); |
3529 | }, |
3530 | "Returns the owning operation of this block.") |
3531 | .def_prop_ro( |
3532 | "region", |
3533 | [](PyBlock &self) { |
3534 | MlirRegion region = mlirBlockGetParentRegion(self.get()); |
3535 | return PyRegion(self.getParentOperation(), region); |
3536 | }, |
3537 | "Returns the owning region of this block.") |
3538 | .def_prop_ro( |
3539 | "arguments", |
3540 | [](PyBlock &self) { |
3541 | return PyBlockArgumentList(self.getParentOperation(), self.get()); |
3542 | }, |
3543 | "Returns a list of block arguments.") |
3544 | .def( |
3545 | "add_argument", |
3546 | [](PyBlock &self, const PyType &type, const PyLocation &loc) { |
3547 | return mlirBlockAddArgument(self.get(), type, loc); |
3548 | }, |
3549 | "Append an argument of the specified type to the block and returns " |
3550 | "the newly added argument.") |
3551 | .def( |
3552 | "erase_argument", |
3553 | [](PyBlock &self, unsigned index) { |
3554 | return mlirBlockEraseArgument(self.get(), index); |
3555 | }, |
3556 | "Erase the argument at 'index' and remove it from the argument list.") |
3557 | .def_prop_ro( |
3558 | "operations", |
3559 | [](PyBlock &self) { |
3560 | return PyOperationList(self.getParentOperation(), self.get()); |
3561 | }, |
3562 | "Returns a forward-optimized sequence of operations.") |
3563 | .def_static( |
3564 | "create_at_start", |
3565 | [](PyRegion &parent, const nb::sequence &pyArgTypes, |
3566 | const std::optional<nb::sequence> &pyArgLocs) { |
3567 | parent.checkValid(); |
3568 | MlirBlock block = createBlock(pyArgTypes, pyArgLocs); |
3569 | mlirRegionInsertOwnedBlock(parent, 0, block); |
3570 | return PyBlock(parent.getParentOperation(), block); |
3571 | }, |
3572 | nb::arg("parent"), nb::arg( "arg_types") = nb::list(), |
3573 | nb::arg("arg_locs") = std::nullopt, |
3574 | "Creates and returns a new Block at the beginning of the given " |
3575 | "region (with given argument types and locations).") |
3576 | .def( |
3577 | "append_to", |
3578 | [](PyBlock &self, PyRegion ®ion) { |
3579 | MlirBlock b = self.get(); |
3580 | if (!mlirRegionIsNull(mlirBlockGetParentRegion(b))) |
3581 | mlirBlockDetach(b); |
3582 | mlirRegionAppendOwnedBlock(region.get(), b); |
3583 | }, |
3584 | "Append this block to a region, transferring ownership if necessary") |
3585 | .def( |
3586 | "create_before", |
3587 | [](PyBlock &self, const nb::args &pyArgTypes, |
3588 | const std::optional<nb::sequence> &pyArgLocs) { |
3589 | self.checkValid(); |
3590 | MlirBlock block = |
3591 | createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs); |
3592 | MlirRegion region = mlirBlockGetParentRegion(self.get()); |
3593 | mlirRegionInsertOwnedBlockBefore(region, self.get(), block); |
3594 | return PyBlock(self.getParentOperation(), block); |
3595 | }, |
3596 | nb::arg("arg_types"), nb::kw_only(), |
3597 | nb::arg("arg_locs") = std::nullopt, |
3598 | "Creates and returns a new Block before this block " |
3599 | "(with given argument types and locations).") |
3600 | .def( |
3601 | "create_after", |
3602 | [](PyBlock &self, const nb::args &pyArgTypes, |
3603 | const std::optional<nb::sequence> &pyArgLocs) { |
3604 | self.checkValid(); |
3605 | MlirBlock block = |
3606 | createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs); |
3607 | MlirRegion region = mlirBlockGetParentRegion(self.get()); |
3608 | mlirRegionInsertOwnedBlockAfter(region, self.get(), block); |
3609 | return PyBlock(self.getParentOperation(), block); |
3610 | }, |
3611 | nb::arg("arg_types"), nb::kw_only(), |
3612 | nb::arg("arg_locs") = std::nullopt, |
3613 | "Creates and returns a new Block after this block " |
3614 | "(with given argument types and locations).") |
3615 | .def( |
3616 | "__iter__", |
3617 | [](PyBlock &self) { |
3618 | self.checkValid(); |
3619 | MlirOperation firstOperation = |
3620 | mlirBlockGetFirstOperation(self.get()); |
3621 | return PyOperationIterator(self.getParentOperation(), |
3622 | firstOperation); |
3623 | }, |
3624 | "Iterates over operations in the block.") |
3625 | .def("__eq__", |
3626 | [](PyBlock &self, PyBlock &other) { |
3627 | return self.get().ptr == other.get().ptr; |
3628 | }) |
3629 | .def("__eq__", [](PyBlock &self, nb::object &other) { return false; }) |
3630 | .def("__hash__", |
3631 | [](PyBlock &self) { |
3632 | return static_cast<size_t>(llvm::hash_value(self.get().ptr)); |
3633 | }) |
3634 | .def( |
3635 | "__str__", |
3636 | [](PyBlock &self) { |
3637 | self.checkValid(); |
3638 | PyPrintAccumulator printAccum; |
3639 | mlirBlockPrint(self.get(), printAccum.getCallback(), |
3640 | printAccum.getUserData()); |
3641 | return printAccum.join(); |
3642 | }, |
3643 | "Returns the assembly form of the block.") |
3644 | .def( |
3645 | "append", |
3646 | [](PyBlock &self, PyOperationBase &operation) { |
3647 | if (operation.getOperation().isAttached()) |
3648 | operation.getOperation().detachFromParent(); |
3649 | |
3650 | MlirOperation mlirOperation = operation.getOperation().get(); |
3651 | mlirBlockAppendOwnedOperation(self.get(), mlirOperation); |
3652 | operation.getOperation().setAttached( |
3653 | self.getParentOperation().getObject()); |
3654 | }, |
3655 | nb::arg("operation"), |
3656 | "Appends an operation to this block. If the operation is currently " |
3657 | "in another block, it will be moved."); |
3658 | |
3659 | //---------------------------------------------------------------------------- |
3660 | // Mapping of PyInsertionPoint. |
3661 | //---------------------------------------------------------------------------- |
3662 | |
3663 | nb::class_<PyInsertionPoint>(m, "InsertionPoint") |
3664 | .def(nb::init<PyBlock &>(), nb::arg("block"), |
3665 | "Inserts after the last operation but still inside the block.") |
3666 | .def("__enter__", &PyInsertionPoint::contextEnter) |
3667 | .def("__exit__", &PyInsertionPoint::contextExit, |
3668 | nb::arg("exc_type").none(), nb::arg( "exc_value").none(), |
3669 | nb::arg("traceback").none()) |
3670 | .def_prop_ro_static( |
3671 | "current", |
3672 | [](nb::object & /*class*/) { |
3673 | auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); |
3674 | if (!ip) |
3675 | throw nb::value_error("No current InsertionPoint"); |
3676 | return ip; |
3677 | }, |
3678 | "Gets the InsertionPoint bound to the current thread or raises " |
3679 | "ValueError if none has been set") |
3680 | .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"), |
3681 | "Inserts before a referenced operation.") |
3682 | .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, |
3683 | nb::arg("block"), "Inserts at the beginning of the block.") |
3684 | .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, |
3685 | nb::arg("block"), "Inserts before the block terminator.") |
3686 | .def("insert", &PyInsertionPoint::insert, nb::arg( "operation"), |
3687 | "Inserts an operation.") |
3688 | .def_prop_ro( |
3689 | "block", [](PyInsertionPoint &self) { return self.getBlock(); }, |
3690 | "Returns the block that this InsertionPoint points to.") |
3691 | .def_prop_ro( |
3692 | "ref_operation", |
3693 | [](PyInsertionPoint &self) -> nb::object { |
3694 | auto refOperation = self.getRefOperation(); |
3695 | if (refOperation) |
3696 | return refOperation->getObject(); |
3697 | return nb::none(); |
3698 | }, |
3699 | "The reference operation before which new operations are " |
3700 | "inserted, or None if the insertion point is at the end of " |
3701 | "the block"); |
3702 | |
3703 | //---------------------------------------------------------------------------- |
3704 | // Mapping of PyAttribute. |
3705 | //---------------------------------------------------------------------------- |
3706 | nb::class_<PyAttribute>(m, "Attribute") |
3707 | // Delegate to the PyAttribute copy constructor, which will also lifetime |
3708 | // extend the backing context which owns the MlirAttribute. |
3709 | .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"), |
3710 | "Casts the passed attribute to the generic Attribute") |
3711 | .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule) |
3712 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) |
3713 | .def_static( |
3714 | "parse", |
3715 | [](const std::string &attrSpec, DefaultingPyMlirContext context) { |
3716 | PyMlirContext::ErrorCapture errors(context->getRef()); |
3717 | MlirAttribute attr = mlirAttributeParseGet( |
3718 | context->get(), toMlirStringRef(attrSpec)); |
3719 | if (mlirAttributeIsNull(attr)) |
3720 | throw MLIRError("Unable to parse attribute", errors.take()); |
3721 | return attr; |
3722 | }, |
3723 | nb::arg("asm"), nb::arg( "context").none() = nb::none(), |
3724 | "Parses an attribute from an assembly form. Raises an MLIRError on " |
3725 | "failure.") |
3726 | .def_prop_ro( |
3727 | "context", |
3728 | [](PyAttribute &self) { return self.getContext().getObject(); }, |
3729 | "Context that owns the Attribute") |
3730 | .def_prop_ro("type", |
3731 | [](PyAttribute &self) { return mlirAttributeGetType(self); }) |
3732 | .def( |
3733 | "get_named", |
3734 | [](PyAttribute &self, std::string name) { |
3735 | return PyNamedAttribute(self, std::move(name)); |
3736 | }, |
3737 | nb::keep_alive<0, 1>(), "Binds a name to the attribute") |
3738 | .def("__eq__", |
3739 | [](PyAttribute &self, PyAttribute &other) { return self == other; }) |
3740 | .def("__eq__", [](PyAttribute &self, nb::object &other) { return false; }) |
3741 | .def("__hash__", |
3742 | [](PyAttribute &self) { |
3743 | return static_cast<size_t>(llvm::hash_value(self.get().ptr)); |
3744 | }) |
3745 | .def( |
3746 | "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, |
3747 | kDumpDocstring) |
3748 | .def( |
3749 | "__str__", |
3750 | [](PyAttribute &self) { |
3751 | PyPrintAccumulator printAccum; |
3752 | mlirAttributePrint(self, printAccum.getCallback(), |
3753 | printAccum.getUserData()); |
3754 | return printAccum.join(); |
3755 | }, |
3756 | "Returns the assembly form of the Attribute.") |
3757 | .def("__repr__", |
3758 | [](PyAttribute &self) { |
3759 | // Generally, assembly formats are not printed for __repr__ because |
3760 | // this can cause exceptionally long debug output and exceptions. |
3761 | // However, attribute values are generally considered useful and |
3762 | // are printed. This may need to be re-evaluated if debug dumps end |
3763 | // up being excessive. |
3764 | PyPrintAccumulator printAccum; |
3765 | printAccum.parts.append("Attribute("); |
3766 | mlirAttributePrint(self, printAccum.getCallback(), |
3767 | printAccum.getUserData()); |
3768 | printAccum.parts.append(")"); |
3769 | return printAccum.join(); |
3770 | }) |
3771 | .def_prop_ro("typeid", |
3772 | [](PyAttribute &self) -> MlirTypeID { |
3773 | MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); |
3774 | assert(!mlirTypeIDIsNull(mlirTypeID) && |
3775 | "mlirTypeID was expected to be non-null."); |
3776 | return mlirTypeID; |
3777 | }) |
3778 | .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) { |
3779 | MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); |
3780 | assert(!mlirTypeIDIsNull(mlirTypeID) && |
3781 | "mlirTypeID was expected to be non-null."); |
3782 | std::optional<nb::callable> typeCaster = |
3783 | PyGlobals::get().lookupTypeCaster(mlirTypeID, |
3784 | mlirAttributeGetDialect(self)); |
3785 | if (!typeCaster) |
3786 | return nb::cast(self); |
3787 | return typeCaster.value()(self); |
3788 | }); |
3789 | |
3790 | //---------------------------------------------------------------------------- |
3791 | // Mapping of PyNamedAttribute |
3792 | //---------------------------------------------------------------------------- |
3793 | nb::class_<PyNamedAttribute>(m, "NamedAttribute") |
3794 | .def("__repr__", |
3795 | [](PyNamedAttribute &self) { |
3796 | PyPrintAccumulator printAccum; |
3797 | printAccum.parts.append("NamedAttribute("); |
3798 | printAccum.parts.append( |
3799 | nb::str(mlirIdentifierStr(self.namedAttr.name).data, |
3800 | mlirIdentifierStr(self.namedAttr.name).length)); |
3801 | printAccum.parts.append("="); |
3802 | mlirAttributePrint(self.namedAttr.attribute, |
3803 | printAccum.getCallback(), |
3804 | printAccum.getUserData()); |
3805 | printAccum.parts.append(")"); |
3806 | return printAccum.join(); |
3807 | }) |
3808 | .def_prop_ro( |
3809 | "name", |
3810 | [](PyNamedAttribute &self) { |
3811 | return mlirIdentifierStr(self.namedAttr.name); |
3812 | }, |
3813 | "The name of the NamedAttribute binding") |
3814 | .def_prop_ro( |
3815 | "attr", |
3816 | [](PyNamedAttribute &self) { return self.namedAttr.attribute; }, |
3817 | nb::keep_alive<0, 1>(), |
3818 | "The underlying generic attribute of the NamedAttribute binding"); |
3819 | |
3820 | //---------------------------------------------------------------------------- |
3821 | // Mapping of PyType. |
3822 | //---------------------------------------------------------------------------- |
3823 | nb::class_<PyType>(m, "Type") |
3824 | // Delegate to the PyType copy constructor, which also extends the |
3825 | // lifetime of the backing context that owns the MlirType. |
3826 | .def(nb::init<PyType &>(), nb::arg("cast_from_type"), |
3827 | "Casts the passed type to the generic Type") |
3828 | .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) |
3829 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) |
3830 | .def_static( |
3831 | "parse", |
3832 | [](std::string typeSpec, DefaultingPyMlirContext context) { |
3833 | PyMlirContext::ErrorCapture errors(context->getRef()); |
3834 | MlirType type = |
3835 | mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); |
3836 | if (mlirTypeIsNull(type)) |
3837 | throw MLIRError("Unable to parse type", errors.take()); |
3838 | return type; |
3839 | }, |
3840 | nb::arg("asm"), nb::arg( "context").none() = nb::none(), |
3841 | kContextParseTypeDocstring) |
3842 | .def_prop_ro( |
3843 | "context", [](PyType &self) { return self.getContext().getObject(); }, |
3844 | "Context that owns the Type") |
3845 | .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) |
3846 | .def( |
3847 | "__eq__", [](PyType &self, nb::object &other) { return false; }, |
3848 | nb::arg("other").none()) |
3849 | .def("__hash__", |
3850 | [](PyType &self) { |
3851 | return static_cast<size_t>(llvm::hash_value(self.get().ptr)); |
3852 | }) |
3853 | .def( |
3854 | "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) |
3855 | .def( |
3856 | "__str__", |
3857 | [](PyType &self) { |
3858 | PyPrintAccumulator printAccum; |
3859 | mlirTypePrint(self, printAccum.getCallback(), |
3860 | printAccum.getUserData()); |
3861 | return printAccum.join(); |
3862 | }, |
3863 | "Returns the assembly form of the type.") |
3864 | .def("__repr__", |
3865 | [](PyType &self) { |
3866 | // Generally, assembly formats are not printed for __repr__ because |
3867 | // this can cause exceptionally long debug output and exceptions. |
3868 | // However, types are an exception as they typically have compact |
3869 | // assembly forms and printing them is useful. |
3870 | PyPrintAccumulator printAccum; |
3871 | printAccum.parts.append("Type("); |
3872 | mlirTypePrint(self, printAccum.getCallback(), |
3873 | printAccum.getUserData()); |
3874 | printAccum.parts.append(")"); |
3875 | return printAccum.join(); |
3876 | }) |
3877 | .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, |
3878 | [](PyType &self) { |
3879 | MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); |
3880 | assert(!mlirTypeIDIsNull(mlirTypeID) && |
3881 | "mlirTypeID was expected to be non-null."); |
3882 | std::optional<nb::callable> typeCaster = |
3883 | PyGlobals::get().lookupTypeCaster(mlirTypeID, |
3884 | mlirTypeGetDialect(self)); |
3885 | if (!typeCaster) |
3886 | return nb::cast(self); |
3887 | return typeCaster.value()(self); |
3888 | }) |
3889 | .def_prop_ro("typeid", [](PyType &self) -> MlirTypeID { |
3890 | MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); |
3891 | if (!mlirTypeIDIsNull(mlirTypeID)) |
3892 | return mlirTypeID; |
3893 | auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self))); |
3894 | throw nb::value_error( |
3895 | (origRepr + llvm::Twine(" has no typeid.")).str().c_str()); |
3896 | }); |
3897 | |
3898 | //---------------------------------------------------------------------------- |
3899 | // Mapping of PyTypeID. |
3900 | //---------------------------------------------------------------------------- |
3901 | nb::class_<PyTypeID>(m, "TypeID") |
3902 | .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule) |
3903 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule) |
3904 | // Note, this tests whether the underlying TypeIDs are the same, |
3905 | // not whether the wrapper MlirTypeIDs are the same, nor whether |
3906 | // the Python objects are the same (i.e., PyTypeID is a value type). |
3907 | .def("__eq__", |
3908 | [](PyTypeID &self, PyTypeID &other) { return self == other; }) |
3909 | .def("__eq__", |
3910 | [](PyTypeID &self, const nb::object &other) { return false; }) |
3911 | // Note, this gives the hash value of the underlying TypeID, not the |
3912 | // hash value of the Python object, nor the hash value of the |
3913 | // MlirTypeID wrapper. |
3914 | .def("__hash__", [](PyTypeID &self) { |
3915 | return static_cast<size_t>(mlirTypeIDHashValue(self)); |
3916 | }); |
3917 | |
3918 | //---------------------------------------------------------------------------- |
3919 | // Mapping of Value. |
3920 | //---------------------------------------------------------------------------- |
3921 | nb::class_<PyValue>(m, "Value") |
3922 | .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value")) |
3923 | .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) |
3924 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) |
3925 | .def_prop_ro( |
3926 | "context", |
3927 | [](PyValue &self) { return self.getParentOperation()->getContext(); }, |
3928 | "Context in which the value lives.") |
3929 | .def( |
3930 | "dump", [](PyValue &self) { mlirValueDump(self.get()); }, |
3931 | kDumpDocstring) |
3932 | .def_prop_ro( |
3933 | "owner", |
3934 | [](PyValue &self) -> nb::object { |
3935 | MlirValue v = self.get(); |
3936 | if (mlirValueIsAOpResult(v)) { |
3937 | assert( |
3938 | mlirOperationEqual(self.getParentOperation()->get(), |
3939 | mlirOpResultGetOwner(self.get())) && |
3940 | "expected the owner of the value in Python to match that in " |
3941 | "the IR"); |
3942 | return self.getParentOperation().getObject(); |
3943 | } |
3944 | |
3945 | if (mlirValueIsABlockArgument(v)) { |
3946 | MlirBlock block = mlirBlockArgumentGetOwner(self.get()); |
3947 | return nb::cast(PyBlock(self.getParentOperation(), block)); |
3948 | } |
3949 | |
3950 | assert(false && "Value must be a block argument or an op result"); |
3951 | return nb::none(); |
3952 | }) |
3953 | .def_prop_ro("uses", |
3954 | [](PyValue &self) { |
3955 | return PyOpOperandIterator( |
3956 | mlirValueGetFirstUse(self.get())); |
3957 | }) |
3958 | .def("__eq__", |
3959 | [](PyValue &self, PyValue &other) { |
3960 | return self.get().ptr == other.get().ptr; |
3961 | }) |
3962 | .def("__eq__", [](PyValue &self, nb::object other) { return false; }) |
3963 | .def("__hash__", |
3964 | [](PyValue &self) { |
3965 | return static_cast<size_t>(llvm::hash_value(self.get().ptr)); |
3966 | }) |
3967 | .def( |
3968 | "__str__", |
3969 | [](PyValue &self) { |
3970 | PyPrintAccumulator printAccum; |
3971 | printAccum.parts.append("Value("); |
3972 | mlirValuePrint(self.get(), printAccum.getCallback(), |
3973 | printAccum.getUserData()); |
3974 | printAccum.parts.append(")"); |
3975 | return printAccum.join(); |
3976 | }, |
3977 | kValueDunderStrDocstring) |
3978 | .def( |
3979 | "get_name", |
3980 | [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) { |
3981 | PyPrintAccumulator printAccum; |
3982 | MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); |
3983 | if (useLocalScope) |
3984 | mlirOpPrintingFlagsUseLocalScope(flags); |
3985 | if (useNameLocAsPrefix) |
3986 | mlirOpPrintingFlagsPrintNameLocAsPrefix(flags); |
3987 | MlirAsmState valueState = |
3988 | mlirAsmStateCreateForValue(self.get(), flags); |
3989 | mlirValuePrintAsOperand(self.get(), valueState, |
3990 | printAccum.getCallback(), |
3991 | printAccum.getUserData()); |
3992 | mlirOpPrintingFlagsDestroy(flags); |
3993 | mlirAsmStateDestroy(valueState); |
3994 | return printAccum.join(); |
3995 | }, |
3996 | nb::arg("use_local_scope") = false, |
3997 | nb::arg("use_name_loc_as_prefix") = false) |
3998 | .def( |
3999 | "get_name", |
4000 | [](PyValue &self, PyAsmState &state) { |
4001 | PyPrintAccumulator printAccum; |
4002 | MlirAsmState valueState = state.get(); |
4003 | mlirValuePrintAsOperand(self.get(), valueState, |
4004 | printAccum.getCallback(), |
4005 | printAccum.getUserData()); |
4006 | return printAccum.join(); |
4007 | }, |
4008 | nb::arg("state"), kGetNameAsOperand) |
4009 | .def_prop_ro("type", |
4010 | [](PyValue &self) { return mlirValueGetType(self.get()); }) |
4011 | .def( |
4012 | "set_type", |
4013 | [](PyValue &self, const PyType &type) { |
4014 | return mlirValueSetType(self.get(), type); |
4015 | }, |
4016 | nb::arg("type")) |
4017 | .def( |
4018 | "replace_all_uses_with", |
4019 | [](PyValue &self, PyValue &with) { |
4020 | mlirValueReplaceAllUsesOfWith(self.get(), with.get()); |
4021 | }, |
4022 | kValueReplaceAllUsesWithDocstring) |
4023 | .def( |
4024 | "replace_all_uses_except", |
4025 | [](MlirValue self, MlirValue with, PyOperation &exception) { |
4026 | MlirOperation exceptedUser = exception.get(); |
4027 | mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser); |
4028 | }, |
4029 | nb::arg("with"), nb::arg( "exceptions"), |
4030 | kValueReplaceAllUsesExceptDocstring) |
4031 | .def( |
4032 | "replace_all_uses_except", |
4033 | [](MlirValue self, MlirValue with, nb::list exceptions) { |
4034 | // Convert Python list to a SmallVector of MlirOperations |
4035 | llvm::SmallVector<MlirOperation> exceptionOps; |
4036 | for (nb::handle exception : exceptions) { |
4037 | exceptionOps.push_back(nb::cast<PyOperation &>(exception).get()); |
4038 | } |
4039 | |
4040 | mlirValueReplaceAllUsesExcept( |
4041 | self, with, static_cast<intptr_t>(exceptionOps.size()), |
4042 | exceptionOps.data()); |
4043 | }, |
4044 | nb::arg("with"), nb::arg( "exceptions"), |
4045 | kValueReplaceAllUsesExceptDocstring) |
4046 | .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, |
4047 | [](PyValue &self) { return self.maybeDownCast(); }) |
4048 | .def_prop_ro( |
4049 | "location", |
4050 | [](MlirValue self) { |
4051 | return PyLocation( |
4052 | PyMlirContext::forContext(mlirValueGetContext(self)), |
4053 | mlirValueGetLocation(self)); |
4054 | }, |
4055 | "Returns the source location the value"); |
4056 | |
4057 | PyBlockArgument::bind(m); |
4058 | PyOpResult::bind(m); |
4059 | PyOpOperand::bind(m); |
4060 | |
4061 | nb::class_<PyAsmState>(m, "AsmState") |
4062 | .def(nb::init<PyValue &, bool>(), nb::arg("value"), |
4063 | nb::arg("use_local_scope") = false) |
4064 | .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"), |
4065 | nb::arg("use_local_scope") = false); |
4066 | |
4067 | //---------------------------------------------------------------------------- |
4068 | // Mapping of SymbolTable. |
4069 | //---------------------------------------------------------------------------- |
4070 | nb::class_<PySymbolTable>(m, "SymbolTable") |
4071 | .def(nb::init<PyOperationBase &>()) |
4072 | .def("__getitem__", &PySymbolTable::dunderGetItem) |
4073 | .def("insert", &PySymbolTable::insert, nb::arg( "operation")) |
4074 | .def("erase", &PySymbolTable::erase, nb::arg( "operation")) |
4075 | .def("__delitem__", &PySymbolTable::dunderDel) |
4076 | .def("__contains__", |
4077 | [](PySymbolTable &table, const std::string &name) { |
4078 | return !mlirOperationIsNull(mlirSymbolTableLookup( |
4079 | table, mlirStringRefCreate(name.data(), name.length()))); |
4080 | }) |
4081 | // Static helpers. |
4082 | .def_static("set_symbol_name", &PySymbolTable::setSymbolName, |
4083 | nb::arg("symbol"), nb::arg( "name")) |
4084 | .def_static("get_symbol_name", &PySymbolTable::getSymbolName, |
4085 | nb::arg("symbol")) |
4086 | .def_static("get_visibility", &PySymbolTable::getVisibility, |
4087 | nb::arg("symbol")) |
4088 | .def_static("set_visibility", &PySymbolTable::setVisibility, |
4089 | nb::arg("symbol"), nb::arg( "visibility")) |
4090 | .def_static("replace_all_symbol_uses", |
4091 | &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"), |
4092 | nb::arg("new_symbol"), nb::arg( "from_op")) |
4093 | .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, |
4094 | nb::arg("from_op"), nb::arg( "all_sym_uses_visible"), |
4095 | nb::arg("callback")); |
4096 | |
4097 | // Container bindings. |
4098 | PyBlockArgumentList::bind(m); |
4099 | PyBlockIterator::bind(m); |
4100 | PyBlockList::bind(m); |
4101 | PyOperationIterator::bind(m); |
4102 | PyOperationList::bind(m); |
4103 | PyOpAttributeMap::bind(m); |
4104 | PyOpOperandIterator::bind(m); |
4105 | PyOpOperandList::bind(m); |
4106 | PyOpResultList::bind(m); |
4107 | PyOpSuccessors::bind(m); |
4108 | PyRegionIterator::bind(m); |
4109 | PyRegionList::bind(m); |
4110 | |
4111 | // Debug bindings. |
4112 | PyGlobalDebugFlag::bind(m); |
4113 | |
4114 | // Attribute builder getter. |
4115 | PyAttrBuilderMap::bind(m); |
4116 | |
4117 | nb::register_exception_translator([](const std::exception_ptr &p, |
4118 | void *payload) { |
4119 | // We can't define exceptions with custom fields through nanobind, so |
4120 | // instead the exception class is defined in Python and imported here. |
4121 | try { |
4122 | if (p) |
4123 | std::rethrow_exception(p); |
4124 | } catch (const MLIRError &e) { |
4125 | nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) |
4126 | .attr("MLIRError")(e.message, e.errorDiagnostics); |
4127 | PyErr_SetObject(PyExc_Exception, obj.ptr()); |
4128 | } |
4129 | }); |
4130 | } |
4131 |
Definitions
- kContextParseTypeDocstring
- kContextGetCallSiteLocationDocstring
- kContextGetFileLocationDocstring
- kContextGetFileRangeDocstring
- kContextGetFusedLocationDocstring
- kContextGetNameLocationDocString
- kModuleParseDocstring
- kOperationCreateDocstring
- kOperationPrintDocstring
- kOperationPrintStateDocstring
- kOperationGetAsmDocstring
- kOperationPrintBytecodeDocstring
- kOperationStrDunderDocstring
- kDumpDocstring
- kAppendBlockDocstring
- kValueDunderStrDocstring
- kGetNameAsOperand
- kValueReplaceAllUsesWithDocstring
- kValueReplaceAllUsesExceptDocstring
- classmethod
- createCustomDialectWrapper
- toMlirStringRef
- toMlirStringRef
- toMlirStringRef
- createBlock
- PyGlobalDebugFlag
- set
- get
- bind
- mutex
- PyAttrBuilderMap
- dunderContains
- dunderGetItemNamed
- dunderSetItemNamed
- bind
- getCapsule
- PyRegionIterator
- PyRegionIterator
- dunderIter
- dunderNext
- bind
- PyRegionList
- pyClassName
- PyRegionList
- dunderIter
- bindDerived
- getRawNumElements
- getRawElement
- slice
- PyBlockIterator
- PyBlockIterator
- dunderIter
- dunderNext
- bind
- PyBlockList
- PyBlockList
- dunderIter
- dunderLen
- dunderGetItem
- appendBlock
- bind
- PyOperationIterator
- PyOperationIterator
- dunderIter
- dunderNext
- bind
- PyOperationList
- PyOperationList
- dunderIter
- dunderLen
- dunderGetItem
- bind
- PyOpOperand
- PyOpOperand
- getOwner
- getOperandNumber
- bind
- PyOpOperandIterator
- PyOpOperandIterator
- dunderIter
- dunderNext
- bind
- PyMlirContext
- ~PyMlirContext
- getCapsule
- createFromCapsule
- forContext
- live_contexts_mutex
- getLiveContexts
- getLiveCount
- getLiveOperationCount
- getLiveOperationObjects
- clearLiveOperations
- clearOperation
- clearOperationsInside
- clearOperationsInside
- clearOperationAndInside
- getLiveModuleCount
- contextEnter
- contextExit
- attachDiagnosticHandler
- handler
- resolve
- getStack
- getTopOfStack
- push
- getContext
- getInsertionPoint
- getLocation
- getDefaultContext
- getDefaultInsertionPoint
- getDefaultLocation
- pushContext
- popContext
- pushInsertionPoint
- popInsertionPoint
- pushLocation
- popLocation
- invalidate
- PyDiagnosticHandler
- ~PyDiagnosticHandler
- detach
- checkValid
- getSeverity
- getLocation
- getMessage
- getNotes
- getInfo
- getDialectForKey
- getCapsule
- createFromCapsule
- getCapsule
- createFromCapsule
- contextEnter
- contextExit
- resolve
- PyModule
- ~PyModule
- forModule
- createFromCapsule
- getCapsule
- PyOperation
- ~PyOperation
- makeObjectRef
- createInstance
- forOperation
- createDetached
- parse
- checkValid
- writeBytecode
- walk
- getAsm
- moveAfter
- moveBefore
- verify
- getParentOperation
- getBlock
- getCapsule
- createFromCapsule
- maybeInsertOperation
- create
- clone
- createOpView
- erase
- PyConcreteValue
- PyConcreteValue
- PyConcreteValue
- PyConcreteValue
- castFrom
- bind
- bindDerived
- PyOpResult
- isaFunction
- pyClassName
- bindDerived
- getValueTypes
- PyOpResultList
- pyClassName
- PyOpResultList
- bindDerived
- getOperation
- getRawNumElements
- getRawElement
- slice
- populateResultTypes
- getUniqueResult
- getOpResultOrValue
- buildGeneric
- constructDerived
- PyOpView
- PyInsertionPoint
- PyInsertionPoint
- insert
- atBlockBegin
- atBlockTerminator
- contextEnter
- contextExit
- operator==
- getCapsule
- createFromCapsule
- PyNamedAttribute
- operator==
- getCapsule
- createFromCapsule
- getCapsule
- createFromCapsule
- operator==
- getCapsule
- maybeDownCast
- createFromCapsule
- PySymbolTable
- dunderGetItem
- erase
- dunderDel
- insert
- getSymbolName
- setSymbolName
- getVisibility
- setVisibility
- replaceAllSymbolUses
- walkSymbolTables
- PyBlockArgument
- isaFunction
- pyClassName
- bindDerived
- PyBlockArgumentList
- pyClassName
- PyBlockArgumentList
- bindDerived
- getRawNumElements
- getRawElement
- slice
- PyOpOperandList
- pyClassName
- PyOpOperandList
- dunderSetItem
- bindDerived
- getRawNumElements
- getRawElement
- slice
- PyOpSuccessors
- pyClassName
- PyOpSuccessors
- dunderSetItem
- bindDerived
- getRawNumElements
- getRawElement
- slice
- PyOpAttributeMap
- PyOpAttributeMap
- dunderGetItemNamed
- dunderGetItemIndexed
- dunderSetItem
- dunderDelItem
- dunderLen
- dunderContains
- bind
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more