1 | //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "IRModule.h" |
10 | |
11 | #include "Globals.h" |
12 | #include "PybindUtils.h" |
13 | |
14 | #include "mlir-c/Bindings/Python/Interop.h" |
15 | #include "mlir-c/BuiltinAttributes.h" |
16 | #include "mlir-c/Debug.h" |
17 | #include "mlir-c/Diagnostics.h" |
18 | #include "mlir-c/IR.h" |
19 | #include "mlir-c/Support.h" |
20 | #include "mlir/Bindings/Python/PybindAdaptors.h" |
21 | #include "llvm/ADT/ArrayRef.h" |
22 | #include "llvm/ADT/SmallVector.h" |
23 | |
24 | #include <optional> |
25 | #include <utility> |
26 | |
27 | namespace py = pybind11; |
28 | using namespace py::literals; |
29 | using namespace mlir; |
30 | using namespace mlir::python; |
31 | |
32 | using llvm::SmallVector; |
33 | using llvm::StringRef; |
34 | using llvm::Twine; |
35 | |
36 | //------------------------------------------------------------------------------ |
37 | // Docstrings (trivial, non-duplicated docstrings are included inline). |
38 | //------------------------------------------------------------------------------ |
39 | |
/// Docstring: parsing a type from its assembly form.
static constexpr char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises an MLIRError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

/// Docstring: building a call-site location.
static constexpr char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

/// Docstring: building a file/line/column location.
static constexpr char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

/// Docstring: building a fused location.
static constexpr char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

/// Docstring: building a named location.
static constexpr char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

/// Docstring: parsing a module from its textual assembly.
static constexpr char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises an MLIRError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";
67 | |
/// Docstring: creating a new (detached) operation from its components.
static constexpr char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
name: Operation name (e.g. "dialect.operation").
results: Sequence of Type representing op result types.
attributes: Dict of str:Attribute.
successors: List of Block for the operation's successors.
regions: Number of regions to create.
location: A Location object (defaults to resolve from context manager).
ip: An InsertionPoint (defaults to resolve from context manager or set to
False to disable insertion, even with an insertion point set in the
context manager).
infer_type: Whether to infer result types.
Returns:
A new "detached" Operation object. Detached operations can be added
to blocks, which causes them to become "attached."
)";
86 | |
/// Docstring: printing an operation's assembly to a file-like object.
///
/// The argument names listed here are the Python keyword names, so they are
/// all spelled in lowercase snake_case (`use_local_scope`, not
/// `use_local_Scope`).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
file: The file like object to write to. Defaults to sys.stdout.
binary: Whether to write bytes (True) or str (False). Defaults to False.
large_elements_limit: Whether to elide elements attributes above this
number of elements. Defaults to None (no limit).
enable_debug_info: Whether to print debug/location information. Defaults
to False.
pretty_debug_info: Whether to format debug information for easier reading
by a human (warning: the result is unparseable).
print_generic_op_form: Whether to print the generic assembly forms of all
ops. Defaults to False.
use_local_scope: Whether to print in a way that is more optimized for
multi-threaded access but may not be consistent with how the overall
module prints.
assume_verified: By default, if not printing generic form, the verifier
will be run and if it fails, generic form will be printed with a comment
about failed verification. While a reasonable default for interactive use,
for systematic use, it is often better for the caller to verify explicitly
and report failures in a more robust fashion. Set this to True if doing this
in order to avoid running a redundant verification. If the IR is actually
invalid, behavior is undefined.
)";
112 | |
/// Docstring: printing an operation using a pre-built AsmState.
static constexpr char kOperationPrintStateDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
file: The file like object to write to. Defaults to sys.stdout.
binary: Whether to write bytes (True) or str (False). Defaults to False.
state: AsmState capturing the operation numbering and flags.
)";

/// Docstring: obtaining an operation's assembly as a str/bytes object.
static constexpr char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
binary: Whether to return a bytes (True) or str (False) object. Defaults to
False.
... others ...: See the print() method for common keyword arguments for
configuring the printout.
Returns:
Either a bytes or str object, depending on the setting of the 'binary'
argument.
)";

/// Docstring: writing an operation's bytecode to a file-like object.
static constexpr char kOperationPrintBytecodeDocstring[] =
    R"(Write the bytecode form of the operation to a file like object.

Args:
file: The file like object to write to.
desired_version: The version of bytecode to emit.
Returns:
The bytecode writer status.
)";
144 | |
/// Docstring: default-option string conversion of an operation.
static constexpr char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

/// Docstring: debug dump to stderr.
static constexpr char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

/// Docstring: appending a block to a block list.
static constexpr char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
The created block.
)";

/// Docstring: string conversion of a value.
static constexpr char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";

/// Docstring: printing a value the way it appears as an operand.
static constexpr char kGetNameAsOperand[] =
    R"(Returns the string form of value as an operand (i.e., the ValueID).
)";

/// Docstring: replacing every use of a value with another value.
static constexpr char kValueReplaceAllUsesWithDocstring[] =
    R"(Replace all uses of value with the new value, updating anything in
the IR that uses 'self' to use the other value instead.
)";
179 | |
180 | //------------------------------------------------------------------------------ |
181 | // Utilities. |
182 | //------------------------------------------------------------------------------ |
183 | |
184 | /// Helper for creating an @classmethod. |
185 | template <class Func, typename... Args> |
186 | py::object classmethod(Func f, Args... args) { |
187 | py::object cf = py::cpp_function(f, args...); |
188 | return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); |
189 | } |
190 | |
191 | static py::object |
192 | createCustomDialectWrapper(const std::string &dialectNamespace, |
193 | py::object dialectDescriptor) { |
194 | auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); |
195 | if (!dialectClass) { |
196 | // Use the base class. |
197 | return py::cast(PyDialect(std::move(dialectDescriptor))); |
198 | } |
199 | |
200 | // Create the custom implementation. |
201 | return (*dialectClass)(std::move(dialectDescriptor)); |
202 | } |
203 | |
204 | static MlirStringRef toMlirStringRef(const std::string &s) { |
205 | return mlirStringRefCreate(s.data(), s.size()); |
206 | } |
207 | |
208 | /// Create a block, using the current location context if no locations are |
209 | /// specified. |
210 | static MlirBlock createBlock(const py::sequence &pyArgTypes, |
211 | const std::optional<py::sequence> &pyArgLocs) { |
212 | SmallVector<MlirType> argTypes; |
213 | argTypes.reserve(pyArgTypes.size()); |
214 | for (const auto &pyType : pyArgTypes) |
215 | argTypes.push_back(pyType.cast<PyType &>()); |
216 | |
217 | SmallVector<MlirLocation> argLocs; |
218 | if (pyArgLocs) { |
219 | argLocs.reserve(pyArgLocs->size()); |
220 | for (const auto &pyLoc : *pyArgLocs) |
221 | argLocs.push_back(pyLoc.cast<PyLocation &>()); |
222 | } else if (!argTypes.empty()) { |
223 | argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve()); |
224 | } |
225 | |
226 | if (argTypes.size() != argLocs.size()) |
227 | throw py::value_error(("Expected " + Twine(argTypes.size()) + |
228 | " locations, got: " + Twine(argLocs.size())) |
229 | .str()); |
230 | return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); |
231 | } |
232 | |
233 | /// Wrapper for the global LLVM debugging flag. |
234 | struct PyGlobalDebugFlag { |
235 | static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } |
236 | |
237 | static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); } |
238 | |
239 | static void bind(py::module &m) { |
240 | // Debug flags. |
241 | py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug" , py::module_local()) |
242 | .def_property_static("flag" , &PyGlobalDebugFlag::get, |
243 | &PyGlobalDebugFlag::set, "LLVM-wide debug flag" ); |
244 | } |
245 | }; |
246 | |
247 | struct PyAttrBuilderMap { |
248 | static bool dunderContains(const std::string &attributeKind) { |
249 | return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); |
250 | } |
251 | static py::function dundeGetItemNamed(const std::string &attributeKind) { |
252 | auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind); |
253 | if (!builder) |
254 | throw py::key_error(attributeKind); |
255 | return *builder; |
256 | } |
257 | static void dundeSetItemNamed(const std::string &attributeKind, |
258 | py::function func, bool replace) { |
259 | PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), |
260 | replace); |
261 | } |
262 | |
263 | static void bind(py::module &m) { |
264 | py::class_<PyAttrBuilderMap>(m, "AttrBuilder" , py::module_local()) |
265 | .def_static("contains" , &PyAttrBuilderMap::dunderContains) |
266 | .def_static("get" , &PyAttrBuilderMap::dundeGetItemNamed) |
267 | .def_static("insert" , &PyAttrBuilderMap::dundeSetItemNamed, |
268 | "attribute_kind"_a , "attr_builder"_a , "replace"_a = false, |
269 | "Register an attribute builder for building MLIR " |
270 | "attributes from python values." ); |
271 | } |
272 | }; |
273 | |
274 | //------------------------------------------------------------------------------ |
275 | // PyBlock |
276 | //------------------------------------------------------------------------------ |
277 | |
278 | py::object PyBlock::getCapsule() { |
279 | return py::reinterpret_steal<py::object>(mlirPythonBlockToCapsule(get())); |
280 | } |
281 | |
282 | //------------------------------------------------------------------------------ |
283 | // Collections. |
284 | //------------------------------------------------------------------------------ |
285 | |
286 | namespace { |
287 | |
288 | class PyRegionIterator { |
289 | public: |
290 | PyRegionIterator(PyOperationRef operation) |
291 | : operation(std::move(operation)) {} |
292 | |
293 | PyRegionIterator &dunderIter() { return *this; } |
294 | |
295 | PyRegion dunderNext() { |
296 | operation->checkValid(); |
297 | if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { |
298 | throw py::stop_iteration(); |
299 | } |
300 | MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); |
301 | return PyRegion(operation, region); |
302 | } |
303 | |
304 | static void bind(py::module &m) { |
305 | py::class_<PyRegionIterator>(m, "RegionIterator" , py::module_local()) |
306 | .def("__iter__" , &PyRegionIterator::dunderIter) |
307 | .def("__next__" , &PyRegionIterator::dunderNext); |
308 | } |
309 | |
310 | private: |
311 | PyOperationRef operation; |
312 | int nextIndex = 0; |
313 | }; |
314 | |
315 | /// Regions of an op are fixed length and indexed numerically so are represented |
316 | /// with a sequence-like container. |
317 | class PyRegionList { |
318 | public: |
319 | PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} |
320 | |
321 | PyRegionIterator dunderIter() { |
322 | operation->checkValid(); |
323 | return PyRegionIterator(operation); |
324 | } |
325 | |
326 | intptr_t dunderLen() { |
327 | operation->checkValid(); |
328 | return mlirOperationGetNumRegions(operation->get()); |
329 | } |
330 | |
331 | PyRegion dunderGetItem(intptr_t index) { |
332 | // dunderLen checks validity. |
333 | if (index < 0 || index >= dunderLen()) { |
334 | throw py::index_error("attempt to access out of bounds region" ); |
335 | } |
336 | MlirRegion region = mlirOperationGetRegion(operation->get(), index); |
337 | return PyRegion(operation, region); |
338 | } |
339 | |
340 | static void bind(py::module &m) { |
341 | py::class_<PyRegionList>(m, "RegionSequence" , py::module_local()) |
342 | .def("__len__" , &PyRegionList::dunderLen) |
343 | .def("__iter__" , &PyRegionList::dunderIter) |
344 | .def("__getitem__" , &PyRegionList::dunderGetItem); |
345 | } |
346 | |
347 | private: |
348 | PyOperationRef operation; |
349 | }; |
350 | |
351 | class PyBlockIterator { |
352 | public: |
353 | PyBlockIterator(PyOperationRef operation, MlirBlock next) |
354 | : operation(std::move(operation)), next(next) {} |
355 | |
356 | PyBlockIterator &dunderIter() { return *this; } |
357 | |
358 | PyBlock dunderNext() { |
359 | operation->checkValid(); |
360 | if (mlirBlockIsNull(next)) { |
361 | throw py::stop_iteration(); |
362 | } |
363 | |
364 | PyBlock returnBlock(operation, next); |
365 | next = mlirBlockGetNextInRegion(next); |
366 | return returnBlock; |
367 | } |
368 | |
369 | static void bind(py::module &m) { |
370 | py::class_<PyBlockIterator>(m, "BlockIterator" , py::module_local()) |
371 | .def("__iter__" , &PyBlockIterator::dunderIter) |
372 | .def("__next__" , &PyBlockIterator::dunderNext); |
373 | } |
374 | |
375 | private: |
376 | PyOperationRef operation; |
377 | MlirBlock next; |
378 | }; |
379 | |
380 | /// Blocks are exposed by the C-API as a forward-only linked list. In Python, |
381 | /// we present them as a more full-featured list-like container but optimize |
382 | /// it for forward iteration. Blocks are always owned by a region. |
383 | class PyBlockList { |
384 | public: |
385 | PyBlockList(PyOperationRef operation, MlirRegion region) |
386 | : operation(std::move(operation)), region(region) {} |
387 | |
388 | PyBlockIterator dunderIter() { |
389 | operation->checkValid(); |
390 | return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); |
391 | } |
392 | |
393 | intptr_t dunderLen() { |
394 | operation->checkValid(); |
395 | intptr_t count = 0; |
396 | MlirBlock block = mlirRegionGetFirstBlock(region); |
397 | while (!mlirBlockIsNull(block)) { |
398 | count += 1; |
399 | block = mlirBlockGetNextInRegion(block); |
400 | } |
401 | return count; |
402 | } |
403 | |
404 | PyBlock dunderGetItem(intptr_t index) { |
405 | operation->checkValid(); |
406 | if (index < 0) { |
407 | throw py::index_error("attempt to access out of bounds block" ); |
408 | } |
409 | MlirBlock block = mlirRegionGetFirstBlock(region); |
410 | while (!mlirBlockIsNull(block)) { |
411 | if (index == 0) { |
412 | return PyBlock(operation, block); |
413 | } |
414 | block = mlirBlockGetNextInRegion(block); |
415 | index -= 1; |
416 | } |
417 | throw py::index_error("attempt to access out of bounds block" ); |
418 | } |
419 | |
420 | PyBlock appendBlock(const py::args &pyArgTypes, |
421 | const std::optional<py::sequence> &pyArgLocs) { |
422 | operation->checkValid(); |
423 | MlirBlock block = createBlock(pyArgTypes, pyArgLocs); |
424 | mlirRegionAppendOwnedBlock(region, block); |
425 | return PyBlock(operation, block); |
426 | } |
427 | |
428 | static void bind(py::module &m) { |
429 | py::class_<PyBlockList>(m, "BlockList" , py::module_local()) |
430 | .def("__getitem__" , &PyBlockList::dunderGetItem) |
431 | .def("__iter__" , &PyBlockList::dunderIter) |
432 | .def("__len__" , &PyBlockList::dunderLen) |
433 | .def("append" , &PyBlockList::appendBlock, kAppendBlockDocstring, |
434 | py::arg("arg_locs" ) = std::nullopt); |
435 | } |
436 | |
437 | private: |
438 | PyOperationRef operation; |
439 | MlirRegion region; |
440 | }; |
441 | |
442 | class PyOperationIterator { |
443 | public: |
444 | PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) |
445 | : parentOperation(std::move(parentOperation)), next(next) {} |
446 | |
447 | PyOperationIterator &dunderIter() { return *this; } |
448 | |
449 | py::object dunderNext() { |
450 | parentOperation->checkValid(); |
451 | if (mlirOperationIsNull(next)) { |
452 | throw py::stop_iteration(); |
453 | } |
454 | |
455 | PyOperationRef returnOperation = |
456 | PyOperation::forOperation(parentOperation->getContext(), next); |
457 | next = mlirOperationGetNextInBlock(next); |
458 | return returnOperation->createOpView(); |
459 | } |
460 | |
461 | static void bind(py::module &m) { |
462 | py::class_<PyOperationIterator>(m, "OperationIterator" , py::module_local()) |
463 | .def("__iter__" , &PyOperationIterator::dunderIter) |
464 | .def("__next__" , &PyOperationIterator::dunderNext); |
465 | } |
466 | |
467 | private: |
468 | PyOperationRef parentOperation; |
469 | MlirOperation next; |
470 | }; |
471 | |
472 | /// Operations are exposed by the C-API as a forward-only linked list. In |
473 | /// Python, we present them as a more full-featured list-like container but |
474 | /// optimize it for forward iteration. Iterable operations are always owned |
475 | /// by a block. |
476 | class PyOperationList { |
477 | public: |
478 | PyOperationList(PyOperationRef parentOperation, MlirBlock block) |
479 | : parentOperation(std::move(parentOperation)), block(block) {} |
480 | |
481 | PyOperationIterator dunderIter() { |
482 | parentOperation->checkValid(); |
483 | return PyOperationIterator(parentOperation, |
484 | mlirBlockGetFirstOperation(block)); |
485 | } |
486 | |
487 | intptr_t dunderLen() { |
488 | parentOperation->checkValid(); |
489 | intptr_t count = 0; |
490 | MlirOperation childOp = mlirBlockGetFirstOperation(block); |
491 | while (!mlirOperationIsNull(childOp)) { |
492 | count += 1; |
493 | childOp = mlirOperationGetNextInBlock(childOp); |
494 | } |
495 | return count; |
496 | } |
497 | |
498 | py::object dunderGetItem(intptr_t index) { |
499 | parentOperation->checkValid(); |
500 | if (index < 0) { |
501 | throw py::index_error("attempt to access out of bounds operation" ); |
502 | } |
503 | MlirOperation childOp = mlirBlockGetFirstOperation(block); |
504 | while (!mlirOperationIsNull(childOp)) { |
505 | if (index == 0) { |
506 | return PyOperation::forOperation(parentOperation->getContext(), childOp) |
507 | ->createOpView(); |
508 | } |
509 | childOp = mlirOperationGetNextInBlock(childOp); |
510 | index -= 1; |
511 | } |
512 | throw py::index_error("attempt to access out of bounds operation" ); |
513 | } |
514 | |
515 | static void bind(py::module &m) { |
516 | py::class_<PyOperationList>(m, "OperationList" , py::module_local()) |
517 | .def("__getitem__" , &PyOperationList::dunderGetItem) |
518 | .def("__iter__" , &PyOperationList::dunderIter) |
519 | .def("__len__" , &PyOperationList::dunderLen); |
520 | } |
521 | |
522 | private: |
523 | PyOperationRef parentOperation; |
524 | MlirBlock block; |
525 | }; |
526 | |
527 | class PyOpOperand { |
528 | public: |
529 | PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {} |
530 | |
531 | py::object getOwner() { |
532 | MlirOperation owner = mlirOpOperandGetOwner(opOperand); |
533 | PyMlirContextRef context = |
534 | PyMlirContext::forContext(mlirOperationGetContext(owner)); |
535 | return PyOperation::forOperation(context, owner)->createOpView(); |
536 | } |
537 | |
538 | size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); } |
539 | |
540 | static void bind(py::module &m) { |
541 | py::class_<PyOpOperand>(m, "OpOperand" , py::module_local()) |
542 | .def_property_readonly("owner" , &PyOpOperand::getOwner) |
543 | .def_property_readonly("operand_number" , |
544 | &PyOpOperand::getOperandNumber); |
545 | } |
546 | |
547 | private: |
548 | MlirOpOperand opOperand; |
549 | }; |
550 | |
551 | class PyOpOperandIterator { |
552 | public: |
553 | PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {} |
554 | |
555 | PyOpOperandIterator &dunderIter() { return *this; } |
556 | |
557 | PyOpOperand dunderNext() { |
558 | if (mlirOpOperandIsNull(opOperand)) |
559 | throw py::stop_iteration(); |
560 | |
561 | PyOpOperand returnOpOperand(opOperand); |
562 | opOperand = mlirOpOperandGetNextUse(opOperand); |
563 | return returnOpOperand; |
564 | } |
565 | |
566 | static void bind(py::module &m) { |
567 | py::class_<PyOpOperandIterator>(m, "OpOperandIterator" , py::module_local()) |
568 | .def("__iter__" , &PyOpOperandIterator::dunderIter) |
569 | .def("__next__" , &PyOpOperandIterator::dunderNext); |
570 | } |
571 | |
572 | private: |
573 | MlirOpOperand opOperand; |
574 | }; |
575 | |
576 | } // namespace |
577 | |
578 | //------------------------------------------------------------------------------ |
579 | // PyMlirContext |
580 | //------------------------------------------------------------------------------ |
581 | |
582 | PyMlirContext::PyMlirContext(MlirContext context) : context(context) { |
583 | py::gil_scoped_acquire acquire; |
584 | auto &liveContexts = getLiveContexts(); |
585 | liveContexts[context.ptr] = this; |
586 | } |
587 | |
588 | PyMlirContext::~PyMlirContext() { |
589 | // Note that the only public way to construct an instance is via the |
590 | // forContext method, which always puts the associated handle into |
591 | // liveContexts. |
592 | py::gil_scoped_acquire acquire; |
593 | getLiveContexts().erase(context.ptr); |
594 | mlirContextDestroy(context); |
595 | } |
596 | |
597 | py::object PyMlirContext::getCapsule() { |
598 | return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); |
599 | } |
600 | |
601 | py::object PyMlirContext::createFromCapsule(py::object capsule) { |
602 | MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); |
603 | if (mlirContextIsNull(rawContext)) |
604 | throw py::error_already_set(); |
605 | return forContext(rawContext).releaseObject(); |
606 | } |
607 | |
608 | PyMlirContext *PyMlirContext::createNewContextForInit() { |
609 | MlirContext context = mlirContextCreateWithThreading(false); |
610 | return new PyMlirContext(context); |
611 | } |
612 | |
613 | PyMlirContextRef PyMlirContext::forContext(MlirContext context) { |
614 | py::gil_scoped_acquire acquire; |
615 | auto &liveContexts = getLiveContexts(); |
616 | auto it = liveContexts.find(context.ptr); |
617 | if (it == liveContexts.end()) { |
618 | // Create. |
619 | PyMlirContext *unownedContextWrapper = new PyMlirContext(context); |
620 | py::object pyRef = py::cast(unownedContextWrapper); |
621 | assert(pyRef && "cast to py::object failed" ); |
622 | liveContexts[context.ptr] = unownedContextWrapper; |
623 | return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); |
624 | } |
625 | // Use existing. |
626 | py::object pyRef = py::cast(it->second); |
627 | return PyMlirContextRef(it->second, std::move(pyRef)); |
628 | } |
629 | |
630 | PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { |
631 | static LiveContextMap liveContexts; |
632 | return liveContexts; |
633 | } |
634 | |
635 | size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } |
636 | |
637 | size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } |
638 | |
639 | std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() { |
640 | std::vector<PyOperation *> liveObjects; |
641 | for (auto &entry : liveOperations) |
642 | liveObjects.push_back(entry.second.second); |
643 | return liveObjects; |
644 | } |
645 | |
646 | size_t PyMlirContext::clearLiveOperations() { |
647 | for (auto &op : liveOperations) |
648 | op.second.second->setInvalid(); |
649 | size_t numInvalidated = liveOperations.size(); |
650 | liveOperations.clear(); |
651 | return numInvalidated; |
652 | } |
653 | |
654 | void PyMlirContext::clearOperation(MlirOperation op) { |
655 | auto it = liveOperations.find(op.ptr); |
656 | if (it != liveOperations.end()) { |
657 | it->second.second->setInvalid(); |
658 | liveOperations.erase(it); |
659 | } |
660 | } |
661 | |
662 | void PyMlirContext::clearOperationsInside(PyOperationBase &op) { |
663 | typedef struct { |
664 | PyOperation &rootOp; |
665 | bool rootSeen; |
666 | } callBackData; |
667 | callBackData data{.rootOp: op.getOperation(), .rootSeen: false}; |
668 | // Mark all ops below the op that the passmanager will be rooted |
669 | // at (but not op itself - note the preorder) as invalid. |
670 | MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, |
671 | void *userData) { |
672 | callBackData *data = static_cast<callBackData *>(userData); |
673 | if (LLVM_LIKELY(data->rootSeen)) |
674 | data->rootOp.getOperation().getContext()->clearOperation(op); |
675 | else |
676 | data->rootSeen = true; |
677 | return MlirWalkResult::MlirWalkResultAdvance; |
678 | }; |
679 | mlirOperationWalk(op.getOperation(), invalidatingCallback, |
680 | static_cast<void *>(&data), MlirWalkPreOrder); |
681 | } |
682 | void PyMlirContext::clearOperationsInside(MlirOperation op) { |
683 | PyOperationRef opRef = PyOperation::forOperation(getRef(), op); |
684 | clearOperationsInside(opRef->getOperation()); |
685 | } |
686 | |
687 | size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } |
688 | |
689 | pybind11::object PyMlirContext::contextEnter() { |
690 | return PyThreadContextEntry::pushContext(*this); |
691 | } |
692 | |
693 | void PyMlirContext::contextExit(const pybind11::object &excType, |
694 | const pybind11::object &excVal, |
695 | const pybind11::object &excTb) { |
696 | PyThreadContextEntry::popContext(context&: *this); |
697 | } |
698 | |
699 | py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { |
700 | // Note that ownership is transferred to the delete callback below by way of |
701 | // an explicit inc_ref (borrow). |
702 | PyDiagnosticHandler *pyHandler = |
703 | new PyDiagnosticHandler(get(), std::move(callback)); |
704 | py::object pyHandlerObject = |
705 | py::cast(pyHandler, py::return_value_policy::take_ownership); |
706 | pyHandlerObject.inc_ref(); |
707 | |
708 | // In these C callbacks, the userData is a PyDiagnosticHandler* that is |
709 | // guaranteed to be known to pybind. |
710 | auto handlerCallback = |
711 | +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult { |
712 | PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic); |
713 | py::object pyDiagnosticObject = |
714 | py::cast(pyDiagnostic, py::return_value_policy::take_ownership); |
715 | |
716 | auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); |
717 | bool result = false; |
718 | { |
719 | // Since this can be called from arbitrary C++ contexts, always get the |
720 | // gil. |
721 | py::gil_scoped_acquire gil; |
722 | try { |
723 | result = py::cast<bool>(pyHandler->callback(pyDiagnostic)); |
724 | } catch (std::exception &e) { |
725 | fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n" , |
726 | e.what()); |
727 | pyHandler->hadError = true; |
728 | } |
729 | } |
730 | |
731 | pyDiagnostic->invalidate(); |
732 | return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure(); |
733 | }; |
734 | auto deleteCallback = +[](void *userData) { |
735 | auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); |
736 | assert(pyHandler->registeredID && "handler is not registered" ); |
737 | pyHandler->registeredID.reset(); |
738 | |
739 | // Decrement reference, balancing the inc_ref() above. |
740 | py::object pyHandlerObject = |
741 | py::cast(pyHandler, py::return_value_policy::reference); |
742 | pyHandlerObject.dec_ref(); |
743 | }; |
744 | |
745 | pyHandler->registeredID = mlirContextAttachDiagnosticHandler( |
746 | get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback); |
747 | return pyHandlerObject; |
748 | } |
749 | |
750 | MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag, |
751 | void *userData) { |
752 | auto *self = static_cast<ErrorCapture *>(userData); |
753 | // Check if the context requested we emit errors instead of capturing them. |
754 | if (self->ctx->emitErrorDiagnostics) |
755 | return mlirLogicalResultFailure(); |
756 | |
757 | if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError) |
758 | return mlirLogicalResultFailure(); |
759 | |
760 | self->errors.emplace_back(args: PyDiagnostic(diag).getInfo()); |
761 | return mlirLogicalResultSuccess(); |
762 | } |
763 | |
764 | PyMlirContext &DefaultingPyMlirContext::resolve() { |
765 | PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); |
766 | if (!context) { |
767 | throw std::runtime_error( |
768 | "An MLIR function requires a Context but none was provided in the call " |
769 | "or from the surrounding environment. Either pass to the function with " |
770 | "a 'context=' argument or establish a default using 'with Context():'" ); |
771 | } |
772 | return *context; |
773 | } |
774 | |
775 | //------------------------------------------------------------------------------ |
776 | // PyThreadContextEntry management |
777 | //------------------------------------------------------------------------------ |
778 | |
779 | std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { |
780 | static thread_local std::vector<PyThreadContextEntry> stack; |
781 | return stack; |
782 | } |
783 | |
784 | PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { |
785 | auto &stack = getStack(); |
786 | if (stack.empty()) |
787 | return nullptr; |
788 | return &stack.back(); |
789 | } |
790 | |
791 | void PyThreadContextEntry::push(FrameKind frameKind, py::object context, |
792 | py::object insertionPoint, |
793 | py::object location) { |
794 | auto &stack = getStack(); |
795 | stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), |
796 | std::move(location)); |
797 | // If the new stack has more than one entry and the context of the new top |
798 | // entry matches the previous, copy the insertionPoint and location from the |
799 | // previous entry if missing from the new top entry. |
800 | if (stack.size() > 1) { |
801 | auto &prev = *(stack.rbegin() + 1); |
802 | auto ¤t = stack.back(); |
803 | if (current.context.is(prev.context)) { |
804 | // Default non-context objects from the previous entry. |
805 | if (!current.insertionPoint) |
806 | current.insertionPoint = prev.insertionPoint; |
807 | if (!current.location) |
808 | current.location = prev.location; |
809 | } |
810 | } |
811 | } |
812 | |
813 | PyMlirContext *PyThreadContextEntry::getContext() { |
814 | if (!context) |
815 | return nullptr; |
816 | return py::cast<PyMlirContext *>(context); |
817 | } |
818 | |
819 | PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { |
820 | if (!insertionPoint) |
821 | return nullptr; |
822 | return py::cast<PyInsertionPoint *>(insertionPoint); |
823 | } |
824 | |
825 | PyLocation *PyThreadContextEntry::getLocation() { |
826 | if (!location) |
827 | return nullptr; |
828 | return py::cast<PyLocation *>(location); |
829 | } |
830 | |
831 | PyMlirContext *PyThreadContextEntry::getDefaultContext() { |
832 | auto *tos = getTopOfStack(); |
833 | return tos ? tos->getContext() : nullptr; |
834 | } |
835 | |
836 | PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { |
837 | auto *tos = getTopOfStack(); |
838 | return tos ? tos->getInsertionPoint() : nullptr; |
839 | } |
840 | |
841 | PyLocation *PyThreadContextEntry::getDefaultLocation() { |
842 | auto *tos = getTopOfStack(); |
843 | return tos ? tos->getLocation() : nullptr; |
844 | } |
845 | |
846 | py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { |
847 | py::object contextObj = py::cast(context); |
848 | push(FrameKind::Context, /*context=*/contextObj, |
849 | /*insertionPoint=*/py::object(), |
850 | /*location=*/py::object()); |
851 | return contextObj; |
852 | } |
853 | |
854 | void PyThreadContextEntry::popContext(PyMlirContext &context) { |
855 | auto &stack = getStack(); |
856 | if (stack.empty()) |
857 | throw std::runtime_error("Unbalanced Context enter/exit" ); |
858 | auto &tos = stack.back(); |
859 | if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) |
860 | throw std::runtime_error("Unbalanced Context enter/exit" ); |
861 | stack.pop_back(); |
862 | } |
863 | |
864 | py::object |
865 | PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { |
866 | py::object contextObj = |
867 | insertionPoint.getBlock().getParentOperation()->getContext().getObject(); |
868 | py::object insertionPointObj = py::cast(insertionPoint); |
869 | push(FrameKind::InsertionPoint, |
870 | /*context=*/contextObj, |
871 | /*insertionPoint=*/insertionPointObj, |
872 | /*location=*/py::object()); |
873 | return insertionPointObj; |
874 | } |
875 | |
876 | void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { |
877 | auto &stack = getStack(); |
878 | if (stack.empty()) |
879 | throw std::runtime_error("Unbalanced InsertionPoint enter/exit" ); |
880 | auto &tos = stack.back(); |
881 | if (tos.frameKind != FrameKind::InsertionPoint && |
882 | tos.getInsertionPoint() != &insertionPoint) |
883 | throw std::runtime_error("Unbalanced InsertionPoint enter/exit" ); |
884 | stack.pop_back(); |
885 | } |
886 | |
887 | py::object PyThreadContextEntry::pushLocation(PyLocation &location) { |
888 | py::object contextObj = location.getContext().getObject(); |
889 | py::object locationObj = py::cast(location); |
890 | push(FrameKind::Location, /*context=*/contextObj, |
891 | /*insertionPoint=*/py::object(), |
892 | /*location=*/locationObj); |
893 | return locationObj; |
894 | } |
895 | |
896 | void PyThreadContextEntry::popLocation(PyLocation &location) { |
897 | auto &stack = getStack(); |
898 | if (stack.empty()) |
899 | throw std::runtime_error("Unbalanced Location enter/exit" ); |
900 | auto &tos = stack.back(); |
901 | if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) |
902 | throw std::runtime_error("Unbalanced Location enter/exit" ); |
903 | stack.pop_back(); |
904 | } |
905 | |
906 | //------------------------------------------------------------------------------ |
907 | // PyDiagnostic* |
908 | //------------------------------------------------------------------------------ |
909 | |
910 | void PyDiagnostic::invalidate() { |
911 | valid = false; |
912 | if (materializedNotes) { |
913 | for (auto ¬eObject : *materializedNotes) { |
914 | PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject); |
915 | note->invalidate(); |
916 | } |
917 | } |
918 | } |
919 | |
/// Stores the context the handler belongs to and takes ownership of the
/// Python `callback` object. Nothing else happens here.
PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
                                         py::object callback)
    : context(context), callback(std::move(callback)) {}
