1//===- MainModule.cpp - Main pybind module --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "PybindUtils.h"
10
11#include "Globals.h"
12#include "IRModule.h"
13#include "Pass.h"
14
15namespace py = pybind11;
16using namespace mlir;
17using namespace py::literals;
18using namespace mlir::python;
19
20// -----------------------------------------------------------------------------
21// Module initialization.
22// -----------------------------------------------------------------------------
23
24PYBIND11_MODULE(_mlir, m) {
25 m.doc() = "MLIR Python Native Extension";
26
27 py::class_<PyGlobals>(m, "_Globals", py::module_local())
28 .def_property("dialect_search_modules",
29 &PyGlobals::getDialectSearchPrefixes,
30 &PyGlobals::setDialectSearchPrefixes)
31 .def(
32 "append_dialect_search_prefix",
33 [](PyGlobals &self, std::string moduleName) {
34 self.getDialectSearchPrefixes().push_back(std::move(moduleName));
35 },
36 "module_name"_a)
37 .def(
38 "_check_dialect_module_loaded",
39 [](PyGlobals &self, const std::string &dialectNamespace) {
40 return self.loadDialectModule(dialectNamespace);
41 },
42 "dialect_namespace"_a)
43 .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
44 "dialect_namespace"_a, "dialect_class"_a,
45 "Testing hook for directly registering a dialect")
46 .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
47 "operation_name"_a, "operation_class"_a, py::kw_only(),
48 "replace"_a = false,
49 "Testing hook for directly registering an operation");
50
51 // Aside from making the globals accessible to python, having python manage
52 // it is necessary to make sure it is destroyed (and releases its python
53 // resources) properly.
54 m.attr("globals") =
55 py::cast(new PyGlobals, py::return_value_policy::take_ownership);
56
57 // Registration decorators.
58 m.def(
59 "register_dialect",
60 [](py::object pyClass) {
61 std::string dialectNamespace =
62 pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
63 PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
64 return pyClass;
65 },
66 "dialect_class"_a,
67 "Class decorator for registering a custom Dialect wrapper");
68 m.def(
69 "register_operation",
70 [](const py::object &dialectClass, bool replace) -> py::cpp_function {
71 return py::cpp_function(
72 [dialectClass, replace](py::object opClass) -> py::object {
73 std::string operationName =
74 opClass.attr("OPERATION_NAME").cast<std::string>();
75 PyGlobals::get().registerOperationImpl(operationName, opClass,
76 replace);
77
78 // Dict-stuff the new opClass by name onto the dialect class.
79 py::object opClassName = opClass.attr("__name__");
80 dialectClass.attr(opClassName) = opClass;
81 return opClass;
82 });
83 },
84 "dialect_class"_a, py::kw_only(), "replace"_a = false,
85 "Produce a class decorator for registering an Operation class as part of "
86 "a dialect");
87 m.def(
88 MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
89 [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
90 return py::cpp_function([mlirTypeID,
91 replace](py::object typeCaster) -> py::object {
92 PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
93 return typeCaster;
94 });
95 },
96 "typeid"_a, py::kw_only(), "replace"_a = false,
97 "Register a type caster for casting MLIR types to custom user types.");
98 m.def(
99 MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
100 [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
101 return py::cpp_function(
102 [mlirTypeID, replace](py::object valueCaster) -> py::object {
103 PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
104 replace);
105 return valueCaster;
106 });
107 },
108 "typeid"_a, py::kw_only(), "replace"_a = false,
109 "Register a value caster for casting MLIR values to custom user values.");
110
111 // Define and populate IR submodule.
112 auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
113 populateIRCore(irModule);
114 populateIRAffine(irModule);
115 populateIRAttributes(irModule);
116 populateIRInterfaces(irModule);
117 populateIRTypes(irModule);
118
119 // Define and populate PassManager submodule.
120 auto passModule =
121 m.def_submodule("passmanager", "MLIR Pass Management Bindings");
122 populatePassManagerSubmodule(passModule);
123}
124

// Source: mlir/lib/Bindings/Python/MainModule.cpp