1 | //===- Tensor.cpp - C API for SparseTensor dialect ------------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir-c/Dialect/SparseTensor.h" |
10 | #include "mlir-c/IR.h" |
11 | #include "mlir/CAPI/AffineMap.h" |
12 | #include "mlir/CAPI/Registration.h" |
13 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
14 | #include "mlir/Support/LLVM.h" |
15 | |
16 | using namespace llvm; |
17 | using namespace mlir::sparse_tensor; |
18 | |
19 | MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor, |
20 | mlir::sparse_tensor::SparseTensorDialect) |
21 | |
22 | // Ensure the C-API enums are int-castable to C++ equivalents. |
23 | static_assert( |
24 | static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) == |
25 | static_cast<int>(LevelFormat::Dense) && |
26 | static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) == |
27 | static_cast<int>(LevelFormat::Compressed) && |
28 | static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) == |
29 | static_cast<int>(LevelFormat::Singleton) && |
30 | static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) == |
31 | static_cast<int>(LevelFormat::LooseCompressed) && |
32 | static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) == |
33 | static_cast<int>(LevelFormat::NOutOfM), |
34 | "MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch"); |
35 | |
36 | static_assert(static_cast<int>(MLIR_SPARSE_PROPERTY_NON_ORDERED) == |
37 | static_cast<int>(LevelPropNonDefault::Nonordered) && |
38 | static_cast<int>(MLIR_SPARSE_PROPERTY_NON_UNIQUE) == |
39 | static_cast<int>(LevelPropNonDefault::Nonunique) && |
40 | static_cast<int>(MLIR_SPARSE_PROPERTY_SOA) == |
41 | static_cast<int>(LevelPropNonDefault::SoA), |
42 | "MlirSparseTensorLevelProperty (C-API) and " |
43 | "LevelPropertyNondefault (C++) mismatch"); |
44 | |
45 | bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { |
46 | return isa<SparseTensorEncodingAttr>(unwrap(attr)); |
47 | } |
48 | |
49 | MlirAttribute mlirSparseTensorEncodingAttrGet( |
50 | MlirContext ctx, intptr_t lvlRank, |
51 | MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl, |
52 | MlirAffineMap lvlToDim, int posWidth, int crdWidth, |
53 | MlirAttribute explicitVal, MlirAttribute implicitVal) { |
54 | SmallVector<LevelType> cppLvlTypes; |
55 | |
56 | cppLvlTypes.reserve(N: lvlRank); |
57 | for (intptr_t l = 0; l < lvlRank; ++l) |
58 | cppLvlTypes.push_back(Elt: static_cast<LevelType>(lvlTypes[l])); |
59 | |
60 | return wrap(SparseTensorEncodingAttr::get( |
61 | unwrap(ctx), cppLvlTypes, unwrap(dimToLvl), unwrap(lvlToDim), posWidth, |
62 | crdWidth, unwrap(explicitVal), unwrap(implicitVal))); |
63 | } |
64 | |
65 | MlirAffineMap mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr) { |
66 | return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getDimToLvl()); |
67 | } |
68 | |
69 | MlirAffineMap mlirSparseTensorEncodingAttrGetLvlToDim(MlirAttribute attr) { |
70 | return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlToDim()); |
71 | } |
72 | |
73 | intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) { |
74 | return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank(); |
75 | } |
76 | |
77 | MlirSparseTensorLevelType |
78 | mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) { |
79 | return static_cast<MlirSparseTensorLevelType>( |
80 | cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl)); |
81 | } |
82 | |
83 | enum MlirSparseTensorLevelFormat |
84 | mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) { |
85 | LevelType lt = |
86 | static_cast<LevelType>(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl)); |
87 | return static_cast<MlirSparseTensorLevelFormat>(lt.getLvlFmt()); |
88 | } |
89 | |
90 | int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) { |
91 | return cast<SparseTensorEncodingAttr>(unwrap(attr)).getPosWidth(); |
92 | } |
93 | |
94 | int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) { |
95 | return cast<SparseTensorEncodingAttr>(unwrap(attr)).getCrdWidth(); |
96 | } |
97 | |
98 | MlirAttribute mlirSparseTensorEncodingAttrGetExplicitVal(MlirAttribute attr) { |
99 | return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getExplicitVal()); |
100 | } |
101 | |
102 | MlirAttribute mlirSparseTensorEncodingAttrGetImplicitVal(MlirAttribute attr) { |
103 | return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getImplicitVal()); |
104 | } |
105 | |
106 | MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType( |
107 | enum MlirSparseTensorLevelFormat lvlFmt, |
108 | const enum MlirSparseTensorLevelPropertyNondefault *properties, |
109 | unsigned size, unsigned n, unsigned m) { |
110 | |
111 | std::vector<LevelPropNonDefault> props; |
112 | props.reserve(n: size); |
113 | for (unsigned i = 0; i < size; i++) |
114 | props.push_back(x: static_cast<LevelPropNonDefault>(properties[i])); |
115 | |
116 | return static_cast<MlirSparseTensorLevelType>( |
117 | *buildLevelType(lf: static_cast<LevelFormat>(lvlFmt), properties: props, n, m)); |
118 | } |
119 | |
120 | unsigned |
121 | mlirSparseTensorEncodingAttrGetStructuredN(MlirSparseTensorLevelType lvlType) { |
122 | return getN(lt: static_cast<LevelType>(lvlType)); |
123 | } |
124 | |
125 | unsigned |
126 | mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType) { |
127 | return getM(lt: static_cast<LevelType>(lvlType)); |
128 | } |
129 |
Definitions
- mlirAttributeIsASparseTensorEncodingAttr
- mlirSparseTensorEncodingAttrGet
- mlirSparseTensorEncodingAttrGetDimToLvl
- mlirSparseTensorEncodingAttrGetLvlToDim
- mlirSparseTensorEncodingGetLvlRank
- mlirSparseTensorEncodingAttrGetLvlType
- mlirSparseTensorEncodingAttrGetLvlFmt
- mlirSparseTensorEncodingAttrGetPosWidth
- mlirSparseTensorEncodingAttrGetCrdWidth
- mlirSparseTensorEncodingAttrGetExplicitVal
- mlirSparseTensorEncodingAttrGetImplicitVal
- mlirSparseTensorEncodingAttrBuildLvlType
- mlirSparseTensorEncodingAttrGetStructuredN
Improve your Profiling and Debugging skills
Find out more