1 | //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include <cstddef> |
10 | #include <cstdint> |
11 | #include <pybind11/cast.h> |
12 | #include <pybind11/detail/common.h> |
13 | #include <pybind11/pybind11.h> |
14 | #include <pybind11/pytypes.h> |
15 | #include <string> |
16 | #include <utility> |
17 | #include <vector> |
18 | |
19 | #include "IRModule.h" |
20 | |
21 | #include "PybindUtils.h" |
22 | |
23 | #include "mlir-c/AffineExpr.h" |
24 | #include "mlir-c/AffineMap.h" |
25 | #include "mlir-c/Bindings/Python/Interop.h" |
26 | #include "mlir-c/IntegerSet.h" |
27 | #include "mlir/Support/LLVM.h" |
28 | #include "llvm/ADT/Hashing.h" |
29 | #include "llvm/ADT/SmallVector.h" |
30 | #include "llvm/ADT/StringRef.h" |
31 | #include "llvm/ADT/Twine.h" |
32 | |
33 | namespace py = pybind11; |
34 | using namespace mlir; |
35 | using namespace mlir::python; |
36 | |
37 | using llvm::SmallVector; |
38 | using llvm::StringRef; |
39 | using llvm::Twine; |
40 | |
/// Python docstring shared by the `dump` methods bound in this file.
static const char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
43 | |
44 | /// Attempts to populate `result` with the content of `list` casted to the |
45 | /// appropriate type (Python and C types are provided as template arguments). |
46 | /// Throws errors in case of failure, using "action" to describe what the caller |
47 | /// was attempting to do. |
48 | template <typename PyType, typename CType> |
49 | static void pyListToVector(const py::list &list, |
50 | llvm::SmallVectorImpl<CType> &result, |
51 | StringRef action) { |
52 | result.reserve(py::len(list)); |
53 | for (py::handle item : list) { |
54 | try { |
55 | result.push_back(item.cast<PyType>()); |
56 | } catch (py::cast_error &err) { |
57 | std::string msg = (llvm::Twine("Invalid expression when " ) + action + |
58 | " (" + err.what() + ")" ) |
59 | .str(); |
60 | throw py::cast_error(msg); |
61 | } catch (py::reference_cast_error &err) { |
62 | std::string msg = (llvm::Twine("Invalid expression (None?) when " ) + |
63 | action + " (" + err.what() + ")" ) |
64 | .str(); |
65 | throw py::cast_error(msg); |
66 | } |
67 | } |
68 | } |
69 | |
70 | template <typename PermutationTy> |
71 | static bool isPermutation(std::vector<PermutationTy> permutation) { |
72 | llvm::SmallVector<bool, 8> seen(permutation.size(), false); |
73 | for (auto val : permutation) { |
74 | if (val < permutation.size()) { |
75 | if (seen[val]) |
76 | return false; |
77 | seen[val] = true; |
78 | continue; |
79 | } |
80 | return false; |
81 | } |
82 | return true; |
83 | } |
84 | |
85 | namespace { |
86 | |
87 | /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr |
88 | /// and should be castable from it. Intermediate hierarchy classes can be |
89 | /// modeled by specifying BaseTy. |
90 | template <typename DerivedTy, typename BaseTy = PyAffineExpr> |
91 | class PyConcreteAffineExpr : public BaseTy { |
92 | public: |
93 | // Derived classes must define statics for: |
94 | // IsAFunctionTy isaFunction |
95 | // const char *pyClassName |
96 | // and redefine bindDerived. |
97 | using ClassTy = py::class_<DerivedTy, BaseTy>; |
98 | using IsAFunctionTy = bool (*)(MlirAffineExpr); |
99 | |
100 | PyConcreteAffineExpr() = default; |
101 | PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) |
102 | : BaseTy(std::move(contextRef), affineExpr) {} |
103 | PyConcreteAffineExpr(PyAffineExpr &orig) |
104 | : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {} |
105 | |
106 | static MlirAffineExpr castFrom(PyAffineExpr &orig) { |
107 | if (!DerivedTy::isaFunction(orig)) { |
108 | auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); |
109 | throw py::value_error((Twine("Cannot cast affine expression to " ) + |
110 | DerivedTy::pyClassName + " (from " + origRepr + |
111 | ")" ) |
112 | .str()); |
113 | } |
114 | return orig; |
115 | } |
116 | |
117 | static void bind(py::module &m) { |
118 | auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); |
119 | cls.def(py::init<PyAffineExpr &>(), py::arg("expr" )); |
120 | cls.def_static( |
121 | "isinstance" , |
122 | [](PyAffineExpr &otherAffineExpr) -> bool { |
123 | return DerivedTy::isaFunction(otherAffineExpr); |
124 | }, |
125 | py::arg("other" )); |
126 | DerivedTy::bindDerived(cls); |
127 | } |
128 | |
129 | /// Implemented by derived classes to add methods to the Python subclass. |
130 | static void bindDerived(ClassTy &m) {} |
131 | }; |
132 | |
133 | class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> { |
134 | public: |
135 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant; |
136 | static constexpr const char *pyClassName = "AffineConstantExpr" ; |
137 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
138 | |
139 | static PyAffineConstantExpr get(intptr_t value, |
140 | DefaultingPyMlirContext context) { |
141 | MlirAffineExpr affineExpr = |
142 | mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value)); |
143 | return PyAffineConstantExpr(context->getRef(), affineExpr); |
144 | } |
145 | |
146 | static void bindDerived(ClassTy &c) { |
147 | c.def_static("get" , &PyAffineConstantExpr::get, py::arg("value" ), |
148 | py::arg("context" ) = py::none()); |
149 | c.def_property_readonly("value" , [](PyAffineConstantExpr &self) { |
150 | return mlirAffineConstantExprGetValue(self); |
151 | }); |
152 | } |
153 | }; |
154 | |
155 | class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> { |
156 | public: |
157 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim; |
158 | static constexpr const char *pyClassName = "AffineDimExpr" ; |
159 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
160 | |
161 | static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) { |
162 | MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos); |
163 | return PyAffineDimExpr(context->getRef(), affineExpr); |
164 | } |
165 | |
166 | static void bindDerived(ClassTy &c) { |
167 | c.def_static("get" , &PyAffineDimExpr::get, py::arg("position" ), |
168 | py::arg("context" ) = py::none()); |
169 | c.def_property_readonly("position" , [](PyAffineDimExpr &self) { |
170 | return mlirAffineDimExprGetPosition(self); |
171 | }); |
172 | } |
173 | }; |
174 | |
175 | class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> { |
176 | public: |
177 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol; |
178 | static constexpr const char *pyClassName = "AffineSymbolExpr" ; |
179 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
180 | |
181 | static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) { |
182 | MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos); |
183 | return PyAffineSymbolExpr(context->getRef(), affineExpr); |
184 | } |
185 | |
186 | static void bindDerived(ClassTy &c) { |
187 | c.def_static("get" , &PyAffineSymbolExpr::get, py::arg("position" ), |
188 | py::arg("context" ) = py::none()); |
189 | c.def_property_readonly("position" , [](PyAffineSymbolExpr &self) { |
190 | return mlirAffineSymbolExprGetPosition(self); |
191 | }); |
192 | } |
193 | }; |
194 | |
195 | class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> { |
196 | public: |
197 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary; |
198 | static constexpr const char *pyClassName = "AffineBinaryExpr" ; |
199 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
200 | |
201 | PyAffineExpr lhs() { |
202 | MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get()); |
203 | return PyAffineExpr(getContext(), lhsExpr); |
204 | } |
205 | |
206 | PyAffineExpr rhs() { |
207 | MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get()); |
208 | return PyAffineExpr(getContext(), rhsExpr); |
209 | } |
210 | |
211 | static void bindDerived(ClassTy &c) { |
212 | c.def_property_readonly("lhs" , &PyAffineBinaryExpr::lhs); |
213 | c.def_property_readonly("rhs" , &PyAffineBinaryExpr::rhs); |
214 | } |
215 | }; |
216 | |
217 | class PyAffineAddExpr |
218 | : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> { |
219 | public: |
220 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd; |
221 | static constexpr const char *pyClassName = "AffineAddExpr" ; |
222 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
223 | |
224 | static PyAffineAddExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { |
225 | MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs); |
226 | return PyAffineAddExpr(lhs.getContext(), expr); |
227 | } |
228 | |
229 | static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { |
230 | MlirAffineExpr expr = mlirAffineAddExprGet( |
231 | lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); |
232 | return PyAffineAddExpr(lhs.getContext(), expr); |
233 | } |
234 | |
235 | static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { |
236 | MlirAffineExpr expr = mlirAffineAddExprGet( |
237 | mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); |
238 | return PyAffineAddExpr(rhs.getContext(), expr); |
239 | } |
240 | |
241 | static void bindDerived(ClassTy &c) { |
242 | c.def_static("get" , &PyAffineAddExpr::get); |
243 | } |
244 | }; |
245 | |
246 | class PyAffineMulExpr |
247 | : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> { |
248 | public: |
249 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul; |
250 | static constexpr const char *pyClassName = "AffineMulExpr" ; |
251 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
252 | |
253 | static PyAffineMulExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { |
254 | MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs); |
255 | return PyAffineMulExpr(lhs.getContext(), expr); |
256 | } |
257 | |
258 | static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { |
259 | MlirAffineExpr expr = mlirAffineMulExprGet( |
260 | lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); |
261 | return PyAffineMulExpr(lhs.getContext(), expr); |
262 | } |
263 | |
264 | static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { |
265 | MlirAffineExpr expr = mlirAffineMulExprGet( |
266 | mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); |
267 | return PyAffineMulExpr(rhs.getContext(), expr); |
268 | } |
269 | |
270 | static void bindDerived(ClassTy &c) { |
271 | c.def_static("get" , &PyAffineMulExpr::get); |
272 | } |
273 | }; |
274 | |
275 | class PyAffineModExpr |
276 | : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> { |
277 | public: |
278 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod; |
279 | static constexpr const char *pyClassName = "AffineModExpr" ; |
280 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
281 | |
282 | static PyAffineModExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { |
283 | MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs); |
284 | return PyAffineModExpr(lhs.getContext(), expr); |
285 | } |
286 | |
287 | static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { |
288 | MlirAffineExpr expr = mlirAffineModExprGet( |
289 | lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); |
290 | return PyAffineModExpr(lhs.getContext(), expr); |
291 | } |
292 | |
293 | static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { |
294 | MlirAffineExpr expr = mlirAffineModExprGet( |
295 | mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); |
296 | return PyAffineModExpr(rhs.getContext(), expr); |
297 | } |
298 | |
299 | static void bindDerived(ClassTy &c) { |
300 | c.def_static("get" , &PyAffineModExpr::get); |
301 | } |
302 | }; |
303 | |
304 | class PyAffineFloorDivExpr |
305 | : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> { |
306 | public: |
307 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv; |
308 | static constexpr const char *pyClassName = "AffineFloorDivExpr" ; |
309 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
310 | |
311 | static PyAffineFloorDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { |
312 | MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs); |
313 | return PyAffineFloorDivExpr(lhs.getContext(), expr); |
314 | } |
315 | |
316 | static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { |
317 | MlirAffineExpr expr = mlirAffineFloorDivExprGet( |
318 | lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); |
319 | return PyAffineFloorDivExpr(lhs.getContext(), expr); |
320 | } |
321 | |
322 | static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { |
323 | MlirAffineExpr expr = mlirAffineFloorDivExprGet( |
324 | mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); |
325 | return PyAffineFloorDivExpr(rhs.getContext(), expr); |
326 | } |
327 | |
328 | static void bindDerived(ClassTy &c) { |
329 | c.def_static("get" , &PyAffineFloorDivExpr::get); |
330 | } |
331 | }; |
332 | |
333 | class PyAffineCeilDivExpr |
334 | : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> { |
335 | public: |
336 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv; |
337 | static constexpr const char *pyClassName = "AffineCeilDivExpr" ; |
338 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
339 | |
340 | static PyAffineCeilDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { |
341 | MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs); |
342 | return PyAffineCeilDivExpr(lhs.getContext(), expr); |
343 | } |
344 | |
345 | static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { |
346 | MlirAffineExpr expr = mlirAffineCeilDivExprGet( |
347 | lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); |
348 | return PyAffineCeilDivExpr(lhs.getContext(), expr); |
349 | } |
350 | |
351 | static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { |
352 | MlirAffineExpr expr = mlirAffineCeilDivExprGet( |
353 | mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); |
354 | return PyAffineCeilDivExpr(rhs.getContext(), expr); |
355 | } |
356 | |
357 | static void bindDerived(ClassTy &c) { |
358 | c.def_static("get" , &PyAffineCeilDivExpr::get); |
359 | } |
360 | }; |
361 | |
362 | } // namespace |
363 | |
364 | bool PyAffineExpr::operator==(const PyAffineExpr &other) const { |
365 | return mlirAffineExprEqual(affineExpr, other.affineExpr); |
366 | } |
367 | |
368 | py::object PyAffineExpr::getCapsule() { |
369 | return py::reinterpret_steal<py::object>( |
370 | mlirPythonAffineExprToCapsule(*this)); |
371 | } |
372 | |
373 | PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) { |
374 | MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); |
375 | if (mlirAffineExprIsNull(rawAffineExpr)) |
376 | throw py::error_already_set(); |
377 | return PyAffineExpr( |
378 | PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), |
379 | rawAffineExpr); |
380 | } |
381 | |
382 | //------------------------------------------------------------------------------ |
383 | // PyAffineMap and utilities. |
384 | //------------------------------------------------------------------------------ |
385 | namespace { |
386 | |
387 | /// A list of expressions contained in an affine map. Internally these are |
388 | /// stored as a consecutive array leading to inexpensive random access. Both |
389 | /// the map and the expression are owned by the context so we need not bother |
390 | /// with lifetime extension. |
391 | class PyAffineMapExprList |
392 | : public Sliceable<PyAffineMapExprList, PyAffineExpr> { |
393 | public: |
394 | static constexpr const char *pyClassName = "AffineExprList" ; |
395 | |
396 | PyAffineMapExprList(const PyAffineMap &map, intptr_t startIndex = 0, |
397 | intptr_t length = -1, intptr_t step = 1) |
398 | : Sliceable(startIndex, |
399 | length == -1 ? mlirAffineMapGetNumResults(map) : length, |
400 | step), |
401 | affineMap(map) {} |
402 | |
403 | private: |
404 | /// Give the parent CRTP class access to hook implementations below. |
405 | friend class Sliceable<PyAffineMapExprList, PyAffineExpr>; |
406 | |
407 | intptr_t getRawNumElements() { return mlirAffineMapGetNumResults(affineMap); } |
408 | |
409 | PyAffineExpr getRawElement(intptr_t pos) { |
410 | return PyAffineExpr(affineMap.getContext(), |
411 | mlirAffineMapGetResult(affineMap, pos)); |
412 | } |
413 | |
414 | PyAffineMapExprList slice(intptr_t startIndex, intptr_t length, |
415 | intptr_t step) { |
416 | return PyAffineMapExprList(affineMap, startIndex, length, step); |
417 | } |
418 | |
419 | PyAffineMap affineMap; |
420 | }; |
421 | } // namespace |
422 | |
423 | bool PyAffineMap::operator==(const PyAffineMap &other) const { |
424 | return mlirAffineMapEqual(affineMap, other.affineMap); |
425 | } |
426 | |
427 | py::object PyAffineMap::getCapsule() { |
428 | return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this)); |
429 | } |
430 | |
431 | PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) { |
432 | MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); |
433 | if (mlirAffineMapIsNull(rawAffineMap)) |
434 | throw py::error_already_set(); |
435 | return PyAffineMap( |
436 | PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), |
437 | rawAffineMap); |
438 | } |
439 | |
440 | //------------------------------------------------------------------------------ |
441 | // PyIntegerSet and utilities. |
442 | //------------------------------------------------------------------------------ |
443 | namespace { |
444 | |
445 | class PyIntegerSetConstraint { |
446 | public: |
447 | PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) |
448 | : set(std::move(set)), pos(pos) {} |
449 | |
450 | PyAffineExpr getExpr() { |
451 | return PyAffineExpr(set.getContext(), |
452 | mlirIntegerSetGetConstraint(set, pos)); |
453 | } |
454 | |
455 | bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } |
456 | |
457 | static void bind(py::module &m) { |
458 | py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint" , |
459 | py::module_local()) |
460 | .def_property_readonly("expr" , &PyIntegerSetConstraint::getExpr) |
461 | .def_property_readonly("is_eq" , &PyIntegerSetConstraint::isEq); |
462 | } |
463 | |
464 | private: |
465 | PyIntegerSet set; |
466 | intptr_t pos; |
467 | }; |
468 | |
469 | class PyIntegerSetConstraintList |
470 | : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> { |
471 | public: |
472 | static constexpr const char *pyClassName = "IntegerSetConstraintList" ; |
473 | |
474 | PyIntegerSetConstraintList(const PyIntegerSet &set, intptr_t startIndex = 0, |
475 | intptr_t length = -1, intptr_t step = 1) |
476 | : Sliceable(startIndex, |
477 | length == -1 ? mlirIntegerSetGetNumConstraints(set) : length, |
478 | step), |
479 | set(set) {} |
480 | |
481 | private: |
482 | /// Give the parent CRTP class access to hook implementations below. |
483 | friend class Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint>; |
484 | |
485 | intptr_t getRawNumElements() { return mlirIntegerSetGetNumConstraints(set); } |
486 | |
487 | PyIntegerSetConstraint getRawElement(intptr_t pos) { |
488 | return PyIntegerSetConstraint(set, pos); |
489 | } |
490 | |
491 | PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length, |
492 | intptr_t step) { |
493 | return PyIntegerSetConstraintList(set, startIndex, length, step); |
494 | } |
495 | |
496 | PyIntegerSet set; |
497 | }; |
498 | } // namespace |
499 | |
500 | bool PyIntegerSet::operator==(const PyIntegerSet &other) const { |
501 | return mlirIntegerSetEqual(integerSet, other.integerSet); |
502 | } |
503 | |
504 | py::object PyIntegerSet::getCapsule() { |
505 | return py::reinterpret_steal<py::object>( |
506 | mlirPythonIntegerSetToCapsule(*this)); |
507 | } |
508 | |
509 | PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { |
510 | MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); |
511 | if (mlirIntegerSetIsNull(rawIntegerSet)) |
512 | throw py::error_already_set(); |
513 | return PyIntegerSet( |
514 | PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), |
515 | rawIntegerSet); |
516 | } |
517 | |
518 | void mlir::python::populateIRAffine(py::module &m) { |
519 | //---------------------------------------------------------------------------- |
520 | // Mapping of PyAffineExpr and derived classes. |
521 | //---------------------------------------------------------------------------- |
522 | py::class_<PyAffineExpr>(m, "AffineExpr" , py::module_local()) |
523 | .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, |
524 | &PyAffineExpr::getCapsule) |
525 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) |
526 | .def("__add__" , &PyAffineAddExpr::get) |
527 | .def("__add__" , &PyAffineAddExpr::getRHSConstant) |
528 | .def("__radd__" , &PyAffineAddExpr::getRHSConstant) |
529 | .def("__mul__" , &PyAffineMulExpr::get) |
530 | .def("__mul__" , &PyAffineMulExpr::getRHSConstant) |
531 | .def("__rmul__" , &PyAffineMulExpr::getRHSConstant) |
532 | .def("__mod__" , &PyAffineModExpr::get) |
533 | .def("__mod__" , &PyAffineModExpr::getRHSConstant) |
534 | .def("__rmod__" , |
535 | [](PyAffineExpr &self, intptr_t other) { |
536 | return PyAffineModExpr::get( |
537 | PyAffineConstantExpr::get(other, *self.getContext().get()), |
538 | self); |
539 | }) |
540 | .def("__sub__" , |
541 | [](PyAffineExpr &self, PyAffineExpr &other) { |
542 | auto negOne = |
543 | PyAffineConstantExpr::get(-1, *self.getContext().get()); |
544 | return PyAffineAddExpr::get(self, |
545 | PyAffineMulExpr::get(negOne, other)); |
546 | }) |
547 | .def("__sub__" , |
548 | [](PyAffineExpr &self, intptr_t other) { |
549 | return PyAffineAddExpr::get( |
550 | self, |
551 | PyAffineConstantExpr::get(-other, *self.getContext().get())); |
552 | }) |
553 | .def("__rsub__" , |
554 | [](PyAffineExpr &self, intptr_t other) { |
555 | return PyAffineAddExpr::getLHSConstant( |
556 | other, PyAffineMulExpr::getLHSConstant(-1, self)); |
557 | }) |
558 | .def("__eq__" , [](PyAffineExpr &self, |
559 | PyAffineExpr &other) { return self == other; }) |
560 | .def("__eq__" , |
561 | [](PyAffineExpr &self, py::object &other) { return false; }) |
562 | .def("__str__" , |
563 | [](PyAffineExpr &self) { |
564 | PyPrintAccumulator printAccum; |
565 | mlirAffineExprPrint(self, printAccum.getCallback(), |
566 | printAccum.getUserData()); |
567 | return printAccum.join(); |
568 | }) |
569 | .def("__repr__" , |
570 | [](PyAffineExpr &self) { |
571 | PyPrintAccumulator printAccum; |
572 | printAccum.parts.append("AffineExpr(" ); |
573 | mlirAffineExprPrint(self, printAccum.getCallback(), |
574 | printAccum.getUserData()); |
575 | printAccum.parts.append(")" ); |
576 | return printAccum.join(); |
577 | }) |
578 | .def("__hash__" , |
579 | [](PyAffineExpr &self) { |
580 | return static_cast<size_t>(llvm::hash_value(self.get().ptr)); |
581 | }) |
582 | .def_property_readonly( |
583 | "context" , |
584 | [](PyAffineExpr &self) { return self.getContext().getObject(); }) |
585 | .def("compose" , |
586 | [](PyAffineExpr &self, PyAffineMap &other) { |
587 | return PyAffineExpr(self.getContext(), |
588 | mlirAffineExprCompose(self, other)); |
589 | }) |
590 | .def_static( |
591 | "get_add" , &PyAffineAddExpr::get, |
592 | "Gets an affine expression containing a sum of two expressions." ) |
593 | .def_static("get_add" , &PyAffineAddExpr::getLHSConstant, |
594 | "Gets an affine expression containing a sum of a constant " |
595 | "and another expression." ) |
596 | .def_static("get_add" , &PyAffineAddExpr::getRHSConstant, |
597 | "Gets an affine expression containing a sum of an expression " |
598 | "and a constant." ) |
599 | .def_static( |
600 | "get_mul" , &PyAffineMulExpr::get, |
601 | "Gets an affine expression containing a product of two expressions." ) |
602 | .def_static("get_mul" , &PyAffineMulExpr::getLHSConstant, |
603 | "Gets an affine expression containing a product of a " |
604 | "constant and another expression." ) |
605 | .def_static("get_mul" , &PyAffineMulExpr::getRHSConstant, |
606 | "Gets an affine expression containing a product of an " |
607 | "expression and a constant." ) |
608 | .def_static("get_mod" , &PyAffineModExpr::get, |
609 | "Gets an affine expression containing the modulo of dividing " |
610 | "one expression by another." ) |
611 | .def_static("get_mod" , &PyAffineModExpr::getLHSConstant, |
612 | "Gets a semi-affine expression containing the modulo of " |
613 | "dividing a constant by an expression." ) |
614 | .def_static("get_mod" , &PyAffineModExpr::getRHSConstant, |
615 | "Gets an affine expression containing the module of dividing" |
616 | "an expression by a constant." ) |
617 | .def_static("get_floor_div" , &PyAffineFloorDivExpr::get, |
618 | "Gets an affine expression containing the rounded-down " |
619 | "result of dividing one expression by another." ) |
620 | .def_static("get_floor_div" , &PyAffineFloorDivExpr::getLHSConstant, |
621 | "Gets a semi-affine expression containing the rounded-down " |
622 | "result of dividing a constant by an expression." ) |
623 | .def_static("get_floor_div" , &PyAffineFloorDivExpr::getRHSConstant, |
624 | "Gets an affine expression containing the rounded-down " |
625 | "result of dividing an expression by a constant." ) |
626 | .def_static("get_ceil_div" , &PyAffineCeilDivExpr::get, |
627 | "Gets an affine expression containing the rounded-up result " |
628 | "of dividing one expression by another." ) |
629 | .def_static("get_ceil_div" , &PyAffineCeilDivExpr::getLHSConstant, |
630 | "Gets a semi-affine expression containing the rounded-up " |
631 | "result of dividing a constant by an expression." ) |
632 | .def_static("get_ceil_div" , &PyAffineCeilDivExpr::getRHSConstant, |
633 | "Gets an affine expression containing the rounded-up result " |
634 | "of dividing an expression by a constant." ) |
635 | .def_static("get_constant" , &PyAffineConstantExpr::get, py::arg("value" ), |
636 | py::arg("context" ) = py::none(), |
637 | "Gets a constant affine expression with the given value." ) |
638 | .def_static( |
639 | "get_dim" , &PyAffineDimExpr::get, py::arg("position" ), |
640 | py::arg("context" ) = py::none(), |
641 | "Gets an affine expression of a dimension at the given position." ) |
642 | .def_static( |
643 | "get_symbol" , &PyAffineSymbolExpr::get, py::arg("position" ), |
644 | py::arg("context" ) = py::none(), |
645 | "Gets an affine expression of a symbol at the given position." ) |
646 | .def( |
647 | "dump" , [](PyAffineExpr &self) { mlirAffineExprDump(self); }, |
648 | kDumpDocstring); |
649 | PyAffineConstantExpr::bind(m); |
650 | PyAffineDimExpr::bind(m); |
651 | PyAffineSymbolExpr::bind(m); |
652 | PyAffineBinaryExpr::bind(m); |
653 | PyAffineAddExpr::bind(m); |
654 | PyAffineMulExpr::bind(m); |
655 | PyAffineModExpr::bind(m); |
656 | PyAffineFloorDivExpr::bind(m); |
657 | PyAffineCeilDivExpr::bind(m); |
658 | |
659 | //---------------------------------------------------------------------------- |
660 | // Mapping of PyAffineMap. |
661 | //---------------------------------------------------------------------------- |
662 | py::class_<PyAffineMap>(m, "AffineMap" , py::module_local()) |
663 | .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, |
664 | &PyAffineMap::getCapsule) |
665 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) |
666 | .def("__eq__" , |
667 | [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) |
668 | .def("__eq__" , [](PyAffineMap &self, py::object &other) { return false; }) |
669 | .def("__str__" , |
670 | [](PyAffineMap &self) { |
671 | PyPrintAccumulator printAccum; |
672 | mlirAffineMapPrint(self, printAccum.getCallback(), |
673 | printAccum.getUserData()); |
674 | return printAccum.join(); |
675 | }) |
676 | .def("__repr__" , |
677 | [](PyAffineMap &self) { |
678 | PyPrintAccumulator printAccum; |
679 | printAccum.parts.append("AffineMap(" ); |
680 | mlirAffineMapPrint(self, printAccum.getCallback(), |
681 | printAccum.getUserData()); |
682 | printAccum.parts.append(")" ); |
683 | return printAccum.join(); |
684 | }) |
685 | .def("__hash__" , |
686 | [](PyAffineMap &self) { |
687 | return static_cast<size_t>(llvm::hash_value(self.get().ptr)); |
688 | }) |
689 | .def_static("compress_unused_symbols" , |
690 | [](py::list affineMaps, DefaultingPyMlirContext context) { |
691 | SmallVector<MlirAffineMap> maps; |
692 | pyListToVector<PyAffineMap, MlirAffineMap>( |
693 | affineMaps, maps, "attempting to create an AffineMap" ); |
694 | std::vector<MlirAffineMap> compressed(affineMaps.size()); |
695 | auto populate = [](void *result, intptr_t idx, |
696 | MlirAffineMap m) { |
697 | static_cast<MlirAffineMap *>(result)[idx] = (m); |
698 | }; |
699 | mlirAffineMapCompressUnusedSymbols( |
700 | maps.data(), maps.size(), compressed.data(), populate); |
701 | std::vector<PyAffineMap> res; |
702 | res.reserve(compressed.size()); |
703 | for (auto m : compressed) |
704 | res.emplace_back(context->getRef(), m); |
705 | return res; |
706 | }) |
707 | .def_property_readonly( |
708 | "context" , |
709 | [](PyAffineMap &self) { return self.getContext().getObject(); }, |
710 | "Context that owns the Affine Map" ) |
711 | .def( |
712 | "dump" , [](PyAffineMap &self) { mlirAffineMapDump(self); }, |
713 | kDumpDocstring) |
714 | .def_static( |
715 | "get" , |
716 | [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, |
717 | DefaultingPyMlirContext context) { |
718 | SmallVector<MlirAffineExpr> affineExprs; |
719 | pyListToVector<PyAffineExpr, MlirAffineExpr>( |
720 | exprs, affineExprs, "attempting to create an AffineMap" ); |
721 | MlirAffineMap map = |
722 | mlirAffineMapGet(context->get(), dimCount, symbolCount, |
723 | affineExprs.size(), affineExprs.data()); |
724 | return PyAffineMap(context->getRef(), map); |
725 | }, |
726 | py::arg("dim_count" ), py::arg("symbol_count" ), py::arg("exprs" ), |
727 | py::arg("context" ) = py::none(), |
728 | "Gets a map with the given expressions as results." ) |
729 | .def_static( |
730 | "get_constant" , |
731 | [](intptr_t value, DefaultingPyMlirContext context) { |
732 | MlirAffineMap affineMap = |
733 | mlirAffineMapConstantGet(context->get(), value); |
734 | return PyAffineMap(context->getRef(), affineMap); |
735 | }, |
736 | py::arg("value" ), py::arg("context" ) = py::none(), |
737 | "Gets an affine map with a single constant result" ) |
738 | .def_static( |
739 | "get_empty" , |
740 | [](DefaultingPyMlirContext context) { |
741 | MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); |
742 | return PyAffineMap(context->getRef(), affineMap); |
743 | }, |
744 | py::arg("context" ) = py::none(), "Gets an empty affine map." ) |
745 | .def_static( |
746 | "get_identity" , |
747 | [](intptr_t nDims, DefaultingPyMlirContext context) { |
748 | MlirAffineMap affineMap = |
749 | mlirAffineMapMultiDimIdentityGet(context->get(), nDims); |
750 | return PyAffineMap(context->getRef(), affineMap); |
751 | }, |
752 | py::arg("n_dims" ), py::arg("context" ) = py::none(), |
753 | "Gets an identity map with the given number of dimensions." ) |
754 | .def_static( |
755 | "get_minor_identity" , |
756 | [](intptr_t nDims, intptr_t nResults, |
757 | DefaultingPyMlirContext context) { |
758 | MlirAffineMap affineMap = |
759 | mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); |
760 | return PyAffineMap(context->getRef(), affineMap); |
761 | }, |
762 | py::arg("n_dims" ), py::arg("n_results" ), |
763 | py::arg("context" ) = py::none(), |
764 | "Gets a minor identity map with the given number of dimensions and " |
765 | "results." ) |
766 | .def_static( |
767 | "get_permutation" , |
768 | [](std::vector<unsigned> permutation, |
769 | DefaultingPyMlirContext context) { |
770 | if (!isPermutation(permutation)) |
771 | throw py::cast_error("Invalid permutation when attempting to " |
772 | "create an AffineMap" ); |
773 | MlirAffineMap affineMap = mlirAffineMapPermutationGet( |
774 | context->get(), permutation.size(), permutation.data()); |
775 | return PyAffineMap(context->getRef(), affineMap); |
776 | }, |
777 | py::arg("permutation" ), py::arg("context" ) = py::none(), |
778 | "Gets an affine map that permutes its inputs." ) |
779 | .def( |
780 | "get_submap" , |
781 | [](PyAffineMap &self, std::vector<intptr_t> &resultPos) { |
782 | intptr_t numResults = mlirAffineMapGetNumResults(self); |
783 | for (intptr_t pos : resultPos) { |
784 | if (pos < 0 || pos >= numResults) |
785 | throw py::value_error("result position out of bounds" ); |
786 | } |
787 | MlirAffineMap affineMap = mlirAffineMapGetSubMap( |
788 | self, resultPos.size(), resultPos.data()); |
789 | return PyAffineMap(self.getContext(), affineMap); |
790 | }, |
791 | py::arg("result_positions" )) |
792 | .def( |
793 | "get_major_submap" , |
794 | [](PyAffineMap &self, intptr_t nResults) { |
795 | if (nResults >= mlirAffineMapGetNumResults(self)) |
796 | throw py::value_error("number of results out of bounds" ); |
797 | MlirAffineMap affineMap = |
798 | mlirAffineMapGetMajorSubMap(self, nResults); |
799 | return PyAffineMap(self.getContext(), affineMap); |
800 | }, |
801 | py::arg("n_results" )) |
802 | .def( |
803 | "get_minor_submap" , |
804 | [](PyAffineMap &self, intptr_t nResults) { |
805 | if (nResults >= mlirAffineMapGetNumResults(self)) |
806 | throw py::value_error("number of results out of bounds" ); |
807 | MlirAffineMap affineMap = |
808 | mlirAffineMapGetMinorSubMap(self, nResults); |
809 | return PyAffineMap(self.getContext(), affineMap); |
810 | }, |
811 | py::arg("n_results" )) |
812 | .def( |
813 | "replace" , |
814 | [](PyAffineMap &self, PyAffineExpr &expression, |
815 | PyAffineExpr &replacement, intptr_t numResultDims, |
816 | intptr_t numResultSyms) { |
817 | MlirAffineMap affineMap = mlirAffineMapReplace( |
818 | self, expression, replacement, numResultDims, numResultSyms); |
819 | return PyAffineMap(self.getContext(), affineMap); |
820 | }, |
821 | py::arg("expr" ), py::arg("replacement" ), py::arg("n_result_dims" ), |
822 | py::arg("n_result_syms" )) |
823 | .def_property_readonly( |
824 | "is_permutation" , |
825 | [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) |
826 | .def_property_readonly("is_projected_permutation" , |
827 | [](PyAffineMap &self) { |
828 | return mlirAffineMapIsProjectedPermutation(self); |
829 | }) |
830 | .def_property_readonly( |
831 | "n_dims" , |
832 | [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) |
833 | .def_property_readonly( |
834 | "n_inputs" , |
835 | [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) |
836 | .def_property_readonly( |
837 | "n_symbols" , |
838 | [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) |
839 | .def_property_readonly("results" , [](PyAffineMap &self) { |
840 | return PyAffineMapExprList(self); |
841 | }); |
842 | PyAffineMapExprList::bind(m); |
843 | |
844 | //---------------------------------------------------------------------------- |
845 | // Mapping of PyIntegerSet. |
846 | //---------------------------------------------------------------------------- |
847 | py::class_<PyIntegerSet>(m, "IntegerSet" , py::module_local()) |
848 | .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, |
849 | &PyIntegerSet::getCapsule) |
850 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) |
851 | .def("__eq__" , [](PyIntegerSet &self, |
852 | PyIntegerSet &other) { return self == other; }) |
853 | .def("__eq__" , [](PyIntegerSet &self, py::object other) { return false; }) |
854 | .def("__str__" , |
855 | [](PyIntegerSet &self) { |
856 | PyPrintAccumulator printAccum; |
857 | mlirIntegerSetPrint(self, printAccum.getCallback(), |
858 | printAccum.getUserData()); |
859 | return printAccum.join(); |
860 | }) |
861 | .def("__repr__" , |
862 | [](PyIntegerSet &self) { |
863 | PyPrintAccumulator printAccum; |
864 | printAccum.parts.append("IntegerSet(" ); |
865 | mlirIntegerSetPrint(self, printAccum.getCallback(), |
866 | printAccum.getUserData()); |
867 | printAccum.parts.append(")" ); |
868 | return printAccum.join(); |
869 | }) |
870 | .def("__hash__" , |
871 | [](PyIntegerSet &self) { |
872 | return static_cast<size_t>(llvm::hash_value(self.get().ptr)); |
873 | }) |
874 | .def_property_readonly( |
875 | "context" , |
876 | [](PyIntegerSet &self) { return self.getContext().getObject(); }) |
877 | .def( |
878 | "dump" , [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, |
879 | kDumpDocstring) |
880 | .def_static( |
881 | "get" , |
882 | [](intptr_t numDims, intptr_t numSymbols, py::list exprs, |
883 | std::vector<bool> eqFlags, DefaultingPyMlirContext context) { |
884 | if (exprs.size() != eqFlags.size()) |
885 | throw py::value_error( |
886 | "Expected the number of constraints to match " |
887 | "that of equality flags" ); |
888 | if (exprs.empty()) |
889 | throw py::value_error("Expected non-empty list of constraints" ); |
890 | |
891 | // Copy over to a SmallVector because std::vector has a |
892 | // specialization for booleans that packs data and does not |
893 | // expose a `bool *`. |
894 | SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end()); |
895 | |
896 | SmallVector<MlirAffineExpr> affineExprs; |
897 | pyListToVector<PyAffineExpr>(exprs, affineExprs, |
898 | "attempting to create an IntegerSet" ); |
899 | MlirIntegerSet set = mlirIntegerSetGet( |
900 | context->get(), numDims, numSymbols, exprs.size(), |
901 | affineExprs.data(), flags.data()); |
902 | return PyIntegerSet(context->getRef(), set); |
903 | }, |
904 | py::arg("num_dims" ), py::arg("num_symbols" ), py::arg("exprs" ), |
905 | py::arg("eq_flags" ), py::arg("context" ) = py::none()) |
906 | .def_static( |
907 | "get_empty" , |
908 | [](intptr_t numDims, intptr_t numSymbols, |
909 | DefaultingPyMlirContext context) { |
910 | MlirIntegerSet set = |
911 | mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); |
912 | return PyIntegerSet(context->getRef(), set); |
913 | }, |
914 | py::arg("num_dims" ), py::arg("num_symbols" ), |
915 | py::arg("context" ) = py::none()) |
916 | .def( |
917 | "get_replaced" , |
918 | [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, |
919 | intptr_t numResultDims, intptr_t numResultSymbols) { |
920 | if (static_cast<intptr_t>(dimExprs.size()) != |
921 | mlirIntegerSetGetNumDims(self)) |
922 | throw py::value_error( |
923 | "Expected the number of dimension replacement expressions " |
924 | "to match that of dimensions" ); |
925 | if (static_cast<intptr_t>(symbolExprs.size()) != |
926 | mlirIntegerSetGetNumSymbols(self)) |
927 | throw py::value_error( |
928 | "Expected the number of symbol replacement expressions " |
929 | "to match that of symbols" ); |
930 | |
931 | SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs; |
932 | pyListToVector<PyAffineExpr>( |
933 | dimExprs, dimAffineExprs, |
934 | "attempting to create an IntegerSet by replacing dimensions" ); |
935 | pyListToVector<PyAffineExpr>( |
936 | symbolExprs, symbolAffineExprs, |
937 | "attempting to create an IntegerSet by replacing symbols" ); |
938 | MlirIntegerSet set = mlirIntegerSetReplaceGet( |
939 | self, dimAffineExprs.data(), symbolAffineExprs.data(), |
940 | numResultDims, numResultSymbols); |
941 | return PyIntegerSet(self.getContext(), set); |
942 | }, |
943 | py::arg("dim_exprs" ), py::arg("symbol_exprs" ), |
944 | py::arg("num_result_dims" ), py::arg("num_result_symbols" )) |
945 | .def_property_readonly("is_canonical_empty" , |
946 | [](PyIntegerSet &self) { |
947 | return mlirIntegerSetIsCanonicalEmpty(self); |
948 | }) |
949 | .def_property_readonly( |
950 | "n_dims" , |
951 | [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); }) |
952 | .def_property_readonly( |
953 | "n_symbols" , |
954 | [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) |
955 | .def_property_readonly( |
956 | "n_inputs" , |
957 | [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) |
958 | .def_property_readonly("n_equalities" , |
959 | [](PyIntegerSet &self) { |
960 | return mlirIntegerSetGetNumEqualities(self); |
961 | }) |
962 | .def_property_readonly("n_inequalities" , |
963 | [](PyIntegerSet &self) { |
964 | return mlirIntegerSetGetNumInequalities(self); |
965 | }) |
966 | .def_property_readonly("constraints" , [](PyIntegerSet &self) { |
967 | return PyIntegerSetConstraintList(self); |
968 | }); |
969 | PyIntegerSetConstraint::bind(m); |
970 | PyIntegerSetConstraintList::bind(m); |
971 | } |
972 | |