1 | //===- SparseTensorRuntime.cpp - SparseTensor runtime support lib ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a light-weight runtime support library for |
10 | // manipulating sparse tensors from MLIR. More specifically, it provides |
11 | // C-API wrappers so that MLIR-generated code can call into the C++ runtime |
12 | // support library. The functionality provided in this library is meant |
13 | // to simplify benchmarking, testing, and debugging of MLIR code operating |
14 | // on sparse tensors. However, the provided functionality is **not** |
15 | // part of core MLIR itself. |
16 | // |
17 | // The following memory-resident sparse storage schemes are supported: |
18 | // |
19 | // (a) A coordinate scheme for temporarily storing and lexicographically |
20 | // sorting a sparse tensor by coordinate (SparseTensorCOO). |
21 | // |
22 | // (b) A "one-size-fits-all" sparse tensor storage scheme defined by |
// per-dimension sparse/dense annotations together with a dimension
24 | // ordering used by MLIR compiler-generated code (SparseTensorStorage). |
25 | // |
26 | // The following external formats are supported: |
27 | // |
28 | // (1) Matrix Market Exchange (MME): *.mtx |
29 | // https://math.nist.gov/MatrixMarket/formats.html |
30 | // |
31 | // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns |
32 | // http://frostt.io/tensors/file-formats.html |
33 | // |
34 | // Two public APIs are supported: |
35 | // |
36 | // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse |
37 | // tensors. These methods should be used exclusively by MLIR |
38 | // compiler-generated code. |
39 | // |
40 | // (II) Methods that accept C-style data structures to interact with sparse |
41 | // tensors. These methods can be used by any external runtime that wants |
42 | // to interact with MLIR compiler-generated code. |
43 | // |
44 | // In both cases (I) and (II), the SparseTensorStorage format is externally |
45 | // only visible as an opaque pointer. |
46 | // |
47 | //===----------------------------------------------------------------------===// |
48 | |
49 | #include "mlir/ExecutionEngine/SparseTensorRuntime.h" |
50 | |
51 | #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS |
52 | |
53 | #include "mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h" |
54 | #include "mlir/ExecutionEngine/SparseTensor/COO.h" |
55 | #include "mlir/ExecutionEngine/SparseTensor/File.h" |
56 | #include "mlir/ExecutionEngine/SparseTensor/Storage.h" |
57 | |
58 | #include <cstring> |
59 | #include <numeric> |
60 | |
61 | using namespace mlir::sparse_tensor; |
62 | |
63 | //===----------------------------------------------------------------------===// |
64 | // |
65 | // Utilities for manipulating `StridedMemRefType`. |
66 | // |
67 | //===----------------------------------------------------------------------===// |
68 | |
69 | namespace { |
70 | |
71 | #define ASSERT_NO_STRIDE(MEMREF) \ |
72 | do { \ |
73 | assert((MEMREF) && "Memref is nullptr"); \ |
74 | assert(((MEMREF)->strides[0] == 1) && "Memref has non-trivial stride"); \ |
75 | } while (false) |
76 | |
77 | #define MEMREF_GET_USIZE(MEMREF) \ |
78 | detail::checkOverflowCast<uint64_t>((MEMREF)->sizes[0]) |
79 | |
80 | #define ASSERT_USIZE_EQ(MEMREF, SZ) \ |
81 | assert(detail::safelyEQ(MEMREF_GET_USIZE(MEMREF), (SZ)) && \ |
82 | "Memref size mismatch") |
83 | |
84 | #define MEMREF_GET_PAYLOAD(MEMREF) ((MEMREF)->data + (MEMREF)->offset) |
85 | |
86 | /// Initializes the memref with the provided size and data pointer. This |
87 | /// is designed for functions which want to "return" a memref that aliases |
88 | /// into memory owned by some other object (e.g., `SparseTensorStorage`), |
89 | /// without doing any actual copying. (The "return" is in scarequotes |
90 | /// because the `_mlir_ciface_` calling convention migrates any returned |
91 | /// memrefs into an out-parameter passed before all the other function |
92 | /// parameters.) |
93 | template <typename DataSizeT, typename T> |
94 | static inline void aliasIntoMemref(DataSizeT size, T *data, |
95 | StridedMemRefType<T, 1> &ref) { |
96 | ref.basePtr = ref.data = data; |
97 | ref.offset = 0; |
98 | using MemrefSizeT = std::remove_reference_t<decltype(ref.sizes[0])>; |
99 | ref.sizes[0] = detail::checkOverflowCast<MemrefSizeT>(size); |
100 | ref.strides[0] = 1; |
101 | } |
102 | |
103 | } // anonymous namespace |
104 | |
105 | extern "C" { |
106 | |
107 | //===----------------------------------------------------------------------===// |
108 | // |
109 | // Public functions which operate on MLIR buffers (memrefs) to interact |
110 | // with sparse tensors (which are only visible as opaque pointers externally). |
111 | // |
112 | //===----------------------------------------------------------------------===// |
113 | |
114 | #define CASE(p, c, v, P, C, V) \ |
115 | if (posTp == (p) && crdTp == (c) && valTp == (v)) { \ |
116 | switch (action) { \ |
117 | case Action::kEmpty: { \ |
118 | return SparseTensorStorage<P, C, V>::newEmpty( \ |
119 | dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim); \ |
120 | } \ |
121 | case Action::kFromReader: { \ |
122 | assert(ptr && "Received nullptr for SparseTensorReader object"); \ |
123 | SparseTensorReader &reader = *static_cast<SparseTensorReader *>(ptr); \ |
124 | return static_cast<void *>(reader.readSparseTensor<P, C, V>( \ |
125 | lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim)); \ |
126 | } \ |
127 | case Action::kPack: { \ |
128 | assert(ptr && "Received nullptr for SparseTensorStorage object"); \ |
129 | intptr_t *buffers = static_cast<intptr_t *>(ptr); \ |
130 | return SparseTensorStorage<P, C, V>::newFromBuffers( \ |
131 | dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \ |
132 | dimRank, buffers); \ |
133 | } \ |
134 | case Action::kSortCOOInPlace: { \ |
135 | assert(ptr && "Received nullptr for SparseTensorStorage object"); \ |
136 | auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr); \ |
137 | tensor.sortInPlace(); \ |
138 | return ptr; \ |
139 | } \ |
140 | } \ |
141 | fprintf(stderr, "unknown action %d\n", static_cast<uint32_t>(action)); \ |
142 | exit(1); \ |
143 | } |
144 | |
145 | #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V) |
146 | |
147 | // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor |
148 | // can safely rewrite kIndex to kU64. We make this assertion to guarantee |
149 | // that this file cannot get out of sync with its header. |
150 | static_assert(std::is_same<index_type, uint64_t>::value, |
151 | "Expected index_type == uint64_t" ); |
152 | |
153 | // The Swiss-army-knife for sparse tensor creation. |
154 | void *_mlir_ciface_newSparseTensor( // NOLINT |
155 | StridedMemRefType<index_type, 1> *dimSizesRef, |
156 | StridedMemRefType<index_type, 1> *lvlSizesRef, |
157 | StridedMemRefType<LevelType, 1> *lvlTypesRef, |
158 | StridedMemRefType<index_type, 1> *dim2lvlRef, |
159 | StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp, |
160 | OverheadType crdTp, PrimaryType valTp, Action action, void *ptr) { |
161 | ASSERT_NO_STRIDE(dimSizesRef); |
162 | ASSERT_NO_STRIDE(lvlSizesRef); |
163 | ASSERT_NO_STRIDE(lvlTypesRef); |
164 | ASSERT_NO_STRIDE(dim2lvlRef); |
165 | ASSERT_NO_STRIDE(lvl2dimRef); |
166 | const uint64_t dimRank = MEMREF_GET_USIZE(dimSizesRef); |
167 | const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef); |
168 | ASSERT_USIZE_EQ(lvlTypesRef, lvlRank); |
169 | ASSERT_USIZE_EQ(dim2lvlRef, lvlRank); |
170 | ASSERT_USIZE_EQ(lvl2dimRef, dimRank); |
171 | const index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef); |
172 | const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef); |
173 | const LevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef); |
174 | const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef); |
175 | const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef); |
176 | |
177 | // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases. |
178 | // This is safe because of the static_assert above. |
179 | if (posTp == OverheadType::kIndex) |
180 | posTp = OverheadType::kU64; |
181 | if (crdTp == OverheadType::kIndex) |
182 | crdTp = OverheadType::kU64; |
183 | |
184 | // Double matrices with all combinations of overhead storage. |
185 | CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t, |
186 | uint64_t, double); |
187 | CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t, |
188 | uint32_t, double); |
189 | CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t, |
190 | uint16_t, double); |
191 | CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t, |
192 | uint8_t, double); |
193 | CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t, |
194 | uint64_t, double); |
195 | CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t, |
196 | uint32_t, double); |
197 | CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t, |
198 | uint16_t, double); |
199 | CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t, |
200 | uint8_t, double); |
201 | CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t, |
202 | uint64_t, double); |
203 | CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t, |
204 | uint32_t, double); |
205 | CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t, |
206 | uint16_t, double); |
207 | CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t, |
208 | uint8_t, double); |
209 | CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t, |
210 | uint64_t, double); |
211 | CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t, |
212 | uint32_t, double); |
213 | CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t, |
214 | uint16_t, double); |
215 | CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t, |
216 | uint8_t, double); |
217 | |
218 | // Float matrices with all combinations of overhead storage. |
219 | CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t, |
220 | uint64_t, float); |
221 | CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t, |
222 | uint32_t, float); |
223 | CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t, |
224 | uint16_t, float); |
225 | CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t, |
226 | uint8_t, float); |
227 | CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t, |
228 | uint64_t, float); |
229 | CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t, |
230 | uint32_t, float); |
231 | CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t, |
232 | uint16_t, float); |
233 | CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t, |
234 | uint8_t, float); |
235 | CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t, |
236 | uint64_t, float); |
237 | CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t, |
238 | uint32_t, float); |
239 | CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t, |
240 | uint16_t, float); |
241 | CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t, |
242 | uint8_t, float); |
243 | CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t, |
244 | uint64_t, float); |
245 | CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t, |
246 | uint32_t, float); |
247 | CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t, |
248 | uint16_t, float); |
249 | CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t, |
250 | uint8_t, float); |
251 | |
252 | // Two-byte floats with both overheads of the same type. |
253 | CASE_SECSAME(OverheadType::kU64, PrimaryType::kF16, uint64_t, f16); |
254 | CASE_SECSAME(OverheadType::kU64, PrimaryType::kBF16, uint64_t, bf16); |
255 | CASE_SECSAME(OverheadType::kU32, PrimaryType::kF16, uint32_t, f16); |
256 | CASE_SECSAME(OverheadType::kU32, PrimaryType::kBF16, uint32_t, bf16); |
257 | CASE_SECSAME(OverheadType::kU16, PrimaryType::kF16, uint16_t, f16); |
258 | CASE_SECSAME(OverheadType::kU16, PrimaryType::kBF16, uint16_t, bf16); |
259 | CASE_SECSAME(OverheadType::kU8, PrimaryType::kF16, uint8_t, f16); |
260 | CASE_SECSAME(OverheadType::kU8, PrimaryType::kBF16, uint8_t, bf16); |
261 | |
262 | // Integral matrices with both overheads of the same type. |
263 | CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t); |
264 | CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t); |
265 | CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t); |
266 | CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t); |
267 | CASE_SECSAME(OverheadType::kU32, PrimaryType::kI64, uint32_t, int64_t); |
268 | CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t); |
269 | CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t); |
270 | CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t); |
271 | CASE_SECSAME(OverheadType::kU16, PrimaryType::kI64, uint16_t, int64_t); |
272 | CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t); |
273 | CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t); |
274 | CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t); |
275 | CASE_SECSAME(OverheadType::kU8, PrimaryType::kI64, uint8_t, int64_t); |
276 | CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t); |
277 | CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t); |
278 | CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t); |
279 | |
280 | // Complex matrices with wide overhead. |
281 | CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64); |
282 | CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32); |
283 | |
284 | // Unsupported case (add above if needed). |
285 | fprintf(stderr, format: "unsupported combination of types: <P=%d, C=%d, V=%d>\n" , |
286 | static_cast<int>(posTp), static_cast<int>(crdTp), |
287 | static_cast<int>(valTp)); |
288 | exit(status: 1); |
289 | } |
290 | #undef CASE |
291 | #undef CASE_SECSAME |
292 | |
293 | #define IMPL_SPARSEVALUES(VNAME, V) \ |
294 | void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref, \ |
295 | void *tensor) { \ |
296 | assert(ref &&tensor); \ |
297 | std::vector<V> *v; \ |
298 | static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v); \ |
299 | assert(v); \ |
300 | aliasIntoMemref(v->size(), v->data(), *ref); \ |
301 | } |
302 | MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES) |
303 | #undef IMPL_SPARSEVALUES |
304 | |
305 | #define IMPL_GETOVERHEAD(NAME, TYPE, LIB) \ |
306 | void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor, \ |
307 | index_type lvl) { \ |
308 | assert(ref &&tensor); \ |
309 | std::vector<TYPE> *v; \ |
310 | static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, lvl); \ |
311 | assert(v); \ |
312 | aliasIntoMemref(v->size(), v->data(), *ref); \ |
313 | } |
314 | |
315 | #define IMPL_SPARSEPOSITIONS(PNAME, P) \ |
316 | IMPL_GETOVERHEAD(sparsePositions##PNAME, P, getPositions) |
317 | MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSEPOSITIONS) |
318 | #undef IMPL_SPARSEPOSITIONS |
319 | |
320 | #define IMPL_SPARSECOORDINATES(CNAME, C) \ |
321 | IMPL_GETOVERHEAD(sparseCoordinates##CNAME, C, getCoordinates) |
322 | MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES) |
323 | #undef IMPL_SPARSECOORDINATES |
324 | |
325 | #define IMPL_SPARSECOORDINATESBUFFER(CNAME, C) \ |
326 | IMPL_GETOVERHEAD(sparseCoordinatesBuffer##CNAME, C, getCoordinatesBuffer) |
327 | MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATESBUFFER) |
328 | #undef IMPL_SPARSECOORDINATESBUFFER |
329 | |
330 | #undef IMPL_GETOVERHEAD |
331 | |
332 | #define IMPL_LEXINSERT(VNAME, V) \ |
333 | void _mlir_ciface_lexInsert##VNAME( \ |
334 | void *t, StridedMemRefType<index_type, 1> *lvlCoordsRef, \ |
335 | StridedMemRefType<V, 0> *vref) { \ |
336 | assert(t &&vref); \ |
337 | auto &tensor = *static_cast<SparseTensorStorageBase *>(t); \ |
338 | ASSERT_NO_STRIDE(lvlCoordsRef); \ |
339 | index_type *lvlCoords = MEMREF_GET_PAYLOAD(lvlCoordsRef); \ |
340 | assert(lvlCoords); \ |
341 | V *value = MEMREF_GET_PAYLOAD(vref); \ |
342 | tensor.lexInsert(lvlCoords, *value); \ |
343 | } |
344 | MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT) |
345 | #undef IMPL_LEXINSERT |
346 | |
347 | #define IMPL_EXPINSERT(VNAME, V) \ |
348 | void _mlir_ciface_expInsert##VNAME( \ |
349 | void *t, StridedMemRefType<index_type, 1> *lvlCoordsRef, \ |
350 | StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref, \ |
351 | StridedMemRefType<index_type, 1> *aref, index_type count) { \ |
352 | assert(t); \ |
353 | auto &tensor = *static_cast<SparseTensorStorageBase *>(t); \ |
354 | ASSERT_NO_STRIDE(lvlCoordsRef); \ |
355 | ASSERT_NO_STRIDE(vref); \ |
356 | ASSERT_NO_STRIDE(fref); \ |
357 | ASSERT_NO_STRIDE(aref); \ |
358 | ASSERT_USIZE_EQ(vref, MEMREF_GET_USIZE(fref)); \ |
359 | index_type *lvlCoords = MEMREF_GET_PAYLOAD(lvlCoordsRef); \ |
360 | V *values = MEMREF_GET_PAYLOAD(vref); \ |
361 | bool *filled = MEMREF_GET_PAYLOAD(fref); \ |
362 | index_type *added = MEMREF_GET_PAYLOAD(aref); \ |
363 | uint64_t expsz = vref->sizes[0]; \ |
364 | tensor.expInsert(lvlCoords, values, filled, added, count, expsz); \ |
365 | } |
366 | MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT) |
367 | #undef IMPL_EXPINSERT |
368 | |
369 | void *_mlir_ciface_createCheckedSparseTensorReader( |
370 | char *filename, StridedMemRefType<index_type, 1> *dimShapeRef, |
371 | PrimaryType valTp) { |
372 | ASSERT_NO_STRIDE(dimShapeRef); |
373 | const uint64_t dimRank = MEMREF_GET_USIZE(dimShapeRef); |
374 | const index_type *dimShape = MEMREF_GET_PAYLOAD(dimShapeRef); |
375 | auto *reader = SparseTensorReader::create(filename, dimRank, dimShape, valTp); |
376 | return static_cast<void *>(reader); |
377 | } |
378 | |
379 | void _mlir_ciface_getSparseTensorReaderDimSizes( |
380 | StridedMemRefType<index_type, 1> *out, void *p) { |
381 | assert(out && p); |
382 | SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p); |
383 | auto *dimSizes = const_cast<uint64_t *>(reader.getDimSizes()); |
384 | aliasIntoMemref(size: reader.getRank(), data: dimSizes, ref&: *out); |
385 | } |
386 | |
387 | #define IMPL_GETNEXT(VNAME, V, CNAME, C) \ |
388 | bool _mlir_ciface_getSparseTensorReaderReadToBuffers##CNAME##VNAME( \ |
389 | void *p, StridedMemRefType<index_type, 1> *dim2lvlRef, \ |
390 | StridedMemRefType<index_type, 1> *lvl2dimRef, \ |
391 | StridedMemRefType<C, 1> *cref, StridedMemRefType<V, 1> *vref) { \ |
392 | assert(p); \ |
393 | auto &reader = *static_cast<SparseTensorReader *>(p); \ |
394 | ASSERT_NO_STRIDE(dim2lvlRef); \ |
395 | ASSERT_NO_STRIDE(lvl2dimRef); \ |
396 | ASSERT_NO_STRIDE(cref); \ |
397 | ASSERT_NO_STRIDE(vref); \ |
398 | const uint64_t dimRank = reader.getRank(); \ |
399 | const uint64_t lvlRank = MEMREF_GET_USIZE(dim2lvlRef); \ |
400 | const uint64_t cSize = MEMREF_GET_USIZE(cref); \ |
401 | const uint64_t vSize = MEMREF_GET_USIZE(vref); \ |
402 | ASSERT_USIZE_EQ(lvl2dimRef, dimRank); \ |
403 | assert(cSize >= lvlRank * reader.getNSE()); \ |
404 | assert(vSize >= reader.getNSE()); \ |
405 | (void)dimRank; \ |
406 | (void)cSize; \ |
407 | (void)vSize; \ |
408 | index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef); \ |
409 | index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef); \ |
410 | C *lvlCoordinates = MEMREF_GET_PAYLOAD(cref); \ |
411 | V *values = MEMREF_GET_PAYLOAD(vref); \ |
412 | return reader.readToBuffers<C, V>(lvlRank, dim2lvl, lvl2dim, \ |
413 | lvlCoordinates, values); \ |
414 | } |
415 | MLIR_SPARSETENSOR_FOREVERY_V_O(IMPL_GETNEXT) |
416 | #undef IMPL_GETNEXT |
417 | |
418 | void _mlir_ciface_outSparseTensorWriterMetaData( |
419 | void *p, index_type dimRank, index_type nse, |
420 | StridedMemRefType<index_type, 1> *dimSizesRef) { |
421 | assert(p); |
422 | ASSERT_NO_STRIDE(dimSizesRef); |
423 | assert(dimRank != 0); |
424 | index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef); |
425 | std::ostream &file = *static_cast<std::ostream *>(p); |
426 | file << dimRank << " " << nse << '\n'; |
427 | for (index_type d = 0; d < dimRank - 1; d++) |
428 | file << dimSizes[d] << " " ; |
429 | file << dimSizes[dimRank - 1] << '\n'; |
430 | } |
431 | |
432 | #define IMPL_OUTNEXT(VNAME, V) \ |
433 | void _mlir_ciface_outSparseTensorWriterNext##VNAME( \ |
434 | void *p, index_type dimRank, \ |
435 | StridedMemRefType<index_type, 1> *dimCoordsRef, \ |
436 | StridedMemRefType<V, 0> *vref) { \ |
437 | assert(p &&vref); \ |
438 | ASSERT_NO_STRIDE(dimCoordsRef); \ |
439 | const index_type *dimCoords = MEMREF_GET_PAYLOAD(dimCoordsRef); \ |
440 | std::ostream &file = *static_cast<std::ostream *>(p); \ |
441 | for (index_type d = 0; d < dimRank; d++) \ |
442 | file << (dimCoords[d] + 1) << " "; \ |
443 | V *value = MEMREF_GET_PAYLOAD(vref); \ |
444 | file << *value << '\n'; \ |
445 | } |
446 | MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT) |
447 | #undef IMPL_OUTNEXT |
448 | |
449 | //===----------------------------------------------------------------------===// |
450 | // |
451 | // Public functions which accept only C-style data structures to interact |
452 | // with sparse tensors (which are only visible as opaque pointers externally). |
453 | // |
454 | //===----------------------------------------------------------------------===// |
455 | |
456 | index_type sparseLvlSize(void *tensor, index_type l) { |
457 | return static_cast<SparseTensorStorageBase *>(tensor)->getLvlSize(l); |
458 | } |
459 | |
460 | index_type sparseDimSize(void *tensor, index_type d) { |
461 | return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d); |
462 | } |
463 | |
464 | void endLexInsert(void *tensor) { |
465 | return static_cast<SparseTensorStorageBase *>(tensor)->endLexInsert(); |
466 | } |
467 | |
468 | void delSparseTensor(void *tensor) { |
469 | delete static_cast<SparseTensorStorageBase *>(tensor); |
470 | } |
471 | |
472 | char *getTensorFilename(index_type id) { |
473 | constexpr size_t bufSize = 80; |
474 | char var[bufSize]; |
475 | snprintf(s: var, maxlen: bufSize, format: "TENSOR%" PRIu64, id); |
476 | char *env = getenv(name: var); |
477 | if (!env) { |
478 | fprintf(stderr, format: "Environment variable %s is not set\n" , var); |
479 | exit(status: 1); |
480 | } |
481 | return env; |
482 | } |
483 | |
484 | index_type getSparseTensorReaderNSE(void *p) { |
485 | return static_cast<SparseTensorReader *>(p)->getNSE(); |
486 | } |
487 | |
488 | void delSparseTensorReader(void *p) { |
489 | delete static_cast<SparseTensorReader *>(p); |
490 | } |
491 | |
/// Creates the output stream for a sparse tensor writer, writes the
/// "# extended FROSTT format" header line, and returns the stream as an
/// opaque pointer.
///
/// An empty `filename` selects `std::cout`. Otherwise the named file is
/// opened for writing (truncating any existing contents). If the file cannot
/// be opened, an error is printed and the process exits with status 1, the
/// same way `getTensorFilename` handles its failure. Without this check the
/// output would be silently discarded. Release the stream with
/// `delSparseTensorWriter`.
void *createSparseTensorWriter(char *filename) {
  assert(filename && "Received nullptr for filename");
  std::ostream *file = &std::cout;
  if (filename[0] != 0) {
    auto *ofs = new std::ofstream(filename);
    if (!ofs->is_open()) {
      fprintf(stderr, "Cannot open file %s for writing\n", filename);
      delete ofs;
      exit(1);
    }
    file = ofs;
  }
  *file << "# extended FROSTT format\n";
  return static_cast<void *>(file);
}
498 | |
/// Flushes the writer stream `p` (debug builds assert it is still good) and
/// deletes it, unless it is `std::cout`, which is never deleted.
void delSparseTensorWriter(void *p) {
  auto *stream = static_cast<std::ostream *>(p);
  stream->flush();
  assert(stream->good());
  if (stream == &std::cout)
    return;
  delete stream;
}
506 | |
507 | } // extern "C" |
508 | |
// Retire the memref helper macros defined near the top of this file; they
// are meant for this file only.
#undef ASSERT_NO_STRIDE
#undef MEMREF_GET_USIZE
#undef ASSERT_USIZE_EQ
#undef MEMREF_GET_PAYLOAD
513 | |
514 | #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS |
515 | |