923 | |
/// Defaulted destructor: members are released by their own destructors.
PyDiagnosticHandler::~PyDiagnosticHandler() = default;
925 | |
/// Detaches this handler from its context if it is currently registered;
/// otherwise does nothing.
void PyDiagnosticHandler::detach() {
  // No recorded handler ID means there is nothing to detach.
  if (!registeredID)
    return;
  // Copy the ID out first: the assert below expects registeredID to be
  // cleared by the time the detach call returns.
  MlirDiagnosticHandlerID localID = *registeredID;
  mlirContextDetachDiagnosticHandler(context, localID);
  // NOTE(review): presumably a deletion callback installed at registration
  // time resets registeredID during the call above - confirm.
  assert(!registeredID && "should have unregistered");
  // Not strictly necessary but keeps stale pointers from being around to cause
  // issues.
  context = {nullptr};
}
936 | |
937 | void PyDiagnostic::checkValid() { |
938 | if (!valid) { |
939 | throw std::invalid_argument( |
940 | "Diagnostic is invalid (used outside of callback)" ); |
941 | } |
942 | } |
943 | |
944 | MlirDiagnosticSeverity PyDiagnostic::getSeverity() { |
945 | checkValid(); |
946 | return mlirDiagnosticGetSeverity(diagnostic); |
947 | } |
948 | |
949 | PyLocation PyDiagnostic::getLocation() { |
950 | checkValid(); |
951 | MlirLocation loc = mlirDiagnosticGetLocation(diagnostic); |
952 | MlirContext context = mlirLocationGetContext(loc); |
953 | return PyLocation(PyMlirContext::forContext(context), loc); |
954 | } |
955 | |
956 | py::str PyDiagnostic::getMessage() { |
957 | checkValid(); |
958 | py::object fileObject = py::module::import("io" ).attr("StringIO" )(); |
959 | PyFileAccumulator accum(fileObject, /*binary=*/false); |
960 | mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData()); |
961 | return fileObject.attr("getvalue" )(); |
962 | } |
963 | |
964 | py::tuple PyDiagnostic::getNotes() { |
965 | checkValid(); |
966 | if (materializedNotes) |
967 | return *materializedNotes; |
968 | intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic); |
969 | materializedNotes = py::tuple(numNotes); |
970 | for (intptr_t i = 0; i < numNotes; ++i) { |
971 | MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i); |
972 | (*materializedNotes)[i] = PyDiagnostic(noteDiag); |
973 | } |
974 | return *materializedNotes; |
975 | } |
976 | |
977 | PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() { |
978 | std::vector<DiagnosticInfo> notes; |
979 | for (py::handle n : getNotes()) |
980 | notes.emplace_back(n.cast<PyDiagnostic>().getInfo()); |
981 | return {getSeverity(), getLocation(), getMessage(), std::move(notes)}; |
982 | } |
983 | |
984 | //------------------------------------------------------------------------------ |
985 | // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry |
986 | //------------------------------------------------------------------------------ |
987 | |
988 | MlirDialect PyDialects::getDialectForKey(const std::string &key, |
989 | bool attrError) { |
990 | MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), |
991 | {key.data(), key.size()}); |
992 | if (mlirDialectIsNull(dialect)) { |
993 | std::string msg = (Twine("Dialect '" ) + key + "' not found" ).str(); |
994 | if (attrError) |
995 | throw py::attribute_error(msg); |
996 | throw py::index_error(msg); |
997 | } |
998 | return dialect; |
999 | } |
1000 | |
1001 | py::object PyDialectRegistry::getCapsule() { |
1002 | return py::reinterpret_steal<py::object>( |
1003 | mlirPythonDialectRegistryToCapsule(*this)); |
1004 | } |
1005 | |
1006 | PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) { |
1007 | MlirDialectRegistry rawRegistry = |
1008 | mlirPythonCapsuleToDialectRegistry(capsule.ptr()); |
1009 | if (mlirDialectRegistryIsNull(rawRegistry)) |
1010 | throw py::error_already_set(); |
1011 | return PyDialectRegistry(rawRegistry); |
1012 | } |
1013 | |
1014 | //------------------------------------------------------------------------------ |
1015 | // PyLocation |
1016 | //------------------------------------------------------------------------------ |
1017 | |
1018 | py::object PyLocation::getCapsule() { |
1019 | return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); |
1020 | } |
1021 | |
1022 | PyLocation PyLocation::createFromCapsule(py::object capsule) { |
1023 | MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); |
1024 | if (mlirLocationIsNull(rawLoc)) |
1025 | throw py::error_already_set(); |
1026 | return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), |
1027 | rawLoc); |
1028 | } |
1029 | |
1030 | py::object PyLocation::contextEnter() { |
1031 | return PyThreadContextEntry::pushLocation(*this); |
1032 | } |
1033 | |
1034 | void PyLocation::contextExit(const pybind11::object &excType, |
1035 | const pybind11::object &excVal, |
1036 | const pybind11::object &excTb) { |
1037 | PyThreadContextEntry::popLocation(location&: *this); |
1038 | } |
1039 | |
1040 | PyLocation &DefaultingPyLocation::resolve() { |
1041 | auto *location = PyThreadContextEntry::getDefaultLocation(); |
1042 | if (!location) { |
1043 | throw std::runtime_error( |
1044 | "An MLIR function requires a Location but none was provided in the " |
1045 | "call or from the surrounding environment. Either pass to the function " |
1046 | "with a 'loc=' argument or establish a default using 'with loc:'" ); |
1047 | } |
1048 | return *location; |
1049 | } |
1050 | |
1051 | //------------------------------------------------------------------------------ |
1052 | // PyModule |
1053 | //------------------------------------------------------------------------------ |
1054 | |
/// Stores the context reference in the base object and records the wrapped
/// module handle. Registration in the live-module map is done by forModule().
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
    : BaseContextObject(std::move(contextRef)), module(module) {}
