//===- DialectSMT.cpp - Nanobind module for SMT dialect API support -------===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "NanobindUtils.h" |
10 | |
11 | #include "mlir-c/Dialect/SMT.h" |
12 | #include "mlir-c/IR.h" |
13 | #include "mlir-c/Support.h" |
14 | #include "mlir-c/Target/ExportSMTLIB.h" |
15 | #include "mlir/Bindings/Python/Diagnostics.h" |
16 | #include "mlir/Bindings/Python/Nanobind.h" |
17 | #include "mlir/Bindings/Python/NanobindAdaptors.h" |
18 | |
19 | namespace nb = nanobind; |
20 | |
21 | using namespace nanobind::literals; |
22 | |
23 | using namespace mlir; |
24 | using namespace mlir::python; |
25 | using namespace mlir::python::nanobind_adaptors; |
26 | |
27 | void populateDialectSMTSubmodule(nanobind::module_ &m) { |
28 | |
29 | auto smtBoolType = mlir_type_subclass(m, "BoolType" , mlirSMTTypeIsABool) |
30 | .def_classmethod( |
31 | "get" , |
32 | [](const nb::object &, MlirContext context) { |
33 | return mlirSMTTypeGetBool(context); |
34 | }, |
35 | "cls"_a , "context"_a .none() = nb::none()); |
36 | auto smtBitVectorType = |
37 | mlir_type_subclass(m, "BitVectorType" , mlirSMTTypeIsABitVector) |
38 | .def_classmethod( |
39 | "get" , |
40 | [](const nb::object &, int32_t width, MlirContext context) { |
41 | return mlirSMTTypeGetBitVector(context, width); |
42 | }, |
43 | "cls"_a , "width"_a , "context"_a .none() = nb::none()); |
44 | |
45 | auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues, |
46 | bool indentLetBody) { |
47 | mlir::python::CollectDiagnosticsToStringScope scope( |
48 | mlirOperationGetContext(module)); |
49 | PyPrintAccumulator printAccum; |
50 | MlirLogicalResult result = mlirTranslateOperationToSMTLIB( |
51 | module, printAccum.getCallback(), printAccum.getUserData(), |
52 | inlineSingleUseValues, indentLetBody); |
53 | if (mlirLogicalResultIsSuccess(result)) |
54 | return printAccum.join(); |
55 | throw nb::value_error( |
56 | ("Failed to export smtlib.\nDiagnostic message " + scope.takeMessage()) |
57 | .c_str()); |
58 | }; |
59 | |
60 | m.def( |
61 | "export_smtlib" , |
62 | [&exportSMTLIB](MlirOperation module, bool inlineSingleUseValues, |
63 | bool indentLetBody) { |
64 | return exportSMTLIB(module, inlineSingleUseValues, indentLetBody); |
65 | }, |
66 | "module"_a , "inline_single_use_values"_a = false, |
67 | "indent_let_body"_a = false); |
68 | m.def( |
69 | "export_smtlib" , |
70 | [&exportSMTLIB](MlirModule module, bool inlineSingleUseValues, |
71 | bool indentLetBody) { |
72 | return exportSMTLIB(mlirModuleGetOperation(module), |
73 | inlineSingleUseValues, indentLetBody); |
74 | }, |
75 | "module"_a , "inline_single_use_values"_a = false, |
76 | "indent_let_body"_a = false); |
77 | } |
78 | |
79 | NB_MODULE(_mlirDialectsSMT, m) { |
80 | m.doc() = "MLIR SMT Dialect" ; |
81 | |
82 | populateDialectSMTSubmodule(m); |
83 | } |
84 | |