1//===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include <cstddef>
#include <cstdint>
#include <pybind11/cast.h>
#include <pybind11/detail/common.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
18
19#include "IRModule.h"
20
21#include "PybindUtils.h"
22
23#include "mlir-c/AffineExpr.h"
24#include "mlir-c/AffineMap.h"
25#include "mlir-c/Bindings/Python/Interop.h"
26#include "mlir-c/IntegerSet.h"
27#include "mlir/Support/LLVM.h"
28#include "llvm/ADT/Hashing.h"
29#include "llvm/ADT/SmallVector.h"
30#include "llvm/ADT/StringRef.h"
31#include "llvm/ADT/Twine.h"
32
33namespace py = pybind11;
34using namespace mlir;
35using namespace mlir::python;
36
37using llvm::SmallVector;
38using llvm::StringRef;
39using llvm::Twine;
40
/// Python docstring attached to every `dump` method bound in this file.
static const char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
43
44/// Attempts to populate `result` with the content of `list` casted to the
45/// appropriate type (Python and C types are provided as template arguments).
46/// Throws errors in case of failure, using "action" to describe what the caller
47/// was attempting to do.
48template <typename PyType, typename CType>
49static void pyListToVector(const py::list &list,
50 llvm::SmallVectorImpl<CType> &result,
51 StringRef action) {
52 result.reserve(py::len(list));
53 for (py::handle item : list) {
54 try {
55 result.push_back(item.cast<PyType>());
56 } catch (py::cast_error &err) {
57 std::string msg = (llvm::Twine("Invalid expression when ") + action +
58 " (" + err.what() + ")")
59 .str();
60 throw py::cast_error(msg);
61 } catch (py::reference_cast_error &err) {
62 std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
63 action + " (" + err.what() + ")")
64 .str();
65 throw py::cast_error(msg);
66 }
67 }
68}
69
/// Returns true if `permutation` contains every value in [0, size) exactly
/// once, i.e. it is a permutation of the first `size` non-negative integers.
/// Negative entries (only possible for signed element types), out-of-range
/// entries and duplicates all make the result false. An empty input counts
/// as a permutation.
template <typename PermutationTy>
static bool isPermutation(const std::vector<PermutationTy> &permutation) {
  const size_t size = permutation.size();
  // One flag per admissible value; `char` avoids the std::vector<bool> proxy.
  std::vector<char> seen(size, 0);
  for (const PermutationTy &val : permutation) {
    if constexpr (std::is_signed_v<PermutationTy>) {
      // Reject negatives explicitly instead of relying on the implicit
      // signed-to-unsigned conversion in the range check below.
      if (val < 0)
        return false;
    }
    const auto index = static_cast<size_t>(val);
    if (index >= size || seen[index])
      return false;
    seen[index] = 1;
  }
  return true;
}
84
85namespace {
86
87/// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
88/// and should be castable from it. Intermediate hierarchy classes can be
89/// modeled by specifying BaseTy.
90template <typename DerivedTy, typename BaseTy = PyAffineExpr>
91class PyConcreteAffineExpr : public BaseTy {
92public:
93 // Derived classes must define statics for:
94 // IsAFunctionTy isaFunction
95 // const char *pyClassName
96 // and redefine bindDerived.
97 using ClassTy = py::class_<DerivedTy, BaseTy>;
98 using IsAFunctionTy = bool (*)(MlirAffineExpr);
99
100 PyConcreteAffineExpr() = default;
101 PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
102 : BaseTy(std::move(contextRef), affineExpr) {}
103 PyConcreteAffineExpr(PyAffineExpr &orig)
104 : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
105
106 static MlirAffineExpr castFrom(PyAffineExpr &orig) {
107 if (!DerivedTy::isaFunction(orig)) {
108 auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
109 throw py::value_error((Twine("Cannot cast affine expression to ") +
110 DerivedTy::pyClassName + " (from " + origRepr +
111 ")")
112 .str());
113 }
114 return orig;
115 }
116
117 static void bind(py::module &m) {
118 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
119 cls.def(py::init<PyAffineExpr &>(), py::arg("expr"));
120 cls.def_static(
121 "isinstance",
122 [](PyAffineExpr &otherAffineExpr) -> bool {
123 return DerivedTy::isaFunction(otherAffineExpr);
124 },
125 py::arg("other"));
126 DerivedTy::bindDerived(cls);
127 }
128
129 /// Implemented by derived classes to add methods to the Python subclass.
130 static void bindDerived(ClassTy &m) {}
131};
132
133class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
134public:
135 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
136 static constexpr const char *pyClassName = "AffineConstantExpr";
137 using PyConcreteAffineExpr::PyConcreteAffineExpr;
138
139 static PyAffineConstantExpr get(intptr_t value,
140 DefaultingPyMlirContext context) {
141 MlirAffineExpr affineExpr =
142 mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
143 return PyAffineConstantExpr(context->getRef(), affineExpr);
144 }
145
146 static void bindDerived(ClassTy &c) {
147 c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
148 py::arg("context") = py::none());
149 c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
150 return mlirAffineConstantExprGetValue(self);
151 });
152 }
153};
154
155class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
156public:
157 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
158 static constexpr const char *pyClassName = "AffineDimExpr";
159 using PyConcreteAffineExpr::PyConcreteAffineExpr;
160
161 static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
162 MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
163 return PyAffineDimExpr(context->getRef(), affineExpr);
164 }
165
166 static void bindDerived(ClassTy &c) {
167 c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
168 py::arg("context") = py::none());
169 c.def_property_readonly("position", [](PyAffineDimExpr &self) {
170 return mlirAffineDimExprGetPosition(self);
171 });
172 }
173};
174
175class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
176public:
177 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
178 static constexpr const char *pyClassName = "AffineSymbolExpr";
179 using PyConcreteAffineExpr::PyConcreteAffineExpr;
180
181 static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
182 MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
183 return PyAffineSymbolExpr(context->getRef(), affineExpr);
184 }
185
186 static void bindDerived(ClassTy &c) {
187 c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
188 py::arg("context") = py::none());
189 c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
190 return mlirAffineSymbolExprGetPosition(self);
191 });
192 }
193};
194
195class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
196public:
197 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
198 static constexpr const char *pyClassName = "AffineBinaryExpr";
199 using PyConcreteAffineExpr::PyConcreteAffineExpr;
200
201 PyAffineExpr lhs() {
202 MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
203 return PyAffineExpr(getContext(), lhsExpr);
204 }
205
206 PyAffineExpr rhs() {
207 MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
208 return PyAffineExpr(getContext(), rhsExpr);
209 }
210
211 static void bindDerived(ClassTy &c) {
212 c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
213 c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
214 }
215};
216
217class PyAffineAddExpr
218 : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
219public:
220 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
221 static constexpr const char *pyClassName = "AffineAddExpr";
222 using PyConcreteAffineExpr::PyConcreteAffineExpr;
223
224 static PyAffineAddExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
225 MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
226 return PyAffineAddExpr(lhs.getContext(), expr);
227 }
228
229 static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
230 MlirAffineExpr expr = mlirAffineAddExprGet(
231 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
232 return PyAffineAddExpr(lhs.getContext(), expr);
233 }
234
235 static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
236 MlirAffineExpr expr = mlirAffineAddExprGet(
237 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
238 return PyAffineAddExpr(rhs.getContext(), expr);
239 }
240
241 static void bindDerived(ClassTy &c) {
242 c.def_static("get", &PyAffineAddExpr::get);
243 }
244};
245
246class PyAffineMulExpr
247 : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
248public:
249 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
250 static constexpr const char *pyClassName = "AffineMulExpr";
251 using PyConcreteAffineExpr::PyConcreteAffineExpr;
252
253 static PyAffineMulExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
254 MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
255 return PyAffineMulExpr(lhs.getContext(), expr);
256 }
257
258 static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
259 MlirAffineExpr expr = mlirAffineMulExprGet(
260 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
261 return PyAffineMulExpr(lhs.getContext(), expr);
262 }
263
264 static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
265 MlirAffineExpr expr = mlirAffineMulExprGet(
266 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
267 return PyAffineMulExpr(rhs.getContext(), expr);
268 }
269
270 static void bindDerived(ClassTy &c) {
271 c.def_static("get", &PyAffineMulExpr::get);
272 }
273};
274
275class PyAffineModExpr
276 : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
277public:
278 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
279 static constexpr const char *pyClassName = "AffineModExpr";
280 using PyConcreteAffineExpr::PyConcreteAffineExpr;
281
282 static PyAffineModExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
283 MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
284 return PyAffineModExpr(lhs.getContext(), expr);
285 }
286
287 static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
288 MlirAffineExpr expr = mlirAffineModExprGet(
289 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
290 return PyAffineModExpr(lhs.getContext(), expr);
291 }
292
293 static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
294 MlirAffineExpr expr = mlirAffineModExprGet(
295 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
296 return PyAffineModExpr(rhs.getContext(), expr);
297 }
298
299 static void bindDerived(ClassTy &c) {
300 c.def_static("get", &PyAffineModExpr::get);
301 }
302};
303
304class PyAffineFloorDivExpr
305 : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
306public:
307 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
308 static constexpr const char *pyClassName = "AffineFloorDivExpr";
309 using PyConcreteAffineExpr::PyConcreteAffineExpr;
310
311 static PyAffineFloorDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
312 MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
313 return PyAffineFloorDivExpr(lhs.getContext(), expr);
314 }
315
316 static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
317 MlirAffineExpr expr = mlirAffineFloorDivExprGet(
318 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
319 return PyAffineFloorDivExpr(lhs.getContext(), expr);
320 }
321
322 static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
323 MlirAffineExpr expr = mlirAffineFloorDivExprGet(
324 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
325 return PyAffineFloorDivExpr(rhs.getContext(), expr);
326 }
327
328 static void bindDerived(ClassTy &c) {
329 c.def_static("get", &PyAffineFloorDivExpr::get);
330 }
331};
332
333class PyAffineCeilDivExpr
334 : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
335public:
336 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
337 static constexpr const char *pyClassName = "AffineCeilDivExpr";
338 using PyConcreteAffineExpr::PyConcreteAffineExpr;
339
340 static PyAffineCeilDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
341 MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
342 return PyAffineCeilDivExpr(lhs.getContext(), expr);
343 }
344
345 static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
346 MlirAffineExpr expr = mlirAffineCeilDivExprGet(
347 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
348 return PyAffineCeilDivExpr(lhs.getContext(), expr);
349 }
350
351 static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
352 MlirAffineExpr expr = mlirAffineCeilDivExprGet(
353 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
354 return PyAffineCeilDivExpr(rhs.getContext(), expr);
355 }
356
357 static void bindDerived(ClassTy &c) {
358 c.def_static("get", &PyAffineCeilDivExpr::get);
359 }
360};
361
362} // namespace
363
364bool PyAffineExpr::operator==(const PyAffineExpr &other) const {
365 return mlirAffineExprEqual(affineExpr, other.affineExpr);
366}
367
368py::object PyAffineExpr::getCapsule() {
369 return py::reinterpret_steal<py::object>(
370 mlirPythonAffineExprToCapsule(*this));
371}
372
373PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
374 MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
375 if (mlirAffineExprIsNull(rawAffineExpr))
376 throw py::error_already_set();
377 return PyAffineExpr(
378 PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
379 rawAffineExpr);
380}
381
382//------------------------------------------------------------------------------
383// PyAffineMap and utilities.
384//------------------------------------------------------------------------------
385namespace {
386
387/// A list of expressions contained in an affine map. Internally these are
388/// stored as a consecutive array leading to inexpensive random access. Both
389/// the map and the expression are owned by the context so we need not bother
390/// with lifetime extension.
391class PyAffineMapExprList
392 : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
393public:
394 static constexpr const char *pyClassName = "AffineExprList";
395
396 PyAffineMapExprList(const PyAffineMap &map, intptr_t startIndex = 0,
397 intptr_t length = -1, intptr_t step = 1)
398 : Sliceable(startIndex,
399 length == -1 ? mlirAffineMapGetNumResults(map) : length,
400 step),
401 affineMap(map) {}
402
403private:
404 /// Give the parent CRTP class access to hook implementations below.
405 friend class Sliceable<PyAffineMapExprList, PyAffineExpr>;
406
407 intptr_t getRawNumElements() { return mlirAffineMapGetNumResults(affineMap); }
408
409 PyAffineExpr getRawElement(intptr_t pos) {
410 return PyAffineExpr(affineMap.getContext(),
411 mlirAffineMapGetResult(affineMap, pos));
412 }
413
414 PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
415 intptr_t step) {
416 return PyAffineMapExprList(affineMap, startIndex, length, step);
417 }
418
419 PyAffineMap affineMap;
420};
421} // namespace
422
423bool PyAffineMap::operator==(const PyAffineMap &other) const {
424 return mlirAffineMapEqual(affineMap, other.affineMap);
425}
426
427py::object PyAffineMap::getCapsule() {
428 return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
429}
430
431PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
432 MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
433 if (mlirAffineMapIsNull(rawAffineMap))
434 throw py::error_already_set();
435 return PyAffineMap(
436 PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
437 rawAffineMap);
438}
439
440//------------------------------------------------------------------------------
441// PyIntegerSet and utilities.
442//------------------------------------------------------------------------------
443namespace {
444
445class PyIntegerSetConstraint {
446public:
447 PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos)
448 : set(std::move(set)), pos(pos) {}
449
450 PyAffineExpr getExpr() {
451 return PyAffineExpr(set.getContext(),
452 mlirIntegerSetGetConstraint(set, pos));
453 }
454
455 bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
456
457 static void bind(py::module &m) {
458 py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint",
459 py::module_local())
460 .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
461 .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
462 }
463
464private:
465 PyIntegerSet set;
466 intptr_t pos;
467};
468
469class PyIntegerSetConstraintList
470 : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
471public:
472 static constexpr const char *pyClassName = "IntegerSetConstraintList";
473
474 PyIntegerSetConstraintList(const PyIntegerSet &set, intptr_t startIndex = 0,
475 intptr_t length = -1, intptr_t step = 1)
476 : Sliceable(startIndex,
477 length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
478 step),
479 set(set) {}
480
481private:
482 /// Give the parent CRTP class access to hook implementations below.
483 friend class Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint>;
484
485 intptr_t getRawNumElements() { return mlirIntegerSetGetNumConstraints(set); }
486
487 PyIntegerSetConstraint getRawElement(intptr_t pos) {
488 return PyIntegerSetConstraint(set, pos);
489 }
490
491 PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
492 intptr_t step) {
493 return PyIntegerSetConstraintList(set, startIndex, length, step);
494 }
495
496 PyIntegerSet set;
497};
498} // namespace
499
500bool PyIntegerSet::operator==(const PyIntegerSet &other) const {
501 return mlirIntegerSetEqual(integerSet, other.integerSet);
502}
503
504py::object PyIntegerSet::getCapsule() {
505 return py::reinterpret_steal<py::object>(
506 mlirPythonIntegerSetToCapsule(*this));
507}
508
509PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
510 MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
511 if (mlirIntegerSetIsNull(rawIntegerSet))
512 throw py::error_already_set();
513 return PyIntegerSet(
514 PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
515 rawIntegerSet);
516}
517
518void mlir::python::populateIRAffine(py::module &m) {
519 //----------------------------------------------------------------------------
520 // Mapping of PyAffineExpr and derived classes.
521 //----------------------------------------------------------------------------
522 py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local())
523 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
524 &PyAffineExpr::getCapsule)
525 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
526 .def("__add__", &PyAffineAddExpr::get)
527 .def("__add__", &PyAffineAddExpr::getRHSConstant)
528 .def("__radd__", &PyAffineAddExpr::getRHSConstant)
529 .def("__mul__", &PyAffineMulExpr::get)
530 .def("__mul__", &PyAffineMulExpr::getRHSConstant)
531 .def("__rmul__", &PyAffineMulExpr::getRHSConstant)
532 .def("__mod__", &PyAffineModExpr::get)
533 .def("__mod__", &PyAffineModExpr::getRHSConstant)
534 .def("__rmod__",
535 [](PyAffineExpr &self, intptr_t other) {
536 return PyAffineModExpr::get(
537 PyAffineConstantExpr::get(other, *self.getContext().get()),
538 self);
539 })
540 .def("__sub__",
541 [](PyAffineExpr &self, PyAffineExpr &other) {
542 auto negOne =
543 PyAffineConstantExpr::get(-1, *self.getContext().get());
544 return PyAffineAddExpr::get(self,
545 PyAffineMulExpr::get(negOne, other));
546 })
547 .def("__sub__",
548 [](PyAffineExpr &self, intptr_t other) {
549 return PyAffineAddExpr::get(
550 self,
551 PyAffineConstantExpr::get(-other, *self.getContext().get()));
552 })
553 .def("__rsub__",
554 [](PyAffineExpr &self, intptr_t other) {
555 return PyAffineAddExpr::getLHSConstant(
556 other, PyAffineMulExpr::getLHSConstant(-1, self));
557 })
558 .def("__eq__", [](PyAffineExpr &self,
559 PyAffineExpr &other) { return self == other; })
560 .def("__eq__",
561 [](PyAffineExpr &self, py::object &other) { return false; })
562 .def("__str__",
563 [](PyAffineExpr &self) {
564 PyPrintAccumulator printAccum;
565 mlirAffineExprPrint(self, printAccum.getCallback(),
566 printAccum.getUserData());
567 return printAccum.join();
568 })
569 .def("__repr__",
570 [](PyAffineExpr &self) {
571 PyPrintAccumulator printAccum;
572 printAccum.parts.append("AffineExpr(");
573 mlirAffineExprPrint(self, printAccum.getCallback(),
574 printAccum.getUserData());
575 printAccum.parts.append(")");
576 return printAccum.join();
577 })
578 .def("__hash__",
579 [](PyAffineExpr &self) {
580 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
581 })
582 .def_property_readonly(
583 "context",
584 [](PyAffineExpr &self) { return self.getContext().getObject(); })
585 .def("compose",
586 [](PyAffineExpr &self, PyAffineMap &other) {
587 return PyAffineExpr(self.getContext(),
588 mlirAffineExprCompose(self, other));
589 })
590 .def_static(
591 "get_add", &PyAffineAddExpr::get,
592 "Gets an affine expression containing a sum of two expressions.")
593 .def_static("get_add", &PyAffineAddExpr::getLHSConstant,
594 "Gets an affine expression containing a sum of a constant "
595 "and another expression.")
596 .def_static("get_add", &PyAffineAddExpr::getRHSConstant,
597 "Gets an affine expression containing a sum of an expression "
598 "and a constant.")
599 .def_static(
600 "get_mul", &PyAffineMulExpr::get,
601 "Gets an affine expression containing a product of two expressions.")
602 .def_static("get_mul", &PyAffineMulExpr::getLHSConstant,
603 "Gets an affine expression containing a product of a "
604 "constant and another expression.")
605 .def_static("get_mul", &PyAffineMulExpr::getRHSConstant,
606 "Gets an affine expression containing a product of an "
607 "expression and a constant.")
608 .def_static("get_mod", &PyAffineModExpr::get,
609 "Gets an affine expression containing the modulo of dividing "
610 "one expression by another.")
611 .def_static("get_mod", &PyAffineModExpr::getLHSConstant,
612 "Gets a semi-affine expression containing the modulo of "
613 "dividing a constant by an expression.")
614 .def_static("get_mod", &PyAffineModExpr::getRHSConstant,
615 "Gets an affine expression containing the module of dividing"
616 "an expression by a constant.")
617 .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
618 "Gets an affine expression containing the rounded-down "
619 "result of dividing one expression by another.")
620 .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant,
621 "Gets a semi-affine expression containing the rounded-down "
622 "result of dividing a constant by an expression.")
623 .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant,
624 "Gets an affine expression containing the rounded-down "
625 "result of dividing an expression by a constant.")
626 .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
627 "Gets an affine expression containing the rounded-up result "
628 "of dividing one expression by another.")
629 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant,
630 "Gets a semi-affine expression containing the rounded-up "
631 "result of dividing a constant by an expression.")
632 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant,
633 "Gets an affine expression containing the rounded-up result "
634 "of dividing an expression by a constant.")
635 .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
636 py::arg("context") = py::none(),
637 "Gets a constant affine expression with the given value.")
638 .def_static(
639 "get_dim", &PyAffineDimExpr::get, py::arg("position"),
640 py::arg("context") = py::none(),
641 "Gets an affine expression of a dimension at the given position.")
642 .def_static(
643 "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
644 py::arg("context") = py::none(),
645 "Gets an affine expression of a symbol at the given position.")
646 .def(
647 "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
648 kDumpDocstring);
649 PyAffineConstantExpr::bind(m);
650 PyAffineDimExpr::bind(m);
651 PyAffineSymbolExpr::bind(m);
652 PyAffineBinaryExpr::bind(m);
653 PyAffineAddExpr::bind(m);
654 PyAffineMulExpr::bind(m);
655 PyAffineModExpr::bind(m);
656 PyAffineFloorDivExpr::bind(m);
657 PyAffineCeilDivExpr::bind(m);
658
659 //----------------------------------------------------------------------------
660 // Mapping of PyAffineMap.
661 //----------------------------------------------------------------------------
662 py::class_<PyAffineMap>(m, "AffineMap", py::module_local())
663 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
664 &PyAffineMap::getCapsule)
665 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
666 .def("__eq__",
667 [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
668 .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
669 .def("__str__",
670 [](PyAffineMap &self) {
671 PyPrintAccumulator printAccum;
672 mlirAffineMapPrint(self, printAccum.getCallback(),
673 printAccum.getUserData());
674 return printAccum.join();
675 })
676 .def("__repr__",
677 [](PyAffineMap &self) {
678 PyPrintAccumulator printAccum;
679 printAccum.parts.append("AffineMap(");
680 mlirAffineMapPrint(self, printAccum.getCallback(),
681 printAccum.getUserData());
682 printAccum.parts.append(")");
683 return printAccum.join();
684 })
685 .def("__hash__",
686 [](PyAffineMap &self) {
687 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
688 })
689 .def_static("compress_unused_symbols",
690 [](py::list affineMaps, DefaultingPyMlirContext context) {
691 SmallVector<MlirAffineMap> maps;
692 pyListToVector<PyAffineMap, MlirAffineMap>(
693 affineMaps, maps, "attempting to create an AffineMap");
694 std::vector<MlirAffineMap> compressed(affineMaps.size());
695 auto populate = [](void *result, intptr_t idx,
696 MlirAffineMap m) {
697 static_cast<MlirAffineMap *>(result)[idx] = (m);
698 };
699 mlirAffineMapCompressUnusedSymbols(
700 maps.data(), maps.size(), compressed.data(), populate);
701 std::vector<PyAffineMap> res;
702 res.reserve(compressed.size());
703 for (auto m : compressed)
704 res.emplace_back(context->getRef(), m);
705 return res;
706 })
707 .def_property_readonly(
708 "context",
709 [](PyAffineMap &self) { return self.getContext().getObject(); },
710 "Context that owns the Affine Map")
711 .def(
712 "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
713 kDumpDocstring)
714 .def_static(
715 "get",
716 [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
717 DefaultingPyMlirContext context) {
718 SmallVector<MlirAffineExpr> affineExprs;
719 pyListToVector<PyAffineExpr, MlirAffineExpr>(
720 exprs, affineExprs, "attempting to create an AffineMap");
721 MlirAffineMap map =
722 mlirAffineMapGet(context->get(), dimCount, symbolCount,
723 affineExprs.size(), affineExprs.data());
724 return PyAffineMap(context->getRef(), map);
725 },
726 py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
727 py::arg("context") = py::none(),
728 "Gets a map with the given expressions as results.")
729 .def_static(
730 "get_constant",
731 [](intptr_t value, DefaultingPyMlirContext context) {
732 MlirAffineMap affineMap =
733 mlirAffineMapConstantGet(context->get(), value);
734 return PyAffineMap(context->getRef(), affineMap);
735 },
736 py::arg("value"), py::arg("context") = py::none(),
737 "Gets an affine map with a single constant result")
738 .def_static(
739 "get_empty",
740 [](DefaultingPyMlirContext context) {
741 MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
742 return PyAffineMap(context->getRef(), affineMap);
743 },
744 py::arg("context") = py::none(), "Gets an empty affine map.")
745 .def_static(
746 "get_identity",
747 [](intptr_t nDims, DefaultingPyMlirContext context) {
748 MlirAffineMap affineMap =
749 mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
750 return PyAffineMap(context->getRef(), affineMap);
751 },
752 py::arg("n_dims"), py::arg("context") = py::none(),
753 "Gets an identity map with the given number of dimensions.")
754 .def_static(
755 "get_minor_identity",
756 [](intptr_t nDims, intptr_t nResults,
757 DefaultingPyMlirContext context) {
758 MlirAffineMap affineMap =
759 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
760 return PyAffineMap(context->getRef(), affineMap);
761 },
762 py::arg("n_dims"), py::arg("n_results"),
763 py::arg("context") = py::none(),
764 "Gets a minor identity map with the given number of dimensions and "
765 "results.")
766 .def_static(
767 "get_permutation",
768 [](std::vector<unsigned> permutation,
769 DefaultingPyMlirContext context) {
770 if (!isPermutation(permutation))
771 throw py::cast_error("Invalid permutation when attempting to "
772 "create an AffineMap");
773 MlirAffineMap affineMap = mlirAffineMapPermutationGet(
774 context->get(), permutation.size(), permutation.data());
775 return PyAffineMap(context->getRef(), affineMap);
776 },
777 py::arg("permutation"), py::arg("context") = py::none(),
778 "Gets an affine map that permutes its inputs.")
779 .def(
780 "get_submap",
781 [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
782 intptr_t numResults = mlirAffineMapGetNumResults(self);
783 for (intptr_t pos : resultPos) {
784 if (pos < 0 || pos >= numResults)
785 throw py::value_error("result position out of bounds");
786 }
787 MlirAffineMap affineMap = mlirAffineMapGetSubMap(
788 self, resultPos.size(), resultPos.data());
789 return PyAffineMap(self.getContext(), affineMap);
790 },
791 py::arg("result_positions"))
792 .def(
793 "get_major_submap",
794 [](PyAffineMap &self, intptr_t nResults) {
795 if (nResults >= mlirAffineMapGetNumResults(self))
796 throw py::value_error("number of results out of bounds");
797 MlirAffineMap affineMap =
798 mlirAffineMapGetMajorSubMap(self, nResults);
799 return PyAffineMap(self.getContext(), affineMap);
800 },
801 py::arg("n_results"))
802 .def(
803 "get_minor_submap",
804 [](PyAffineMap &self, intptr_t nResults) {
805 if (nResults >= mlirAffineMapGetNumResults(self))
806 throw py::value_error("number of results out of bounds");
807 MlirAffineMap affineMap =
808 mlirAffineMapGetMinorSubMap(self, nResults);
809 return PyAffineMap(self.getContext(), affineMap);
810 },
811 py::arg("n_results"))
812 .def(
813 "replace",
814 [](PyAffineMap &self, PyAffineExpr &expression,
815 PyAffineExpr &replacement, intptr_t numResultDims,
816 intptr_t numResultSyms) {
817 MlirAffineMap affineMap = mlirAffineMapReplace(
818 self, expression, replacement, numResultDims, numResultSyms);
819 return PyAffineMap(self.getContext(), affineMap);
820 },
821 py::arg("expr"), py::arg("replacement"), py::arg("n_result_dims"),
822 py::arg("n_result_syms"))
823 .def_property_readonly(
824 "is_permutation",
825 [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
826 .def_property_readonly("is_projected_permutation",
827 [](PyAffineMap &self) {
828 return mlirAffineMapIsProjectedPermutation(self);
829 })
830 .def_property_readonly(
831 "n_dims",
832 [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
833 .def_property_readonly(
834 "n_inputs",
835 [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
836 .def_property_readonly(
837 "n_symbols",
838 [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
839 .def_property_readonly("results", [](PyAffineMap &self) {
840 return PyAffineMapExprList(self);
841 });
842 PyAffineMapExprList::bind(m);
843
844 //----------------------------------------------------------------------------
845 // Mapping of PyIntegerSet.
846 //----------------------------------------------------------------------------
847 py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local())
848 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
849 &PyIntegerSet::getCapsule)
850 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
851 .def("__eq__", [](PyIntegerSet &self,
852 PyIntegerSet &other) { return self == other; })
853 .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
854 .def("__str__",
855 [](PyIntegerSet &self) {
856 PyPrintAccumulator printAccum;
857 mlirIntegerSetPrint(self, printAccum.getCallback(),
858 printAccum.getUserData());
859 return printAccum.join();
860 })
861 .def("__repr__",
862 [](PyIntegerSet &self) {
863 PyPrintAccumulator printAccum;
864 printAccum.parts.append("IntegerSet(");
865 mlirIntegerSetPrint(self, printAccum.getCallback(),
866 printAccum.getUserData());
867 printAccum.parts.append(")");
868 return printAccum.join();
869 })
870 .def("__hash__",
871 [](PyIntegerSet &self) {
872 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
873 })
874 .def_property_readonly(
875 "context",
876 [](PyIntegerSet &self) { return self.getContext().getObject(); })
877 .def(
878 "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
879 kDumpDocstring)
880 .def_static(
881 "get",
882 [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
883 std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
884 if (exprs.size() != eqFlags.size())
885 throw py::value_error(
886 "Expected the number of constraints to match "
887 "that of equality flags");
888 if (exprs.empty())
889 throw py::value_error("Expected non-empty list of constraints");
890
891 // Copy over to a SmallVector because std::vector has a
892 // specialization for booleans that packs data and does not
893 // expose a `bool *`.
894 SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
895
896 SmallVector<MlirAffineExpr> affineExprs;
897 pyListToVector<PyAffineExpr>(exprs, affineExprs,
898 "attempting to create an IntegerSet");
899 MlirIntegerSet set = mlirIntegerSetGet(
900 context->get(), numDims, numSymbols, exprs.size(),
901 affineExprs.data(), flags.data());
902 return PyIntegerSet(context->getRef(), set);
903 },
904 py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
905 py::arg("eq_flags"), py::arg("context") = py::none())
906 .def_static(
907 "get_empty",
908 [](intptr_t numDims, intptr_t numSymbols,
909 DefaultingPyMlirContext context) {
910 MlirIntegerSet set =
911 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
912 return PyIntegerSet(context->getRef(), set);
913 },
914 py::arg("num_dims"), py::arg("num_symbols"),
915 py::arg("context") = py::none())
916 .def(
917 "get_replaced",
918 [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
919 intptr_t numResultDims, intptr_t numResultSymbols) {
920 if (static_cast<intptr_t>(dimExprs.size()) !=
921 mlirIntegerSetGetNumDims(self))
922 throw py::value_error(
923 "Expected the number of dimension replacement expressions "
924 "to match that of dimensions");
925 if (static_cast<intptr_t>(symbolExprs.size()) !=
926 mlirIntegerSetGetNumSymbols(self))
927 throw py::value_error(
928 "Expected the number of symbol replacement expressions "
929 "to match that of symbols");
930
931 SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
932 pyListToVector<PyAffineExpr>(
933 dimExprs, dimAffineExprs,
934 "attempting to create an IntegerSet by replacing dimensions");
935 pyListToVector<PyAffineExpr>(
936 symbolExprs, symbolAffineExprs,
937 "attempting to create an IntegerSet by replacing symbols");
938 MlirIntegerSet set = mlirIntegerSetReplaceGet(
939 self, dimAffineExprs.data(), symbolAffineExprs.data(),
940 numResultDims, numResultSymbols);
941 return PyIntegerSet(self.getContext(), set);
942 },
943 py::arg("dim_exprs"), py::arg("symbol_exprs"),
944 py::arg("num_result_dims"), py::arg("num_result_symbols"))
945 .def_property_readonly("is_canonical_empty",
946 [](PyIntegerSet &self) {
947 return mlirIntegerSetIsCanonicalEmpty(self);
948 })
949 .def_property_readonly(
950 "n_dims",
951 [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
952 .def_property_readonly(
953 "n_symbols",
954 [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
955 .def_property_readonly(
956 "n_inputs",
957 [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
958 .def_property_readonly("n_equalities",
959 [](PyIntegerSet &self) {
960 return mlirIntegerSetGetNumEqualities(self);
961 })
962 .def_property_readonly("n_inequalities",
963 [](PyIntegerSet &self) {
964 return mlirIntegerSetGetNumInequalities(self);
965 })
966 .def_property_readonly("constraints", [](PyIntegerSet &self) {
967 return PyIntegerSetConstraintList(self);
968 });
969 PyIntegerSetConstraint::bind(m);
970 PyIntegerSetConstraintList::bind(m);
971}
972

source code of mlir/lib/Bindings/Python/IRAffine.cpp