1057 | |
1058 | PyModule::~PyModule() { |
1059 | py::gil_scoped_acquire acquire; |
1060 | auto &liveModules = getContext()->liveModules; |
1061 | assert(liveModules.count(module.ptr) == 1 && |
1062 | "destroying module not in live map" ); |
1063 | liveModules.erase(module.ptr); |
1064 | mlirModuleDestroy(module); |
1065 | } |
1066 | |
1067 | PyModuleRef PyModule::forModule(MlirModule module) { |
1068 | MlirContext context = mlirModuleGetContext(module); |
1069 | PyMlirContextRef contextRef = PyMlirContext::forContext(context); |
1070 | |
1071 | py::gil_scoped_acquire acquire; |
1072 | auto &liveModules = contextRef->liveModules; |
1073 | auto it = liveModules.find(module.ptr); |
1074 | if (it == liveModules.end()) { |
1075 | // Create. |
1076 | PyModule *unownedModule = new PyModule(std::move(contextRef), module); |
1077 | // Note that the default return value policy on cast is automatic_reference, |
1078 | // which does not take ownership (delete will not be called). |
1079 | // Just be explicit. |
1080 | py::object pyRef = |
1081 | py::cast(unownedModule, py::return_value_policy::take_ownership); |
1082 | unownedModule->handle = pyRef; |
1083 | liveModules[module.ptr] = |
1084 | std::make_pair(unownedModule->handle, unownedModule); |
1085 | return PyModuleRef(unownedModule, std::move(pyRef)); |
1086 | } |
1087 | // Use existing. |
1088 | PyModule *existing = it->second.second; |
1089 | py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); |
1090 | return PyModuleRef(existing, std::move(pyRef)); |
1091 | } |
1092 | |
1093 | py::object PyModule::createFromCapsule(py::object capsule) { |
1094 | MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); |
1095 | if (mlirModuleIsNull(rawModule)) |
1096 | throw py::error_already_set(); |
1097 | return forModule(rawModule).releaseObject(); |
1098 | } |
1099 | |
1100 | py::object PyModule::getCapsule() { |
1101 | return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); |
1102 | } |
1103 | |
1104 | //------------------------------------------------------------------------------ |
1105 | // PyOperation |
1106 | //------------------------------------------------------------------------------ |
1107 | |
/// Stores the context reference in the base object and records the wrapped
/// operation handle. Registration in the live-operation map is done by
/// createInstance().
PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
    : BaseContextObject(std::move(contextRef)), operation(operation) {}
