1//===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cstddef>
10#include <cstdint>
11#include <stdexcept>
12#include <string>
13#include <utility>
14#include <vector>
15
16#include "IRModule.h"
17#include "NanobindUtils.h"
18#include "mlir-c/AffineExpr.h"
19#include "mlir-c/AffineMap.h"
20#include "mlir-c/IntegerSet.h"
21#include "mlir/Bindings/Python/Nanobind.h"
22#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
23#include "mlir/Support/LLVM.h"
24#include "llvm/ADT/Hashing.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/StringRef.h"
27#include "llvm/ADT/Twine.h"
28
29namespace nb = nanobind;
30using namespace mlir;
31using namespace mlir::python;
32
33using llvm::SmallVector;
34using llvm::StringRef;
35using llvm::Twine;
36
/// Docstring attached to every Python `dump` method bound in this file.
static constexpr const char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
39
40/// Attempts to populate `result` with the content of `list` casted to the
41/// appropriate type (Python and C types are provided as template arguments).
42/// Throws errors in case of failure, using "action" to describe what the caller
43/// was attempting to do.
44template <typename PyType, typename CType>
45static void pyListToVector(const nb::list &list,
46 llvm::SmallVectorImpl<CType> &result,
47 StringRef action) {
48 result.reserve(nb::len(list));
49 for (nb::handle item : list) {
50 try {
51 result.push_back(nb::cast<PyType>(item));
52 } catch (nb::cast_error &err) {
53 std::string msg = (llvm::Twine("Invalid expression when ") + action +
54 " (" + err.what() + ")")
55 .str();
56 throw std::runtime_error(msg.c_str());
57 } catch (std::runtime_error &err) {
58 std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
59 action + " (" + err.what() + ")")
60 .str();
61 throw std::runtime_error(msg.c_str());
62 }
63 }
64}
65
/// Returns true if `permutation` contains every index in
/// [0, permutation.size()) exactly once.
///
/// The vector is taken by const reference so callers do not pay for a copy.
/// For signed element types no wider than std::size_t, a negative entry
/// converts to a very large unsigned value in the bounds comparison and is
/// therefore rejected.
template <typename PermutationTy>
static bool isPermutation(const std::vector<PermutationTy> &permutation) {
  const std::size_t size = permutation.size();
  // One "index already used" flag per position.
  std::vector<char> seen(size, 0);
  for (const PermutationTy &val : permutation) {
    if (!(val < size) || seen[val])
      return false;
    seen[val] = 1;
  }
  return true;
}
80
81namespace {
82
83/// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
84/// and should be castable from it. Intermediate hierarchy classes can be
85/// modeled by specifying BaseTy.
86template <typename DerivedTy, typename BaseTy = PyAffineExpr>
87class PyConcreteAffineExpr : public BaseTy {
88public:
89 // Derived classes must define statics for:
90 // IsAFunctionTy isaFunction
91 // const char *pyClassName
92 // and redefine bindDerived.
93 using ClassTy = nb::class_<DerivedTy, BaseTy>;
94 using IsAFunctionTy = bool (*)(MlirAffineExpr);
95
96 PyConcreteAffineExpr() = default;
97 PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
98 : BaseTy(std::move(contextRef), affineExpr) {}
99 PyConcreteAffineExpr(PyAffineExpr &orig)
100 : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
101
102 static MlirAffineExpr castFrom(PyAffineExpr &orig) {
103 if (!DerivedTy::isaFunction(orig)) {
104 auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
105 throw nb::value_error((Twine("Cannot cast affine expression to ") +
106 DerivedTy::pyClassName + " (from " + origRepr +
107 ")")
108 .str()
109 .c_str());
110 }
111 return orig;
112 }
113
114 static void bind(nb::module_ &m) {
115 auto cls = ClassTy(m, DerivedTy::pyClassName);
116 cls.def(nb::init<PyAffineExpr &>(), nb::arg("expr"));
117 cls.def_static(
118 "isinstance",
119 [](PyAffineExpr &otherAffineExpr) -> bool {
120 return DerivedTy::isaFunction(otherAffineExpr);
121 },
122 nb::arg("other"));
123 DerivedTy::bindDerived(cls);
124 }
125
126 /// Implemented by derived classes to add methods to the Python subclass.
127 static void bindDerived(ClassTy &m) {}
128};
129
130class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
131public:
132 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
133 static constexpr const char *pyClassName = "AffineConstantExpr";
134 using PyConcreteAffineExpr::PyConcreteAffineExpr;
135
136 static PyAffineConstantExpr get(intptr_t value,
137 DefaultingPyMlirContext context) {
138 MlirAffineExpr affineExpr =
139 mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
140 return PyAffineConstantExpr(context->getRef(), affineExpr);
141 }
142
143 static void bindDerived(ClassTy &c) {
144 c.def_static("get", &PyAffineConstantExpr::get, nb::arg("value"),
145 nb::arg("context").none() = nb::none());
146 c.def_prop_ro("value", [](PyAffineConstantExpr &self) {
147 return mlirAffineConstantExprGetValue(self);
148 });
149 }
150};
151
152class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
153public:
154 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
155 static constexpr const char *pyClassName = "AffineDimExpr";
156 using PyConcreteAffineExpr::PyConcreteAffineExpr;
157
158 static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
159 MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
160 return PyAffineDimExpr(context->getRef(), affineExpr);
161 }
162
163 static void bindDerived(ClassTy &c) {
164 c.def_static("get", &PyAffineDimExpr::get, nb::arg("position"),
165 nb::arg("context").none() = nb::none());
166 c.def_prop_ro("position", [](PyAffineDimExpr &self) {
167 return mlirAffineDimExprGetPosition(self);
168 });
169 }
170};
171
172class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
173public:
174 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
175 static constexpr const char *pyClassName = "AffineSymbolExpr";
176 using PyConcreteAffineExpr::PyConcreteAffineExpr;
177
178 static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
179 MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
180 return PyAffineSymbolExpr(context->getRef(), affineExpr);
181 }
182
183 static void bindDerived(ClassTy &c) {
184 c.def_static("get", &PyAffineSymbolExpr::get, nb::arg("position"),
185 nb::arg("context").none() = nb::none());
186 c.def_prop_ro("position", [](PyAffineSymbolExpr &self) {
187 return mlirAffineSymbolExprGetPosition(self);
188 });
189 }
190};
191
192class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
193public:
194 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
195 static constexpr const char *pyClassName = "AffineBinaryExpr";
196 using PyConcreteAffineExpr::PyConcreteAffineExpr;
197
198 PyAffineExpr lhs() {
199 MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
200 return PyAffineExpr(getContext(), lhsExpr);
201 }
202
203 PyAffineExpr rhs() {
204 MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
205 return PyAffineExpr(getContext(), rhsExpr);
206 }
207
208 static void bindDerived(ClassTy &c) {
209 c.def_prop_ro("lhs", &PyAffineBinaryExpr::lhs);
210 c.def_prop_ro("rhs", &PyAffineBinaryExpr::rhs);
211 }
212};
213
214class PyAffineAddExpr
215 : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
216public:
217 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
218 static constexpr const char *pyClassName = "AffineAddExpr";
219 using PyConcreteAffineExpr::PyConcreteAffineExpr;
220
221 static PyAffineAddExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
222 MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
223 return PyAffineAddExpr(lhs.getContext(), expr);
224 }
225
226 static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
227 MlirAffineExpr expr = mlirAffineAddExprGet(
228 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
229 return PyAffineAddExpr(lhs.getContext(), expr);
230 }
231
232 static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
233 MlirAffineExpr expr = mlirAffineAddExprGet(
234 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
235 return PyAffineAddExpr(rhs.getContext(), expr);
236 }
237
238 static void bindDerived(ClassTy &c) {
239 c.def_static("get", &PyAffineAddExpr::get);
240 }
241};
242
243class PyAffineMulExpr
244 : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
245public:
246 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
247 static constexpr const char *pyClassName = "AffineMulExpr";
248 using PyConcreteAffineExpr::PyConcreteAffineExpr;
249
250 static PyAffineMulExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
251 MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
252 return PyAffineMulExpr(lhs.getContext(), expr);
253 }
254
255 static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
256 MlirAffineExpr expr = mlirAffineMulExprGet(
257 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
258 return PyAffineMulExpr(lhs.getContext(), expr);
259 }
260
261 static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
262 MlirAffineExpr expr = mlirAffineMulExprGet(
263 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
264 return PyAffineMulExpr(rhs.getContext(), expr);
265 }
266
267 static void bindDerived(ClassTy &c) {
268 c.def_static("get", &PyAffineMulExpr::get);
269 }
270};
271
272class PyAffineModExpr
273 : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
274public:
275 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
276 static constexpr const char *pyClassName = "AffineModExpr";
277 using PyConcreteAffineExpr::PyConcreteAffineExpr;
278
279 static PyAffineModExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
280 MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
281 return PyAffineModExpr(lhs.getContext(), expr);
282 }
283
284 static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
285 MlirAffineExpr expr = mlirAffineModExprGet(
286 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
287 return PyAffineModExpr(lhs.getContext(), expr);
288 }
289
290 static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
291 MlirAffineExpr expr = mlirAffineModExprGet(
292 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
293 return PyAffineModExpr(rhs.getContext(), expr);
294 }
295
296 static void bindDerived(ClassTy &c) {
297 c.def_static("get", &PyAffineModExpr::get);
298 }
299};
300
301class PyAffineFloorDivExpr
302 : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
303public:
304 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
305 static constexpr const char *pyClassName = "AffineFloorDivExpr";
306 using PyConcreteAffineExpr::PyConcreteAffineExpr;
307
308 static PyAffineFloorDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
309 MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
310 return PyAffineFloorDivExpr(lhs.getContext(), expr);
311 }
312
313 static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
314 MlirAffineExpr expr = mlirAffineFloorDivExprGet(
315 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
316 return PyAffineFloorDivExpr(lhs.getContext(), expr);
317 }
318
319 static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
320 MlirAffineExpr expr = mlirAffineFloorDivExprGet(
321 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
322 return PyAffineFloorDivExpr(rhs.getContext(), expr);
323 }
324
325 static void bindDerived(ClassTy &c) {
326 c.def_static("get", &PyAffineFloorDivExpr::get);
327 }
328};
329
330class PyAffineCeilDivExpr
331 : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
332public:
333 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
334 static constexpr const char *pyClassName = "AffineCeilDivExpr";
335 using PyConcreteAffineExpr::PyConcreteAffineExpr;
336
337 static PyAffineCeilDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
338 MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
339 return PyAffineCeilDivExpr(lhs.getContext(), expr);
340 }
341
342 static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
343 MlirAffineExpr expr = mlirAffineCeilDivExprGet(
344 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
345 return PyAffineCeilDivExpr(lhs.getContext(), expr);
346 }
347
348 static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
349 MlirAffineExpr expr = mlirAffineCeilDivExprGet(
350 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
351 return PyAffineCeilDivExpr(rhs.getContext(), expr);
352 }
353
354 static void bindDerived(ClassTy &c) {
355 c.def_static("get", &PyAffineCeilDivExpr::get);
356 }
357};
358
359} // namespace
360
361bool PyAffineExpr::operator==(const PyAffineExpr &other) const {
362 return mlirAffineExprEqual(affineExpr, other.affineExpr);
363}
364
365nb::object PyAffineExpr::getCapsule() {
366 return nb::steal<nb::object>(mlirPythonAffineExprToCapsule(*this));
367}
368
369PyAffineExpr PyAffineExpr::createFromCapsule(nb::object capsule) {
370 MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
371 if (mlirAffineExprIsNull(rawAffineExpr))
372 throw nb::python_error();
373 return PyAffineExpr(
374 PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
375 rawAffineExpr);
376}
377
378//------------------------------------------------------------------------------
379// PyAffineMap and utilities.
380//------------------------------------------------------------------------------
381namespace {
382
383/// A list of expressions contained in an affine map. Internally these are
384/// stored as a consecutive array leading to inexpensive random access. Both
385/// the map and the expression are owned by the context so we need not bother
386/// with lifetime extension.
387class PyAffineMapExprList
388 : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
389public:
390 static constexpr const char *pyClassName = "AffineExprList";
391
392 PyAffineMapExprList(const PyAffineMap &map, intptr_t startIndex = 0,
393 intptr_t length = -1, intptr_t step = 1)
394 : Sliceable(startIndex,
395 length == -1 ? mlirAffineMapGetNumResults(map) : length,
396 step),
397 affineMap(map) {}
398
399private:
400 /// Give the parent CRTP class access to hook implementations below.
401 friend class Sliceable<PyAffineMapExprList, PyAffineExpr>;
402
403 intptr_t getRawNumElements() { return mlirAffineMapGetNumResults(affineMap); }
404
405 PyAffineExpr getRawElement(intptr_t pos) {
406 return PyAffineExpr(affineMap.getContext(),
407 mlirAffineMapGetResult(affineMap, pos));
408 }
409
410 PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
411 intptr_t step) {
412 return PyAffineMapExprList(affineMap, startIndex, length, step);
413 }
414
415 PyAffineMap affineMap;
416};
417} // namespace
418
419bool PyAffineMap::operator==(const PyAffineMap &other) const {
420 return mlirAffineMapEqual(affineMap, other.affineMap);
421}
422
423nb::object PyAffineMap::getCapsule() {
424 return nb::steal<nb::object>(mlirPythonAffineMapToCapsule(*this));
425}
426
427PyAffineMap PyAffineMap::createFromCapsule(nb::object capsule) {
428 MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
429 if (mlirAffineMapIsNull(rawAffineMap))
430 throw nb::python_error();
431 return PyAffineMap(
432 PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
433 rawAffineMap);
434}
435
436//------------------------------------------------------------------------------
437// PyIntegerSet and utilities.
438//------------------------------------------------------------------------------
439namespace {
440
441class PyIntegerSetConstraint {
442public:
443 PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos)
444 : set(std::move(set)), pos(pos) {}
445
446 PyAffineExpr getExpr() {
447 return PyAffineExpr(set.getContext(),
448 mlirIntegerSetGetConstraint(set, pos));
449 }
450
451 bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
452
453 static void bind(nb::module_ &m) {
454 nb::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint")
455 .def_prop_ro("expr", &PyIntegerSetConstraint::getExpr)
456 .def_prop_ro("is_eq", &PyIntegerSetConstraint::isEq);
457 }
458
459private:
460 PyIntegerSet set;
461 intptr_t pos;
462};
463
464class PyIntegerSetConstraintList
465 : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
466public:
467 static constexpr const char *pyClassName = "IntegerSetConstraintList";
468
469 PyIntegerSetConstraintList(const PyIntegerSet &set, intptr_t startIndex = 0,
470 intptr_t length = -1, intptr_t step = 1)
471 : Sliceable(startIndex,
472 length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
473 step),
474 set(set) {}
475
476private:
477 /// Give the parent CRTP class access to hook implementations below.
478 friend class Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint>;
479
480 intptr_t getRawNumElements() { return mlirIntegerSetGetNumConstraints(set); }
481
482 PyIntegerSetConstraint getRawElement(intptr_t pos) {
483 return PyIntegerSetConstraint(set, pos);
484 }
485
486 PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
487 intptr_t step) {
488 return PyIntegerSetConstraintList(set, startIndex, length, step);
489 }
490
491 PyIntegerSet set;
492};
493} // namespace
494
495bool PyIntegerSet::operator==(const PyIntegerSet &other) const {
496 return mlirIntegerSetEqual(integerSet, other.integerSet);
497}
498
499nb::object PyIntegerSet::getCapsule() {
500 return nb::steal<nb::object>(mlirPythonIntegerSetToCapsule(*this));
501}
502
503PyIntegerSet PyIntegerSet::createFromCapsule(nb::object capsule) {
504 MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
505 if (mlirIntegerSetIsNull(rawIntegerSet))
506 throw nb::python_error();
507 return PyIntegerSet(
508 PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
509 rawIntegerSet);
510}
511
512void mlir::python::populateIRAffine(nb::module_ &m) {
513 //----------------------------------------------------------------------------
514 // Mapping of PyAffineExpr and derived classes.
515 //----------------------------------------------------------------------------
516 nb::class_<PyAffineExpr>(m, "AffineExpr")
517 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineExpr::getCapsule)
518 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
519 .def("__add__", &PyAffineAddExpr::get)
520 .def("__add__", &PyAffineAddExpr::getRHSConstant)
521 .def("__radd__", &PyAffineAddExpr::getRHSConstant)
522 .def("__mul__", &PyAffineMulExpr::get)
523 .def("__mul__", &PyAffineMulExpr::getRHSConstant)
524 .def("__rmul__", &PyAffineMulExpr::getRHSConstant)
525 .def("__mod__", &PyAffineModExpr::get)
526 .def("__mod__", &PyAffineModExpr::getRHSConstant)
527 .def("__rmod__",
528 [](PyAffineExpr &self, intptr_t other) {
529 return PyAffineModExpr::get(
530 PyAffineConstantExpr::get(other, *self.getContext().get()),
531 self);
532 })
533 .def("__sub__",
534 [](PyAffineExpr &self, PyAffineExpr &other) {
535 auto negOne =
536 PyAffineConstantExpr::get(-1, *self.getContext().get());
537 return PyAffineAddExpr::get(self,
538 PyAffineMulExpr::get(negOne, other));
539 })
540 .def("__sub__",
541 [](PyAffineExpr &self, intptr_t other) {
542 return PyAffineAddExpr::get(
543 self,
544 PyAffineConstantExpr::get(-other, *self.getContext().get()));
545 })
546 .def("__rsub__",
547 [](PyAffineExpr &self, intptr_t other) {
548 return PyAffineAddExpr::getLHSConstant(
549 other, PyAffineMulExpr::getLHSConstant(-1, self));
550 })
551 .def("__eq__", [](PyAffineExpr &self,
552 PyAffineExpr &other) { return self == other; })
553 .def("__eq__",
554 [](PyAffineExpr &self, nb::object &other) { return false; })
555 .def("__str__",
556 [](PyAffineExpr &self) {
557 PyPrintAccumulator printAccum;
558 mlirAffineExprPrint(self, printAccum.getCallback(),
559 printAccum.getUserData());
560 return printAccum.join();
561 })
562 .def("__repr__",
563 [](PyAffineExpr &self) {
564 PyPrintAccumulator printAccum;
565 printAccum.parts.append("AffineExpr(");
566 mlirAffineExprPrint(self, printAccum.getCallback(),
567 printAccum.getUserData());
568 printAccum.parts.append(")");
569 return printAccum.join();
570 })
571 .def("__hash__",
572 [](PyAffineExpr &self) {
573 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
574 })
575 .def_prop_ro(
576 "context",
577 [](PyAffineExpr &self) { return self.getContext().getObject(); })
578 .def("compose",
579 [](PyAffineExpr &self, PyAffineMap &other) {
580 return PyAffineExpr(self.getContext(),
581 mlirAffineExprCompose(self, other));
582 })
583 .def(
584 "shift_dims",
585 [](PyAffineExpr &self, uint32_t numDims, uint32_t shift,
586 uint32_t offset) {
587 return PyAffineExpr(
588 self.getContext(),
589 mlirAffineExprShiftDims(self, numDims, shift, offset));
590 },
591 nb::arg("num_dims"), nb::arg("shift"), nb::arg("offset").none() = 0)
592 .def(
593 "shift_symbols",
594 [](PyAffineExpr &self, uint32_t numSymbols, uint32_t shift,
595 uint32_t offset) {
596 return PyAffineExpr(
597 self.getContext(),
598 mlirAffineExprShiftSymbols(self, numSymbols, shift, offset));
599 },
600 nb::arg("num_symbols"), nb::arg("shift"),
601 nb::arg("offset").none() = 0)
602 .def_static(
603 "simplify_affine_expr",
604 [](PyAffineExpr &self, uint32_t numDims, uint32_t numSymbols) {
605 return PyAffineExpr(
606 self.getContext(),
607 mlirSimplifyAffineExpr(self, numDims, numSymbols));
608 },
609 nb::arg("expr"), nb::arg("num_dims"), nb::arg("num_symbols"),
610 "Simplify an affine expression by flattening and some amount of "
611 "simple analysis.")
612 .def_static(
613 "get_add", &PyAffineAddExpr::get,
614 "Gets an affine expression containing a sum of two expressions.")
615 .def_static("get_add", &PyAffineAddExpr::getLHSConstant,
616 "Gets an affine expression containing a sum of a constant "
617 "and another expression.")
618 .def_static("get_add", &PyAffineAddExpr::getRHSConstant,
619 "Gets an affine expression containing a sum of an expression "
620 "and a constant.")
621 .def_static(
622 "get_mul", &PyAffineMulExpr::get,
623 "Gets an affine expression containing a product of two expressions.")
624 .def_static("get_mul", &PyAffineMulExpr::getLHSConstant,
625 "Gets an affine expression containing a product of a "
626 "constant and another expression.")
627 .def_static("get_mul", &PyAffineMulExpr::getRHSConstant,
628 "Gets an affine expression containing a product of an "
629 "expression and a constant.")
630 .def_static("get_mod", &PyAffineModExpr::get,
631 "Gets an affine expression containing the modulo of dividing "
632 "one expression by another.")
633 .def_static("get_mod", &PyAffineModExpr::getLHSConstant,
634 "Gets a semi-affine expression containing the modulo of "
635 "dividing a constant by an expression.")
636 .def_static("get_mod", &PyAffineModExpr::getRHSConstant,
637 "Gets an affine expression containing the module of dividing"
638 "an expression by a constant.")
639 .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
640 "Gets an affine expression containing the rounded-down "
641 "result of dividing one expression by another.")
642 .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant,
643 "Gets a semi-affine expression containing the rounded-down "
644 "result of dividing a constant by an expression.")
645 .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant,
646 "Gets an affine expression containing the rounded-down "
647 "result of dividing an expression by a constant.")
648 .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
649 "Gets an affine expression containing the rounded-up result "
650 "of dividing one expression by another.")
651 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant,
652 "Gets a semi-affine expression containing the rounded-up "
653 "result of dividing a constant by an expression.")
654 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant,
655 "Gets an affine expression containing the rounded-up result "
656 "of dividing an expression by a constant.")
657 .def_static("get_constant", &PyAffineConstantExpr::get, nb::arg("value"),
658 nb::arg("context").none() = nb::none(),
659 "Gets a constant affine expression with the given value.")
660 .def_static(
661 "get_dim", &PyAffineDimExpr::get, nb::arg("position"),
662 nb::arg("context").none() = nb::none(),
663 "Gets an affine expression of a dimension at the given position.")
664 .def_static(
665 "get_symbol", &PyAffineSymbolExpr::get, nb::arg("position"),
666 nb::arg("context").none() = nb::none(),
667 "Gets an affine expression of a symbol at the given position.")
668 .def(
669 "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
670 kDumpDocstring);
671 PyAffineConstantExpr::bind(m);
672 PyAffineDimExpr::bind(m);
673 PyAffineSymbolExpr::bind(m);
674 PyAffineBinaryExpr::bind(m);
675 PyAffineAddExpr::bind(m);
676 PyAffineMulExpr::bind(m);
677 PyAffineModExpr::bind(m);
678 PyAffineFloorDivExpr::bind(m);
679 PyAffineCeilDivExpr::bind(m);
680
681 //----------------------------------------------------------------------------
682 // Mapping of PyAffineMap.
683 //----------------------------------------------------------------------------
684 nb::class_<PyAffineMap>(m, "AffineMap")
685 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineMap::getCapsule)
686 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
687 .def("__eq__",
688 [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
689 .def("__eq__", [](PyAffineMap &self, nb::object &other) { return false; })
690 .def("__str__",
691 [](PyAffineMap &self) {
692 PyPrintAccumulator printAccum;
693 mlirAffineMapPrint(self, printAccum.getCallback(),
694 printAccum.getUserData());
695 return printAccum.join();
696 })
697 .def("__repr__",
698 [](PyAffineMap &self) {
699 PyPrintAccumulator printAccum;
700 printAccum.parts.append("AffineMap(");
701 mlirAffineMapPrint(self, printAccum.getCallback(),
702 printAccum.getUserData());
703 printAccum.parts.append(")");
704 return printAccum.join();
705 })
706 .def("__hash__",
707 [](PyAffineMap &self) {
708 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
709 })
710 .def_static("compress_unused_symbols",
711 [](nb::list affineMaps, DefaultingPyMlirContext context) {
712 SmallVector<MlirAffineMap> maps;
713 pyListToVector<PyAffineMap, MlirAffineMap>(
714 affineMaps, maps, "attempting to create an AffineMap");
715 std::vector<MlirAffineMap> compressed(affineMaps.size());
716 auto populate = [](void *result, intptr_t idx,
717 MlirAffineMap m) {
718 static_cast<MlirAffineMap *>(result)[idx] = (m);
719 };
720 mlirAffineMapCompressUnusedSymbols(
721 maps.data(), maps.size(), compressed.data(), populate);
722 std::vector<PyAffineMap> res;
723 res.reserve(compressed.size());
724 for (auto m : compressed)
725 res.emplace_back(context->getRef(), m);
726 return res;
727 })
728 .def_prop_ro(
729 "context",
730 [](PyAffineMap &self) { return self.getContext().getObject(); },
731 "Context that owns the Affine Map")
732 .def(
733 "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
734 kDumpDocstring)
735 .def_static(
736 "get",
737 [](intptr_t dimCount, intptr_t symbolCount, nb::list exprs,
738 DefaultingPyMlirContext context) {
739 SmallVector<MlirAffineExpr> affineExprs;
740 pyListToVector<PyAffineExpr, MlirAffineExpr>(
741 exprs, affineExprs, "attempting to create an AffineMap");
742 MlirAffineMap map =
743 mlirAffineMapGet(context->get(), dimCount, symbolCount,
744 affineExprs.size(), affineExprs.data());
745 return PyAffineMap(context->getRef(), map);
746 },
747 nb::arg("dim_count"), nb::arg("symbol_count"), nb::arg("exprs"),
748 nb::arg("context").none() = nb::none(),
749 "Gets a map with the given expressions as results.")
750 .def_static(
751 "get_constant",
752 [](intptr_t value, DefaultingPyMlirContext context) {
753 MlirAffineMap affineMap =
754 mlirAffineMapConstantGet(context->get(), value);
755 return PyAffineMap(context->getRef(), affineMap);
756 },
757 nb::arg("value"), nb::arg("context").none() = nb::none(),
758 "Gets an affine map with a single constant result")
759 .def_static(
760 "get_empty",
761 [](DefaultingPyMlirContext context) {
762 MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
763 return PyAffineMap(context->getRef(), affineMap);
764 },
765 nb::arg("context").none() = nb::none(), "Gets an empty affine map.")
766 .def_static(
767 "get_identity",
768 [](intptr_t nDims, DefaultingPyMlirContext context) {
769 MlirAffineMap affineMap =
770 mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
771 return PyAffineMap(context->getRef(), affineMap);
772 },
773 nb::arg("n_dims"), nb::arg("context").none() = nb::none(),
774 "Gets an identity map with the given number of dimensions.")
775 .def_static(
776 "get_minor_identity",
777 [](intptr_t nDims, intptr_t nResults,
778 DefaultingPyMlirContext context) {
779 MlirAffineMap affineMap =
780 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
781 return PyAffineMap(context->getRef(), affineMap);
782 },
783 nb::arg("n_dims"), nb::arg("n_results"),
784 nb::arg("context").none() = nb::none(),
785 "Gets a minor identity map with the given number of dimensions and "
786 "results.")
787 .def_static(
788 "get_permutation",
789 [](std::vector<unsigned> permutation,
790 DefaultingPyMlirContext context) {
791 if (!isPermutation(permutation))
792 throw std::runtime_error("Invalid permutation when attempting to "
793 "create an AffineMap");
794 MlirAffineMap affineMap = mlirAffineMapPermutationGet(
795 context->get(), permutation.size(), permutation.data());
796 return PyAffineMap(context->getRef(), affineMap);
797 },
798 nb::arg("permutation"), nb::arg("context").none() = nb::none(),
799 "Gets an affine map that permutes its inputs.")
800 .def(
801 "get_submap",
802 [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
803 intptr_t numResults = mlirAffineMapGetNumResults(self);
804 for (intptr_t pos : resultPos) {
805 if (pos < 0 || pos >= numResults)
806 throw nb::value_error("result position out of bounds");
807 }
808 MlirAffineMap affineMap = mlirAffineMapGetSubMap(
809 self, resultPos.size(), resultPos.data());
810 return PyAffineMap(self.getContext(), affineMap);
811 },
812 nb::arg("result_positions"))
813 .def(
814 "get_major_submap",
815 [](PyAffineMap &self, intptr_t nResults) {
816 if (nResults >= mlirAffineMapGetNumResults(self))
817 throw nb::value_error("number of results out of bounds");
818 MlirAffineMap affineMap =
819 mlirAffineMapGetMajorSubMap(self, nResults);
820 return PyAffineMap(self.getContext(), affineMap);
821 },
822 nb::arg("n_results"))
823 .def(
824 "get_minor_submap",
825 [](PyAffineMap &self, intptr_t nResults) {
826 if (nResults >= mlirAffineMapGetNumResults(self))
827 throw nb::value_error("number of results out of bounds");
828 MlirAffineMap affineMap =
829 mlirAffineMapGetMinorSubMap(self, nResults);
830 return PyAffineMap(self.getContext(), affineMap);
831 },
832 nb::arg("n_results"))
833 .def(
834 "replace",
835 [](PyAffineMap &self, PyAffineExpr &expression,
836 PyAffineExpr &replacement, intptr_t numResultDims,
837 intptr_t numResultSyms) {
838 MlirAffineMap affineMap = mlirAffineMapReplace(
839 self, expression, replacement, numResultDims, numResultSyms);
840 return PyAffineMap(self.getContext(), affineMap);
841 },
842 nb::arg("expr"), nb::arg("replacement"), nb::arg("n_result_dims"),
843 nb::arg("n_result_syms"))
844 .def_prop_ro(
845 "is_permutation",
846 [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
847 .def_prop_ro("is_projected_permutation",
848 [](PyAffineMap &self) {
849 return mlirAffineMapIsProjectedPermutation(self);
850 })
851 .def_prop_ro(
852 "n_dims",
853 [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
854 .def_prop_ro(
855 "n_inputs",
856 [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
857 .def_prop_ro(
858 "n_symbols",
859 [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
860 .def_prop_ro("results",
861 [](PyAffineMap &self) { return PyAffineMapExprList(self); });
862 PyAffineMapExprList::bind(m);
863
864 //----------------------------------------------------------------------------
865 // Mapping of PyIntegerSet.
866 //----------------------------------------------------------------------------
867 nb::class_<PyIntegerSet>(m, "IntegerSet")
868 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyIntegerSet::getCapsule)
869 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
870 .def("__eq__", [](PyIntegerSet &self,
871 PyIntegerSet &other) { return self == other; })
872 .def("__eq__", [](PyIntegerSet &self, nb::object other) { return false; })
873 .def("__str__",
874 [](PyIntegerSet &self) {
875 PyPrintAccumulator printAccum;
876 mlirIntegerSetPrint(self, printAccum.getCallback(),
877 printAccum.getUserData());
878 return printAccum.join();
879 })
880 .def("__repr__",
881 [](PyIntegerSet &self) {
882 PyPrintAccumulator printAccum;
883 printAccum.parts.append("IntegerSet(");
884 mlirIntegerSetPrint(self, printAccum.getCallback(),
885 printAccum.getUserData());
886 printAccum.parts.append(")");
887 return printAccum.join();
888 })
889 .def("__hash__",
890 [](PyIntegerSet &self) {
891 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
892 })
893 .def_prop_ro(
894 "context",
895 [](PyIntegerSet &self) { return self.getContext().getObject(); })
896 .def(
897 "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
898 kDumpDocstring)
899 .def_static(
900 "get",
901 [](intptr_t numDims, intptr_t numSymbols, nb::list exprs,
902 std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
903 if (exprs.size() != eqFlags.size())
904 throw nb::value_error(
905 "Expected the number of constraints to match "
906 "that of equality flags");
907 if (exprs.size() == 0)
908 throw nb::value_error("Expected non-empty list of constraints");
909
910 // Copy over to a SmallVector because std::vector has a
911 // specialization for booleans that packs data and does not
912 // expose a `bool *`.
913 SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
914
915 SmallVector<MlirAffineExpr> affineExprs;
916 pyListToVector<PyAffineExpr>(exprs, affineExprs,
917 "attempting to create an IntegerSet");
918 MlirIntegerSet set = mlirIntegerSetGet(
919 context->get(), numDims, numSymbols, exprs.size(),
920 affineExprs.data(), flags.data());
921 return PyIntegerSet(context->getRef(), set);
922 },
923 nb::arg("num_dims"), nb::arg("num_symbols"), nb::arg("exprs"),
924 nb::arg("eq_flags"), nb::arg("context").none() = nb::none())
925 .def_static(
926 "get_empty",
927 [](intptr_t numDims, intptr_t numSymbols,
928 DefaultingPyMlirContext context) {
929 MlirIntegerSet set =
930 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
931 return PyIntegerSet(context->getRef(), set);
932 },
933 nb::arg("num_dims"), nb::arg("num_symbols"),
934 nb::arg("context").none() = nb::none())
935 .def(
936 "get_replaced",
937 [](PyIntegerSet &self, nb::list dimExprs, nb::list symbolExprs,
938 intptr_t numResultDims, intptr_t numResultSymbols) {
939 if (static_cast<intptr_t>(dimExprs.size()) !=
940 mlirIntegerSetGetNumDims(self))
941 throw nb::value_error(
942 "Expected the number of dimension replacement expressions "
943 "to match that of dimensions");
944 if (static_cast<intptr_t>(symbolExprs.size()) !=
945 mlirIntegerSetGetNumSymbols(self))
946 throw nb::value_error(
947 "Expected the number of symbol replacement expressions "
948 "to match that of symbols");
949
950 SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
951 pyListToVector<PyAffineExpr>(
952 dimExprs, dimAffineExprs,
953 "attempting to create an IntegerSet by replacing dimensions");
954 pyListToVector<PyAffineExpr>(
955 symbolExprs, symbolAffineExprs,
956 "attempting to create an IntegerSet by replacing symbols");
957 MlirIntegerSet set = mlirIntegerSetReplaceGet(
958 self, dimAffineExprs.data(), symbolAffineExprs.data(),
959 numResultDims, numResultSymbols);
960 return PyIntegerSet(self.getContext(), set);
961 },
962 nb::arg("dim_exprs"), nb::arg("symbol_exprs"),
963 nb::arg("num_result_dims"), nb::arg("num_result_symbols"))
964 .def_prop_ro("is_canonical_empty",
965 [](PyIntegerSet &self) {
966 return mlirIntegerSetIsCanonicalEmpty(self);
967 })
968 .def_prop_ro(
969 "n_dims",
970 [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
971 .def_prop_ro(
972 "n_symbols",
973 [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
974 .def_prop_ro(
975 "n_inputs",
976 [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
977 .def_prop_ro("n_equalities",
978 [](PyIntegerSet &self) {
979 return mlirIntegerSetGetNumEqualities(self);
980 })
981 .def_prop_ro("n_inequalities",
982 [](PyIntegerSet &self) {
983 return mlirIntegerSetGetNumInequalities(self);
984 })
985 .def_prop_ro("constraints", [](PyIntegerSet &self) {
986 return PyIntegerSetConstraintList(self);
987 });
988 PyIntegerSetConstraint::bind(m);
989 PyIntegerSetConstraintList::bind(m);
990}
991

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Bindings/Python/IRAffine.cpp