1 | //===- StorageBase.cpp - TACO-flavored sparse tensor representation -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains method definitions for `SparseTensorStorageBase`. |
10 | // In particular we want to ensure that the default implementations of |
11 | // the "partial method specialization" trick aren't inline (since there's |
12 | // no benefit). |
13 | // |
14 | //===----------------------------------------------------------------------===// |
15 | |
16 | #include "mlir/ExecutionEngine/SparseTensor/Storage.h" |
17 | |
18 | using namespace mlir::sparse_tensor; |
19 | |
20 | static inline bool isAllDense(uint64_t lvlRank, const LevelType *lvlTypes) { |
21 | for (uint64_t l = 0; l < lvlRank; l++) |
22 | if (!isDenseLT(lt: lvlTypes[l])) |
23 | return false; |
24 | return true; |
25 | } |
26 | |
27 | SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT |
28 | uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank, |
29 | const uint64_t *lvlSizes, const LevelType *lvlTypes, |
30 | const uint64_t *dim2lvl, const uint64_t *lvl2dim) |
31 | : dimSizes(dimSizes, dimSizes + dimRank), |
32 | lvlSizes(lvlSizes, lvlSizes + lvlRank), |
33 | lvlTypes(lvlTypes, lvlTypes + lvlRank), |
34 | dim2lvlVec(dim2lvl, dim2lvl + lvlRank), |
35 | lvl2dimVec(lvl2dim, lvl2dim + dimRank), |
36 | map(dimRank, lvlRank, dim2lvlVec.data(), lvl2dimVec.data()), |
37 | allDense(isAllDense(lvlRank, lvlTypes)) { |
38 | assert(dimSizes && lvlSizes && lvlTypes && dim2lvl && lvl2dim); |
39 | // Validate dim-indexed parameters. |
40 | assert(dimRank > 0 && "Trivial shape is unsupported" ); |
41 | for (uint64_t d = 0; d < dimRank; d++) |
42 | assert(dimSizes[d] > 0 && "Dimension size zero has trivial storage" ); |
43 | // Validate lvl-indexed parameters. |
44 | assert(lvlRank > 0 && "Trivial shape is unsupported" ); |
45 | for (uint64_t l = 0; l < lvlRank; l++) { |
46 | assert(lvlSizes[l] > 0 && "Level size zero has trivial storage" ); |
47 | assert(isDenseLvl(l) || isCompressedLvl(l) || isLooseCompressedLvl(l) || |
48 | isSingletonLvl(l) || isNOutOfMLvl(l)); |
49 | } |
50 | } |
51 | |
// Helper macro for wrong "partial method specialization" errors: prints the
// name of the mismatched method to stderr and terminates the process with
// exit status 1. NAME must be a string literal (adjacent-literal concatenation
// such as `"getValues" "F64"` is fine); it is printed through `%s` rather than
// stringized with `#`, which would embed literal quote characters in the
// message. The do/while(0) wrapper makes `FATAL_PIV(...);` a single statement,
// so it is safe as the body of an unbraced `if`/`else`.
#define FATAL_PIV(NAME)                                                        \
  do {                                                                         \
    fprintf(stderr, "<P,I,V> type mismatch for: %s\n", NAME);                  \
    exit(1);                                                                   \
  } while (0)
56 | |
57 | #define IMPL_GETPOSITIONS(PNAME, P) \ |
58 | void SparseTensorStorageBase::getPositions(std::vector<P> **, uint64_t) { \ |
59 | FATAL_PIV("getPositions" #PNAME); \ |
60 | } |
61 | MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETPOSITIONS) |
62 | #undef IMPL_GETPOSITIONS |
63 | |
64 | #define IMPL_GETCOORDINATES(CNAME, C) \ |
65 | void SparseTensorStorageBase::getCoordinates(std::vector<C> **, uint64_t) { \ |
66 | FATAL_PIV("getCoordinates" #CNAME); \ |
67 | } |
68 | MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATES) |
69 | #undef IMPL_GETCOORDINATES |
70 | |
71 | #define IMPL_GETCOORDINATESBUFFER(CNAME, C) \ |
72 | void SparseTensorStorageBase::getCoordinatesBuffer(std::vector<C> **, \ |
73 | uint64_t) { \ |
74 | FATAL_PIV("getCoordinatesBuffer" #CNAME); \ |
75 | } |
76 | MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATESBUFFER) |
77 | #undef IMPL_GETCOORDINATESBUFFER |
78 | |
79 | #define IMPL_GETVALUES(VNAME, V) \ |
80 | void SparseTensorStorageBase::getValues(std::vector<V> **) { \ |
81 | FATAL_PIV("getValues" #VNAME); \ |
82 | } |
83 | MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETVALUES) |
84 | #undef IMPL_GETVALUES |
85 | |
86 | #define IMPL_LEXINSERT(VNAME, V) \ |
87 | void SparseTensorStorageBase::lexInsert(const uint64_t *, V) { \ |
88 | FATAL_PIV("lexInsert" #VNAME); \ |
89 | } |
90 | MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT) |
91 | #undef IMPL_LEXINSERT |
92 | |
93 | #define IMPL_EXPINSERT(VNAME, V) \ |
94 | void SparseTensorStorageBase::expInsert(uint64_t *, V *, bool *, uint64_t *, \ |
95 | uint64_t, uint64_t) { \ |
96 | FATAL_PIV("expInsert" #VNAME); \ |
97 | } |
98 | MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT) |
99 | #undef IMPL_EXPINSERT |
100 | |
101 | #undef FATAL_PIV |
102 | |