1110 | |
1111 | PyOperation::~PyOperation() { |
1112 | // If the operation has already been invalidated there is nothing to do. |
1113 | if (!valid) |
1114 | return; |
1115 | auto &liveOperations = getContext()->liveOperations; |
1116 | assert(liveOperations.count(operation.ptr) == 1 && |
1117 | "destroying operation not in live map" ); |
1118 | liveOperations.erase(operation.ptr); |
1119 | if (!isAttached()) { |
1120 | mlirOperationDestroy(operation); |
1121 | } |
1122 | } |
1123 | |
1124 | PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, |
1125 | MlirOperation operation, |
1126 | py::object parentKeepAlive) { |
1127 | auto &liveOperations = contextRef->liveOperations; |
1128 | // Create. |
1129 | PyOperation *unownedOperation = |
1130 | new PyOperation(std::move(contextRef), operation); |
1131 | // Note that the default return value policy on cast is automatic_reference, |
1132 | // which does not take ownership (delete will not be called). |
1133 | // Just be explicit. |
1134 | py::object pyRef = |
1135 | py::cast(unownedOperation, py::return_value_policy::take_ownership); |
1136 | unownedOperation->handle = pyRef; |
1137 | if (parentKeepAlive) { |
1138 | unownedOperation->parentKeepAlive = std::move(parentKeepAlive); |
1139 | } |
1140 | liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); |
1141 | return PyOperationRef(unownedOperation, std::move(pyRef)); |
1142 | } |
1143 | |
1144 | PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, |
1145 | MlirOperation operation, |
1146 | py::object parentKeepAlive) { |
1147 | auto &liveOperations = contextRef->liveOperations; |
1148 | auto it = liveOperations.find(operation.ptr); |
1149 | if (it == liveOperations.end()) { |
1150 | // Create. |
1151 | return createInstance(std::move(contextRef), operation, |
1152 | std::move(parentKeepAlive)); |
1153 | } |
1154 | // Use existing. |
1155 | PyOperation *existing = it->second.second; |
1156 | py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); |
1157 | return PyOperationRef(existing, std::move(pyRef)); |
1158 | } |
1159 | |
1160 | PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, |
1161 | MlirOperation operation, |
1162 | py::object parentKeepAlive) { |
1163 | auto &liveOperations = contextRef->liveOperations; |
1164 | assert(liveOperations.count(operation.ptr) == 0 && |
1165 | "cannot create detached operation that already exists" ); |
1166 | (void)liveOperations; |
1167 | |
1168 | PyOperationRef created = createInstance(std::move(contextRef), operation, |
1169 | std::move(parentKeepAlive)); |
1170 | created->attached = false; |
1171 | return created; |
1172 | } |
1173 | |
1174 | PyOperationRef PyOperation::parse(PyMlirContextRef contextRef, |
1175 | const std::string &sourceStr, |
1176 | const std::string &sourceName) { |
1177 | PyMlirContext::ErrorCapture errors(contextRef); |
1178 | MlirOperation op = |
1179 | mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr), |
1180 | toMlirStringRef(sourceName)); |
1181 | if (mlirOperationIsNull(op)) |
1182 | throw MLIRError("Unable to parse operation assembly" , errors.take()); |
1183 | return PyOperation::createDetached(std::move(contextRef), op); |
1184 | } |
1185 | |
1186 | void PyOperation::checkValid() const { |
1187 | if (!valid) { |
1188 | throw std::runtime_error("the operation has been invalidated" ); |
1189 | } |
1190 | } |
1191 | |
1192 | void PyOperationBase::print(std::optional<int64_t> largeElementsLimit, |
1193 | bool enableDebugInfo, bool prettyDebugInfo, |
1194 | bool printGenericOpForm, bool useLocalScope, |
1195 | bool assumeVerified, py::object fileObject, |
1196 | bool binary) { |
1197 | PyOperation &operation = getOperation(); |
1198 | operation.checkValid(); |
1199 | if (fileObject.is_none()) |
1200 | fileObject = py::module::import("sys" ).attr("stdout" ); |
1201 | |
1202 | MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); |
1203 | if (largeElementsLimit) |
1204 | mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); |
1205 | if (enableDebugInfo) |
1206 | mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, |
1207 | /*prettyForm=*/prettyDebugInfo); |
1208 | if (printGenericOpForm) |
1209 | mlirOpPrintingFlagsPrintGenericOpForm(flags); |
1210 | if (useLocalScope) |
1211 | mlirOpPrintingFlagsUseLocalScope(flags); |
1212 | if (assumeVerified) |
1213 | mlirOpPrintingFlagsAssumeVerified(flags); |
1214 | |
1215 | PyFileAccumulator accum(fileObject, binary); |
1216 | mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), |
1217 | accum.getUserData()); |
1218 | mlirOpPrintingFlagsDestroy(flags); |
1219 | } |
1220 | |
1221 | void PyOperationBase::print(PyAsmState &state, py::object fileObject, |
1222 | bool binary) { |
1223 | PyOperation &operation = getOperation(); |
1224 | operation.checkValid(); |
1225 | if (fileObject.is_none()) |
1226 | fileObject = py::module::import("sys" ).attr("stdout" ); |
1227 | PyFileAccumulator accum(fileObject, binary); |
1228 | mlirOperationPrintWithState(operation, state.get(), accum.getCallback(), |
1229 | accum.getUserData()); |
1230 | } |
1231 | |
1232 | void PyOperationBase::writeBytecode(const py::object &fileObject, |
1233 | std::optional<int64_t> bytecodeVersion) { |
1234 | PyOperation &operation = getOperation(); |
1235 | operation.checkValid(); |
1236 | PyFileAccumulator accum(fileObject, /*binary=*/true); |
1237 | |
1238 | if (!bytecodeVersion.has_value()) |
1239 | return mlirOperationWriteBytecode(operation, accum.getCallback(), |
1240 | accum.getUserData()); |
1241 | |
1242 | MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate(); |
1243 | mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion); |
1244 | MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig( |
1245 | operation, config, accum.getCallback(), accum.getUserData()); |
1246 | mlirBytecodeWriterConfigDestroy(config); |
1247 | if (mlirLogicalResultIsFailure(res)) |
1248 | throw py::value_error((Twine("Unable to honor desired bytecode version " ) + |
1249 | Twine(*bytecodeVersion)) |
1250 | .str()); |
1251 | } |
1252 | |
1253 | void PyOperationBase::walk( |
1254 | std::function<MlirWalkResult(MlirOperation)> callback, |
1255 | MlirWalkOrder walkOrder) { |
1256 | PyOperation &operation = getOperation(); |
1257 | operation.checkValid(); |
1258 | struct UserData { |
1259 | std::function<MlirWalkResult(MlirOperation)> callback; |
1260 | bool gotException; |
1261 | std::string exceptionWhat; |
1262 | py::object exceptionType; |
1263 | }; |
1264 | UserData userData{callback, false, {}, {}}; |
1265 | MlirOperationWalkCallback walkCallback = [](MlirOperation op, |
1266 | void *userData) { |
1267 | UserData *calleeUserData = static_cast<UserData *>(userData); |
1268 | try { |
1269 | return (calleeUserData->callback)(op); |
1270 | } catch (py::error_already_set &e) { |
1271 | calleeUserData->gotException = true; |
1272 | calleeUserData->exceptionWhat = e.what(); |
1273 | calleeUserData->exceptionType = e.type(); |
1274 | return MlirWalkResult::MlirWalkResultInterrupt; |
1275 | } |
1276 | }; |
1277 | mlirOperationWalk(operation, walkCallback, &userData, walkOrder); |
1278 | if (userData.gotException) { |
1279 | std::string message("Exception raised in callback: " ); |
1280 | message.append(str: userData.exceptionWhat); |
1281 | throw std::runtime_error(message); |
1282 | } |
1283 | } |
1284 | |
1285 | py::object PyOperationBase::getAsm(bool binary, |
1286 | std::optional<int64_t> largeElementsLimit, |
1287 | bool enableDebugInfo, bool prettyDebugInfo, |
1288 | bool printGenericOpForm, bool useLocalScope, |
1289 | bool assumeVerified) { |
1290 | py::object fileObject; |
1291 | if (binary) { |
1292 | fileObject = py::module::import("io" ).attr("BytesIO" )(); |
1293 | } else { |
1294 | fileObject = py::module::import("io" ).attr("StringIO" )(); |
1295 | } |
1296 | print(/*largeElementsLimit=*/largeElementsLimit, |
1297 | /*enableDebugInfo=*/enableDebugInfo, |
1298 | /*prettyDebugInfo=*/prettyDebugInfo, |
1299 | /*printGenericOpForm=*/printGenericOpForm, |
1300 | /*useLocalScope=*/useLocalScope, |
1301 | /*assumeVerified=*/assumeVerified, |
1302 | /*fileObject=*/fileObject, |
1303 | /*binary=*/binary); |
1304 | |
1305 | return fileObject.attr("getvalue" )(); |
1306 | } |
1307 | |
1308 | void PyOperationBase::moveAfter(PyOperationBase &other) { |
1309 | PyOperation &operation = getOperation(); |
1310 | PyOperation &otherOp = other.getOperation(); |
1311 | operation.checkValid(); |
1312 | otherOp.checkValid(); |
1313 | mlirOperationMoveAfter(operation, otherOp); |
1314 | operation.parentKeepAlive = otherOp.parentKeepAlive; |
1315 | } |
1316 | |
1317 | void PyOperationBase::moveBefore(PyOperationBase &other) { |
1318 | PyOperation &operation = getOperation(); |
1319 | PyOperation &otherOp = other.getOperation(); |
1320 | operation.checkValid(); |
1321 | otherOp.checkValid(); |
1322 | mlirOperationMoveBefore(operation, otherOp); |
1323 | operation.parentKeepAlive = otherOp.parentKeepAlive; |
1324 | } |
1325 | |
1326 | bool PyOperationBase::verify() { |
1327 | PyOperation &op = getOperation(); |
1328 | PyMlirContext::ErrorCapture errors(op.getContext()); |
1329 | if (!mlirOperationVerify(op.get())) |
1330 | throw MLIRError("Verification failed" , errors.take()); |
1331 | return true; |
1332 | } |
1333 | |
1334 | std::optional<PyOperationRef> PyOperation::getParentOperation() { |
1335 | checkValid(); |
1336 | if (!isAttached()) |
1337 | throw py::value_error("Detached operations have no parent" ); |
1338 | MlirOperation operation = mlirOperationGetParentOperation(get()); |
1339 | if (mlirOperationIsNull(operation)) |
1340 | return {}; |
1341 | return PyOperation::forOperation(getContext(), operation); |
1342 | } |
1343 | |
1344 | PyBlock PyOperation::getBlock() { |
1345 | checkValid(); |
1346 | std::optional<PyOperationRef> parentOperation = getParentOperation(); |
1347 | MlirBlock block = mlirOperationGetBlock(get()); |
1348 | assert(!mlirBlockIsNull(block) && "Attached operation has null parent" ); |
1349 | assert(parentOperation && "Operation has no parent" ); |
1350 | return PyBlock{std::move(*parentOperation), block}; |
1351 | } |
1352 | |
1353 | py::object PyOperation::getCapsule() { |
1354 | checkValid(); |
1355 | return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); |
1356 | } |
1357 | |
1358 | py::object PyOperation::createFromCapsule(py::object capsule) { |
1359 | MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); |
1360 | if (mlirOperationIsNull(rawOperation)) |
1361 | throw py::error_already_set(); |
1362 | MlirContext rawCtxt = mlirOperationGetContext(rawOperation); |
1363 | return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) |
1364 | .releaseObject(); |
1365 | } |
1366 | |
1367 | static void maybeInsertOperation(PyOperationRef &op, |
1368 | const py::object &maybeIp) { |
1369 | // InsertPoint active? |
1370 | if (!maybeIp.is(py::cast(false))) { |
1371 | PyInsertionPoint *ip; |
1372 | if (maybeIp.is_none()) { |
1373 | ip = PyThreadContextEntry::getDefaultInsertionPoint(); |
1374 | } else { |
1375 | ip = py::cast<PyInsertionPoint *>(maybeIp); |
1376 | } |
1377 | if (ip) |
1378 | ip->insert(*op.get()); |
1379 | } |
1380 | } |
1381 | |
/// Creates a new operation named `name` and returns its op view.
///
/// The optional `results`, `operands`, `attributes` and `successors` are
/// unpacked and validated first; `regions` fresh empty regions are attached;
/// `location` supplies both the operation's location and the context the
/// wrapper is created in. `inferType` enables result type inference on the
/// operation state. After creation the operation is inserted according to
/// `maybeIp` (see maybeInsertOperation).
///
/// Throws py::value_error for a negative region count, a None operand,
/// result type or successor, or when operation creation fails; throws
/// py::cast_error for an attribute key that is not a string or an attribute
/// value that is not an attribute.
py::object PyOperation::create(const std::string &name,
                               std::optional<std::vector<PyType *>> results,
                               std::optional<std::vector<PyValue *>> operands,
                               std::optional<py::dict> attributes,
                               std::optional<std::vector<PyBlock *>> successors,
                               int regions, DefaultingPyLocation location,
                               const py::object &maybeIp, bool inferType) {
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw py::value_error("number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      // A null pointer here corresponds to a Python None element.
      if (!operand)
        throw py::value_error("operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw py::value_error("result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      std::string key;
      // Keys must convert to std::string; rethrow with the operation name so
      // the error message identifies where the bad key came from.
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      // Values must convert to PyAttribute; both failure modes are rethrown
      // as py::cast_error with the key and operation name in the message.
      try {
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception seems thrown when the value is "None".
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw py::value_error("successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  state.enableResultTypeInference = inferType;
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    // Each name is turned into an identifier in the context of its own
    // attribute value.
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Create the requested number of fresh regions and hand them to the
    // state as owned regions.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  MlirOperation operation = mlirOperationCreate(&state);
  if (!operation.ptr)
    throw py::value_error("Operation creation failed");
  // The new operation starts out detached; insertion (if any) happens next.
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);
  maybeInsertOperation(created, maybeIp);

  return created->createOpView();
}
1507 | |
1508 | py::object PyOperation::clone(const py::object &maybeIp) { |
1509 | MlirOperation clonedOperation = mlirOperationClone(operation); |
1510 | PyOperationRef cloned = |
1511 | PyOperation::createDetached(getContext(), clonedOperation); |
1512 | maybeInsertOperation(cloned, maybeIp); |
1513 | |
1514 | return cloned->createOpView(); |
1515 | } |
1516 | |
1517 | py::object PyOperation::createOpView() { |
1518 | checkValid(); |
1519 | MlirIdentifier ident = mlirOperationGetName(get()); |
1520 | MlirStringRef identStr = mlirIdentifierStr(ident); |
1521 | auto operationCls = PyGlobals::get().lookupOperationClass( |
1522 | StringRef(identStr.data, identStr.length)); |
1523 | if (operationCls) |
1524 | return PyOpView::constructDerived(*operationCls, *getRef().get()); |
1525 | return py::cast(PyOpView(getRef().getObject())); |
1526 | } |
1527 | |
1528 | void PyOperation::erase() { |
1529 | checkValid(); |
1530 | // TODO: Fix memory hazards when erasing a tree of operations for which a deep |
1531 | // Python reference to a child operation is live. All children should also |
1532 | // have their `valid` bit set to false. |
1533 | auto &liveOperations = getContext()->liveOperations; |
1534 | if (liveOperations.count(operation.ptr)) |
1535 | liveOperations.erase(operation.ptr); |
1536 | mlirOperationDestroy(operation); |
1537 | valid = false; |
1538 | } |
1539 | |
1540 | //------------------------------------------------------------------------------ |
1541 | // PyOpView |
1542 | //------------------------------------------------------------------------------ |
1543 | |
1544 | static void populateResultTypes(StringRef name, py::list resultTypeList, |
1545 | const py::object &resultSegmentSpecObj, |
1546 | std::vector<int32_t> &resultSegmentLengths, |
1547 | std::vector<PyType *> &resultTypes) { |
1548 | resultTypes.reserve(n: resultTypeList.size()); |
1549 | if (resultSegmentSpecObj.is_none()) { |
1550 | // Non-variadic result unpacking. |
1551 | for (const auto &it : llvm::enumerate(resultTypeList)) { |
1552 | try { |
1553 | resultTypes.push_back(py::cast<PyType *>(it.value())); |
1554 | if (!resultTypes.back()) |
1555 | throw py::cast_error(); |
1556 | } catch (py::cast_error &err) { |
1557 | throw py::value_error((llvm::Twine("Result " ) + |
1558 | llvm::Twine(it.index()) + " of operation \"" + |
1559 | name + "\" must be a Type (" + err.what() + ")" ) |
1560 | .str()); |
1561 | } |
1562 | } |
1563 | } else { |
1564 | // Sized result unpacking. |
1565 | auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); |
1566 | if (resultSegmentSpec.size() != resultTypeList.size()) { |
1567 | throw py::value_error((llvm::Twine("Operation \"" ) + name + |
1568 | "\" requires " + |
1569 | llvm::Twine(resultSegmentSpec.size()) + |
1570 | " result segments but was provided " + |
1571 | llvm::Twine(resultTypeList.size())) |
1572 | .str()); |
1573 | } |
1574 | resultSegmentLengths.reserve(n: resultTypeList.size()); |
1575 | for (const auto &it : |
1576 | llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { |
1577 | int segmentSpec = std::get<1>(it.value()); |
1578 | if (segmentSpec == 1 || segmentSpec == 0) { |
1579 | // Unpack unary element. |
1580 | try { |
1581 | auto *resultType = py::cast<PyType *>(std::get<0>(it.value())); |
1582 | if (resultType) { |
1583 | resultTypes.push_back(resultType); |
1584 | resultSegmentLengths.push_back(1); |
1585 | } else if (segmentSpec == 0) { |
1586 | // Allowed to be optional. |
1587 | resultSegmentLengths.push_back(0); |
1588 | } else { |
1589 | throw py::cast_error("was None and result is not optional" ); |
1590 | } |
1591 | } catch (py::cast_error &err) { |
1592 | throw py::value_error((llvm::Twine("Result " ) + |
1593 | llvm::Twine(it.index()) + " of operation \"" + |
1594 | name + "\" must be a Type (" + err.what() + |
1595 | ")" ) |
1596 | .str()); |
1597 | } |
1598 | } else if (segmentSpec == -1) { |
1599 | // Unpack sequence by appending. |
1600 | try { |
1601 | if (std::get<0>(it.value()).is_none()) { |
1602 | // Treat it as an empty list. |
1603 | resultSegmentLengths.push_back(0); |
1604 | } else { |
1605 | // Unpack the list. |
1606 | auto segment = py::cast<py::sequence>(std::get<0>(it.value())); |
1607 | for (py::object segmentItem : segment) { |
1608 | resultTypes.push_back(py::cast<PyType *>(segmentItem)); |
1609 | if (!resultTypes.back()) { |
1610 | throw py::cast_error("contained a None item" ); |
1611 | } |
1612 | } |
1613 | resultSegmentLengths.push_back(segment.size()); |
1614 | } |
1615 | } catch (std::exception &err) { |
1616 | // NOTE: Sloppy to be using a catch-all here, but there are at least |
1617 | // three different unrelated exceptions that can be thrown in the |
1618 | // above "casts". Just keep the scope above small and catch them all. |
1619 | throw py::value_error((llvm::Twine("Result " ) + |
1620 | llvm::Twine(it.index()) + " of operation \"" + |
1621 | name + "\" must be a Sequence of Types (" + |
1622 | err.what() + ")" ) |
1623 | .str()); |
1624 | } |
1625 | } else { |
1626 | throw py::value_error("Unexpected segment spec" ); |
1627 | } |
1628 | } |
1629 | } |
1630 | } |
1631 | |
1632 | py::object PyOpView::buildGeneric( |
1633 | const py::object &cls, std::optional<py::list> resultTypeList, |
1634 | py::list operandList, std::optional<py::dict> attributes, |
1635 | std::optional<std::vector<PyBlock *>> successors, |
1636 | std::optional<int> regions, DefaultingPyLocation location, |
1637 | const py::object &maybeIp) { |
1638 | PyMlirContextRef context = location->getContext(); |
1639 | // Class level operation construction metadata. |
1640 | std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME" )); |
1641 | // Operand and result segment specs are either none, which does no |
1642 | // variadic unpacking, or a list of ints with segment sizes, where each |
1643 | // element is either a positive number (typically 1 for a scalar) or -1 to |
1644 | // indicate that it is derived from the length of the same-indexed operand |
1645 | // or result (implying that it is a list at that position). |
1646 | py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS" ); |
1647 | py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS" ); |
1648 | |
1649 | std::vector<int32_t> operandSegmentLengths; |
1650 | std::vector<int32_t> resultSegmentLengths; |
1651 | |
1652 | // Validate/determine region count. |
1653 | auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS" )); |
1654 | int opMinRegionCount = std::get<0>(opRegionSpec); |
1655 | bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); |
1656 | if (!regions) { |
1657 | regions = opMinRegionCount; |
1658 | } |
1659 | if (*regions < opMinRegionCount) { |
1660 | throw py::value_error( |
1661 | (llvm::Twine("Operation \"" ) + name + "\" requires a minimum of " + |
1662 | llvm::Twine(opMinRegionCount) + |
1663 | " regions but was built with regions=" + llvm::Twine(*regions)) |
1664 | .str()); |
1665 | } |
1666 | if (opHasNoVariadicRegions && *regions > opMinRegionCount) { |
1667 | throw py::value_error( |
1668 | (llvm::Twine("Operation \"" ) + name + "\" requires a maximum of " + |
1669 | llvm::Twine(opMinRegionCount) + |
1670 | " regions but was built with regions=" + llvm::Twine(*regions)) |
1671 | .str()); |
1672 | } |
1673 | |
1674 | // Unpack results. |
1675 | std::vector<PyType *> resultTypes; |
1676 | if (resultTypeList.has_value()) { |
1677 | populateResultTypes(name, *resultTypeList, resultSegmentSpecObj, |
1678 | resultSegmentLengths, resultTypes); |
1679 | } |
1680 | |
1681 | // Unpack operands. |
1682 | std::vector<PyValue *> operands; |
1683 | operands.reserve(n: operands.size()); |
1684 | if (operandSegmentSpecObj.is_none()) { |
1685 | // Non-sized operand unpacking. |
1686 | for (const auto &it : llvm::enumerate(operandList)) { |
1687 | try { |
1688 | operands.push_back(py::cast<PyValue *>(it.value())); |
1689 | if (!operands.back()) |
1690 | throw py::cast_error(); |
1691 | } catch (py::cast_error &err) { |
1692 | throw py::value_error((llvm::Twine("Operand " ) + |
1693 | llvm::Twine(it.index()) + " of operation \"" + |
1694 | name + "\" must be a Value (" + err.what() + ")" ) |
1695 | .str()); |
1696 | } |
1697 | } |
1698 | } else { |
1699 | // Sized operand unpacking. |
1700 | auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); |
1701 | if (operandSegmentSpec.size() != operandList.size()) { |
1702 | throw py::value_error((llvm::Twine("Operation \"" ) + name + |
1703 | "\" requires " + |
1704 | llvm::Twine(operandSegmentSpec.size()) + |
1705 | "operand segments but was provided " + |
1706 | llvm::Twine(operandList.size())) |
1707 | .str()); |
1708 | } |
1709 | operandSegmentLengths.reserve(n: operandList.size()); |
1710 | for (const auto &it : |
1711 | llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { |
1712 | int segmentSpec = std::get<1>(it.value()); |
1713 | if (segmentSpec == 1 || segmentSpec == 0) { |
1714 | // Unpack unary element. |
1715 | try { |
1716 | auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value())); |
1717 | if (operandValue) { |
1718 | operands.push_back(operandValue); |
1719 | operandSegmentLengths.push_back(1); |
1720 | } else if (segmentSpec == 0) { |
1721 | // Allowed to be optional. |
1722 | operandSegmentLengths.push_back(0); |
1723 | } else { |
1724 | throw py::cast_error("was None and operand is not optional" ); |
1725 | } |
1726 | } catch (py::cast_error &err) { |
1727 | throw py::value_error((llvm::Twine("Operand " ) + |
1728 | llvm::Twine(it.index()) + " of operation \"" + |
1729 | name + "\" must be a Value (" + err.what() + |
1730 | ")" ) |
1731 | .str()); |
1732 | } |
1733 | } else if (segmentSpec == -1) { |
1734 | // Unpack sequence by appending. |
1735 | try { |
1736 | if (std::get<0>(it.value()).is_none()) { |
1737 | // Treat it as an empty list. |
1738 | operandSegmentLengths.push_back(0); |
1739 | } else { |
1740 | // Unpack the list. |
1741 | auto segment = py::cast<py::sequence>(std::get<0>(it.value())); |
1742 | for (py::object segmentItem : segment) { |
1743 | operands.push_back(py::cast<PyValue *>(segmentItem)); |
1744 | if (!operands.back()) { |
1745 | throw py::cast_error("contained a None item" ); |
1746 | } |
1747 | } |
1748 | operandSegmentLengths.push_back(segment.size()); |
1749 | } |
1750 | } catch (std::exception &err) { |
1751 | // NOTE: Sloppy to be using a catch-all here, but there are at least |
1752 | // three different unrelated exceptions that can be thrown in the |
1753 | // above "casts". Just keep the scope above small and catch them all. |
1754 | throw py::value_error((llvm::Twine("Operand " ) + |
1755 | llvm::Twine(it.index()) + " of operation \"" + |
1756 | name + "\" must be a Sequence of Values (" + |
1757 | err.what() + ")" ) |
1758 | .str()); |
1759 | } |
1760 | } else { |
1761 | throw py::value_error("Unexpected segment spec" ); |
1762 | } |
1763 | } |
1764 | } |
1765 | |
1766 | // Merge operand/result segment lengths into attributes if needed. |
1767 | if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { |
1768 | // Dup. |
1769 | if (attributes) { |
1770 | attributes = py::dict(*attributes); |
1771 | } else { |
1772 | attributes = py::dict(); |
1773 | } |
1774 | if (attributes->contains("resultSegmentSizes" ) || |
1775 | attributes->contains("operandSegmentSizes" )) { |
1776 | throw py::value_error("Manually setting a 'resultSegmentSizes' or " |
1777 | "'operandSegmentSizes' attribute is unsupported. " |
1778 | "Use Operation.create for such low-level access." ); |
1779 | } |
1780 | |
1781 | // Add resultSegmentSizes attribute. |
1782 | if (!resultSegmentLengths.empty()) { |
1783 | MlirAttribute segmentLengthAttr = |
1784 | mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(), |
1785 | resultSegmentLengths.data()); |
1786 | (*attributes)["resultSegmentSizes" ] = |
1787 | PyAttribute(context, segmentLengthAttr); |
1788 | } |
1789 | |
1790 | // Add operandSegmentSizes attribute. |
1791 | if (!operandSegmentLengths.empty()) { |
1792 | MlirAttribute segmentLengthAttr = |
1793 | mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(), |
1794 | operandSegmentLengths.data()); |
1795 | (*attributes)["operandSegmentSizes" ] = |
1796 | PyAttribute(context, segmentLengthAttr); |
1797 | } |
1798 | } |
1799 | |
1800 | // Delegate to create. |
1801 | return PyOperation::create(name, |
1802 | /*results=*/std::move(resultTypes), |
1803 | /*operands=*/std::move(operands), |
1804 | /*attributes=*/std::move(attributes), |
1805 | /*successors=*/std::move(successors), |
1806 | /*regions=*/*regions, location, maybeIp, |
1807 | !resultTypeList); |
1808 | } |
1809 | |
1810 | pybind11::object PyOpView::constructDerived(const pybind11::object &cls, |
1811 | const PyOperation &operation) { |
1812 | // TODO: pybind11 2.6 supports a more direct form. |
1813 | // Upgrade many years from now. |
1814 | // auto opViewType = py::type::of<PyOpView>(); |
1815 | py::handle opViewType = py::detail::get_type_handle(typeid(PyOpView), true); |
1816 | py::object instance = cls.attr("__new__" )(cls); |
1817 | opViewType.attr("__init__" )(instance, operation); |
1818 | return instance; |
1819 | } |
1820 | |
1821 | PyOpView::PyOpView(const py::object &operationObject) |
1822 | // Casting through the PyOperationBase base-class and then back to the |
1823 | // Operation lets us accept any PyOperationBase subclass. |
1824 | : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), |
1825 | operationObject(operation.getRef().getObject()) {} |
1826 | |
1827 | //------------------------------------------------------------------------------ |
1828 | // PyInsertionPoint. |
1829 | //------------------------------------------------------------------------------ |
1830 | |
1831 | PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} |
1832 | |
1833 | PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) |
1834 | : refOperation(beforeOperationBase.getOperation().getRef()), |
1835 | block((*refOperation)->getBlock()) {} |
1836 | |
1837 | void PyInsertionPoint::insert(PyOperationBase &operationBase) { |
1838 | PyOperation &operation = operationBase.getOperation(); |
1839 | if (operation.isAttached()) |
1840 | throw py::value_error( |
1841 | "Attempt to insert operation that is already attached" ); |
1842 | block.getParentOperation()->checkValid(); |
1843 | MlirOperation beforeOp = {nullptr}; |
1844 | if (refOperation) { |
1845 | // Insert before operation. |
1846 | (*refOperation)->checkValid(); |
1847 | beforeOp = (*refOperation)->get(); |
1848 | } else { |
1849 | // Insert at end (before null) is only valid if the block does not |
1850 | // already end in a known terminator (violating this will cause assertion |
1851 | // failures later). |
1852 | if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { |
1853 | throw py::index_error("Cannot insert operation at the end of a block " |
1854 | "that already has a terminator. Did you mean to " |
1855 | "use 'InsertionPoint.at_block_terminator(block)' " |
1856 | "versus 'InsertionPoint(block)'?" ); |
1857 | } |
1858 | } |
1859 | mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); |
1860 | operation.setAttached(); |
1861 | } |
1862 | |
1863 | PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { |
1864 | MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); |
1865 | if (mlirOperationIsNull(firstOp)) { |
1866 | // Just insert at end. |
1867 | return PyInsertionPoint(block); |
1868 | } |
1869 | |
1870 | // Insert before first op. |
1871 | PyOperationRef firstOpRef = PyOperation::forOperation( |
1872 | block.getParentOperation()->getContext(), firstOp); |
1873 | return PyInsertionPoint{block, std::move(firstOpRef)}; |
1874 | } |
1875 | |
1876 | PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { |
1877 | MlirOperation terminator = mlirBlockGetTerminator(block.get()); |
1878 | if (mlirOperationIsNull(terminator)) |
1879 | throw py::value_error("Block has no terminator" ); |
1880 | PyOperationRef terminatorOpRef = PyOperation::forOperation( |
1881 | block.getParentOperation()->getContext(), terminator); |
1882 | return PyInsertionPoint{block, std::move(terminatorOpRef)}; |
1883 | } |
1884 | |
1885 | py::object PyInsertionPoint::contextEnter() { |
1886 | return PyThreadContextEntry::pushInsertionPoint(*this); |
1887 | } |
1888 | |
1889 | void PyInsertionPoint::contextExit(const pybind11::object &excType, |
1890 | const pybind11::object &excVal, |
1891 | const pybind11::object &excTb) { |
1892 | PyThreadContextEntry::popInsertionPoint(insertionPoint&: *this); |
1893 | } |
1894 | |
1895 | //------------------------------------------------------------------------------ |
1896 | // PyAttribute. |
1897 | //------------------------------------------------------------------------------ |
1898 | |
1899 | bool PyAttribute::operator==(const PyAttribute &other) const { |
1900 | return mlirAttributeEqual(attr, other.attr); |
1901 | } |
1902 | |
1903 | py::object PyAttribute::getCapsule() { |
1904 | return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); |
1905 | } |
1906 | |
1907 | PyAttribute PyAttribute::createFromCapsule(py::object capsule) { |
1908 | MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); |
1909 | if (mlirAttributeIsNull(rawAttr)) |
1910 | throw py::error_already_set(); |
1911 | return PyAttribute( |
1912 | PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); |
1913 | } |
1914 | |
1915 | //------------------------------------------------------------------------------ |
1916 | // PyNamedAttribute. |
1917 | //------------------------------------------------------------------------------ |
1918 | |
1919 | PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) |
1920 | : ownedName(new std::string(std::move(ownedName))) { |
1921 | namedAttr = mlirNamedAttributeGet( |
1922 | mlirIdentifierGet(mlirAttributeGetContext(attr), |
1923 | toMlirStringRef(*this->ownedName)), |
1924 | attr); |
1925 | } |
1926 | |
1927 | //------------------------------------------------------------------------------ |
1928 | // PyType. |
1929 | //------------------------------------------------------------------------------ |
1930 | |
1931 | bool PyType::operator==(const PyType &other) const { |
1932 | return mlirTypeEqual(type, other.type); |
1933 | } |
1934 | |
1935 | py::object PyType::getCapsule() { |
1936 | return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); |
1937 | } |
1938 | |
1939 | PyType PyType::createFromCapsule(py::object capsule) { |
1940 | MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); |
1941 | if (mlirTypeIsNull(rawType)) |
1942 | throw py::error_already_set(); |
1943 | return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), |
1944 | rawType); |
1945 | } |
1946 | |
1947 | //------------------------------------------------------------------------------ |
1948 | // PyTypeID. |
1949 | //------------------------------------------------------------------------------ |
1950 | |
1951 | py::object PyTypeID::getCapsule() { |
1952 | return py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(*this)); |
1953 | } |
1954 | |
1955 | PyTypeID PyTypeID::createFromCapsule(py::object capsule) { |
1956 | MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr()); |
1957 | if (mlirTypeIDIsNull(mlirTypeID)) |
1958 | throw py::error_already_set(); |
1959 | return PyTypeID(mlirTypeID); |
1960 | } |
1961 | bool PyTypeID::operator==(const PyTypeID &other) const { |
1962 | return mlirTypeIDEqual(typeID, other.typeID); |
1963 | } |
1964 | |
1965 | //------------------------------------------------------------------------------ |
1966 | // PyValue and subclasses. |
1967 | //------------------------------------------------------------------------------ |
1968 | |
1969 | pybind11::object PyValue::getCapsule() { |
1970 | return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); |
1971 | } |
1972 | |
1973 | pybind11::object PyValue::maybeDownCast() { |
1974 | MlirType type = mlirValueGetType(get()); |
1975 | MlirTypeID mlirTypeID = mlirTypeGetTypeID(type); |
1976 | assert(!mlirTypeIDIsNull(mlirTypeID) && |
1977 | "mlirTypeID was expected to be non-null." ); |
1978 | std::optional<pybind11::function> valueCaster = |
1979 | PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type)); |
1980 | // py::return_value_policy::move means use std::move to move the return value |
1981 | // contents into a new instance that will be owned by Python. |
1982 | py::object thisObj = py::cast(this, py::return_value_policy::move); |
1983 | if (!valueCaster) |
1984 | return thisObj; |
1985 | return valueCaster.value()(thisObj); |
1986 | } |
1987 | |
1988 | PyValue PyValue::createFromCapsule(pybind11::object capsule) { |
1989 | MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); |
1990 | if (mlirValueIsNull(value)) |
1991 | throw py::error_already_set(); |
1992 | MlirOperation owner; |
1993 | if (mlirValueIsAOpResult(value)) |
1994 | owner = mlirOpResultGetOwner(value); |
1995 | if (mlirValueIsABlockArgument(value)) |
1996 | owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); |
1997 | if (mlirOperationIsNull(owner)) |
1998 | throw py::error_already_set(); |
1999 | MlirContext ctx = mlirOperationGetContext(owner); |
2000 | PyOperationRef ownerRef = |
2001 | PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); |
2002 | return PyValue(ownerRef, value); |
2003 | } |
2004 | |
2005 | //------------------------------------------------------------------------------ |
2006 | // PySymbolTable. |
2007 | //------------------------------------------------------------------------------ |
2008 | |
2009 | PySymbolTable::PySymbolTable(PyOperationBase &operation) |
2010 | : operation(operation.getOperation().getRef()) { |
2011 | symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); |
2012 | if (mlirSymbolTableIsNull(symbolTable)) { |
2013 | throw py::cast_error("Operation is not a Symbol Table." ); |
2014 | } |
2015 | } |
2016 | |
2017 | py::object PySymbolTable::dunderGetItem(const std::string &name) { |
2018 | operation->checkValid(); |
2019 | MlirOperation symbol = mlirSymbolTableLookup( |
2020 | symbolTable, mlirStringRefCreate(name.data(), name.length())); |
2021 | if (mlirOperationIsNull(symbol)) |
2022 | throw py::key_error("Symbol '" + name + "' not in the symbol table." ); |
2023 | |
2024 | return PyOperation::forOperation(operation->getContext(), symbol, |
2025 | operation.getObject()) |
2026 | ->createOpView(); |
2027 | } |
2028 | |
2029 | void PySymbolTable::erase(PyOperationBase &symbol) { |
2030 | operation->checkValid(); |
2031 | symbol.getOperation().checkValid(); |
2032 | mlirSymbolTableErase(symbolTable, symbol.getOperation().get()); |
2033 | // The operation is also erased, so we must invalidate it. There may be Python |
2034 | // references to this operation so we don't want to delete it from the list of |
2035 | // live operations here. |
2036 | symbol.getOperation().valid = false; |
2037 | } |
2038 | |
2039 | void PySymbolTable::dunderDel(const std::string &name) { |
2040 | py::object operation = dunderGetItem(name); |
2041 | erase(py::cast<PyOperationBase &>(operation)); |
2042 | } |
2043 | |
2044 | MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) { |
2045 | operation->checkValid(); |
2046 | symbol.getOperation().checkValid(); |
2047 | MlirAttribute symbolAttr = mlirOperationGetAttributeByName( |
2048 | symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); |
2049 | if (mlirAttributeIsNull(symbolAttr)) |
2050 | throw py::value_error("Expected operation to have a symbol name." ); |
2051 | return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()); |
2052 | } |
2053 | |
2054 | MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { |
2055 | // Op must already be a symbol. |
2056 | PyOperation &operation = symbol.getOperation(); |
2057 | operation.checkValid(); |
2058 | MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); |
2059 | MlirAttribute existingNameAttr = |
2060 | mlirOperationGetAttributeByName(operation.get(), attrName); |
2061 | if (mlirAttributeIsNull(existingNameAttr)) |
2062 | throw py::value_error("Expected operation to have a symbol name." ); |
2063 | return existingNameAttr; |
2064 | } |
2065 | |
2066 | void PySymbolTable::setSymbolName(PyOperationBase &symbol, |
2067 | const std::string &name) { |
2068 | // Op must already be a symbol. |
2069 | PyOperation &operation = symbol.getOperation(); |
2070 | operation.checkValid(); |
2071 | MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); |
2072 | MlirAttribute existingNameAttr = |
2073 | mlirOperationGetAttributeByName(operation.get(), attrName); |
2074 | if (mlirAttributeIsNull(existingNameAttr)) |
2075 | throw py::value_error("Expected operation to have a symbol name." ); |
2076 | MlirAttribute newNameAttr = |
2077 | mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name)); |
2078 | mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); |
2079 | } |
2080 | |
2081 | MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { |
2082 | PyOperation &operation = symbol.getOperation(); |
2083 | operation.checkValid(); |
2084 | MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); |
2085 | MlirAttribute existingVisAttr = |
2086 | mlirOperationGetAttributeByName(operation.get(), attrName); |
2087 | if (mlirAttributeIsNull(existingVisAttr)) |
2088 | throw py::value_error("Expected operation to have a symbol visibility." ); |
2089 | return existingVisAttr; |
2090 | } |
2091 | |
2092 | void PySymbolTable::setVisibility(PyOperationBase &symbol, |
2093 | const std::string &visibility) { |
2094 | if (visibility != "public" && visibility != "private" && |
2095 | visibility != "nested" ) |
2096 | throw py::value_error( |
2097 | "Expected visibility to be 'public', 'private' or 'nested'" ); |
2098 | PyOperation &operation = symbol.getOperation(); |
2099 | operation.checkValid(); |
2100 | MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); |
2101 | MlirAttribute existingVisAttr = |
2102 | mlirOperationGetAttributeByName(operation.get(), attrName); |
2103 | if (mlirAttributeIsNull(existingVisAttr)) |
2104 | throw py::value_error("Expected operation to have a symbol visibility." ); |
2105 | MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(), |
2106 | toMlirStringRef(visibility)); |
2107 | mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr); |
2108 | } |
2109 | |
2110 | void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol, |
2111 | const std::string &newSymbol, |
2112 | PyOperationBase &from) { |
2113 | PyOperation &fromOperation = from.getOperation(); |
2114 | fromOperation.checkValid(); |
2115 | if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses( |
2116 | toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol), |
2117 | from.getOperation()))) |
2118 | |
2119 | throw py::value_error("Symbol rename failed" ); |
2120 | } |
2121 | |
2122 | void PySymbolTable::walkSymbolTables(PyOperationBase &from, |
2123 | bool allSymUsesVisible, |
2124 | py::object callback) { |
2125 | PyOperation &fromOperation = from.getOperation(); |
2126 | fromOperation.checkValid(); |
2127 | struct UserData { |
2128 | PyMlirContextRef context; |
2129 | py::object callback; |
2130 | bool gotException; |
2131 | std::string exceptionWhat; |
2132 | py::object exceptionType; |
2133 | }; |
2134 | UserData userData{ |
2135 | fromOperation.getContext(), std::move(callback), false, {}, {}}; |
2136 | mlirSymbolTableWalkSymbolTables( |
2137 | fromOperation.get(), allSymUsesVisible, |
2138 | [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) { |
2139 | UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid); |
2140 | auto pyFoundOp = |
2141 | PyOperation::forOperation(calleeUserData->context, foundOp); |
2142 | if (calleeUserData->gotException) |
2143 | return; |
2144 | try { |
2145 | calleeUserData->callback(pyFoundOp.getObject(), isVisible); |
2146 | } catch (py::error_already_set &e) { |
2147 | calleeUserData->gotException = true; |
2148 | calleeUserData->exceptionWhat = e.what(); |
2149 | calleeUserData->exceptionType = e.type(); |
2150 | } |
2151 | }, |
2152 | static_cast<void *>(&userData)); |
2153 | if (userData.gotException) { |
2154 | std::string message("Exception raised in callback: " ); |
2155 | message.append(str: userData.exceptionWhat); |
2156 | throw std::runtime_error(message); |
2157 | } |
2158 | } |
2159 | |
2160 | namespace { |
2161 | /// CRTP base class for Python MLIR values that subclass Value and should be |
2162 | /// castable from it. The value hierarchy is one level deep and is not supposed |
2163 | /// to accommodate other levels unless core MLIR changes. |
2164 | template <typename DerivedTy> |
2165 | class PyConcreteValue : public PyValue { |
2166 | public: |
2167 | // Derived classes must define statics for: |
2168 | // IsAFunctionTy isaFunction |
2169 | // const char *pyClassName |
2170 | // and redefine bindDerived. |
2171 | using ClassTy = py::class_<DerivedTy, PyValue>; |
2172 | using IsAFunctionTy = bool (*)(MlirValue); |
2173 | |
2174 | PyConcreteValue() = default; |
2175 | PyConcreteValue(PyOperationRef operationRef, MlirValue value) |
2176 | : PyValue(operationRef, value) {} |
2177 | PyConcreteValue(PyValue &orig) |
2178 | : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} |
2179 | |
2180 | /// Attempts to cast the original value to the derived type and throws on |
2181 | /// type mismatches. |
2182 | static MlirValue castFrom(PyValue &orig) { |
2183 | if (!DerivedTy::isaFunction(orig.get())) { |
2184 | auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); |
2185 | throw py::value_error((Twine("Cannot cast value to " ) + |
2186 | DerivedTy::pyClassName + " (from " + origRepr + |
2187 | ")" ) |
2188 | .str()); |
2189 | } |
2190 | return orig.get(); |
2191 | } |
2192 | |
2193 | /// Binds the Python module objects to functions of this class. |
2194 | static void bind(py::module &m) { |
2195 | auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); |
2196 | cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value" )); |
2197 | cls.def_static( |
2198 | "isinstance" , |
2199 | [](PyValue &otherValue) -> bool { |
2200 | return DerivedTy::isaFunction(otherValue); |
2201 | }, |
2202 | py::arg("other_value" )); |
2203 | cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, |
2204 | [](DerivedTy &self) { return self.maybeDownCast(); }); |
2205 | DerivedTy::bindDerived(cls); |
2206 | } |
2207 | |
2208 | /// Implemented by derived classes to add methods to the Python subclass. |
2209 | static void bindDerived(ClassTy &m) {} |
2210 | }; |
2211 | |
2212 | /// Python wrapper for MlirBlockArgument. |
2213 | class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { |
2214 | public: |
2215 | static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; |
2216 | static constexpr const char *pyClassName = "BlockArgument" ; |
2217 | using PyConcreteValue::PyConcreteValue; |
2218 | |
2219 | static void bindDerived(ClassTy &c) { |
2220 | c.def_property_readonly("owner" , [](PyBlockArgument &self) { |
2221 | return PyBlock(self.getParentOperation(), |
2222 | mlirBlockArgumentGetOwner(self.get())); |
2223 | }); |
2224 | c.def_property_readonly("arg_number" , [](PyBlockArgument &self) { |
2225 | return mlirBlockArgumentGetArgNumber(self.get()); |
2226 | }); |
2227 | c.def( |
2228 | "set_type" , |
2229 | [](PyBlockArgument &self, PyType type) { |
2230 | return mlirBlockArgumentSetType(self.get(), type); |
2231 | }, |
2232 | py::arg("type" )); |
2233 | } |
2234 | }; |
2235 | |
2236 | /// Python wrapper for MlirOpResult. |
2237 | class PyOpResult : public PyConcreteValue<PyOpResult> { |
2238 | public: |
2239 | static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; |
2240 | static constexpr const char *pyClassName = "OpResult" ; |
2241 | using PyConcreteValue::PyConcreteValue; |
2242 | |
2243 | static void bindDerived(ClassTy &c) { |
2244 | c.def_property_readonly("owner" , [](PyOpResult &self) { |
2245 | assert( |
2246 | mlirOperationEqual(self.getParentOperation()->get(), |
2247 | mlirOpResultGetOwner(self.get())) && |
2248 | "expected the owner of the value in Python to match that in the IR" ); |
2249 | return self.getParentOperation().getObject(); |
2250 | }); |
2251 | c.def_property_readonly("result_number" , [](PyOpResult &self) { |
2252 | return mlirOpResultGetResultNumber(self.get()); |
2253 | }); |
2254 | } |
2255 | }; |
2256 | |
2257 | /// Returns the list of types of the values held by container. |
2258 | template <typename Container> |
2259 | static std::vector<MlirType> getValueTypes(Container &container, |
2260 | PyMlirContextRef &context) { |
2261 | std::vector<MlirType> result; |
2262 | result.reserve(container.size()); |
2263 | for (int i = 0, e = container.size(); i < e; ++i) { |
2264 | result.push_back(mlirValueGetType(container.getElement(i).get())); |
2265 | } |
2266 | return result; |
2267 | } |
2268 | |
2269 | /// A list of block arguments. Internally, these are stored as consecutive |
2270 | /// elements, random access is cheap. The argument list is associated with the |
2271 | /// operation that contains the block (detached blocks are not allowed in |
2272 | /// Python bindings) and extends its lifetime. |
2273 | class PyBlockArgumentList |
2274 | : public Sliceable<PyBlockArgumentList, PyBlockArgument> { |
2275 | public: |
2276 | static constexpr const char *pyClassName = "BlockArgumentList" ; |
2277 | using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>; |
2278 | |
2279 | PyBlockArgumentList(PyOperationRef operation, MlirBlock block, |
2280 | intptr_t startIndex = 0, intptr_t length = -1, |
2281 | intptr_t step = 1) |
2282 | : Sliceable(startIndex, |
2283 | length == -1 ? mlirBlockGetNumArguments(block) : length, |
2284 | step), |
2285 | operation(std::move(operation)), block(block) {} |
2286 | |
2287 | static void bindDerived(ClassTy &c) { |
2288 | c.def_property_readonly("types" , [](PyBlockArgumentList &self) { |
2289 | return getValueTypes(self, self.operation->getContext()); |
2290 | }); |
2291 | } |
2292 | |
2293 | private: |
2294 | /// Give the parent CRTP class access to hook implementations below. |
2295 | friend class Sliceable<PyBlockArgumentList, PyBlockArgument>; |
2296 | |
2297 | /// Returns the number of arguments in the list. |
2298 | intptr_t getRawNumElements() { |
2299 | operation->checkValid(); |
2300 | return mlirBlockGetNumArguments(block); |
2301 | } |
2302 | |
2303 | /// Returns `pos`-the element in the list. |
2304 | PyBlockArgument getRawElement(intptr_t pos) { |
2305 | MlirValue argument = mlirBlockGetArgument(block, pos); |
2306 | return PyBlockArgument(operation, argument); |
2307 | } |
2308 | |
2309 | /// Returns a sublist of this list. |
2310 | PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, |
2311 | intptr_t step) { |
2312 | return PyBlockArgumentList(operation, block, startIndex, length, step); |
2313 | } |
2314 | |
2315 | PyOperationRef operation; |
2316 | MlirBlock block; |
2317 | }; |
2318 | |
2319 | /// A list of operation operands. Internally, these are stored as consecutive |
2320 | /// elements, random access is cheap. The (returned) operand list is associated |
2321 | /// with the operation whose operands these are, and thus extends the lifetime |
2322 | /// of this operation. |
2323 | class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { |
2324 | public: |
2325 | static constexpr const char *pyClassName = "OpOperandList" ; |
2326 | using SliceableT = Sliceable<PyOpOperandList, PyValue>; |
2327 | |
2328 | PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, |
2329 | intptr_t length = -1, intptr_t step = 1) |
2330 | : Sliceable(startIndex, |
2331 | length == -1 ? mlirOperationGetNumOperands(operation->get()) |
2332 | : length, |
2333 | step), |
2334 | operation(operation) {} |
2335 | |
2336 | void dunderSetItem(intptr_t index, PyValue value) { |
2337 | index = wrapIndex(index); |
2338 | mlirOperationSetOperand(operation->get(), index, value.get()); |
2339 | } |
2340 | |
2341 | static void bindDerived(ClassTy &c) { |
2342 | c.def("__setitem__" , &PyOpOperandList::dunderSetItem); |
2343 | } |
2344 | |
2345 | private: |
2346 | /// Give the parent CRTP class access to hook implementations below. |
2347 | friend class Sliceable<PyOpOperandList, PyValue>; |
2348 | |
2349 | intptr_t getRawNumElements() { |
2350 | operation->checkValid(); |
2351 | return mlirOperationGetNumOperands(operation->get()); |
2352 | } |
2353 | |
2354 | PyValue getRawElement(intptr_t pos) { |
2355 | MlirValue operand = mlirOperationGetOperand(operation->get(), pos); |
2356 | MlirOperation owner; |
2357 | if (mlirValueIsAOpResult(operand)) |
2358 | owner = mlirOpResultGetOwner(operand); |
2359 | else if (mlirValueIsABlockArgument(operand)) |
2360 | owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); |
2361 | else |
2362 | assert(false && "Value must be an block arg or op result." ); |
2363 | PyOperationRef pyOwner = |
2364 | PyOperation::forOperation(operation->getContext(), owner); |
2365 | return PyValue(pyOwner, operand); |
2366 | } |
2367 | |
2368 | PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { |
2369 | return PyOpOperandList(operation, startIndex, length, step); |
2370 | } |
2371 | |
2372 | PyOperationRef operation; |
2373 | }; |
2374 | |
2375 | /// A list of operation results. Internally, these are stored as consecutive |
2376 | /// elements, random access is cheap. The (returned) result list is associated |
2377 | /// with the operation whose results these are, and thus extends the lifetime of |
2378 | /// this operation. |
2379 | class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { |
2380 | public: |
2381 | static constexpr const char *pyClassName = "OpResultList" ; |
2382 | using SliceableT = Sliceable<PyOpResultList, PyOpResult>; |
2383 | |
2384 | PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, |
2385 | intptr_t length = -1, intptr_t step = 1) |
2386 | : Sliceable(startIndex, |
2387 | length == -1 ? mlirOperationGetNumResults(operation->get()) |
2388 | : length, |
2389 | step), |
2390 | operation(std::move(operation)) {} |
2391 | |
2392 | static void bindDerived(ClassTy &c) { |
2393 | c.def_property_readonly("types" , [](PyOpResultList &self) { |
2394 | return getValueTypes(self, self.operation->getContext()); |
2395 | }); |
2396 | c.def_property_readonly("owner" , [](PyOpResultList &self) { |
2397 | return self.operation->createOpView(); |
2398 | }); |
2399 | } |
2400 | |
2401 | private: |
2402 | /// Give the parent CRTP class access to hook implementations below. |
2403 | friend class Sliceable<PyOpResultList, PyOpResult>; |
2404 | |
2405 | intptr_t getRawNumElements() { |
2406 | operation->checkValid(); |
2407 | return mlirOperationGetNumResults(operation->get()); |
2408 | } |
2409 | |
2410 | PyOpResult getRawElement(intptr_t index) { |
2411 | PyValue value(operation, mlirOperationGetResult(operation->get(), index)); |
2412 | return PyOpResult(value); |
2413 | } |
2414 | |
2415 | PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { |
2416 | return PyOpResultList(operation, startIndex, length, step); |
2417 | } |
2418 | |
2419 | PyOperationRef operation; |
2420 | }; |
2421 | |
2422 | /// A list of operation successors. Internally, these are stored as consecutive |
2423 | /// elements, random access is cheap. The (returned) successor list is |
2424 | /// associated with the operation whose successors these are, and thus extends |
2425 | /// the lifetime of this operation. |
2426 | class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> { |
2427 | public: |
2428 | static constexpr const char *pyClassName = "OpSuccessors" ; |
2429 | |
2430 | PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0, |
2431 | intptr_t length = -1, intptr_t step = 1) |
2432 | : Sliceable(startIndex, |
2433 | length == -1 ? mlirOperationGetNumSuccessors(operation->get()) |
2434 | : length, |
2435 | step), |
2436 | operation(operation) {} |
2437 | |
2438 | void dunderSetItem(intptr_t index, PyBlock block) { |
2439 | index = wrapIndex(index); |
2440 | mlirOperationSetSuccessor(operation->get(), index, block.get()); |
2441 | } |
2442 | |
2443 | static void bindDerived(ClassTy &c) { |
2444 | c.def("__setitem__" , &PyOpSuccessors::dunderSetItem); |
2445 | } |
2446 | |
2447 | private: |
2448 | /// Give the parent CRTP class access to hook implementations below. |
2449 | friend class Sliceable<PyOpSuccessors, PyBlock>; |
2450 | |
2451 | intptr_t getRawNumElements() { |
2452 | operation->checkValid(); |
2453 | return mlirOperationGetNumSuccessors(operation->get()); |
2454 | } |
2455 | |
2456 | PyBlock getRawElement(intptr_t pos) { |
2457 | MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos); |
2458 | return PyBlock(operation, block); |
2459 | } |
2460 | |
2461 | PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) { |
2462 | return PyOpSuccessors(operation, startIndex, length, step); |
2463 | } |
2464 | |
2465 | PyOperationRef operation; |
2466 | }; |
2467 | |
2468 | /// A list of operation attributes. Can be indexed by name, producing |
2469 | /// attributes, or by index, producing named attributes. |
2470 | class PyOpAttributeMap { |
2471 | public: |
2472 | PyOpAttributeMap(PyOperationRef operation) |
2473 | : operation(std::move(operation)) {} |
2474 | |
2475 | MlirAttribute dunderGetItemNamed(const std::string &name) { |
2476 | MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), |
2477 | toMlirStringRef(name)); |
2478 | if (mlirAttributeIsNull(attr)) { |
2479 | throw py::key_error("attempt to access a non-existent attribute" ); |
2480 | } |
2481 | return attr; |
2482 | } |
2483 | |
2484 | PyNamedAttribute dunderGetItemIndexed(intptr_t index) { |
2485 | if (index < 0 || index >= dunderLen()) { |
2486 | throw py::index_error("attempt to access out of bounds attribute" ); |
2487 | } |
2488 | MlirNamedAttribute namedAttr = |
2489 | mlirOperationGetAttribute(operation->get(), index); |
2490 | return PyNamedAttribute( |
2491 | namedAttr.attribute, |
2492 | std::string(mlirIdentifierStr(namedAttr.name).data, |
2493 | mlirIdentifierStr(namedAttr.name).length)); |
2494 | } |
2495 | |
2496 | void dunderSetItem(const std::string &name, const PyAttribute &attr) { |
2497 | mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), |
2498 | attr); |
2499 | } |
2500 | |
2501 | void dunderDelItem(const std::string &name) { |
2502 | int removed = mlirOperationRemoveAttributeByName(operation->get(), |
2503 | toMlirStringRef(name)); |
2504 | if (!removed) |
2505 | throw py::key_error("attempt to delete a non-existent attribute" ); |
2506 | } |
2507 | |
2508 | intptr_t dunderLen() { |
2509 | return mlirOperationGetNumAttributes(operation->get()); |
2510 | } |
2511 | |
2512 | bool dunderContains(const std::string &name) { |
2513 | return !mlirAttributeIsNull(mlirOperationGetAttributeByName( |
2514 | operation->get(), toMlirStringRef(name))); |
2515 | } |
2516 | |
2517 | static void bind(py::module &m) { |
2518 | py::class_<PyOpAttributeMap>(m, "OpAttributeMap" , py::module_local()) |
2519 | .def("__contains__" , &PyOpAttributeMap::dunderContains) |
2520 | .def("__len__" , &PyOpAttributeMap::dunderLen) |
2521 | .def("__getitem__" , &PyOpAttributeMap::dunderGetItemNamed) |
2522 | .def("__getitem__" , &PyOpAttributeMap::dunderGetItemIndexed) |
2523 | .def("__setitem__" , &PyOpAttributeMap::dunderSetItem) |
2524 | .def("__delitem__" , &PyOpAttributeMap::dunderDelItem); |
2525 | } |
2526 | |
2527 | private: |
2528 | PyOperationRef operation; |
2529 | }; |
2530 | |
2531 | } // namespace |
2532 | |
2533 | //------------------------------------------------------------------------------ |
2534 | // Populates the core exports of the 'ir' submodule. |
2535 | //------------------------------------------------------------------------------ |
2536 | |
2537 | void mlir::python::populateIRCore(py::module &m) { |
2538 | //---------------------------------------------------------------------------- |
2539 | // Enums. |
2540 | //---------------------------------------------------------------------------- |
2541 | py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity" , py::module_local()) |
2542 | .value("ERROR" , MlirDiagnosticError) |
2543 | .value("WARNING" , MlirDiagnosticWarning) |
2544 | .value("NOTE" , MlirDiagnosticNote) |
2545 | .value("REMARK" , MlirDiagnosticRemark); |
2546 | |
2547 | py::enum_<MlirWalkOrder>(m, "WalkOrder" , py::module_local()) |
2548 | .value("PRE_ORDER" , MlirWalkPreOrder) |
2549 | .value("POST_ORDER" , MlirWalkPostOrder); |
2550 | |
2551 | py::enum_<MlirWalkResult>(m, "WalkResult" , py::module_local()) |
2552 | .value("ADVANCE" , MlirWalkResultAdvance) |
2553 | .value("INTERRUPT" , MlirWalkResultInterrupt) |
2554 | .value("SKIP" , MlirWalkResultSkip); |
2555 | |
2556 | //---------------------------------------------------------------------------- |
2557 | // Mapping of Diagnostics. |
2558 | //---------------------------------------------------------------------------- |
2559 | py::class_<PyDiagnostic>(m, "Diagnostic" , py::module_local()) |
2560 | .def_property_readonly("severity" , &PyDiagnostic::getSeverity) |
2561 | .def_property_readonly("location" , &PyDiagnostic::getLocation) |
2562 | .def_property_readonly("message" , &PyDiagnostic::getMessage) |
2563 | .def_property_readonly("notes" , &PyDiagnostic::getNotes) |
2564 | .def("__str__" , [](PyDiagnostic &self) -> py::str { |
2565 | if (!self.isValid()) |
2566 | return "<Invalid Diagnostic>" ; |
2567 | return self.getMessage(); |
2568 | }); |
2569 | |
2570 | py::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo" , |
2571 | py::module_local()) |
2572 | .def(py::init<>([](PyDiagnostic diag) { return diag.getInfo(); })) |
2573 | .def_readonly("severity" , &PyDiagnostic::DiagnosticInfo::severity) |
2574 | .def_readonly("location" , &PyDiagnostic::DiagnosticInfo::location) |
2575 | .def_readonly("message" , &PyDiagnostic::DiagnosticInfo::message) |
2576 | .def_readonly("notes" , &PyDiagnostic::DiagnosticInfo::notes) |
2577 | .def("__str__" , |
2578 | [](PyDiagnostic::DiagnosticInfo &self) { return self.message; }); |
2579 | |
2580 | py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler" , py::module_local()) |
2581 | .def("detach" , &PyDiagnosticHandler::detach) |
2582 | .def_property_readonly("attached" , &PyDiagnosticHandler::isAttached) |
2583 | .def_property_readonly("had_error" , &PyDiagnosticHandler::getHadError) |
2584 | .def("__enter__" , &PyDiagnosticHandler::contextEnter) |
2585 | .def("__exit__" , &PyDiagnosticHandler::contextExit); |
2586 | |
2587 | //---------------------------------------------------------------------------- |
2588 | // Mapping of MlirContext. |
2589 | // Note that this is exported as _BaseContext. The containing, Python level |
2590 | // __init__.py will subclass it with site-specific functionality and set a |
2591 | // "Context" attribute on this module. |
2592 | //---------------------------------------------------------------------------- |
2593 | py::class_<PyMlirContext>(m, "_BaseContext" , py::module_local()) |
2594 | .def(py::init<>(&PyMlirContext::createNewContextForInit)) |
2595 | .def_static("_get_live_count" , &PyMlirContext::getLiveCount) |
2596 | .def("_get_context_again" , |
2597 | [](PyMlirContext &self) { |
2598 | PyMlirContextRef ref = PyMlirContext::forContext(self.get()); |
2599 | return ref.releaseObject(); |
2600 | }) |
2601 | .def("_get_live_operation_count" , &PyMlirContext::getLiveOperationCount) |
2602 | .def("_get_live_operation_objects" , |
2603 | &PyMlirContext::getLiveOperationObjects) |
2604 | .def("_clear_live_operations" , &PyMlirContext::clearLiveOperations) |
2605 | .def("_clear_live_operations_inside" , |
2606 | py::overload_cast<MlirOperation>( |
2607 | &PyMlirContext::clearOperationsInside)) |
2608 | .def("_get_live_module_count" , &PyMlirContext::getLiveModuleCount) |
2609 | .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, |
2610 | &PyMlirContext::getCapsule) |
2611 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) |
2612 | .def("__enter__" , &PyMlirContext::contextEnter) |
2613 | .def("__exit__" , &PyMlirContext::contextExit) |
2614 | .def_property_readonly_static( |
2615 | "current" , |
2616 | [](py::object & /*class*/) { |
2617 | auto *context = PyThreadContextEntry::getDefaultContext(); |
2618 | if (!context) |
2619 | return py::none().cast<py::object>(); |
2620 | return py::cast(context); |
2621 | }, |
2622 | "Gets the Context bound to the current thread or raises ValueError" ) |
2623 | .def_property_readonly( |
2624 | "dialects" , |
2625 | [](PyMlirContext &self) { return PyDialects(self.getRef()); }, |
2626 | "Gets a container for accessing dialects by name" ) |
2627 | .def_property_readonly( |
2628 | "d" , [](PyMlirContext &self) { return PyDialects(self.getRef()); }, |
2629 | "Alias for 'dialect'" ) |
2630 | .def( |
2631 | "get_dialect_descriptor" , |
2632 | [=](PyMlirContext &self, std::string &name) { |
2633 | MlirDialect dialect = mlirContextGetOrLoadDialect( |
2634 | self.get(), {name.data(), name.size()}); |
2635 | if (mlirDialectIsNull(dialect)) { |
2636 | throw py::value_error( |
2637 | (Twine("Dialect '" ) + name + "' not found" ).str()); |
2638 | } |
2639 | return PyDialectDescriptor(self.getRef(), dialect); |
2640 | }, |
2641 | py::arg("dialect_name" ), |
2642 | "Gets or loads a dialect by name, returning its descriptor object" ) |
2643 | .def_property( |
2644 | "allow_unregistered_dialects" , |
2645 | [](PyMlirContext &self) -> bool { |
2646 | return mlirContextGetAllowUnregisteredDialects(self.get()); |
2647 | }, |
2648 | [](PyMlirContext &self, bool value) { |
2649 | mlirContextSetAllowUnregisteredDialects(self.get(), value); |
2650 | }) |
2651 | .def("attach_diagnostic_handler" , &PyMlirContext::attachDiagnosticHandler, |
2652 | py::arg("callback" ), |
2653 | "Attaches a diagnostic handler that will receive callbacks" ) |
2654 | .def( |
2655 | "enable_multithreading" , |
2656 | [](PyMlirContext &self, bool enable) { |
2657 | mlirContextEnableMultithreading(self.get(), enable); |
2658 | }, |
2659 | py::arg("enable" )) |
2660 | .def( |
2661 | "is_registered_operation" , |
2662 | [](PyMlirContext &self, std::string &name) { |
2663 | return mlirContextIsRegisteredOperation( |
2664 | self.get(), MlirStringRef{name.data(), name.size()}); |
2665 | }, |
2666 | py::arg("operation_name" )) |
2667 | .def( |
2668 | "append_dialect_registry" , |
2669 | [](PyMlirContext &self, PyDialectRegistry ®istry) { |
2670 | mlirContextAppendDialectRegistry(self.get(), registry); |
2671 | }, |
2672 | py::arg("registry" )) |
2673 | .def_property("emit_error_diagnostics" , nullptr, |
2674 | &PyMlirContext::setEmitErrorDiagnostics, |
2675 | "Emit error diagnostics to diagnostic handlers. By default " |
2676 | "error diagnostics are captured and reported through " |
2677 | "MLIRError exceptions." ) |
2678 | .def("load_all_available_dialects" , [](PyMlirContext &self) { |
2679 | mlirContextLoadAllAvailableDialects(self.get()); |
2680 | }); |
2681 | |
2682 | //---------------------------------------------------------------------------- |
2683 | // Mapping of PyDialectDescriptor |
2684 | //---------------------------------------------------------------------------- |
2685 | py::class_<PyDialectDescriptor>(m, "DialectDescriptor" , py::module_local()) |
2686 | .def_property_readonly("namespace" , |
2687 | [](PyDialectDescriptor &self) { |
2688 | MlirStringRef ns = |
2689 | mlirDialectGetNamespace(self.get()); |
2690 | return py::str(ns.data, ns.length); |
2691 | }) |
2692 | .def("__repr__" , [](PyDialectDescriptor &self) { |
2693 | MlirStringRef ns = mlirDialectGetNamespace(self.get()); |
2694 | std::string repr("<DialectDescriptor " ); |
2695 | repr.append(ns.data, ns.length); |
2696 | repr.append(">" ); |
2697 | return repr; |
2698 | }); |
2699 | |
2700 | //---------------------------------------------------------------------------- |
2701 | // Mapping of PyDialects |
2702 | //---------------------------------------------------------------------------- |
2703 | py::class_<PyDialects>(m, "Dialects" , py::module_local()) |
2704 | .def("__getitem__" , |
2705 | [=](PyDialects &self, std::string keyName) { |
2706 | MlirDialect dialect = |
2707 | self.getDialectForKey(keyName, /*attrError=*/false); |
2708 | py::object descriptor = |
2709 | py::cast(PyDialectDescriptor{self.getContext(), dialect}); |
2710 | return createCustomDialectWrapper(keyName, std::move(descriptor)); |
2711 | }) |
2712 | .def("__getattr__" , [=](PyDialects &self, std::string attrName) { |
2713 | MlirDialect dialect = |
2714 | self.getDialectForKey(attrName, /*attrError=*/true); |
2715 | py::object descriptor = |
2716 | py::cast(PyDialectDescriptor{self.getContext(), dialect}); |
2717 | return createCustomDialectWrapper(attrName, std::move(descriptor)); |
2718 | }); |
2719 | |
2720 | //---------------------------------------------------------------------------- |
2721 | // Mapping of PyDialect |
2722 | //---------------------------------------------------------------------------- |
2723 | py::class_<PyDialect>(m, "Dialect" , py::module_local()) |
2724 | .def(py::init<py::object>(), py::arg("descriptor" )) |
2725 | .def_property_readonly( |
2726 | "descriptor" , [](PyDialect &self) { return self.getDescriptor(); }) |
2727 | .def("__repr__" , [](py::object self) { |
2728 | auto clazz = self.attr("__class__" ); |
2729 | return py::str("<Dialect " ) + |
2730 | self.attr("descriptor" ).attr("namespace" ) + py::str(" (class " ) + |
2731 | clazz.attr("__module__" ) + py::str("." ) + |
2732 | clazz.attr("__name__" ) + py::str(")>" ); |
2733 | }); |
2734 | |
2735 | //---------------------------------------------------------------------------- |
2736 | // Mapping of PyDialectRegistry |
2737 | //---------------------------------------------------------------------------- |
2738 | py::class_<PyDialectRegistry>(m, "DialectRegistry" , py::module_local()) |
2739 | .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, |
2740 | &PyDialectRegistry::getCapsule) |
2741 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule) |
2742 | .def(py::init<>()); |
2743 | |
2744 | //---------------------------------------------------------------------------- |
2745 | // Mapping of Location |
2746 | //---------------------------------------------------------------------------- |
2747 | py::class_<PyLocation>(m, "Location" , py::module_local()) |
2748 | .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) |
2749 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) |
2750 | .def("__enter__" , &PyLocation::contextEnter) |
2751 | .def("__exit__" , &PyLocation::contextExit) |
2752 | .def("__eq__" , |
2753 | [](PyLocation &self, PyLocation &other) -> bool { |
2754 | return mlirLocationEqual(self, other); |
2755 | }) |
2756 | .def("__eq__" , [](PyLocation &self, py::object other) { return false; }) |
2757 | .def_property_readonly_static( |
2758 | "current" , |
2759 | [](py::object & /*class*/) { |
2760 | auto *loc = PyThreadContextEntry::getDefaultLocation(); |
2761 | if (!loc) |
2762 | throw py::value_error("No current Location" ); |
2763 | return loc; |
2764 | }, |
2765 | "Gets the Location bound to the current thread or raises ValueError" ) |
2766 | .def_static( |
2767 | "unknown" , |
2768 | [](DefaultingPyMlirContext context) { |
2769 | return PyLocation(context->getRef(), |
2770 | mlirLocationUnknownGet(context->get())); |
2771 | }, |
2772 | py::arg("context" ) = py::none(), |
2773 | "Gets a Location representing an unknown location" ) |
2774 | .def_static( |
2775 | "callsite" , |
2776 | [](PyLocation callee, const std::vector<PyLocation> &frames, |
2777 | DefaultingPyMlirContext context) { |
2778 | if (frames.empty()) |
2779 | throw py::value_error("No caller frames provided" ); |
2780 | MlirLocation caller = frames.back().get(); |
2781 | for (const PyLocation &frame : |
2782 | llvm::reverse(llvm::ArrayRef(frames).drop_back())) |
2783 | caller = mlirLocationCallSiteGet(frame.get(), caller); |
2784 | return PyLocation(context->getRef(), |
2785 | mlirLocationCallSiteGet(callee.get(), caller)); |
2786 | }, |
2787 | py::arg("callee" ), py::arg("frames" ), py::arg("context" ) = py::none(), |
2788 | kContextGetCallSiteLocationDocstring) |
2789 | .def_static( |
2790 | "file" , |
2791 | [](std::string filename, int line, int col, |
2792 | DefaultingPyMlirContext context) { |
2793 | return PyLocation( |
2794 | context->getRef(), |
2795 | mlirLocationFileLineColGet( |
2796 | context->get(), toMlirStringRef(filename), line, col)); |
2797 | }, |
2798 | py::arg("filename" ), py::arg("line" ), py::arg("col" ), |
2799 | py::arg("context" ) = py::none(), kContextGetFileLocationDocstring) |
2800 | .def_static( |
2801 | "fused" , |
2802 | [](const std::vector<PyLocation> &pyLocations, |
2803 | std::optional<PyAttribute> metadata, |
2804 | DefaultingPyMlirContext context) { |
2805 | llvm::SmallVector<MlirLocation, 4> locations; |
2806 | locations.reserve(pyLocations.size()); |
2807 | for (auto &pyLocation : pyLocations) |
2808 | locations.push_back(pyLocation.get()); |
2809 | MlirLocation location = mlirLocationFusedGet( |
2810 | context->get(), locations.size(), locations.data(), |
2811 | metadata ? metadata->get() : MlirAttribute{0}); |
2812 | return PyLocation(context->getRef(), location); |
2813 | }, |
2814 | py::arg("locations" ), py::arg("metadata" ) = py::none(), |
2815 | py::arg("context" ) = py::none(), kContextGetFusedLocationDocstring) |
2816 | .def_static( |
2817 | "name" , |
2818 | [](std::string name, std::optional<PyLocation> childLoc, |
2819 | DefaultingPyMlirContext context) { |
2820 | return PyLocation( |
2821 | context->getRef(), |
2822 | mlirLocationNameGet( |
2823 | context->get(), toMlirStringRef(name), |
2824 | childLoc ? childLoc->get() |
2825 | : mlirLocationUnknownGet(context->get()))); |
2826 | }, |
2827 | py::arg("name" ), py::arg("childLoc" ) = py::none(), |
2828 | py::arg("context" ) = py::none(), kContextGetNameLocationDocString) |
2829 | .def_static( |
2830 | "from_attr" , |
2831 | [](PyAttribute &attribute, DefaultingPyMlirContext context) { |
2832 | return PyLocation(context->getRef(), |
2833 | mlirLocationFromAttribute(attribute)); |
2834 | }, |
2835 | py::arg("attribute" ), py::arg("context" ) = py::none(), |
2836 | "Gets a Location from a LocationAttr" ) |
2837 | .def_property_readonly( |
2838 | "context" , |
2839 | [](PyLocation &self) { return self.getContext().getObject(); }, |
2840 | "Context that owns the Location" ) |
2841 | .def_property_readonly( |
2842 | "attr" , |
2843 | [](PyLocation &self) { return mlirLocationGetAttribute(self); }, |
2844 | "Get the underlying LocationAttr" ) |
2845 | .def( |
2846 | "emit_error" , |
2847 | [](PyLocation &self, std::string message) { |
2848 | mlirEmitError(self, message.c_str()); |
2849 | }, |
2850 | py::arg("message" ), "Emits an error at this location" ) |
2851 | .def("__repr__" , [](PyLocation &self) { |
2852 | PyPrintAccumulator printAccum; |
2853 | mlirLocationPrint(self, printAccum.getCallback(), |
2854 | printAccum.getUserData()); |
2855 | return printAccum.join(); |
2856 | }); |
2857 | |
2858 | //---------------------------------------------------------------------------- |
2859 | // Mapping of Module |
2860 | //---------------------------------------------------------------------------- |
2861 | py::class_<PyModule>(m, "Module" , py::module_local()) |
2862 | .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) |
2863 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) |
2864 | .def_static( |
2865 | "parse" , |
2866 | [](const std::string &moduleAsm, DefaultingPyMlirContext context) { |
2867 | PyMlirContext::ErrorCapture errors(context->getRef()); |
2868 | MlirModule module = mlirModuleCreateParse( |
2869 | context->get(), toMlirStringRef(moduleAsm)); |
2870 | if (mlirModuleIsNull(module)) |
2871 | throw MLIRError("Unable to parse module assembly" , errors.take()); |
2872 | return PyModule::forModule(module).releaseObject(); |
2873 | }, |
2874 | py::arg("asm" ), py::arg("context" ) = py::none(), |
2875 | kModuleParseDocstring) |
2876 | .def_static( |
2877 | "create" , |
2878 | [](DefaultingPyLocation loc) { |
2879 | MlirModule module = mlirModuleCreateEmpty(loc); |
2880 | return PyModule::forModule(module).releaseObject(); |
2881 | }, |
2882 | py::arg("loc" ) = py::none(), "Creates an empty module" ) |
2883 | .def_property_readonly( |
2884 | "context" , |
2885 | [](PyModule &self) { return self.getContext().getObject(); }, |
2886 | "Context that created the Module" ) |
2887 | .def_property_readonly( |
2888 | "operation" , |
2889 | [](PyModule &self) { |
2890 | return PyOperation::forOperation(self.getContext(), |
2891 | mlirModuleGetOperation(self.get()), |
2892 | self.getRef().releaseObject()) |
2893 | .releaseObject(); |
2894 | }, |
2895 | "Accesses the module as an operation" ) |
2896 | .def_property_readonly( |
2897 | "body" , |
2898 | [](PyModule &self) { |
2899 | PyOperationRef moduleOp = PyOperation::forOperation( |
2900 | self.getContext(), mlirModuleGetOperation(self.get()), |
2901 | self.getRef().releaseObject()); |
2902 | PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get())); |
2903 | return returnBlock; |
2904 | }, |
2905 | "Return the block for this module" ) |
2906 | .def( |
2907 | "dump" , |
2908 | [](PyModule &self) { |
2909 | mlirOperationDump(mlirModuleGetOperation(self.get())); |
2910 | }, |
2911 | kDumpDocstring) |
2912 | .def( |
2913 | "__str__" , |
2914 | [](py::object self) { |
2915 | // Defer to the operation's __str__. |
2916 | return self.attr("operation" ).attr("__str__" )(); |
2917 | }, |
2918 | kOperationStrDunderDocstring); |
2919 | |
2920 | //---------------------------------------------------------------------------- |
2921 | // Mapping of Operation. |
2922 | //---------------------------------------------------------------------------- |
2923 | py::class_<PyOperationBase>(m, "_OperationBase" , py::module_local()) |
2924 | .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, |
2925 | [](PyOperationBase &self) { |
2926 | return self.getOperation().getCapsule(); |
2927 | }) |
2928 | .def("__eq__" , |
2929 | [](PyOperationBase &self, PyOperationBase &other) { |
2930 | return &self.getOperation() == &other.getOperation(); |
2931 | }) |
2932 | .def("__eq__" , |
2933 | [](PyOperationBase &self, py::object other) { return false; }) |
2934 | .def("__hash__" , |
2935 | [](PyOperationBase &self) { |
2936 | return static_cast<size_t>(llvm::hash_value(&self.getOperation())); |
2937 | }) |
2938 | .def_property_readonly("attributes" , |
2939 | [](PyOperationBase &self) { |
2940 | return PyOpAttributeMap( |
2941 | self.getOperation().getRef()); |
2942 | }) |
2943 | .def_property_readonly( |
2944 | "context" , |
2945 | [](PyOperationBase &self) { |
2946 | PyOperation &concreteOperation = self.getOperation(); |
2947 | concreteOperation.checkValid(); |
2948 | return concreteOperation.getContext().getObject(); |
2949 | }, |
2950 | "Context that owns the Operation" ) |
2951 | .def_property_readonly("name" , |
2952 | [](PyOperationBase &self) { |
2953 | auto &concreteOperation = self.getOperation(); |
2954 | concreteOperation.checkValid(); |
2955 | MlirOperation operation = |
2956 | concreteOperation.get(); |
2957 | MlirStringRef name = mlirIdentifierStr( |
2958 | mlirOperationGetName(operation)); |
2959 | return py::str(name.data, name.length); |
2960 | }) |
2961 | .def_property_readonly("operands" , |
2962 | [](PyOperationBase &self) { |
2963 | return PyOpOperandList( |
2964 | self.getOperation().getRef()); |
2965 | }) |
2966 | .def_property_readonly("regions" , |
2967 | [](PyOperationBase &self) { |
2968 | return PyRegionList( |
2969 | self.getOperation().getRef()); |
2970 | }) |
2971 | .def_property_readonly( |
2972 | "results" , |
2973 | [](PyOperationBase &self) { |
2974 | return PyOpResultList(self.getOperation().getRef()); |
2975 | }, |
2976 | "Returns the list of Operation results." ) |
2977 | .def_property_readonly( |
2978 | "result" , |
2979 | [](PyOperationBase &self) { |
2980 | auto &operation = self.getOperation(); |
2981 | auto numResults = mlirOperationGetNumResults(operation); |
2982 | if (numResults != 1) { |
2983 | auto name = mlirIdentifierStr(mlirOperationGetName(operation)); |
2984 | throw py::value_error( |
2985 | (Twine("Cannot call .result on operation " ) + |
2986 | StringRef(name.data, name.length) + " which has " + |
2987 | Twine(numResults) + |
2988 | " results (it is only valid for operations with a " |
2989 | "single result)" ) |
2990 | .str()); |
2991 | } |
2992 | return PyOpResult(operation.getRef(), |
2993 | mlirOperationGetResult(operation, 0)) |
2994 | .maybeDownCast(); |
2995 | }, |
2996 | "Shortcut to get an op result if it has only one (throws an error " |
2997 | "otherwise)." ) |
2998 | .def_property_readonly( |
2999 | "location" , |
3000 | [](PyOperationBase &self) { |
3001 | PyOperation &operation = self.getOperation(); |
3002 | return PyLocation(operation.getContext(), |
3003 | mlirOperationGetLocation(operation.get())); |
3004 | }, |
3005 | "Returns the source location the operation was defined or derived " |
3006 | "from." ) |
3007 | .def_property_readonly("parent" , |
3008 | [](PyOperationBase &self) -> py::object { |
3009 | auto parent = |
3010 | self.getOperation().getParentOperation(); |
3011 | if (parent) |
3012 | return parent->getObject(); |
3013 | return py::none(); |
3014 | }) |
3015 | .def( |
3016 | "__str__" , |
3017 | [](PyOperationBase &self) { |
3018 | return self.getAsm(/*binary=*/false, |
3019 | /*largeElementsLimit=*/std::nullopt, |
3020 | /*enableDebugInfo=*/false, |
3021 | /*prettyDebugInfo=*/false, |
3022 | /*printGenericOpForm=*/false, |
3023 | /*useLocalScope=*/false, |
3024 | /*assumeVerified=*/false); |
3025 | }, |
3026 | "Returns the assembly form of the operation." ) |
3027 | .def("print" , |
3028 | py::overload_cast<PyAsmState &, pybind11::object, bool>( |
3029 | &PyOperationBase::print), |
3030 | py::arg("state" ), py::arg("file" ) = py::none(), |
3031 | py::arg("binary" ) = false, kOperationPrintStateDocstring) |
3032 | .def("print" , |
3033 | py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool, |
3034 | bool, py::object, bool>(&PyOperationBase::print), |
3035 | // Careful: Lots of arguments must match up with print method. |
3036 | py::arg("large_elements_limit" ) = py::none(), |
3037 | py::arg("enable_debug_info" ) = false, |
3038 | py::arg("pretty_debug_info" ) = false, |
3039 | py::arg("print_generic_op_form" ) = false, |
3040 | py::arg("use_local_scope" ) = false, |
3041 | py::arg("assume_verified" ) = false, py::arg("file" ) = py::none(), |
3042 | py::arg("binary" ) = false, kOperationPrintDocstring) |
3043 | .def("write_bytecode" , &PyOperationBase::writeBytecode, py::arg("file" ), |
3044 | py::arg("desired_version" ) = py::none(), |
3045 | kOperationPrintBytecodeDocstring) |
3046 | .def("get_asm" , &PyOperationBase::getAsm, |
3047 | // Careful: Lots of arguments must match up with get_asm method. |
3048 | py::arg("binary" ) = false, |
3049 | py::arg("large_elements_limit" ) = py::none(), |
3050 | py::arg("enable_debug_info" ) = false, |
3051 | py::arg("pretty_debug_info" ) = false, |
3052 | py::arg("print_generic_op_form" ) = false, |
3053 | py::arg("use_local_scope" ) = false, |
3054 | py::arg("assume_verified" ) = false, kOperationGetAsmDocstring) |
3055 | .def("verify" , &PyOperationBase::verify, |
3056 | "Verify the operation. Raises MLIRError if verification fails, and " |
3057 | "returns true otherwise." ) |
3058 | .def("move_after" , &PyOperationBase::moveAfter, py::arg("other" ), |
3059 | "Puts self immediately after the other operation in its parent " |
3060 | "block." ) |
3061 | .def("move_before" , &PyOperationBase::moveBefore, py::arg("other" ), |
3062 | "Puts self immediately before the other operation in its parent " |
3063 | "block." ) |
3064 | .def( |
3065 | "clone" , |
3066 | [](PyOperationBase &self, py::object ip) { |
3067 | return self.getOperation().clone(ip); |
3068 | }, |
3069 | py::arg("ip" ) = py::none()) |
3070 | .def( |
3071 | "detach_from_parent" , |
3072 | [](PyOperationBase &self) { |
3073 | PyOperation &operation = self.getOperation(); |
3074 | operation.checkValid(); |
3075 | if (!operation.isAttached()) |
3076 | throw py::value_error("Detached operation has no parent." ); |
3077 | |
3078 | operation.detachFromParent(); |
3079 | return operation.createOpView(); |
3080 | }, |
3081 | "Detaches the operation from its parent block." ) |
3082 | .def("erase" , [](PyOperationBase &self) { self.getOperation().erase(); }) |
3083 | .def("walk" , &PyOperationBase::walk, py::arg("callback" ), |
3084 | py::arg("walk_order" ) = MlirWalkPostOrder); |
3085 | |
3086 | py::class_<PyOperation, PyOperationBase>(m, "Operation" , py::module_local()) |
3087 | .def_static("create" , &PyOperation::create, py::arg("name" ), |
3088 | py::arg("results" ) = py::none(), |
3089 | py::arg("operands" ) = py::none(), |
3090 | py::arg("attributes" ) = py::none(), |
3091 | py::arg("successors" ) = py::none(), py::arg("regions" ) = 0, |
3092 | py::arg("loc" ) = py::none(), py::arg("ip" ) = py::none(), |
3093 | py::arg("infer_type" ) = false, kOperationCreateDocstring) |
3094 | .def_static( |
3095 | "parse" , |
3096 | [](const std::string &sourceStr, const std::string &sourceName, |
3097 | DefaultingPyMlirContext context) { |
3098 | return PyOperation::parse(context->getRef(), sourceStr, sourceName) |
3099 | ->createOpView(); |
3100 | }, |
3101 | py::arg("source" ), py::kw_only(), py::arg("source_name" ) = "" , |
3102 | py::arg("context" ) = py::none(), |
3103 | "Parses an operation. Supports both text assembly format and binary " |
3104 | "bytecode format." ) |
3105 | .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, |
3106 | &PyOperation::getCapsule) |
3107 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) |
3108 | .def_property_readonly("operation" , [](py::object self) { return self; }) |
3109 | .def_property_readonly("opview" , &PyOperation::createOpView) |
3110 | .def_property_readonly( |
3111 | "successors" , |
3112 | [](PyOperationBase &self) { |
3113 | return PyOpSuccessors(self.getOperation().getRef()); |
3114 | }, |
3115 | "Returns the list of Operation successors." ); |
3116 | |
3117 | auto opViewClass = |
3118 | py::class_<PyOpView, PyOperationBase>(m, "OpView" , py::module_local()) |
3119 | .def(py::init<py::object>(), py::arg("operation" )) |
3120 | .def_property_readonly("operation" , &PyOpView::getOperationObject) |
3121 | .def_property_readonly("opview" , [](py::object self) { return self; }) |
3122 | .def( |
3123 | "__str__" , |
3124 | [](PyOpView &self) { return py::str(self.getOperationObject()); }) |
3125 | .def_property_readonly( |
3126 | "successors" , |
3127 | [](PyOperationBase &self) { |
3128 | return PyOpSuccessors(self.getOperation().getRef()); |
3129 | }, |
3130 | "Returns the list of Operation successors." ); |
3131 | opViewClass.attr("_ODS_REGIONS" ) = py::make_tuple(0, true); |
3132 | opViewClass.attr("_ODS_OPERAND_SEGMENTS" ) = py::none(); |
3133 | opViewClass.attr("_ODS_RESULT_SEGMENTS" ) = py::none(); |
3134 | opViewClass.attr("build_generic" ) = classmethod( |
3135 | &PyOpView::buildGeneric, py::arg("cls" ), py::arg("results" ) = py::none(), |
3136 | py::arg("operands" ) = py::none(), py::arg("attributes" ) = py::none(), |
3137 | py::arg("successors" ) = py::none(), py::arg("regions" ) = py::none(), |
3138 | py::arg("loc" ) = py::none(), py::arg("ip" ) = py::none(), |
3139 | "Builds a specific, generated OpView based on class level attributes." ); |
3140 | opViewClass.attr("parse" ) = classmethod( |
3141 | [](const py::object &cls, const std::string &sourceStr, |
3142 | const std::string &sourceName, DefaultingPyMlirContext context) { |
3143 | PyOperationRef parsed = |
3144 | PyOperation::parse(context->getRef(), sourceStr, sourceName); |
3145 | |
3146 | // Check if the expected operation was parsed, and cast to to the |
3147 | // appropriate `OpView` subclass if successful. |
3148 | // NOTE: This accesses attributes that have been automatically added to |
3149 | // `OpView` subclasses, and is not intended to be used on `OpView` |
3150 | // directly. |
3151 | std::string clsOpName = |
3152 | py::cast<std::string>(cls.attr("OPERATION_NAME" )); |
3153 | MlirStringRef identifier = |
3154 | mlirIdentifierStr(mlirOperationGetName(*parsed.get())); |
3155 | std::string_view parsedOpName(identifier.data, identifier.length); |
3156 | if (clsOpName != parsedOpName) |
3157 | throw MLIRError(Twine("Expected a '" ) + clsOpName + "' op, got: '" + |
3158 | parsedOpName + "'" ); |
3159 | return PyOpView::constructDerived(cls, *parsed.get()); |
3160 | }, |
3161 | py::arg("cls" ), py::arg("source" ), py::kw_only(), |
3162 | py::arg("source_name" ) = "" , py::arg("context" ) = py::none(), |
3163 | "Parses a specific, generated OpView based on class level attributes" ); |
3164 | |
3165 | //---------------------------------------------------------------------------- |
3166 | // Mapping of PyRegion. |
3167 | //---------------------------------------------------------------------------- |
3168 | py::class_<PyRegion>(m, "Region" , py::module_local()) |
3169 | .def_property_readonly( |
3170 | "blocks" , |
3171 | [](PyRegion &self) { |
3172 | return PyBlockList(self.getParentOperation(), self.get()); |
3173 | }, |
3174 | "Returns a forward-optimized sequence of blocks." ) |
3175 | .def_property_readonly( |
3176 | "owner" , |
3177 | [](PyRegion &self) { |
3178 | return self.getParentOperation()->createOpView(); |
3179 | }, |
3180 | "Returns the operation owning this region." ) |
3181 | .def( |
3182 | "__iter__" , |
3183 | [](PyRegion &self) { |
3184 | self.checkValid(); |
3185 | MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); |
3186 | return PyBlockIterator(self.getParentOperation(), firstBlock); |
3187 | }, |
3188 | "Iterates over blocks in the region." ) |
3189 | .def("__eq__" , |
3190 | [](PyRegion &self, PyRegion &other) { |
3191 | return self.get().ptr == other.get().ptr; |
3192 | }) |
3193 | .def("__eq__" , [](PyRegion &self, py::object &other) { return false; }); |
3194 | |
3195 | //---------------------------------------------------------------------------- |
3196 | // Mapping of PyBlock. |
3197 | //---------------------------------------------------------------------------- |
3198 | py::class_<PyBlock>(m, "Block" , py::module_local()) |
3199 | .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule) |
3200 | .def_property_readonly( |
3201 | "owner" , |
3202 | [](PyBlock &self) { |
3203 | return self.getParentOperation()->createOpView(); |
3204 | }, |
3205 | "Returns the owning operation of this block." ) |
3206 | .def_property_readonly( |
3207 | "region" , |
3208 | [](PyBlock &self) { |
3209 | MlirRegion region = mlirBlockGetParentRegion(self.get()); |
3210 | return PyRegion(self.getParentOperation(), region); |
3211 | }, |
3212 | "Returns the owning region of this block." ) |
3213 | .def_property_readonly( |
3214 | "arguments" , |
3215 | [](PyBlock &self) { |
3216 | return PyBlockArgumentList(self.getParentOperation(), self.get()); |
3217 | }, |
3218 | "Returns a list of block arguments." ) |
3219 | .def_property_readonly( |
3220 | "operations" , |
3221 | [](PyBlock &self) { |
3222 | return PyOperationList(self.getParentOperation(), self.get()); |
3223 | }, |
3224 | "Returns a forward-optimized sequence of operations." ) |
3225 | .def_static( |
3226 | "create_at_start" , |
3227 | [](PyRegion &parent, const py::list &pyArgTypes, |
3228 | const std::optional<py::sequence> &pyArgLocs) { |
3229 | parent.checkValid(); |
3230 | MlirBlock block = createBlock(pyArgTypes, pyArgLocs); |
3231 | mlirRegionInsertOwnedBlock(parent, 0, block); |
3232 | return PyBlock(parent.getParentOperation(), block); |
3233 | }, |
3234 | py::arg("parent" ), py::arg("arg_types" ) = py::list(), |
3235 | py::arg("arg_locs" ) = std::nullopt, |
3236 | "Creates and returns a new Block at the beginning of the given " |
3237 | "region (with given argument types and locations)." ) |
3238 | .def( |
3239 | "append_to" , |
3240 | [](PyBlock &self, PyRegion ®ion) { |
3241 | MlirBlock b = self.get(); |
3242 | if (!mlirRegionIsNull(mlirBlockGetParentRegion(b))) |
3243 | mlirBlockDetach(b); |
3244 | mlirRegionAppendOwnedBlock(region.get(), b); |
3245 | }, |
3246 | "Append this block to a region, transferring ownership if necessary" ) |
3247 | .def( |
3248 | "create_before" , |
3249 | [](PyBlock &self, const py::args &pyArgTypes, |
3250 | const std::optional<py::sequence> &pyArgLocs) { |
3251 | self.checkValid(); |
3252 | MlirBlock block = createBlock(pyArgTypes, pyArgLocs); |
3253 | MlirRegion region = mlirBlockGetParentRegion(self.get()); |
3254 | mlirRegionInsertOwnedBlockBefore(region, self.get(), block); |
3255 | return PyBlock(self.getParentOperation(), block); |
3256 | }, |
3257 | py::arg("arg_locs" ) = std::nullopt, |
3258 | "Creates and returns a new Block before this block " |
3259 | "(with given argument types and locations)." ) |
3260 | .def( |
3261 | "create_after" , |
3262 | [](PyBlock &self, const py::args &pyArgTypes, |
3263 | const std::optional<py::sequence> &pyArgLocs) { |
3264 | self.checkValid(); |
3265 | MlirBlock block = createBlock(pyArgTypes, pyArgLocs); |
3266 | MlirRegion region = mlirBlockGetParentRegion(self.get()); |
3267 | mlirRegionInsertOwnedBlockAfter(region, self.get(), block); |
3268 | return PyBlock(self.getParentOperation(), block); |
3269 | }, |
3270 | py::arg("arg_locs" ) = std::nullopt, |
3271 | "Creates and returns a new Block after this block " |
3272 | "(with given argument types and locations)." ) |
3273 | .def( |
3274 | "__iter__" , |
3275 | [](PyBlock &self) { |
3276 | self.checkValid(); |
3277 | MlirOperation firstOperation = |
3278 | mlirBlockGetFirstOperation(self.get()); |
3279 | return PyOperationIterator(self.getParentOperation(), |
3280 | firstOperation); |
3281 | }, |
3282 | "Iterates over operations in the block." ) |
3283 | .def("__eq__" , |
3284 | [](PyBlock &self, PyBlock &other) { |
3285 | return self.get().ptr == other.get().ptr; |
3286 | }) |
3287 | .def("__eq__" , [](PyBlock &self, py::object &other) { return false; }) |
3288 | .def("__hash__" , |
3289 | [](PyBlock &self) { |
3290 | return static_cast<size_t>(llvm::hash_value(self.get().ptr)); |
3291 | }) |
3292 | .def( |
3293 | "__str__" , |
3294 | [](PyBlock &self) { |
3295 | self.checkValid(); |
3296 | PyPrintAccumulator printAccum; |
3297 | mlirBlockPrint(self.get(), printAccum.getCallback(), |
3298 | printAccum.getUserData()); |
3299 | return printAccum.join(); |
3300 | }, |
3301 | "Returns the assembly form of the block." ) |
3302 | .def( |
3303 | "append" , |
3304 | [](PyBlock &self, PyOperationBase &operation) { |
3305 | if (operation.getOperation().isAttached()) |
3306 | operation.getOperation().detachFromParent(); |
3307 | |
3308 | MlirOperation mlirOperation = operation.getOperation().get(); |
3309 | mlirBlockAppendOwnedOperation(self.get(), mlirOperation); |
3310 | operation.getOperation().setAttached( |
3311 | self.getParentOperation().getObject()); |
3312 | }, |
3313 | py::arg("operation" ), |
3314 | "Appends an operation to this block. If the operation is currently " |
3315 | "in another block, it will be moved." ); |
3316 | |
3317 | //---------------------------------------------------------------------------- |
3318 | // Mapping of PyInsertionPoint. |
3319 | //---------------------------------------------------------------------------- |
3320 | |
3321 | py::class_<PyInsertionPoint>(m, "InsertionPoint" , py::module_local()) |
3322 | .def(py::init<PyBlock &>(), py::arg("block" ), |
3323 | "Inserts after the last operation but still inside the block." ) |
3324 | .def("__enter__" , &PyInsertionPoint::contextEnter) |
3325 | .def("__exit__" , &PyInsertionPoint::contextExit) |
3326 | .def_property_readonly_static( |
3327 | "current" , |
3328 | [](py::object & /*class*/) { |
3329 | auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); |
3330 | if (!ip) |
3331 | throw py::value_error("No current InsertionPoint" ); |
3332 | return ip; |
3333 | }, |
3334 | "Gets the InsertionPoint bound to the current thread or raises " |
3335 | "ValueError if none has been set" ) |
3336 | .def(py::init<PyOperationBase &>(), py::arg("beforeOperation" ), |
3337 | "Inserts before a referenced operation." ) |
3338 | .def_static("at_block_begin" , &PyInsertionPoint::atBlockBegin, |
3339 | py::arg("block" ), "Inserts at the beginning of the block." ) |
3340 | .def_static("at_block_terminator" , &PyInsertionPoint::atBlockTerminator, |
3341 | py::arg("block" ), "Inserts before the block terminator." ) |
3342 | .def("insert" , &PyInsertionPoint::insert, py::arg("operation" ), |
3343 | "Inserts an operation." ) |
3344 | .def_property_readonly( |
3345 | "block" , [](PyInsertionPoint &self) { return self.getBlock(); }, |
3346 | "Returns the block that this InsertionPoint points to." ) |
3347 | .def_property_readonly( |
3348 | "ref_operation" , |
3349 | [](PyInsertionPoint &self) -> py::object { |
3350 | auto refOperation = self.getRefOperation(); |
3351 | if (refOperation) |
3352 | return refOperation->getObject(); |
3353 | return py::none(); |
3354 | }, |
3355 | "The reference operation before which new operations are " |
3356 | "inserted, or None if the insertion point is at the end of " |
3357 | "the block" ); |
3358 | |
3359 | //---------------------------------------------------------------------------- |
3360 | // Mapping of PyAttribute. |
3361 | //---------------------------------------------------------------------------- |
3362 | py::class_<PyAttribute>(m, "Attribute" , py::module_local()) |
3363 | // Delegate to the PyAttribute copy constructor, which will also lifetime |
3364 | // extend the backing context which owns the MlirAttribute. |
3365 | .def(py::init<PyAttribute &>(), py::arg("cast_from_type" ), |
3366 | "Casts the passed attribute to the generic Attribute" ) |
3367 | .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, |
3368 | &PyAttribute::getCapsule) |
3369 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) |
3370 | .def_static( |
3371 | "parse" , |
3372 | [](const std::string &attrSpec, DefaultingPyMlirContext context) { |
3373 | PyMlirContext::ErrorCapture errors(context->getRef()); |
3374 | MlirAttribute attr = mlirAttributeParseGet( |
3375 | context->get(), toMlirStringRef(attrSpec)); |
3376 | if (mlirAttributeIsNull(attr)) |
3377 | throw MLIRError("Unable to parse attribute" , errors.take()); |
3378 | return attr; |
3379 | }, |
3380 | py::arg("asm" ), py::arg("context" ) = py::none(), |
3381 | "Parses an attribute from an assembly form. Raises an MLIRError on " |
3382 | "failure." ) |
3383 | .def_property_readonly( |
3384 | "context" , |
3385 | [](PyAttribute &self) { return self.getContext().getObject(); }, |
3386 | "Context that owns the Attribute" ) |
3387 | .def_property_readonly( |
3388 | "type" , [](PyAttribute &self) { return mlirAttributeGetType(self); }) |
3389 | .def( |
3390 | "get_named" , |
3391 | [](PyAttribute &self, std::string name) { |
3392 | return PyNamedAttribute(self, std::move(name)); |
3393 | }, |
3394 | py::keep_alive<0, 1>(), "Binds a name to the attribute" ) |
3395 | .def("__eq__" , |
3396 | [](PyAttribute &self, PyAttribute &other) { return self == other; }) |
3397 | .def("__eq__" , [](PyAttribute &self, py::object &other) { return false; }) |
3398 | .def("__hash__" , |
3399 | [](PyAttribute &self) { |
3400 | return static_cast<size_t>(llvm::hash_value(self.get().ptr)); |
3401 | }) |
3402 | .def( |
3403 | "dump" , [](PyAttribute &self) { mlirAttributeDump(self); }, |
3404 | kDumpDocstring) |
3405 | .def( |
3406 | "__str__" , |
3407 | [](PyAttribute &self) { |
3408 | PyPrintAccumulator printAccum; |
3409 | mlirAttributePrint(self, printAccum.getCallback(), |
3410 | printAccum.getUserData()); |
3411 | return printAccum.join(); |
3412 | }, |
3413 | "Returns the assembly form of the Attribute." ) |
3414 | .def("__repr__" , |
3415 | [](PyAttribute &self) { |
3416 | // Generally, assembly formats are not printed for __repr__ because |
3417 | // this can cause exceptionally long debug output and exceptions. |
3418 | // However, attribute values are generally considered useful and |
3419 | // are printed. This may need to be re-evaluated if debug dumps end |
3420 | // up being excessive. |
3421 | PyPrintAccumulator printAccum; |
3422 | printAccum.parts.append("Attribute(" ); |
3423 | mlirAttributePrint(self, printAccum.getCallback(), |
3424 | printAccum.getUserData()); |
3425 | printAccum.parts.append(")" ); |
3426 | return printAccum.join(); |
3427 | }) |
3428 | .def_property_readonly( |
3429 | "typeid" , |
3430 | [](PyAttribute &self) -> MlirTypeID { |
3431 | MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); |
3432 | assert(!mlirTypeIDIsNull(mlirTypeID) && |
3433 | "mlirTypeID was expected to be non-null." ); |
3434 | return mlirTypeID; |
3435 | }) |
3436 | .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) { |
3437 | MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); |
3438 | assert(!mlirTypeIDIsNull(mlirTypeID) && |
3439 | "mlirTypeID was expected to be non-null." ); |
3440 | std::optional<pybind11::function> typeCaster = |
3441 | PyGlobals::get().lookupTypeCaster(mlirTypeID, |
3442 | mlirAttributeGetDialect(self)); |
3443 | if (!typeCaster) |
3444 | return py::cast(self); |
3445 | return typeCaster.value()(self); |
3446 | }); |
3447 | |
3448 | //---------------------------------------------------------------------------- |
3449 | // Mapping of PyNamedAttribute |
3450 | //---------------------------------------------------------------------------- |
3451 | py::class_<PyNamedAttribute>(m, "NamedAttribute" , py::module_local()) |
3452 | .def("__repr__" , |
3453 | [](PyNamedAttribute &self) { |
3454 | PyPrintAccumulator printAccum; |
3455 | printAccum.parts.append("NamedAttribute(" ); |
3456 | printAccum.parts.append( |
3457 | py::str(mlirIdentifierStr(self.namedAttr.name).data, |
3458 | mlirIdentifierStr(self.namedAttr.name).length)); |
3459 | printAccum.parts.append("=" ); |
3460 | mlirAttributePrint(self.namedAttr.attribute, |
3461 | printAccum.getCallback(), |
3462 | printAccum.getUserData()); |
3463 | printAccum.parts.append(")" ); |
3464 | return printAccum.join(); |
3465 | }) |
3466 | .def_property_readonly( |
3467 | "name" , |
3468 | [](PyNamedAttribute &self) { |
3469 | return py::str(mlirIdentifierStr(self.namedAttr.name).data, |
3470 | mlirIdentifierStr(self.namedAttr.name).length); |
3471 | }, |
3472 | "The name of the NamedAttribute binding" ) |
3473 | .def_property_readonly( |
3474 | "attr" , |
3475 | [](PyNamedAttribute &self) { return self.namedAttr.attribute; }, |
3476 | py::keep_alive<0, 1>(), |
3477 | "The underlying generic attribute of the NamedAttribute binding" ); |
3478 | |
3479 | //---------------------------------------------------------------------------- |
3480 | // Mapping of PyType. |
3481 | //---------------------------------------------------------------------------- |
3482 | py::class_<PyType>(m, "Type" , py::module_local()) |
3483 | // Delegate to the PyType copy constructor, which will also lifetime |
3484 | // extend the backing context which owns the MlirType. |
3485 | .def(py::init<PyType &>(), py::arg("cast_from_type" ), |
3486 | "Casts the passed type to the generic Type" ) |
3487 | .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) |
3488 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) |
3489 | .def_static( |
3490 | "parse" , |
3491 | [](std::string typeSpec, DefaultingPyMlirContext context) { |
3492 | PyMlirContext::ErrorCapture errors(context->getRef()); |
3493 | MlirType type = |
3494 | mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); |
3495 | if (mlirTypeIsNull(type)) |
3496 | throw MLIRError("Unable to parse type" , errors.take()); |
3497 | return type; |
3498 | }, |
3499 | py::arg("asm" ), py::arg("context" ) = py::none(), |
3500 | kContextParseTypeDocstring) |
3501 | .def_property_readonly( |
3502 | "context" , [](PyType &self) { return self.getContext().getObject(); }, |
3503 | "Context that owns the Type" ) |
3504 | .def("__eq__" , [](PyType &self, PyType &other) { return self == other; }) |
3505 | .def("__eq__" , [](PyType &self, py::object &other) { return false; }) |
3506 | .def("__hash__" , |
3507 | [](PyType &self) { |
3508 | return static_cast<size_t>(llvm::hash_value(self.get().ptr)); |
3509 | }) |
3510 | .def( |
3511 | "dump" , [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) |
3512 | .def( |
3513 | "__str__" , |
3514 | [](PyType &self) { |
3515 | PyPrintAccumulator printAccum; |
3516 | mlirTypePrint(self, printAccum.getCallback(), |
3517 | printAccum.getUserData()); |
3518 | return printAccum.join(); |
3519 | }, |
3520 | "Returns the assembly form of the type." ) |
3521 | .def("__repr__" , |
3522 | [](PyType &self) { |
3523 | // Generally, assembly formats are not printed for __repr__ because |
3524 | // this can cause exceptionally long debug output and exceptions. |
3525 | // However, types are an exception as they typically have compact |
3526 | // assembly forms and printing them is useful. |
3527 | PyPrintAccumulator printAccum; |
3528 | printAccum.parts.append("Type(" ); |
3529 | mlirTypePrint(self, printAccum.getCallback(), |
3530 | printAccum.getUserData()); |
3531 | printAccum.parts.append(")" ); |
3532 | return printAccum.join(); |
3533 | }) |
3534 | .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, |
3535 | [](PyType &self) { |
3536 | MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); |
3537 | assert(!mlirTypeIDIsNull(mlirTypeID) && |
3538 | "mlirTypeID was expected to be non-null." ); |
3539 | std::optional<pybind11::function> typeCaster = |
3540 | PyGlobals::get().lookupTypeCaster(mlirTypeID, |
3541 | mlirTypeGetDialect(self)); |
3542 | if (!typeCaster) |
3543 | return py::cast(self); |
3544 | return typeCaster.value()(self); |
3545 | }) |
3546 | .def_property_readonly("typeid" , [](PyType &self) -> MlirTypeID { |
3547 | MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); |
3548 | if (!mlirTypeIDIsNull(mlirTypeID)) |
3549 | return mlirTypeID; |
3550 | auto origRepr = |
3551 | pybind11::repr(pybind11::cast(self)).cast<std::string>(); |
3552 | throw py::value_error( |
3553 | (origRepr + llvm::Twine(" has no typeid." )).str()); |
3554 | }); |
3555 | |
3556 | //---------------------------------------------------------------------------- |
3557 | // Mapping of PyTypeID. |
3558 | //---------------------------------------------------------------------------- |
3559 | py::class_<PyTypeID>(m, "TypeID" , py::module_local()) |
3560 | .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule) |
3561 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule) |
3562 | // Note, this tests whether the underlying TypeIDs are the same, |
3563 | // not whether the wrapper MlirTypeIDs are the same, nor whether |
3564 | // the Python objects are the same (i.e., PyTypeID is a value type). |
3565 | .def("__eq__" , |
3566 | [](PyTypeID &self, PyTypeID &other) { return self == other; }) |
3567 | .def("__eq__" , |
3568 | [](PyTypeID &self, const py::object &other) { return false; }) |
3569 | // Note, this gives the hash value of the underlying TypeID, not the |
3570 | // hash value of the Python object, nor the hash value of the |
3571 | // MlirTypeID wrapper. |
3572 | .def("__hash__" , [](PyTypeID &self) { |
3573 | return static_cast<size_t>(mlirTypeIDHashValue(self)); |
3574 | }); |
3575 | |
3576 | //---------------------------------------------------------------------------- |
3577 | // Mapping of Value. |
3578 | //---------------------------------------------------------------------------- |
3579 | py::class_<PyValue>(m, "Value" , py::module_local()) |
3580 | .def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value" )) |
3581 | .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) |
3582 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) |
3583 | .def_property_readonly( |
3584 | "context" , |
3585 | [](PyValue &self) { return self.getParentOperation()->getContext(); }, |
3586 | "Context in which the value lives." ) |
3587 | .def( |
3588 | "dump" , [](PyValue &self) { mlirValueDump(self.get()); }, |
3589 | kDumpDocstring) |
3590 | .def_property_readonly( |
3591 | "owner" , |
3592 | [](PyValue &self) -> py::object { |
3593 | MlirValue v = self.get(); |
3594 | if (mlirValueIsAOpResult(v)) { |
3595 | assert( |
3596 | mlirOperationEqual(self.getParentOperation()->get(), |
3597 | mlirOpResultGetOwner(self.get())) && |
3598 | "expected the owner of the value in Python to match that in " |
3599 | "the IR" ); |
3600 | return self.getParentOperation().getObject(); |
3601 | } |
3602 | |
3603 | if (mlirValueIsABlockArgument(v)) { |
3604 | MlirBlock block = mlirBlockArgumentGetOwner(self.get()); |
3605 | return py::cast(PyBlock(self.getParentOperation(), block)); |
3606 | } |
3607 | |
3608 | assert(false && "Value must be a block argument or an op result" ); |
3609 | return py::none(); |
3610 | }) |
3611 | .def_property_readonly("uses" , |
3612 | [](PyValue &self) { |
3613 | return PyOpOperandIterator( |
3614 | mlirValueGetFirstUse(self.get())); |
3615 | }) |
3616 | .def("__eq__" , |
3617 | [](PyValue &self, PyValue &other) { |
3618 | return self.get().ptr == other.get().ptr; |
3619 | }) |
3620 | .def("__eq__" , [](PyValue &self, py::object other) { return false; }) |
3621 | .def("__hash__" , |
3622 | [](PyValue &self) { |
3623 | return static_cast<size_t>(llvm::hash_value(self.get().ptr)); |
3624 | }) |
3625 | .def( |
3626 | "__str__" , |
3627 | [](PyValue &self) { |
3628 | PyPrintAccumulator printAccum; |
3629 | printAccum.parts.append("Value(" ); |
3630 | mlirValuePrint(self.get(), printAccum.getCallback(), |
3631 | printAccum.getUserData()); |
3632 | printAccum.parts.append(")" ); |
3633 | return printAccum.join(); |
3634 | }, |
3635 | kValueDunderStrDocstring) |
3636 | .def( |
3637 | "get_name" , |
3638 | [](PyValue &self, bool useLocalScope) { |
3639 | PyPrintAccumulator printAccum; |
3640 | MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); |
3641 | if (useLocalScope) |
3642 | mlirOpPrintingFlagsUseLocalScope(flags); |
3643 | MlirAsmState valueState = |
3644 | mlirAsmStateCreateForValue(self.get(), flags); |
3645 | mlirValuePrintAsOperand(self.get(), valueState, |
3646 | printAccum.getCallback(), |
3647 | printAccum.getUserData()); |
3648 | mlirOpPrintingFlagsDestroy(flags); |
3649 | mlirAsmStateDestroy(valueState); |
3650 | return printAccum.join(); |
3651 | }, |
3652 | py::arg("use_local_scope" ) = false) |
3653 | .def( |
3654 | "get_name" , |
3655 | [](PyValue &self, std::reference_wrapper<PyAsmState> state) { |
3656 | PyPrintAccumulator printAccum; |
3657 | MlirAsmState valueState = state.get().get(); |
3658 | mlirValuePrintAsOperand(self.get(), valueState, |
3659 | printAccum.getCallback(), |
3660 | printAccum.getUserData()); |
3661 | return printAccum.join(); |
3662 | }, |
3663 | py::arg("state" ), kGetNameAsOperand) |
3664 | .def_property_readonly( |
3665 | "type" , [](PyValue &self) { return mlirValueGetType(self.get()); }) |
3666 | .def( |
3667 | "set_type" , |
3668 | [](PyValue &self, const PyType &type) { |
3669 | return mlirValueSetType(self.get(), type); |
3670 | }, |
3671 | py::arg("type" )) |
3672 | .def( |
3673 | "replace_all_uses_with" , |
3674 | [](PyValue &self, PyValue &with) { |
3675 | mlirValueReplaceAllUsesOfWith(self.get(), with.get()); |
3676 | }, |
3677 | kValueReplaceAllUsesWithDocstring) |
3678 | .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, |
3679 | [](PyValue &self) { return self.maybeDownCast(); }); |
3680 | PyBlockArgument::bind(m); |
3681 | PyOpResult::bind(m); |
3682 | PyOpOperand::bind(m); |
3683 | |
3684 | py::class_<PyAsmState>(m, "AsmState" , py::module_local()) |
3685 | .def(py::init<PyValue &, bool>(), py::arg("value" ), |
3686 | py::arg("use_local_scope" ) = false) |
3687 | .def(py::init<PyOperationBase &, bool>(), py::arg("op" ), |
3688 | py::arg("use_local_scope" ) = false); |
3689 | |
3690 | //---------------------------------------------------------------------------- |
3691 | // Mapping of SymbolTable. |
3692 | //---------------------------------------------------------------------------- |
3693 | py::class_<PySymbolTable>(m, "SymbolTable" , py::module_local()) |
3694 | .def(py::init<PyOperationBase &>()) |
3695 | .def("__getitem__" , &PySymbolTable::dunderGetItem) |
3696 | .def("insert" , &PySymbolTable::insert, py::arg("operation" )) |
3697 | .def("erase" , &PySymbolTable::erase, py::arg("operation" )) |
3698 | .def("__delitem__" , &PySymbolTable::dunderDel) |
3699 | .def("__contains__" , |
3700 | [](PySymbolTable &table, const std::string &name) { |
3701 | return !mlirOperationIsNull(mlirSymbolTableLookup( |
3702 | table, mlirStringRefCreate(name.data(), name.length()))); |
3703 | }) |
3704 | // Static helpers. |
3705 | .def_static("set_symbol_name" , &PySymbolTable::setSymbolName, |
3706 | py::arg("symbol" ), py::arg("name" )) |
3707 | .def_static("get_symbol_name" , &PySymbolTable::getSymbolName, |
3708 | py::arg("symbol" )) |
3709 | .def_static("get_visibility" , &PySymbolTable::getVisibility, |
3710 | py::arg("symbol" )) |
3711 | .def_static("set_visibility" , &PySymbolTable::setVisibility, |
3712 | py::arg("symbol" ), py::arg("visibility" )) |
3713 | .def_static("replace_all_symbol_uses" , |
3714 | &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol" ), |
3715 | py::arg("new_symbol" ), py::arg("from_op" )) |
3716 | .def_static("walk_symbol_tables" , &PySymbolTable::walkSymbolTables, |
3717 | py::arg("from_op" ), py::arg("all_sym_uses_visible" ), |
3718 | py::arg("callback" )); |
3719 | |
3720 | // Container bindings. |
3721 | PyBlockArgumentList::bind(m); |
3722 | PyBlockIterator::bind(m); |
3723 | PyBlockList::bind(m); |
3724 | PyOperationIterator::bind(m); |
3725 | PyOperationList::bind(m); |
3726 | PyOpAttributeMap::bind(m); |
3727 | PyOpOperandIterator::bind(m); |
3728 | PyOpOperandList::bind(m); |
3729 | PyOpResultList::bind(m); |
3730 | PyOpSuccessors::bind(m); |
3731 | PyRegionIterator::bind(m); |
3732 | PyRegionList::bind(m); |
3733 | |
3734 | // Debug bindings. |
3735 | PyGlobalDebugFlag::bind(m); |
3736 | |
3737 | // Attribute builder getter. |
3738 | PyAttrBuilderMap::bind(m); |
3739 | |
3740 | py::register_local_exception_translator([](std::exception_ptr p) { |
3741 | // We can't define exceptions with custom fields through pybind, so instead |
3742 | // the exception class is defined in python and imported here. |
3743 | try { |
3744 | if (p) |
3745 | std::rethrow_exception(p); |
3746 | } catch (const MLIRError &e) { |
3747 | py::object obj = py::module_::import(MAKE_MLIR_PYTHON_QUALNAME("ir" )) |
3748 | .attr("MLIRError" )(e.message, e.errorDiagnostics); |
3749 | PyErr_SetObject(PyExc_Exception, obj.ptr()); |
3750 | } |
3751 | }); |
3752 | } |
3753 | |