| 1 | //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
| 15 | |
| 16 | #include "IRModule.h" |
| 17 | #include "NanobindUtils.h" |
| 18 | #include "mlir-c/AffineExpr.h" |
| 19 | #include "mlir-c/AffineMap.h" |
| 20 | #include "mlir-c/IntegerSet.h" |
| 21 | #include "mlir/Bindings/Python/Nanobind.h" |
| 22 | #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. |
| 23 | #include "mlir/Support/LLVM.h" |
| 24 | #include "llvm/ADT/Hashing.h" |
| 25 | #include "llvm/ADT/SmallVector.h" |
| 26 | #include "llvm/ADT/StringRef.h" |
| 27 | #include "llvm/ADT/Twine.h" |
| 28 | |
| 29 | namespace nb = nanobind; |
| 30 | using namespace mlir; |
| 31 | using namespace mlir::python; |
| 32 | |
| 33 | using llvm::SmallVector; |
| 34 | using llvm::StringRef; |
| 35 | using llvm::Twine; |
| 36 | |
/// Docstring attached to the `dump` methods registered in this file.
static constexpr char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
| 39 | |
| 40 | /// Attempts to populate `result` with the content of `list` casted to the |
| 41 | /// appropriate type (Python and C types are provided as template arguments). |
| 42 | /// Throws errors in case of failure, using "action" to describe what the caller |
| 43 | /// was attempting to do. |
| 44 | template <typename PyType, typename CType> |
| 45 | static void pyListToVector(const nb::list &list, |
| 46 | llvm::SmallVectorImpl<CType> &result, |
| 47 | StringRef action) { |
| 48 | result.reserve(nb::len(list)); |
| 49 | for (nb::handle item : list) { |
| 50 | try { |
| 51 | result.push_back(nb::cast<PyType>(item)); |
| 52 | } catch (nb::cast_error &err) { |
| 53 | std::string msg = (llvm::Twine("Invalid expression when " ) + action + |
| 54 | " (" + err.what() + ")" ) |
| 55 | .str(); |
| 56 | throw std::runtime_error(msg.c_str()); |
| 57 | } catch (std::runtime_error &err) { |
| 58 | std::string msg = (llvm::Twine("Invalid expression (None?) when " ) + |
| 59 | action + " (" + err.what() + ")" ) |
| 60 | .str(); |
| 61 | throw std::runtime_error(msg.c_str()); |
| 62 | } |
| 63 | } |
| 64 | } |
| 65 | |
/// Returns true when `permutation` contains every index in [0, size) exactly
/// once, where size is the number of elements. Out-of-range entries
/// (including negative values of signed element types) and duplicates yield
/// false; an empty vector is trivially a permutation.
///
/// Takes the vector by const reference: the check never mutates it, so the
/// previous by-value parameter only cost an extra copy per call.
template <typename PermutationTy>
static bool isPermutation(const std::vector<PermutationTy> &permutation) {
  const size_t size = permutation.size();
  // One flag per admissible index; char avoids the std::vector<bool> proxy.
  std::vector<char> seen(size, 0);
  for (const PermutationTy &val : permutation) {
    // Reject negatives explicitly rather than relying on an implicit
    // signed-to-unsigned conversion in the bounds check (-Wsign-compare).
    if constexpr (std::is_signed_v<PermutationTy>) {
      if (val < 0)
        return false;
    }
    // Widen both sides to the largest unsigned type so neither value is
    // truncated before the comparison (e.g. 64-bit indices on 32-bit hosts).
    if (static_cast<std::uintmax_t>(val) >= static_cast<std::uintmax_t>(size))
      return false;
    // `val` is now known to lie in [0, size), so indexing is in bounds.
    char &wasSeen = seen[static_cast<size_t>(val)];
    if (wasSeen)
      return false;
    wasSeen = 1;
  }
  return true;
}
| 80 | |
| 81 | namespace { |
| 82 | |
| 83 | /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr |
| 84 | /// and should be castable from it. Intermediate hierarchy classes can be |
| 85 | /// modeled by specifying BaseTy. |
| 86 | template <typename DerivedTy, typename BaseTy = PyAffineExpr> |
| 87 | class PyConcreteAffineExpr : public BaseTy { |
| 88 | public: |
| 89 | // Derived classes must define statics for: |
| 90 | // IsAFunctionTy isaFunction |
| 91 | // const char *pyClassName |
| 92 | // and redefine bindDerived. |
| 93 | using ClassTy = nb::class_<DerivedTy, BaseTy>; |
| 94 | using IsAFunctionTy = bool (*)(MlirAffineExpr); |
| 95 | |
| 96 | PyConcreteAffineExpr() = default; |
| 97 | PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) |
| 98 | : BaseTy(std::move(contextRef), affineExpr) {} |
| 99 | PyConcreteAffineExpr(PyAffineExpr &orig) |
| 100 | : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {} |
| 101 | |
| 102 | static MlirAffineExpr castFrom(PyAffineExpr &orig) { |
| 103 | if (!DerivedTy::isaFunction(orig)) { |
| 104 | auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig))); |
| 105 | throw nb::value_error((Twine("Cannot cast affine expression to " ) + |
| 106 | DerivedTy::pyClassName + " (from " + origRepr + |
| 107 | ")" ) |
| 108 | .str() |
| 109 | .c_str()); |
| 110 | } |
| 111 | return orig; |
| 112 | } |
| 113 | |
| 114 | static void bind(nb::module_ &m) { |
| 115 | auto cls = ClassTy(m, DerivedTy::pyClassName); |
| 116 | cls.def(nb::init<PyAffineExpr &>(), nb::arg("expr" )); |
| 117 | cls.def_static( |
| 118 | "isinstance" , |
| 119 | [](PyAffineExpr &otherAffineExpr) -> bool { |
| 120 | return DerivedTy::isaFunction(otherAffineExpr); |
| 121 | }, |
| 122 | nb::arg("other" )); |
| 123 | DerivedTy::bindDerived(cls); |
| 124 | } |
| 125 | |
| 126 | /// Implemented by derived classes to add methods to the Python subclass. |
| 127 | static void bindDerived(ClassTy &m) {} |
| 128 | }; |
| 129 | |
| 130 | class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> { |
| 131 | public: |
| 132 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant; |
| 133 | static constexpr const char *pyClassName = "AffineConstantExpr" ; |
| 134 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
| 135 | |
| 136 | static PyAffineConstantExpr get(intptr_t value, |
| 137 | DefaultingPyMlirContext context) { |
| 138 | MlirAffineExpr affineExpr = |
| 139 | mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value)); |
| 140 | return PyAffineConstantExpr(context->getRef(), affineExpr); |
| 141 | } |
| 142 | |
| 143 | static void bindDerived(ClassTy &c) { |
| 144 | c.def_static("get" , &PyAffineConstantExpr::get, nb::arg("value" ), |
| 145 | nb::arg("context" ).none() = nb::none()); |
| 146 | c.def_prop_ro("value" , [](PyAffineConstantExpr &self) { |
| 147 | return mlirAffineConstantExprGetValue(self); |
| 148 | }); |
| 149 | } |
| 150 | }; |
| 151 | |
| 152 | class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> { |
| 153 | public: |
| 154 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim; |
| 155 | static constexpr const char *pyClassName = "AffineDimExpr" ; |
| 156 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
| 157 | |
| 158 | static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) { |
| 159 | MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos); |
| 160 | return PyAffineDimExpr(context->getRef(), affineExpr); |
| 161 | } |
| 162 | |
| 163 | static void bindDerived(ClassTy &c) { |
| 164 | c.def_static("get" , &PyAffineDimExpr::get, nb::arg("position" ), |
| 165 | nb::arg("context" ).none() = nb::none()); |
| 166 | c.def_prop_ro("position" , [](PyAffineDimExpr &self) { |
| 167 | return mlirAffineDimExprGetPosition(self); |
| 168 | }); |
| 169 | } |
| 170 | }; |
| 171 | |
| 172 | class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> { |
| 173 | public: |
| 174 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol; |
| 175 | static constexpr const char *pyClassName = "AffineSymbolExpr" ; |
| 176 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
| 177 | |
| 178 | static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) { |
| 179 | MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos); |
| 180 | return PyAffineSymbolExpr(context->getRef(), affineExpr); |
| 181 | } |
| 182 | |
| 183 | static void bindDerived(ClassTy &c) { |
| 184 | c.def_static("get" , &PyAffineSymbolExpr::get, nb::arg("position" ), |
| 185 | nb::arg("context" ).none() = nb::none()); |
| 186 | c.def_prop_ro("position" , [](PyAffineSymbolExpr &self) { |
| 187 | return mlirAffineSymbolExprGetPosition(self); |
| 188 | }); |
| 189 | } |
| 190 | }; |
| 191 | |
| 192 | class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> { |
| 193 | public: |
| 194 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary; |
| 195 | static constexpr const char *pyClassName = "AffineBinaryExpr" ; |
| 196 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
| 197 | |
| 198 | PyAffineExpr lhs() { |
| 199 | MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get()); |
| 200 | return PyAffineExpr(getContext(), lhsExpr); |
| 201 | } |
| 202 | |
| 203 | PyAffineExpr rhs() { |
| 204 | MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get()); |
| 205 | return PyAffineExpr(getContext(), rhsExpr); |
| 206 | } |
| 207 | |
| 208 | static void bindDerived(ClassTy &c) { |
| 209 | c.def_prop_ro("lhs" , &PyAffineBinaryExpr::lhs); |
| 210 | c.def_prop_ro("rhs" , &PyAffineBinaryExpr::rhs); |
| 211 | } |
| 212 | }; |
| 213 | |
| 214 | class PyAffineAddExpr |
| 215 | : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> { |
| 216 | public: |
| 217 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd; |
| 218 | static constexpr const char *pyClassName = "AffineAddExpr" ; |
| 219 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
| 220 | |
| 221 | static PyAffineAddExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { |
| 222 | MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs); |
| 223 | return PyAffineAddExpr(lhs.getContext(), expr); |
| 224 | } |
| 225 | |
| 226 | static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { |
| 227 | MlirAffineExpr expr = mlirAffineAddExprGet( |
| 228 | lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); |
| 229 | return PyAffineAddExpr(lhs.getContext(), expr); |
| 230 | } |
| 231 | |
| 232 | static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { |
| 233 | MlirAffineExpr expr = mlirAffineAddExprGet( |
| 234 | mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); |
| 235 | return PyAffineAddExpr(rhs.getContext(), expr); |
| 236 | } |
| 237 | |
| 238 | static void bindDerived(ClassTy &c) { |
| 239 | c.def_static("get" , &PyAffineAddExpr::get); |
| 240 | } |
| 241 | }; |
| 242 | |
| 243 | class PyAffineMulExpr |
| 244 | : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> { |
| 245 | public: |
| 246 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul; |
| 247 | static constexpr const char *pyClassName = "AffineMulExpr" ; |
| 248 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
| 249 | |
| 250 | static PyAffineMulExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { |
| 251 | MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs); |
| 252 | return PyAffineMulExpr(lhs.getContext(), expr); |
| 253 | } |
| 254 | |
| 255 | static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { |
| 256 | MlirAffineExpr expr = mlirAffineMulExprGet( |
| 257 | lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); |
| 258 | return PyAffineMulExpr(lhs.getContext(), expr); |
| 259 | } |
| 260 | |
| 261 | static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { |
| 262 | MlirAffineExpr expr = mlirAffineMulExprGet( |
| 263 | mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); |
| 264 | return PyAffineMulExpr(rhs.getContext(), expr); |
| 265 | } |
| 266 | |
| 267 | static void bindDerived(ClassTy &c) { |
| 268 | c.def_static("get" , &PyAffineMulExpr::get); |
| 269 | } |
| 270 | }; |
| 271 | |
| 272 | class PyAffineModExpr |
| 273 | : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> { |
| 274 | public: |
| 275 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod; |
| 276 | static constexpr const char *pyClassName = "AffineModExpr" ; |
| 277 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
| 278 | |
| 279 | static PyAffineModExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { |
| 280 | MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs); |
| 281 | return PyAffineModExpr(lhs.getContext(), expr); |
| 282 | } |
| 283 | |
| 284 | static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { |
| 285 | MlirAffineExpr expr = mlirAffineModExprGet( |
| 286 | lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); |
| 287 | return PyAffineModExpr(lhs.getContext(), expr); |
| 288 | } |
| 289 | |
| 290 | static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { |
| 291 | MlirAffineExpr expr = mlirAffineModExprGet( |
| 292 | mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); |
| 293 | return PyAffineModExpr(rhs.getContext(), expr); |
| 294 | } |
| 295 | |
| 296 | static void bindDerived(ClassTy &c) { |
| 297 | c.def_static("get" , &PyAffineModExpr::get); |
| 298 | } |
| 299 | }; |
| 300 | |
| 301 | class PyAffineFloorDivExpr |
| 302 | : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> { |
| 303 | public: |
| 304 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv; |
| 305 | static constexpr const char *pyClassName = "AffineFloorDivExpr" ; |
| 306 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
| 307 | |
| 308 | static PyAffineFloorDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { |
| 309 | MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs); |
| 310 | return PyAffineFloorDivExpr(lhs.getContext(), expr); |
| 311 | } |
| 312 | |
| 313 | static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { |
| 314 | MlirAffineExpr expr = mlirAffineFloorDivExprGet( |
| 315 | lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); |
| 316 | return PyAffineFloorDivExpr(lhs.getContext(), expr); |
| 317 | } |
| 318 | |
| 319 | static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { |
| 320 | MlirAffineExpr expr = mlirAffineFloorDivExprGet( |
| 321 | mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); |
| 322 | return PyAffineFloorDivExpr(rhs.getContext(), expr); |
| 323 | } |
| 324 | |
| 325 | static void bindDerived(ClassTy &c) { |
| 326 | c.def_static("get" , &PyAffineFloorDivExpr::get); |
| 327 | } |
| 328 | }; |
| 329 | |
| 330 | class PyAffineCeilDivExpr |
| 331 | : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> { |
| 332 | public: |
| 333 | static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv; |
| 334 | static constexpr const char *pyClassName = "AffineCeilDivExpr" ; |
| 335 | using PyConcreteAffineExpr::PyConcreteAffineExpr; |
| 336 | |
| 337 | static PyAffineCeilDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { |
| 338 | MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs); |
| 339 | return PyAffineCeilDivExpr(lhs.getContext(), expr); |
| 340 | } |
| 341 | |
| 342 | static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { |
| 343 | MlirAffineExpr expr = mlirAffineCeilDivExprGet( |
| 344 | lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); |
| 345 | return PyAffineCeilDivExpr(lhs.getContext(), expr); |
| 346 | } |
| 347 | |
| 348 | static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { |
| 349 | MlirAffineExpr expr = mlirAffineCeilDivExprGet( |
| 350 | mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); |
| 351 | return PyAffineCeilDivExpr(rhs.getContext(), expr); |
| 352 | } |
| 353 | |
| 354 | static void bindDerived(ClassTy &c) { |
| 355 | c.def_static("get" , &PyAffineCeilDivExpr::get); |
| 356 | } |
| 357 | }; |
| 358 | |
| 359 | } // namespace |
| 360 | |
| 361 | bool PyAffineExpr::operator==(const PyAffineExpr &other) const { |
| 362 | return mlirAffineExprEqual(affineExpr, other.affineExpr); |
| 363 | } |
| 364 | |
| 365 | nb::object PyAffineExpr::getCapsule() { |
| 366 | return nb::steal<nb::object>(mlirPythonAffineExprToCapsule(*this)); |
| 367 | } |
| 368 | |
| 369 | PyAffineExpr PyAffineExpr::createFromCapsule(nb::object capsule) { |
| 370 | MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); |
| 371 | if (mlirAffineExprIsNull(rawAffineExpr)) |
| 372 | throw nb::python_error(); |
| 373 | return PyAffineExpr( |
| 374 | PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), |
| 375 | rawAffineExpr); |
| 376 | } |
| 377 | |
| 378 | //------------------------------------------------------------------------------ |
| 379 | // PyAffineMap and utilities. |
| 380 | //------------------------------------------------------------------------------ |
| 381 | namespace { |
| 382 | |
| 383 | /// A list of expressions contained in an affine map. Internally these are |
| 384 | /// stored as a consecutive array leading to inexpensive random access. Both |
| 385 | /// the map and the expression are owned by the context so we need not bother |
| 386 | /// with lifetime extension. |
| 387 | class PyAffineMapExprList |
| 388 | : public Sliceable<PyAffineMapExprList, PyAffineExpr> { |
| 389 | public: |
| 390 | static constexpr const char *pyClassName = "AffineExprList" ; |
| 391 | |
| 392 | PyAffineMapExprList(const PyAffineMap &map, intptr_t startIndex = 0, |
| 393 | intptr_t length = -1, intptr_t step = 1) |
| 394 | : Sliceable(startIndex, |
| 395 | length == -1 ? mlirAffineMapGetNumResults(map) : length, |
| 396 | step), |
| 397 | affineMap(map) {} |
| 398 | |
| 399 | private: |
| 400 | /// Give the parent CRTP class access to hook implementations below. |
| 401 | friend class Sliceable<PyAffineMapExprList, PyAffineExpr>; |
| 402 | |
| 403 | intptr_t getRawNumElements() { return mlirAffineMapGetNumResults(affineMap); } |
| 404 | |
| 405 | PyAffineExpr getRawElement(intptr_t pos) { |
| 406 | return PyAffineExpr(affineMap.getContext(), |
| 407 | mlirAffineMapGetResult(affineMap, pos)); |
| 408 | } |
| 409 | |
| 410 | PyAffineMapExprList slice(intptr_t startIndex, intptr_t length, |
| 411 | intptr_t step) { |
| 412 | return PyAffineMapExprList(affineMap, startIndex, length, step); |
| 413 | } |
| 414 | |
| 415 | PyAffineMap affineMap; |
| 416 | }; |
| 417 | } // namespace |
| 418 | |
| 419 | bool PyAffineMap::operator==(const PyAffineMap &other) const { |
| 420 | return mlirAffineMapEqual(affineMap, other.affineMap); |
| 421 | } |
| 422 | |
| 423 | nb::object PyAffineMap::getCapsule() { |
| 424 | return nb::steal<nb::object>(mlirPythonAffineMapToCapsule(*this)); |
| 425 | } |
| 426 | |
| 427 | PyAffineMap PyAffineMap::createFromCapsule(nb::object capsule) { |
| 428 | MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); |
| 429 | if (mlirAffineMapIsNull(rawAffineMap)) |
| 430 | throw nb::python_error(); |
| 431 | return PyAffineMap( |
| 432 | PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), |
| 433 | rawAffineMap); |
| 434 | } |
| 435 | |
| 436 | //------------------------------------------------------------------------------ |
| 437 | // PyIntegerSet and utilities. |
| 438 | //------------------------------------------------------------------------------ |
| 439 | namespace { |
| 440 | |
| 441 | class PyIntegerSetConstraint { |
| 442 | public: |
| 443 | PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) |
| 444 | : set(std::move(set)), pos(pos) {} |
| 445 | |
| 446 | PyAffineExpr getExpr() { |
| 447 | return PyAffineExpr(set.getContext(), |
| 448 | mlirIntegerSetGetConstraint(set, pos)); |
| 449 | } |
| 450 | |
| 451 | bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } |
| 452 | |
| 453 | static void bind(nb::module_ &m) { |
| 454 | nb::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint" ) |
| 455 | .def_prop_ro("expr" , &PyIntegerSetConstraint::getExpr) |
| 456 | .def_prop_ro("is_eq" , &PyIntegerSetConstraint::isEq); |
| 457 | } |
| 458 | |
| 459 | private: |
| 460 | PyIntegerSet set; |
| 461 | intptr_t pos; |
| 462 | }; |
| 463 | |
| 464 | class PyIntegerSetConstraintList |
| 465 | : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> { |
| 466 | public: |
| 467 | static constexpr const char *pyClassName = "IntegerSetConstraintList" ; |
| 468 | |
| 469 | PyIntegerSetConstraintList(const PyIntegerSet &set, intptr_t startIndex = 0, |
| 470 | intptr_t length = -1, intptr_t step = 1) |
| 471 | : Sliceable(startIndex, |
| 472 | length == -1 ? mlirIntegerSetGetNumConstraints(set) : length, |
| 473 | step), |
| 474 | set(set) {} |
| 475 | |
| 476 | private: |
| 477 | /// Give the parent CRTP class access to hook implementations below. |
| 478 | friend class Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint>; |
| 479 | |
| 480 | intptr_t getRawNumElements() { return mlirIntegerSetGetNumConstraints(set); } |
| 481 | |
| 482 | PyIntegerSetConstraint getRawElement(intptr_t pos) { |
| 483 | return PyIntegerSetConstraint(set, pos); |
| 484 | } |
| 485 | |
| 486 | PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length, |
| 487 | intptr_t step) { |
| 488 | return PyIntegerSetConstraintList(set, startIndex, length, step); |
| 489 | } |
| 490 | |
| 491 | PyIntegerSet set; |
| 492 | }; |
| 493 | } // namespace |
| 494 | |
| 495 | bool PyIntegerSet::operator==(const PyIntegerSet &other) const { |
| 496 | return mlirIntegerSetEqual(integerSet, other.integerSet); |
| 497 | } |
| 498 | |
| 499 | nb::object PyIntegerSet::getCapsule() { |
| 500 | return nb::steal<nb::object>(mlirPythonIntegerSetToCapsule(*this)); |
| 501 | } |
| 502 | |
| 503 | PyIntegerSet PyIntegerSet::createFromCapsule(nb::object capsule) { |
| 504 | MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); |
| 505 | if (mlirIntegerSetIsNull(rawIntegerSet)) |
| 506 | throw nb::python_error(); |
| 507 | return PyIntegerSet( |
| 508 | PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), |
| 509 | rawIntegerSet); |
| 510 | } |
| 511 | |
| 512 | void mlir::python::populateIRAffine(nb::module_ &m) { |
| 513 | //---------------------------------------------------------------------------- |
| 514 | // Mapping of PyAffineExpr and derived classes. |
| 515 | //---------------------------------------------------------------------------- |
| 516 | nb::class_<PyAffineExpr>(m, "AffineExpr" ) |
| 517 | .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineExpr::getCapsule) |
| 518 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) |
| 519 | .def("__add__" , &PyAffineAddExpr::get) |
| 520 | .def("__add__" , &PyAffineAddExpr::getRHSConstant) |
| 521 | .def("__radd__" , &PyAffineAddExpr::getRHSConstant) |
| 522 | .def("__mul__" , &PyAffineMulExpr::get) |
| 523 | .def("__mul__" , &PyAffineMulExpr::getRHSConstant) |
| 524 | .def("__rmul__" , &PyAffineMulExpr::getRHSConstant) |
| 525 | .def("__mod__" , &PyAffineModExpr::get) |
| 526 | .def("__mod__" , &PyAffineModExpr::getRHSConstant) |
| 527 | .def("__rmod__" , |
| 528 | [](PyAffineExpr &self, intptr_t other) { |
| 529 | return PyAffineModExpr::get( |
| 530 | PyAffineConstantExpr::get(other, *self.getContext().get()), |
| 531 | self); |
| 532 | }) |
| 533 | .def("__sub__" , |
| 534 | [](PyAffineExpr &self, PyAffineExpr &other) { |
| 535 | auto negOne = |
| 536 | PyAffineConstantExpr::get(-1, *self.getContext().get()); |
| 537 | return PyAffineAddExpr::get(self, |
| 538 | PyAffineMulExpr::get(negOne, other)); |
| 539 | }) |
| 540 | .def("__sub__" , |
| 541 | [](PyAffineExpr &self, intptr_t other) { |
| 542 | return PyAffineAddExpr::get( |
| 543 | self, |
| 544 | PyAffineConstantExpr::get(-other, *self.getContext().get())); |
| 545 | }) |
| 546 | .def("__rsub__" , |
| 547 | [](PyAffineExpr &self, intptr_t other) { |
| 548 | return PyAffineAddExpr::getLHSConstant( |
| 549 | other, PyAffineMulExpr::getLHSConstant(-1, self)); |
| 550 | }) |
| 551 | .def("__eq__" , [](PyAffineExpr &self, |
| 552 | PyAffineExpr &other) { return self == other; }) |
| 553 | .def("__eq__" , |
| 554 | [](PyAffineExpr &self, nb::object &other) { return false; }) |
| 555 | .def("__str__" , |
| 556 | [](PyAffineExpr &self) { |
| 557 | PyPrintAccumulator printAccum; |
| 558 | mlirAffineExprPrint(self, printAccum.getCallback(), |
| 559 | printAccum.getUserData()); |
| 560 | return printAccum.join(); |
| 561 | }) |
| 562 | .def("__repr__" , |
| 563 | [](PyAffineExpr &self) { |
| 564 | PyPrintAccumulator printAccum; |
| 565 | printAccum.parts.append("AffineExpr(" ); |
| 566 | mlirAffineExprPrint(self, printAccum.getCallback(), |
| 567 | printAccum.getUserData()); |
| 568 | printAccum.parts.append(")" ); |
| 569 | return printAccum.join(); |
| 570 | }) |
| 571 | .def("__hash__" , |
| 572 | [](PyAffineExpr &self) { |
| 573 | return static_cast<size_t>(llvm::hash_value(self.get().ptr)); |
| 574 | }) |
| 575 | .def_prop_ro( |
| 576 | "context" , |
| 577 | [](PyAffineExpr &self) { return self.getContext().getObject(); }) |
| 578 | .def("compose" , |
| 579 | [](PyAffineExpr &self, PyAffineMap &other) { |
| 580 | return PyAffineExpr(self.getContext(), |
| 581 | mlirAffineExprCompose(self, other)); |
| 582 | }) |
| 583 | .def( |
| 584 | "shift_dims" , |
| 585 | [](PyAffineExpr &self, uint32_t numDims, uint32_t shift, |
| 586 | uint32_t offset) { |
| 587 | return PyAffineExpr( |
| 588 | self.getContext(), |
| 589 | mlirAffineExprShiftDims(self, numDims, shift, offset)); |
| 590 | }, |
| 591 | nb::arg("num_dims" ), nb::arg("shift" ), nb::arg("offset" ).none() = 0) |
| 592 | .def( |
| 593 | "shift_symbols" , |
| 594 | [](PyAffineExpr &self, uint32_t numSymbols, uint32_t shift, |
| 595 | uint32_t offset) { |
| 596 | return PyAffineExpr( |
| 597 | self.getContext(), |
| 598 | mlirAffineExprShiftSymbols(self, numSymbols, shift, offset)); |
| 599 | }, |
| 600 | nb::arg("num_symbols" ), nb::arg("shift" ), |
| 601 | nb::arg("offset" ).none() = 0) |
| 602 | .def_static( |
| 603 | "simplify_affine_expr" , |
| 604 | [](PyAffineExpr &self, uint32_t numDims, uint32_t numSymbols) { |
| 605 | return PyAffineExpr( |
| 606 | self.getContext(), |
| 607 | mlirSimplifyAffineExpr(self, numDims, numSymbols)); |
| 608 | }, |
| 609 | nb::arg("expr" ), nb::arg("num_dims" ), nb::arg("num_symbols" ), |
| 610 | "Simplify an affine expression by flattening and some amount of " |
| 611 | "simple analysis." ) |
| 612 | .def_static( |
| 613 | "get_add" , &PyAffineAddExpr::get, |
| 614 | "Gets an affine expression containing a sum of two expressions." ) |
| 615 | .def_static("get_add" , &PyAffineAddExpr::getLHSConstant, |
| 616 | "Gets an affine expression containing a sum of a constant " |
| 617 | "and another expression." ) |
| 618 | .def_static("get_add" , &PyAffineAddExpr::getRHSConstant, |
| 619 | "Gets an affine expression containing a sum of an expression " |
| 620 | "and a constant." ) |
| 621 | .def_static( |
| 622 | "get_mul" , &PyAffineMulExpr::get, |
| 623 | "Gets an affine expression containing a product of two expressions." ) |
| 624 | .def_static("get_mul" , &PyAffineMulExpr::getLHSConstant, |
| 625 | "Gets an affine expression containing a product of a " |
| 626 | "constant and another expression." ) |
| 627 | .def_static("get_mul" , &PyAffineMulExpr::getRHSConstant, |
| 628 | "Gets an affine expression containing a product of an " |
| 629 | "expression and a constant." ) |
| 630 | .def_static("get_mod" , &PyAffineModExpr::get, |
| 631 | "Gets an affine expression containing the modulo of dividing " |
| 632 | "one expression by another." ) |
| 633 | .def_static("get_mod" , &PyAffineModExpr::getLHSConstant, |
| 634 | "Gets a semi-affine expression containing the modulo of " |
| 635 | "dividing a constant by an expression." ) |
| 636 | .def_static("get_mod" , &PyAffineModExpr::getRHSConstant, |
| 637 | "Gets an affine expression containing the module of dividing" |
| 638 | "an expression by a constant." ) |
| 639 | .def_static("get_floor_div" , &PyAffineFloorDivExpr::get, |
| 640 | "Gets an affine expression containing the rounded-down " |
| 641 | "result of dividing one expression by another." ) |
| 642 | .def_static("get_floor_div" , &PyAffineFloorDivExpr::getLHSConstant, |
| 643 | "Gets a semi-affine expression containing the rounded-down " |
| 644 | "result of dividing a constant by an expression." ) |
| 645 | .def_static("get_floor_div" , &PyAffineFloorDivExpr::getRHSConstant, |
| 646 | "Gets an affine expression containing the rounded-down " |
| 647 | "result of dividing an expression by a constant." ) |
| 648 | .def_static("get_ceil_div" , &PyAffineCeilDivExpr::get, |
| 649 | "Gets an affine expression containing the rounded-up result " |
| 650 | "of dividing one expression by another." ) |
| 651 | .def_static("get_ceil_div" , &PyAffineCeilDivExpr::getLHSConstant, |
| 652 | "Gets a semi-affine expression containing the rounded-up " |
| 653 | "result of dividing a constant by an expression." ) |
| 654 | .def_static("get_ceil_div" , &PyAffineCeilDivExpr::getRHSConstant, |
| 655 | "Gets an affine expression containing the rounded-up result " |
| 656 | "of dividing an expression by a constant." ) |
| 657 | .def_static("get_constant" , &PyAffineConstantExpr::get, nb::arg("value" ), |
| 658 | nb::arg("context" ).none() = nb::none(), |
| 659 | "Gets a constant affine expression with the given value." ) |
| 660 | .def_static( |
| 661 | "get_dim" , &PyAffineDimExpr::get, nb::arg("position" ), |
| 662 | nb::arg("context" ).none() = nb::none(), |
| 663 | "Gets an affine expression of a dimension at the given position." ) |
| 664 | .def_static( |
| 665 | "get_symbol" , &PyAffineSymbolExpr::get, nb::arg("position" ), |
| 666 | nb::arg("context" ).none() = nb::none(), |
| 667 | "Gets an affine expression of a symbol at the given position." ) |
| 668 | .def( |
| 669 | "dump" , [](PyAffineExpr &self) { mlirAffineExprDump(self); }, |
| 670 | kDumpDocstring); |
| 671 | PyAffineConstantExpr::bind(m); |
| 672 | PyAffineDimExpr::bind(m); |
| 673 | PyAffineSymbolExpr::bind(m); |
| 674 | PyAffineBinaryExpr::bind(m); |
| 675 | PyAffineAddExpr::bind(m); |
| 676 | PyAffineMulExpr::bind(m); |
| 677 | PyAffineModExpr::bind(m); |
| 678 | PyAffineFloorDivExpr::bind(m); |
| 679 | PyAffineCeilDivExpr::bind(m); |
| 680 | |
| 681 | //---------------------------------------------------------------------------- |
| 682 | // Mapping of PyAffineMap. |
| 683 | //---------------------------------------------------------------------------- |
| 684 | nb::class_<PyAffineMap>(m, "AffineMap" ) |
| 685 | .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineMap::getCapsule) |
| 686 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) |
| 687 | .def("__eq__" , |
| 688 | [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) |
| 689 | .def("__eq__" , [](PyAffineMap &self, nb::object &other) { return false; }) |
| 690 | .def("__str__" , |
| 691 | [](PyAffineMap &self) { |
| 692 | PyPrintAccumulator printAccum; |
| 693 | mlirAffineMapPrint(self, printAccum.getCallback(), |
| 694 | printAccum.getUserData()); |
| 695 | return printAccum.join(); |
| 696 | }) |
| 697 | .def("__repr__" , |
| 698 | [](PyAffineMap &self) { |
| 699 | PyPrintAccumulator printAccum; |
| 700 | printAccum.parts.append("AffineMap(" ); |
| 701 | mlirAffineMapPrint(self, printAccum.getCallback(), |
| 702 | printAccum.getUserData()); |
| 703 | printAccum.parts.append(")" ); |
| 704 | return printAccum.join(); |
| 705 | }) |
| 706 | .def("__hash__" , |
| 707 | [](PyAffineMap &self) { |
| 708 | return static_cast<size_t>(llvm::hash_value(self.get().ptr)); |
| 709 | }) |
| 710 | .def_static("compress_unused_symbols" , |
| 711 | [](nb::list affineMaps, DefaultingPyMlirContext context) { |
| 712 | SmallVector<MlirAffineMap> maps; |
| 713 | pyListToVector<PyAffineMap, MlirAffineMap>( |
| 714 | affineMaps, maps, "attempting to create an AffineMap" ); |
| 715 | std::vector<MlirAffineMap> compressed(affineMaps.size()); |
| 716 | auto populate = [](void *result, intptr_t idx, |
| 717 | MlirAffineMap m) { |
| 718 | static_cast<MlirAffineMap *>(result)[idx] = (m); |
| 719 | }; |
| 720 | mlirAffineMapCompressUnusedSymbols( |
| 721 | maps.data(), maps.size(), compressed.data(), populate); |
| 722 | std::vector<PyAffineMap> res; |
| 723 | res.reserve(compressed.size()); |
| 724 | for (auto m : compressed) |
| 725 | res.emplace_back(context->getRef(), m); |
| 726 | return res; |
| 727 | }) |
| 728 | .def_prop_ro( |
| 729 | "context" , |
| 730 | [](PyAffineMap &self) { return self.getContext().getObject(); }, |
| 731 | "Context that owns the Affine Map" ) |
| 732 | .def( |
| 733 | "dump" , [](PyAffineMap &self) { mlirAffineMapDump(self); }, |
| 734 | kDumpDocstring) |
| 735 | .def_static( |
| 736 | "get" , |
| 737 | [](intptr_t dimCount, intptr_t symbolCount, nb::list exprs, |
| 738 | DefaultingPyMlirContext context) { |
| 739 | SmallVector<MlirAffineExpr> affineExprs; |
| 740 | pyListToVector<PyAffineExpr, MlirAffineExpr>( |
| 741 | exprs, affineExprs, "attempting to create an AffineMap" ); |
| 742 | MlirAffineMap map = |
| 743 | mlirAffineMapGet(context->get(), dimCount, symbolCount, |
| 744 | affineExprs.size(), affineExprs.data()); |
| 745 | return PyAffineMap(context->getRef(), map); |
| 746 | }, |
| 747 | nb::arg("dim_count" ), nb::arg("symbol_count" ), nb::arg("exprs" ), |
| 748 | nb::arg("context" ).none() = nb::none(), |
| 749 | "Gets a map with the given expressions as results." ) |
| 750 | .def_static( |
| 751 | "get_constant" , |
| 752 | [](intptr_t value, DefaultingPyMlirContext context) { |
| 753 | MlirAffineMap affineMap = |
| 754 | mlirAffineMapConstantGet(context->get(), value); |
| 755 | return PyAffineMap(context->getRef(), affineMap); |
| 756 | }, |
| 757 | nb::arg("value" ), nb::arg("context" ).none() = nb::none(), |
| 758 | "Gets an affine map with a single constant result" ) |
| 759 | .def_static( |
| 760 | "get_empty" , |
| 761 | [](DefaultingPyMlirContext context) { |
| 762 | MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); |
| 763 | return PyAffineMap(context->getRef(), affineMap); |
| 764 | }, |
| 765 | nb::arg("context" ).none() = nb::none(), "Gets an empty affine map." ) |
| 766 | .def_static( |
| 767 | "get_identity" , |
| 768 | [](intptr_t nDims, DefaultingPyMlirContext context) { |
| 769 | MlirAffineMap affineMap = |
| 770 | mlirAffineMapMultiDimIdentityGet(context->get(), nDims); |
| 771 | return PyAffineMap(context->getRef(), affineMap); |
| 772 | }, |
| 773 | nb::arg("n_dims" ), nb::arg("context" ).none() = nb::none(), |
| 774 | "Gets an identity map with the given number of dimensions." ) |
| 775 | .def_static( |
| 776 | "get_minor_identity" , |
| 777 | [](intptr_t nDims, intptr_t nResults, |
| 778 | DefaultingPyMlirContext context) { |
| 779 | MlirAffineMap affineMap = |
| 780 | mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); |
| 781 | return PyAffineMap(context->getRef(), affineMap); |
| 782 | }, |
| 783 | nb::arg("n_dims" ), nb::arg("n_results" ), |
| 784 | nb::arg("context" ).none() = nb::none(), |
| 785 | "Gets a minor identity map with the given number of dimensions and " |
| 786 | "results." ) |
| 787 | .def_static( |
| 788 | "get_permutation" , |
| 789 | [](std::vector<unsigned> permutation, |
| 790 | DefaultingPyMlirContext context) { |
| 791 | if (!isPermutation(permutation)) |
| 792 | throw std::runtime_error("Invalid permutation when attempting to " |
| 793 | "create an AffineMap" ); |
| 794 | MlirAffineMap affineMap = mlirAffineMapPermutationGet( |
| 795 | context->get(), permutation.size(), permutation.data()); |
| 796 | return PyAffineMap(context->getRef(), affineMap); |
| 797 | }, |
| 798 | nb::arg("permutation" ), nb::arg("context" ).none() = nb::none(), |
| 799 | "Gets an affine map that permutes its inputs." ) |
| 800 | .def( |
| 801 | "get_submap" , |
| 802 | [](PyAffineMap &self, std::vector<intptr_t> &resultPos) { |
| 803 | intptr_t numResults = mlirAffineMapGetNumResults(self); |
| 804 | for (intptr_t pos : resultPos) { |
| 805 | if (pos < 0 || pos >= numResults) |
| 806 | throw nb::value_error("result position out of bounds" ); |
| 807 | } |
| 808 | MlirAffineMap affineMap = mlirAffineMapGetSubMap( |
| 809 | self, resultPos.size(), resultPos.data()); |
| 810 | return PyAffineMap(self.getContext(), affineMap); |
| 811 | }, |
| 812 | nb::arg("result_positions" )) |
| 813 | .def( |
| 814 | "get_major_submap" , |
| 815 | [](PyAffineMap &self, intptr_t nResults) { |
| 816 | if (nResults >= mlirAffineMapGetNumResults(self)) |
| 817 | throw nb::value_error("number of results out of bounds" ); |
| 818 | MlirAffineMap affineMap = |
| 819 | mlirAffineMapGetMajorSubMap(self, nResults); |
| 820 | return PyAffineMap(self.getContext(), affineMap); |
| 821 | }, |
| 822 | nb::arg("n_results" )) |
| 823 | .def( |
| 824 | "get_minor_submap" , |
| 825 | [](PyAffineMap &self, intptr_t nResults) { |
| 826 | if (nResults >= mlirAffineMapGetNumResults(self)) |
| 827 | throw nb::value_error("number of results out of bounds" ); |
| 828 | MlirAffineMap affineMap = |
| 829 | mlirAffineMapGetMinorSubMap(self, nResults); |
| 830 | return PyAffineMap(self.getContext(), affineMap); |
| 831 | }, |
| 832 | nb::arg("n_results" )) |
| 833 | .def( |
| 834 | "replace" , |
| 835 | [](PyAffineMap &self, PyAffineExpr &expression, |
| 836 | PyAffineExpr &replacement, intptr_t numResultDims, |
| 837 | intptr_t numResultSyms) { |
| 838 | MlirAffineMap affineMap = mlirAffineMapReplace( |
| 839 | self, expression, replacement, numResultDims, numResultSyms); |
| 840 | return PyAffineMap(self.getContext(), affineMap); |
| 841 | }, |
| 842 | nb::arg("expr" ), nb::arg("replacement" ), nb::arg("n_result_dims" ), |
| 843 | nb::arg("n_result_syms" )) |
| 844 | .def_prop_ro( |
| 845 | "is_permutation" , |
| 846 | [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) |
| 847 | .def_prop_ro("is_projected_permutation" , |
| 848 | [](PyAffineMap &self) { |
| 849 | return mlirAffineMapIsProjectedPermutation(self); |
| 850 | }) |
| 851 | .def_prop_ro( |
| 852 | "n_dims" , |
| 853 | [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) |
| 854 | .def_prop_ro( |
| 855 | "n_inputs" , |
| 856 | [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) |
| 857 | .def_prop_ro( |
| 858 | "n_symbols" , |
| 859 | [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) |
| 860 | .def_prop_ro("results" , |
| 861 | [](PyAffineMap &self) { return PyAffineMapExprList(self); }); |
| 862 | PyAffineMapExprList::bind(m); |
| 863 | |
| 864 | //---------------------------------------------------------------------------- |
| 865 | // Mapping of PyIntegerSet. |
| 866 | //---------------------------------------------------------------------------- |
| 867 | nb::class_<PyIntegerSet>(m, "IntegerSet" ) |
| 868 | .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyIntegerSet::getCapsule) |
| 869 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) |
| 870 | .def("__eq__" , [](PyIntegerSet &self, |
| 871 | PyIntegerSet &other) { return self == other; }) |
| 872 | .def("__eq__" , [](PyIntegerSet &self, nb::object other) { return false; }) |
| 873 | .def("__str__" , |
| 874 | [](PyIntegerSet &self) { |
| 875 | PyPrintAccumulator printAccum; |
| 876 | mlirIntegerSetPrint(self, printAccum.getCallback(), |
| 877 | printAccum.getUserData()); |
| 878 | return printAccum.join(); |
| 879 | }) |
| 880 | .def("__repr__" , |
| 881 | [](PyIntegerSet &self) { |
| 882 | PyPrintAccumulator printAccum; |
| 883 | printAccum.parts.append("IntegerSet(" ); |
| 884 | mlirIntegerSetPrint(self, printAccum.getCallback(), |
| 885 | printAccum.getUserData()); |
| 886 | printAccum.parts.append(")" ); |
| 887 | return printAccum.join(); |
| 888 | }) |
| 889 | .def("__hash__" , |
| 890 | [](PyIntegerSet &self) { |
| 891 | return static_cast<size_t>(llvm::hash_value(self.get().ptr)); |
| 892 | }) |
| 893 | .def_prop_ro( |
| 894 | "context" , |
| 895 | [](PyIntegerSet &self) { return self.getContext().getObject(); }) |
| 896 | .def( |
| 897 | "dump" , [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, |
| 898 | kDumpDocstring) |
| 899 | .def_static( |
| 900 | "get" , |
| 901 | [](intptr_t numDims, intptr_t numSymbols, nb::list exprs, |
| 902 | std::vector<bool> eqFlags, DefaultingPyMlirContext context) { |
| 903 | if (exprs.size() != eqFlags.size()) |
| 904 | throw nb::value_error( |
| 905 | "Expected the number of constraints to match " |
| 906 | "that of equality flags" ); |
| 907 | if (exprs.size() == 0) |
| 908 | throw nb::value_error("Expected non-empty list of constraints" ); |
| 909 | |
| 910 | // Copy over to a SmallVector because std::vector has a |
| 911 | // specialization for booleans that packs data and does not |
| 912 | // expose a `bool *`. |
| 913 | SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end()); |
| 914 | |
| 915 | SmallVector<MlirAffineExpr> affineExprs; |
| 916 | pyListToVector<PyAffineExpr>(exprs, affineExprs, |
| 917 | "attempting to create an IntegerSet" ); |
| 918 | MlirIntegerSet set = mlirIntegerSetGet( |
| 919 | context->get(), numDims, numSymbols, exprs.size(), |
| 920 | affineExprs.data(), flags.data()); |
| 921 | return PyIntegerSet(context->getRef(), set); |
| 922 | }, |
| 923 | nb::arg("num_dims" ), nb::arg("num_symbols" ), nb::arg("exprs" ), |
| 924 | nb::arg("eq_flags" ), nb::arg("context" ).none() = nb::none()) |
| 925 | .def_static( |
| 926 | "get_empty" , |
| 927 | [](intptr_t numDims, intptr_t numSymbols, |
| 928 | DefaultingPyMlirContext context) { |
| 929 | MlirIntegerSet set = |
| 930 | mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); |
| 931 | return PyIntegerSet(context->getRef(), set); |
| 932 | }, |
| 933 | nb::arg("num_dims" ), nb::arg("num_symbols" ), |
| 934 | nb::arg("context" ).none() = nb::none()) |
| 935 | .def( |
| 936 | "get_replaced" , |
| 937 | [](PyIntegerSet &self, nb::list dimExprs, nb::list symbolExprs, |
| 938 | intptr_t numResultDims, intptr_t numResultSymbols) { |
| 939 | if (static_cast<intptr_t>(dimExprs.size()) != |
| 940 | mlirIntegerSetGetNumDims(self)) |
| 941 | throw nb::value_error( |
| 942 | "Expected the number of dimension replacement expressions " |
| 943 | "to match that of dimensions" ); |
| 944 | if (static_cast<intptr_t>(symbolExprs.size()) != |
| 945 | mlirIntegerSetGetNumSymbols(self)) |
| 946 | throw nb::value_error( |
| 947 | "Expected the number of symbol replacement expressions " |
| 948 | "to match that of symbols" ); |
| 949 | |
| 950 | SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs; |
| 951 | pyListToVector<PyAffineExpr>( |
| 952 | dimExprs, dimAffineExprs, |
| 953 | "attempting to create an IntegerSet by replacing dimensions" ); |
| 954 | pyListToVector<PyAffineExpr>( |
| 955 | symbolExprs, symbolAffineExprs, |
| 956 | "attempting to create an IntegerSet by replacing symbols" ); |
| 957 | MlirIntegerSet set = mlirIntegerSetReplaceGet( |
| 958 | self, dimAffineExprs.data(), symbolAffineExprs.data(), |
| 959 | numResultDims, numResultSymbols); |
| 960 | return PyIntegerSet(self.getContext(), set); |
| 961 | }, |
| 962 | nb::arg("dim_exprs" ), nb::arg("symbol_exprs" ), |
| 963 | nb::arg("num_result_dims" ), nb::arg("num_result_symbols" )) |
| 964 | .def_prop_ro("is_canonical_empty" , |
| 965 | [](PyIntegerSet &self) { |
| 966 | return mlirIntegerSetIsCanonicalEmpty(self); |
| 967 | }) |
| 968 | .def_prop_ro( |
| 969 | "n_dims" , |
| 970 | [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); }) |
| 971 | .def_prop_ro( |
| 972 | "n_symbols" , |
| 973 | [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) |
| 974 | .def_prop_ro( |
| 975 | "n_inputs" , |
| 976 | [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) |
| 977 | .def_prop_ro("n_equalities" , |
| 978 | [](PyIntegerSet &self) { |
| 979 | return mlirIntegerSetGetNumEqualities(self); |
| 980 | }) |
| 981 | .def_prop_ro("n_inequalities" , |
| 982 | [](PyIntegerSet &self) { |
| 983 | return mlirIntegerSetGetNumInequalities(self); |
| 984 | }) |
| 985 | .def_prop_ro("constraints" , [](PyIntegerSet &self) { |
| 986 | return PyIntegerSetConstraintList(self); |
| 987 | }); |
| 988 | PyIntegerSetConstraint::bind(m); |
| 989 | PyIntegerSetConstraintList::bind(m); |
| 990 | } |
| 991 | |