1//===- DialectLinalg.cpp - Pybind module for Linalg dialect API support --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir-c/Dialect/Linalg.h"
10#include "mlir-c/IR.h"
11#include "mlir/Bindings/Python/Nanobind.h"
12#include "mlir/Bindings/Python/NanobindAdaptors.h"
13
14namespace nb = nanobind;
15using namespace mlir::python::nanobind_adaptors;
16
17static std::optional<MlirLinalgContractionDimensions>
18InferContractionDimensions(MlirOperation op) {
19 MlirLinalgContractionDimensions dims =
20 mlirLinalgInferContractionDimensions(op);
21
22 // Detect "empty" result. This occurs when `op` is not a contraction op,
23 // or when `linalg::inferContractionDims` fails.
24 if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
25 mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
26 return std::nullopt;
27 }
28 return dims;
29}
30
31static std::optional<MlirLinalgConvolutionDimensions>
32InferConvolutionDimensions(MlirOperation op) {
33 MlirLinalgConvolutionDimensions dims =
34 mlirLinalgInferConvolutionDimensions(op);
35
36 // Detect "empty" result. This occurs when `op` is not a convolution op,
37 // or when `linalg::inferConvolutionDims` fails.
38 if (mlirAttributeIsNull(dims.batch) &&
39 mlirAttributeIsNull(dims.outputImage) &&
40 mlirAttributeIsNull(dims.outputChannel) &&
41 mlirAttributeIsNull(dims.filterLoop) &&
42 mlirAttributeIsNull(dims.inputChannel) &&
43 mlirAttributeIsNull(dims.depth) && mlirAttributeIsNull(dims.strides) &&
44 mlirAttributeIsNull(dims.dilations)) {
45 return std::nullopt;
46 }
47
48 return dims;
49}
50
51static void populateDialectLinalgSubmodule(nb::module_ m) {
52 m.def(
53 "fill_builtin_region",
54 [](MlirOperation op) { mlirLinalgFillBuiltinNamedOpRegion(op); },
55 nb::arg("op"),
56 "Fill the region for `op`, which is assumed to be a builtin named Linalg "
57 "op.");
58
59 m.def("isa_contraction_op", &mlirLinalgIsAContractionOp,
60 "Checks if the given operation is a Linalg contraction operation.",
61 nb::arg("op"));
62
63 nb::class_<MlirLinalgContractionDimensions>(m, "ContractionDimensions")
64 .def_prop_ro("batch",
65 [](const MlirLinalgContractionDimensions &self) {
66 return self.batch;
67 })
68 .def_prop_ro(
69 "m",
70 [](const MlirLinalgContractionDimensions &self) { return self.m; })
71 .def_prop_ro(
72 "n",
73 [](const MlirLinalgContractionDimensions &self) { return self.n; })
74 .def_prop_ro("k", [](const MlirLinalgContractionDimensions &self) {
75 return self.k;
76 });
77
78 m.def("infer_contraction_dimensions", &InferContractionDimensions,
79 "Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
80 "op.",
81 nb::arg("op"));
82
83 m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
84 "Checks if the given operation is a Linalg convolution operation.",
85 nb::arg("op"));
86
87 nb::class_<MlirLinalgConvolutionDimensions>(m, "ConvolutionDimensions")
88 .def_prop_ro("batch",
89 [](const MlirLinalgConvolutionDimensions &self) {
90 return self.batch;
91 })
92 .def_prop_ro("output_image",
93 [](const MlirLinalgConvolutionDimensions &self) {
94 return self.outputImage;
95 })
96 .def_prop_ro("output_channel",
97 [](const MlirLinalgConvolutionDimensions &self) {
98 return self.outputChannel;
99 })
100 .def_prop_ro("filter_loop",
101 [](const MlirLinalgConvolutionDimensions &self) {
102 return self.filterLoop;
103 })
104 .def_prop_ro("input_channel",
105 [](const MlirLinalgConvolutionDimensions &self) {
106 return self.inputChannel;
107 })
108 .def_prop_ro("depth",
109 [](const MlirLinalgConvolutionDimensions &self) {
110 return self.depth;
111 })
112 .def_prop_ro("strides",
113 [](const MlirLinalgConvolutionDimensions &self) {
114 return self.strides;
115 })
116 .def_prop_ro("dilations",
117 [](const MlirLinalgConvolutionDimensions &self) {
118 return self.dilations;
119 });
120
121 m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
122 "Infers convolution dimensions", nb::arg("op"));
123
124 m.def(
125 "get_indexing_maps",
126 [](MlirOperation op) -> std::optional<MlirAttribute> {
127 MlirAttribute attr = mlirLinalgGetIndexingMapsAttribute(op);
128 if (mlirAttributeIsNull(attr))
129 return std::nullopt;
130 return attr;
131 },
132 "Returns the indexing_maps attribute for a linalg op.");
133}
134
135NB_MODULE(_mlirDialectsLinalg, m) {
136 m.doc() = "MLIR Linalg dialect.";
137
138 populateDialectLinalgSubmodule(m);
139}
140

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Bindings/Python/DialectLinalg.cpp