1 | //===- Tensor.cpp - C API for SparseTensor dialect ------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
#include "mlir-c/Dialect/SparseTensor.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Support/LLVM.h"

#include <cassert>
#include <optional>
15 | |
16 | using namespace llvm; |
17 | using namespace mlir::sparse_tensor; |
18 | |
// Expose the SparseTensor dialect (mlir::sparse_tensor::SparseTensorDialect,
// namespace `sparse_tensor`) through the C API. The macro presumably comes
// from mlir/CAPI/Registration.h and emits the dialect-handle entry point
// (e.g. mlirGetDialectHandle__sparse_tensor__) -- confirm against that header.
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
                                      mlir::sparse_tensor::SparseTensorDialect)
21 | |
22 | // Ensure the C-API enums are int-castable to C++ equivalents. |
23 | static_assert( |
24 | static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) == |
25 | static_cast<int>(LevelFormat::Dense) && |
26 | static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) == |
27 | static_cast<int>(LevelFormat::Compressed) && |
28 | static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) == |
29 | static_cast<int>(LevelFormat::Singleton) && |
30 | static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) == |
31 | static_cast<int>(LevelFormat::LooseCompressed) && |
32 | static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) == |
33 | static_cast<int>(LevelFormat::NOutOfM), |
34 | "MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch" ); |
35 | |
36 | static_assert(static_cast<int>(MLIR_SPARSE_PROPERTY_NON_ORDERED) == |
37 | static_cast<int>(LevelPropNonDefault::Nonordered) && |
38 | static_cast<int>(MLIR_SPARSE_PROPERTY_NON_UNIQUE) == |
39 | static_cast<int>(LevelPropNonDefault::Nonunique), |
40 | "MlirSparseTensorLevelProperty (C-API) and " |
41 | "LevelPropertyNondefault (C++) mismatch" ); |
42 | |
43 | bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { |
44 | return isa<SparseTensorEncodingAttr>(unwrap(attr)); |
45 | } |
46 | |
47 | MlirAttribute |
48 | mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank, |
49 | MlirSparseTensorLevelType const *lvlTypes, |
50 | MlirAffineMap dimToLvl, MlirAffineMap lvlToDim, |
51 | int posWidth, int crdWidth) { |
52 | SmallVector<LevelType> cppLvlTypes; |
53 | cppLvlTypes.reserve(N: lvlRank); |
54 | for (intptr_t l = 0; l < lvlRank; ++l) |
55 | cppLvlTypes.push_back(Elt: static_cast<LevelType>(lvlTypes[l])); |
56 | return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes, |
57 | unwrap(dimToLvl), unwrap(lvlToDim), |
58 | posWidth, crdWidth)); |
59 | } |
60 | |
61 | MlirAffineMap mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr) { |
62 | return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getDimToLvl()); |
63 | } |
64 | |
65 | MlirAffineMap mlirSparseTensorEncodingAttrGetLvlToDim(MlirAttribute attr) { |
66 | return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlToDim()); |
67 | } |
68 | |
69 | intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) { |
70 | return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank(); |
71 | } |
72 | |
73 | MlirSparseTensorLevelType |
74 | mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) { |
75 | return static_cast<MlirSparseTensorLevelType>( |
76 | cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl)); |
77 | } |
78 | |
79 | enum MlirSparseTensorLevelFormat |
80 | mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) { |
81 | LevelType lt = |
82 | static_cast<LevelType>(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl)); |
83 | return static_cast<MlirSparseTensorLevelFormat>(lt.getLvlFmt()); |
84 | } |
85 | |
86 | int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) { |
87 | return cast<SparseTensorEncodingAttr>(unwrap(attr)).getPosWidth(); |
88 | } |
89 | |
90 | int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) { |
91 | return cast<SparseTensorEncodingAttr>(unwrap(attr)).getCrdWidth(); |
92 | } |
93 | |
94 | MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType( |
95 | enum MlirSparseTensorLevelFormat lvlFmt, |
96 | const enum MlirSparseTensorLevelPropertyNondefault *properties, |
97 | unsigned size, unsigned n, unsigned m) { |
98 | |
99 | std::vector<LevelPropNonDefault> props; |
100 | for (unsigned i = 0; i < size; i++) |
101 | props.push_back(x: static_cast<LevelPropNonDefault>(properties[i])); |
102 | |
103 | return static_cast<MlirSparseTensorLevelType>( |
104 | *buildLevelType(lf: static_cast<LevelFormat>(lvlFmt), properties: props, n, m)); |
105 | } |
106 | |
107 | unsigned |
108 | mlirSparseTensorEncodingAttrGetStructuredN(MlirSparseTensorLevelType lvlType) { |
109 | return getN(lt: static_cast<LevelType>(lvlType)); |
110 | } |
111 | |
112 | unsigned |
113 | mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType) { |
114 | return getM(lt: static_cast<LevelType>(lvlType)); |
115 | } |
116 | |