1 | //===- COO.h - Coordinate-scheme sparse tensor representation ---*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains a coordinate-scheme representation of sparse tensors. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_COO_H |
14 | #define MLIR_EXECUTIONENGINE_SPARSETENSOR_COO_H |
15 | |
16 | #include <algorithm> |
17 | #include <cassert> |
18 | #include <cinttypes> |
19 | #include <functional> |
20 | #include <vector> |
21 | |
22 | namespace mlir { |
23 | namespace sparse_tensor { |
24 | |
/// An element of a sparse tensor in coordinate-scheme representation
/// (i.e., a pair of coordinates and a value). For example, a rank-1
/// vector element would look like
///   ({i}, a[i])
/// and a rank-5 tensor element would look like
///   ({i,j,k,l,m}, a[i,j,k,l,m])
///
/// The coordinates are represented as a (non-owning) pointer into a
/// shared pool of coordinates, rather than being stored directly in this
/// object. This significantly improves performance because it reduces the
/// per-element memory footprint and centralizes the memory management for
/// coordinates. The only downside is that the coordinates themselves cannot
/// be retrieved without knowing the rank of the tensor to which this element
/// belongs (and that rank is not stored in this object).
template <typename V>
struct Element final {
  /// Creates an element whose coordinates start at `coords` and whose value
  /// is `val`. Only the pointer is stored: the pointed-to coordinates are
  /// not copied, so whoever owns that storage must keep it alive (and
  /// rebase this pointer if the storage moves).
  Element(const uint64_t *coords, V val) : coords(coords), value(val) {}

  const uint64_t *coords; // pointer into shared coordinates pool
  V value;
};
45 | |
46 | /// Closure object for `operator<` on `Element` with a given rank. |
47 | template <typename V> |
48 | struct ElementLT final { |
49 | ElementLT(uint64_t rank) : rank(rank) {} |
50 | bool operator()(const Element<V> &e1, const Element<V> &e2) const { |
51 | for (uint64_t d = 0; d < rank; ++d) { |
52 | if (e1.coords[d] == e2.coords[d]) |
53 | continue; |
54 | return e1.coords[d] < e2.coords[d]; |
55 | } |
56 | return false; |
57 | } |
58 | const uint64_t rank; |
59 | }; |
60 | |
61 | /// A memory-resident sparse tensor in coordinate-scheme representation |
62 | /// (a collection of `Element`s). This data structure is used as an |
63 | /// intermediate representation, e.g., for reading sparse tensors from |
64 | /// external formats into memory. |
65 | template <typename V> |
66 | class SparseTensorCOO final { |
67 | public: |
68 | /// Constructs a new coordinate-scheme sparse tensor with the given |
69 | /// sizes and an optional initial storage capacity. |
70 | explicit SparseTensorCOO(const std::vector<uint64_t> &dimSizes, |
71 | uint64_t capacity = 0) |
72 | : SparseTensorCOO(dimSizes.size(), dimSizes.data(), capacity) {} |
73 | |
74 | /// Constructs a new coordinate-scheme sparse tensor with the given |
75 | /// sizes and an optional initial storage capacity. The size of the |
76 | /// dimSizes array is determined by dimRank. |
77 | explicit SparseTensorCOO(uint64_t dimRank, const uint64_t *dimSizes, |
78 | uint64_t capacity = 0) |
79 | : dimSizes(dimSizes, dimSizes + dimRank), isSorted(true) { |
80 | assert(dimRank > 0 && "Trivial shape is not supported" ); |
81 | for (uint64_t d = 0; d < dimRank; ++d) |
82 | assert(dimSizes[d] > 0 && "Dimension size zero has trivial storage" ); |
83 | if (capacity) { |
84 | elements.reserve(capacity); |
85 | coordinates.reserve(n: capacity * dimRank); |
86 | } |
87 | } |
88 | |
89 | /// Gets the dimension-rank of the tensor. |
90 | uint64_t getRank() const { return dimSizes.size(); } |
91 | |
92 | /// Gets the dimension-sizes array. |
93 | const std::vector<uint64_t> &getDimSizes() const { return dimSizes; } |
94 | |
95 | /// Gets the elements array. |
96 | const std::vector<Element<V>> &getElements() const { return elements; } |
97 | |
98 | /// Returns the `operator<` closure object for the COO's element type. |
99 | ElementLT<V> getElementLT() const { return ElementLT<V>(getRank()); } |
100 | |
101 | /// Adds an element to the tensor. |
102 | void add(const std::vector<uint64_t> &dimCoords, V val) { |
103 | const uint64_t *base = coordinates.data(); |
104 | const uint64_t size = coordinates.size(); |
105 | const uint64_t dimRank = getRank(); |
106 | assert(dimCoords.size() == dimRank && "Element rank mismatch" ); |
107 | for (uint64_t d = 0; d < dimRank; ++d) { |
108 | assert(dimCoords[d] < dimSizes[d] && |
109 | "Coordinate is too large for the dimension" ); |
110 | coordinates.push_back(x: dimCoords[d]); |
111 | } |
112 | // This base only changes if `coordinates` was reallocated. In which |
113 | // case, we need to correct all previous pointers into the vector. |
114 | // Note that this only happens if we did not set the initial capacity |
115 | // right, and then only for every internal vector reallocation (which |
116 | // with the doubling rule should only incur an amortized linear overhead). |
117 | const uint64_t *const newBase = coordinates.data(); |
118 | if (newBase != base) { |
119 | for (uint64_t i = 0, n = elements.size(); i < n; ++i) |
120 | elements[i].coords = newBase + (elements[i].coords - base); |
121 | base = newBase; |
122 | } |
123 | // Add the new element and update the sorted bit. |
124 | const Element<V> addedElem(base + size, val); |
125 | if (!elements.empty() && isSorted) |
126 | isSorted = getElementLT()(elements.back(), addedElem); |
127 | elements.push_back(addedElem); |
128 | } |
129 | |
130 | /// Sorts elements lexicographically by coordinates. If a coordinate |
131 | /// is mapped to multiple values, then the relative order of those |
132 | /// values is unspecified. |
133 | void sort() { |
134 | if (isSorted) |
135 | return; |
136 | std::sort(elements.begin(), elements.end(), getElementLT()); |
137 | isSorted = true; |
138 | } |
139 | |
140 | private: |
141 | const std::vector<uint64_t> dimSizes; // per-dimension sizes |
142 | std::vector<Element<V>> elements; // all COO elements |
143 | std::vector<uint64_t> coordinates; // shared coordinate pool |
144 | bool isSorted; |
145 | }; |
146 | |
147 | } // namespace sparse_tensor |
148 | } // namespace mlir |
149 | |
150 | #endif // MLIR_EXECUTIONENGINE_SPARSETENSOR_COO_H |
151 | |