| 1 | //===- IntegerSet.cpp - C API for MLIR Integer Sets -----------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir-c/IntegerSet.h" |
| 10 | #include "mlir-c/AffineExpr.h" |
| 11 | #include "mlir/CAPI/AffineExpr.h" |
| 12 | #include "mlir/CAPI/IR.h" |
| 13 | #include "mlir/CAPI/IntegerSet.h" |
| 14 | #include "mlir/CAPI/Utils.h" |
| 15 | #include "mlir/IR/IntegerSet.h" |
| 16 | |
| 17 | using namespace mlir; |
| 18 | |
| 19 | MlirContext mlirIntegerSetGetContext(MlirIntegerSet set) { |
| 20 | return wrap(cpp: unwrap(c: set).getContext()); |
| 21 | } |
| 22 | |
| 23 | bool mlirIntegerSetEqual(MlirIntegerSet s1, MlirIntegerSet s2) { |
| 24 | return unwrap(c: s1) == unwrap(c: s2); |
| 25 | } |
| 26 | |
| 27 | void mlirIntegerSetPrint(MlirIntegerSet set, MlirStringCallback callback, |
| 28 | void *userData) { |
| 29 | mlir::detail::CallbackOstream stream(callback, userData); |
| 30 | unwrap(c: set).print(os&: stream); |
| 31 | } |
| 32 | |
| 33 | void mlirIntegerSetDump(MlirIntegerSet set) { unwrap(c: set).dump(); } |
| 34 | |
| 35 | MlirIntegerSet mlirIntegerSetEmptyGet(MlirContext context, intptr_t numDims, |
| 36 | intptr_t numSymbols) { |
| 37 | return wrap(cpp: IntegerSet::getEmptySet(numDims: static_cast<unsigned>(numDims), |
| 38 | numSymbols: static_cast<unsigned>(numSymbols), |
| 39 | context: unwrap(c: context))); |
| 40 | } |
| 41 | |
| 42 | MlirIntegerSet mlirIntegerSetGet(MlirContext context, intptr_t numDims, |
| 43 | intptr_t numSymbols, intptr_t numConstraints, |
| 44 | const MlirAffineExpr *constraints, |
| 45 | const bool *eqFlags) { |
| 46 | SmallVector<AffineExpr> mlirConstraints; |
| 47 | (void)unwrapList(size: static_cast<size_t>(numConstraints), first: constraints, |
| 48 | storage&: mlirConstraints); |
| 49 | return wrap(cpp: IntegerSet::get( |
| 50 | dimCount: static_cast<unsigned>(numDims), symbolCount: static_cast<unsigned>(numSymbols), |
| 51 | constraints: mlirConstraints, |
| 52 | eqFlags: llvm::ArrayRef(eqFlags, static_cast<size_t>(numConstraints)))); |
| 53 | } |
| 54 | |
| 55 | MlirIntegerSet |
| 56 | mlirIntegerSetReplaceGet(MlirIntegerSet set, |
| 57 | const MlirAffineExpr *dimReplacements, |
| 58 | const MlirAffineExpr *symbolReplacements, |
| 59 | intptr_t numResultDims, intptr_t numResultSymbols) { |
| 60 | SmallVector<AffineExpr> mlirDims, mlirSymbols; |
| 61 | (void)unwrapList(size: unwrap(c: set).getNumDims(), first: dimReplacements, storage&: mlirDims); |
| 62 | (void)unwrapList(size: unwrap(c: set).getNumSymbols(), first: symbolReplacements, |
| 63 | storage&: mlirSymbols); |
| 64 | return wrap(cpp: unwrap(c: set).replaceDimsAndSymbols( |
| 65 | dimReplacements: mlirDims, symReplacements: mlirSymbols, numResultDims: static_cast<unsigned>(numResultDims), |
| 66 | numResultSyms: static_cast<unsigned>(numResultSymbols))); |
| 67 | } |
| 68 | |
| 69 | bool mlirIntegerSetIsCanonicalEmpty(MlirIntegerSet set) { |
| 70 | return unwrap(c: set).isEmptyIntegerSet(); |
| 71 | } |
| 72 | |
| 73 | intptr_t mlirIntegerSetGetNumDims(MlirIntegerSet set) { |
| 74 | return static_cast<intptr_t>(unwrap(c: set).getNumDims()); |
| 75 | } |
| 76 | |
| 77 | intptr_t mlirIntegerSetGetNumSymbols(MlirIntegerSet set) { |
| 78 | return static_cast<intptr_t>(unwrap(c: set).getNumSymbols()); |
| 79 | } |
| 80 | |
| 81 | intptr_t mlirIntegerSetGetNumInputs(MlirIntegerSet set) { |
| 82 | return static_cast<intptr_t>(unwrap(c: set).getNumInputs()); |
| 83 | } |
| 84 | |
| 85 | intptr_t mlirIntegerSetGetNumConstraints(MlirIntegerSet set) { |
| 86 | return static_cast<intptr_t>(unwrap(c: set).getNumConstraints()); |
| 87 | } |
| 88 | |
| 89 | intptr_t mlirIntegerSetGetNumEqualities(MlirIntegerSet set) { |
| 90 | return static_cast<intptr_t>(unwrap(c: set).getNumEqualities()); |
| 91 | } |
| 92 | |
| 93 | intptr_t mlirIntegerSetGetNumInequalities(MlirIntegerSet set) { |
| 94 | return static_cast<intptr_t>(unwrap(c: set).getNumInequalities()); |
| 95 | } |
| 96 | |
| 97 | MlirAffineExpr mlirIntegerSetGetConstraint(MlirIntegerSet set, intptr_t pos) { |
| 98 | return wrap(cpp: unwrap(c: set).getConstraint(idx: static_cast<unsigned>(pos))); |
| 99 | } |
| 100 | |
| 101 | bool mlirIntegerSetIsConstraintEq(MlirIntegerSet set, intptr_t pos) { |
| 102 | return unwrap(c: set).isEq(idx: pos); |
| 103 | } |
| 104 | |