//===- SparseTensorRuntime.cpp - SparseTensor runtime support lib ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a light-weight runtime support library for
// manipulating sparse tensors from MLIR. More specifically, it provides
// C-API wrappers so that MLIR-generated code can call into the C++ runtime
// support library. The functionality provided in this library is meant
// to simplify benchmarking, testing, and debugging of MLIR code operating
// on sparse tensors. However, the provided functionality is **not**
// part of core MLIR itself.
//
// The following memory-resident sparse storage schemes are supported:
//
// (a) A coordinate scheme for temporarily storing and lexicographically
//     sorting a sparse tensor by coordinate (SparseTensorCOO).
//
// (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
//     ordering used by MLIR compiler-generated code (SparseTensorStorage).
//
// The following external formats are supported:
//
// (1) Matrix Market Exchange (MME): *.mtx
//     https://math.nist.gov/MatrixMarket/formats.html
//
// (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
//     http://frostt.io/tensors/file-formats.html
//
// Two public APIs are supported:
//
// (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
//     tensors. These methods should be used exclusively by MLIR
//     compiler-generated code.
//
// (II) Methods that accept C-style data structures to interact with sparse
//      tensors. These methods can be used by any external runtime that wants
//      to interact with MLIR compiler-generated code.
//
// In both cases (I) and (II), the SparseTensorStorage format is only
// visible externally as an opaque pointer.
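//
// For example (a hedged sketch, not an exhaustive API listing), an external
// runtime using API (II) only ever handles that opaque pointer:
//
//   void *tensor = ...;                        // e.g. from newSparseTensor
//   uint64_t nrows = sparseDimSize(tensor, 0); // query a dimension size
//   uint64_t ncols = sparseDimSize(tensor, 1);
//   delSparseTensor(tensor);                   // release the storage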
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/SparseTensorRuntime.h"

#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS

#include "mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h"
#include "mlir/ExecutionEngine/SparseTensor/COO.h"
#include "mlir/ExecutionEngine/SparseTensor/File.h"
#include "mlir/ExecutionEngine/SparseTensor/Storage.h"

#include <cstring>
#include <numeric>

using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
//
// Utilities for manipulating `StridedMemRefType`.
//
//===----------------------------------------------------------------------===//

namespace {

#define ASSERT_NO_STRIDE(MEMREF)                                               \
  do {                                                                         \
    assert((MEMREF) && "Memref is nullptr");                                   \
    assert(((MEMREF)->strides[0] == 1) && "Memref has non-trivial stride");    \
  } while (false)

#define MEMREF_GET_USIZE(MEMREF)                                               \
  detail::checkOverflowCast<uint64_t>((MEMREF)->sizes[0])

#define ASSERT_USIZE_EQ(MEMREF, SZ)                                            \
  assert(detail::safelyEQ(MEMREF_GET_USIZE(MEMREF), (SZ)) &&                   \
         "Memref size mismatch")

#define MEMREF_GET_PAYLOAD(MEMREF) ((MEMREF)->data + (MEMREF)->offset)

/// Initializes the memref with the provided size and data pointer. This
/// is designed for functions which want to "return" a memref that aliases
/// into memory owned by some other object (e.g., `SparseTensorStorage`),
/// without doing any actual copying. (The "return" is in scare quotes
/// because the `_mlir_ciface_` calling convention migrates any returned
/// memrefs into an out-parameter passed before all the other function
/// parameters.)
template <typename DataSizeT, typename T>
static inline void aliasIntoMemref(DataSizeT size, T *data,
                                   StridedMemRefType<T, 1> &ref) {
  ref.basePtr = ref.data = data;
  ref.offset = 0;
  using MemrefSizeT = std::remove_reference_t<decltype(ref.sizes[0])>;
  ref.sizes[0] = detail::checkOverflowCast<MemrefSizeT>(size);
  ref.strides[0] = 1;
}
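
// To illustrate the calling convention mentioned above (a hedged sketch, not
// part of the library): MLIR-generated callers pass the destination memref
// descriptor as a leading out-parameter, and the accessor merely aliases it
// into the tensor's internal storage:
//
//   StridedMemRefType<double, 1> ref;           // uninitialized descriptor
//   _mlir_ciface_sparseValuesF64(&ref, tensor); // no copy; aliases storage
//   // ref.data[0 .. ref.sizes[0]-1] now views the tensor's values array.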

} // anonymous namespace

extern "C" {

//===----------------------------------------------------------------------===//
//
// Public functions which operate on MLIR buffers (memrefs) to interact
// with sparse tensors (which are only visible as opaque pointers externally).
//
//===----------------------------------------------------------------------===//

#define CASE(p, c, v, P, C, V)                                                 \
  if (posTp == (p) && crdTp == (c) && valTp == (v)) {                          \
    switch (action) {                                                          \
    case Action::kEmpty: {                                                     \
      return SparseTensorStorage<P, C, V>::newEmpty(                           \
          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim);   \
    }                                                                          \
    case Action::kFromReader: {                                                \
      assert(ptr && "Received nullptr for SparseTensorReader object");         \
      SparseTensorReader &reader = *static_cast<SparseTensorReader *>(ptr);    \
      return static_cast<void *>(reader.readSparseTensor<P, C, V>(             \
          lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim));                     \
    }                                                                          \
    case Action::kPack: {                                                      \
      assert(ptr && "Received nullptr for SparseTensorStorage object");        \
      intptr_t *buffers = static_cast<intptr_t *>(ptr);                        \
      return SparseTensorStorage<P, C, V>::newFromBuffers(                     \
          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,    \
          dimRank, buffers);                                                   \
    }                                                                          \
    case Action::kSortCOOInPlace: {                                            \
      assert(ptr && "Received nullptr for SparseTensorStorage object");        \
      auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr);        \
      tensor.sortInPlace();                                                    \
      return ptr;                                                              \
    }                                                                          \
    }                                                                          \
    fprintf(stderr, "unknown action %d\n", static_cast<uint32_t>(action));     \
    exit(1);                                                                   \
  }

#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)

// Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
// can safely rewrite kIndex to kU64. We make this assertion to guarantee
// that this file cannot get out of sync with its header.
static_assert(std::is_same<index_type, uint64_t>::value,
              "Expected index_type == uint64_t");

// The Swiss-army-knife for sparse tensor creation.
void *_mlir_ciface_newSparseTensor( // NOLINT
    StridedMemRefType<index_type, 1> *dimSizesRef,
    StridedMemRefType<index_type, 1> *lvlSizesRef,
    StridedMemRefType<LevelType, 1> *lvlTypesRef,
    StridedMemRefType<index_type, 1> *dim2lvlRef,
    StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
    OverheadType crdTp, PrimaryType valTp, Action action, void *ptr) {
  ASSERT_NO_STRIDE(dimSizesRef);
  ASSERT_NO_STRIDE(lvlSizesRef);
  ASSERT_NO_STRIDE(lvlTypesRef);
  ASSERT_NO_STRIDE(dim2lvlRef);
  ASSERT_NO_STRIDE(lvl2dimRef);
  const uint64_t dimRank = MEMREF_GET_USIZE(dimSizesRef);
  const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef);
  ASSERT_USIZE_EQ(lvlTypesRef, lvlRank);
  ASSERT_USIZE_EQ(dim2lvlRef, lvlRank);
  ASSERT_USIZE_EQ(lvl2dimRef, dimRank);
  const index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef);
  const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
  const LevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
  const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
  const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);

  // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
  // This is safe because of the static_assert above.
  if (posTp == OverheadType::kIndex)
    posTp = OverheadType::kU64;
  if (crdTp == OverheadType::kIndex)
    crdTp = OverheadType::kU64;

  // Double matrices with all combinations of overhead storage.
  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
       uint64_t, double);
  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
       uint32_t, double);
  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
       uint16_t, double);
  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
       uint8_t, double);
  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
       uint64_t, double);
  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
       uint32_t, double);
  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
       uint16_t, double);
  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
       uint8_t, double);
  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
       uint64_t, double);
  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
       uint32_t, double);
  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
       uint16_t, double);
  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
       uint8_t, double);
  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
       uint64_t, double);
  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
       uint32_t, double);
  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
       uint16_t, double);
  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
       uint8_t, double);

  // Float matrices with all combinations of overhead storage.
  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
       uint64_t, float);
  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
       uint32_t, float);
  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
       uint16_t, float);
  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
       uint8_t, float);
  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
       uint64_t, float);
  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
       uint32_t, float);
  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
       uint16_t, float);
  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
       uint8_t, float);
  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
       uint64_t, float);
  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
       uint32_t, float);
  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
       uint16_t, float);
  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
       uint8_t, float);
  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
       uint64_t, float);
  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
       uint32_t, float);
  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
       uint16_t, float);
  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
       uint8_t, float);

  // Two-byte floats with both overheads of the same type.
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kF16, uint64_t, f16);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kBF16, uint64_t, bf16);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kF16, uint32_t, f16);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kBF16, uint32_t, bf16);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kF16, uint16_t, f16);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kBF16, uint16_t, bf16);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kF16, uint8_t, f16);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kBF16, uint8_t, bf16);

  // Integral matrices with both overheads of the same type.
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI64, uint32_t, int64_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI64, uint16_t, int64_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI64, uint8_t, int64_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);

  // Complex matrices with wide overhead.
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);

  // Unsupported case (add above if needed).
  fprintf(stderr, "unsupported combination of types: <P=%d, C=%d, V=%d>\n",
          static_cast<int>(posTp), static_cast<int>(crdTp),
          static_cast<int>(valTp));
  exit(1);
}
#undef CASE
#undef CASE_SECSAME
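
// As a hedged illustration (not part of the library), compiler-generated code
// creating an empty 2-d F64 tensor with 64-bit positions/coordinates would
// populate five rank-1 memref descriptors and invoke:
//
//   // dimSizesRef/lvlSizesRef hold {rows, cols}; lvlTypesRef holds the two
//   // level types; dim2lvlRef/lvl2dimRef hold the identity mapping {0, 1}.
//   void *tensor = _mlir_ciface_newSparseTensor(
//       &dimSizesRef, &lvlSizesRef, &lvlTypesRef, &dim2lvlRef, &lvl2dimRef,
//       OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64,
//       Action::kEmpty, /*ptr=*/nullptr);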

#define IMPL_SPARSEVALUES(VNAME, V)                                            \
  void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref,          \
                                        void *tensor) {                        \
    assert(ref &&tensor);                                                      \
    std::vector<V> *v;                                                         \
    static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);             \
    assert(v);                                                                 \
    aliasIntoMemref(v->size(), v->data(), *ref);                               \
  }
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES)
#undef IMPL_SPARSEVALUES
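
// For reference, the x-macro above expands IMPL_SPARSEVALUES once per value
// type; e.g. the F64 instance has the signature (illustrative sketch only):
//
//   void _mlir_ciface_sparseValuesF64(StridedMemRefType<double, 1> *ref,
//                                     void *tensor);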

#define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
  void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
                           index_type lvl) {                                   \
    assert(ref &&tensor);                                                      \
    std::vector<TYPE> *v;                                                      \
    static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, lvl);              \
    assert(v);                                                                 \
    aliasIntoMemref(v->size(), v->data(), *ref);                               \
  }

#define IMPL_SPARSEPOSITIONS(PNAME, P)                                         \
  IMPL_GETOVERHEAD(sparsePositions##PNAME, P, getPositions)
MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSEPOSITIONS)
#undef IMPL_SPARSEPOSITIONS

#define IMPL_SPARSECOORDINATES(CNAME, C)                                       \
  IMPL_GETOVERHEAD(sparseCoordinates##CNAME, C, getCoordinates)
MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES)
#undef IMPL_SPARSECOORDINATES

#define IMPL_SPARSECOORDINATESBUFFER(CNAME, C)                                 \
  IMPL_GETOVERHEAD(sparseCoordinatesBuffer##CNAME, C, getCoordinatesBuffer)
MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATESBUFFER)
#undef IMPL_SPARSECOORDINATESBUFFER

#undef IMPL_GETOVERHEAD
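
// Analogously, the instantiations above yield per-width accessors that alias
// the requested level's overhead array. Hedged usage sketch (assuming a
// CSR-like tensor whose positions are stored with a 64-bit overhead type):
//
//   StridedMemRefType<uint64_t, 1> pos;
//   _mlir_ciface_sparsePositions64(&pos, tensor, /*lvl=*/1);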

#define IMPL_LEXINSERT(VNAME, V)                                               \
  void _mlir_ciface_lexInsert##VNAME(                                          \
      void *t, StridedMemRefType<index_type, 1> *lvlCoordsRef,                 \
      StridedMemRefType<V, 0> *vref) {                                         \
    assert(t &&vref);                                                          \
    auto &tensor = *static_cast<SparseTensorStorageBase *>(t);                 \
    ASSERT_NO_STRIDE(lvlCoordsRef);                                            \
    index_type *lvlCoords = MEMREF_GET_PAYLOAD(lvlCoordsRef);                  \
    assert(lvlCoords);                                                         \
    V *value = MEMREF_GET_PAYLOAD(vref);                                       \
    tensor.lexInsert(lvlCoords, *value);                                       \
  }
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
#undef IMPL_LEXINSERT
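
// Insertion protocol (hedged sketch, assuming F64 values and a rank-2 tensor):
// call lexInsert repeatedly with level-coordinates in lexicographic order, and
// finalize with endLexInsert once all entries have been inserted:
//
//   index_type lvlCoords[2] = {i, j};
//   StridedMemRefType<index_type, 1> coordsRef; // wraps lvlCoords
//   StridedMemRefType<double, 0> valRef;        // wraps the scalar value
//   _mlir_ciface_lexInsertF64(tensor, &coordsRef, &valRef);
//   // ... more insertions in lexicographic coordinate order ...
//   endLexInsert(tensor);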

#define IMPL_EXPINSERT(VNAME, V)                                               \
  void _mlir_ciface_expInsert##VNAME(                                          \
      void *t, StridedMemRefType<index_type, 1> *lvlCoordsRef,                 \
      StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
      StridedMemRefType<index_type, 1> *aref, index_type count) {              \
    assert(t);                                                                 \
    auto &tensor = *static_cast<SparseTensorStorageBase *>(t);                 \
    ASSERT_NO_STRIDE(lvlCoordsRef);                                            \
    ASSERT_NO_STRIDE(vref);                                                    \
    ASSERT_NO_STRIDE(fref);                                                    \
    ASSERT_NO_STRIDE(aref);                                                    \
    ASSERT_USIZE_EQ(vref, MEMREF_GET_USIZE(fref));                             \
    index_type *lvlCoords = MEMREF_GET_PAYLOAD(lvlCoordsRef);                  \
    V *values = MEMREF_GET_PAYLOAD(vref);                                      \
    bool *filled = MEMREF_GET_PAYLOAD(fref);                                   \
    index_type *added = MEMREF_GET_PAYLOAD(aref);                              \
    uint64_t expsz = vref->sizes[0];                                           \
    tensor.expInsert(lvlCoords, values, filled, added, count, expsz);          \
  }
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT)
#undef IMPL_EXPINSERT
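
// Note on expInsert (descriptive, based on the buffers passed above): `values`
// and `filled` form a dense scratch row of size `expsz` for the compiler's
// expanded access pattern, and `added` lists the `count` positions that were
// written; the call compresses those entries into the sparse storage and
// clears the scratch entries it consumed.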

void *_mlir_ciface_createCheckedSparseTensorReader(
    char *filename, StridedMemRefType<index_type, 1> *dimShapeRef,
    PrimaryType valTp) {
  ASSERT_NO_STRIDE(dimShapeRef);
  const uint64_t dimRank = MEMREF_GET_USIZE(dimShapeRef);
  const index_type *dimShape = MEMREF_GET_PAYLOAD(dimShapeRef);
  auto *reader = SparseTensorReader::create(filename, dimRank, dimShape, valTp);
  return static_cast<void *>(reader);
}

void _mlir_ciface_getSparseTensorReaderDimSizes(
    StridedMemRefType<index_type, 1> *out, void *p) {
  assert(out && p);
  SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
  auto *dimSizes = const_cast<uint64_t *>(reader.getDimSizes());
  aliasIntoMemref(reader.getRank(), dimSizes, *out);
}

#define IMPL_GETNEXT(VNAME, V, CNAME, C)                                       \
  bool _mlir_ciface_getSparseTensorReaderReadToBuffers##CNAME##VNAME(          \
      void *p, StridedMemRefType<index_type, 1> *dim2lvlRef,                   \
      StridedMemRefType<index_type, 1> *lvl2dimRef,                            \
      StridedMemRefType<C, 1> *cref, StridedMemRefType<V, 1> *vref) {          \
    assert(p);                                                                 \
    auto &reader = *static_cast<SparseTensorReader *>(p);                      \
    ASSERT_NO_STRIDE(dim2lvlRef);                                              \
    ASSERT_NO_STRIDE(lvl2dimRef);                                              \
    ASSERT_NO_STRIDE(cref);                                                    \
    ASSERT_NO_STRIDE(vref);                                                    \
    const uint64_t dimRank = reader.getRank();                                 \
    const uint64_t lvlRank = MEMREF_GET_USIZE(dim2lvlRef);                     \
    const uint64_t cSize = MEMREF_GET_USIZE(cref);                             \
    const uint64_t vSize = MEMREF_GET_USIZE(vref);                             \
    ASSERT_USIZE_EQ(lvl2dimRef, dimRank);                                      \
    assert(cSize >= lvlRank * reader.getNSE());                                \
    assert(vSize >= reader.getNSE());                                          \
    (void)dimRank;                                                             \
    (void)cSize;                                                               \
    (void)vSize;                                                               \
    index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);                      \
    index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);                      \
    C *lvlCoordinates = MEMREF_GET_PAYLOAD(cref);                              \
    V *values = MEMREF_GET_PAYLOAD(vref);                                      \
    return reader.readToBuffers<C, V>(lvlRank, dim2lvl, lvl2dim,               \
                                      lvlCoordinates, values);                 \
  }
MLIR_SPARSETENSOR_FOREVERY_V_O(IMPL_GETNEXT)
#undef IMPL_GETNEXT
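
// Taken together, a reader is used roughly as follows (hedged sketch; error
// handling omitted, and the ReadToBuffers suffix shown is just one of the
// generated CNAME/VNAME instantiations):
//
//   void *reader = _mlir_ciface_createCheckedSparseTensorReader(
//       filename, &dimShapeRef, PrimaryType::kF64);
//   index_type nse = getSparseTensorReaderNSE(reader);
//   _mlir_ciface_getSparseTensorReaderDimSizes(&dimSizesRef, reader);
//   // ... allocate coordinate/value buffers large enough for nse entries ...
//   _mlir_ciface_getSparseTensorReaderReadToBuffers0F64(
//       reader, &dim2lvlRef, &lvl2dimRef, &cref, &vref);
//   delSparseTensorReader(reader);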

void _mlir_ciface_outSparseTensorWriterMetaData(
    void *p, index_type dimRank, index_type nse,
    StridedMemRefType<index_type, 1> *dimSizesRef) {
  assert(p);
  ASSERT_NO_STRIDE(dimSizesRef);
  assert(dimRank != 0);
  index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef);
  std::ostream &file = *static_cast<std::ostream *>(p);
  file << dimRank << " " << nse << '\n';
  for (index_type d = 0; d < dimRank - 1; d++)
    file << dimSizes[d] << " ";
  file << dimSizes[dimRank - 1] << '\n';
}

#define IMPL_OUTNEXT(VNAME, V)                                                 \
  void _mlir_ciface_outSparseTensorWriterNext##VNAME(                          \
      void *p, index_type dimRank,                                             \
      StridedMemRefType<index_type, 1> *dimCoordsRef,                          \
      StridedMemRefType<V, 0> *vref) {                                         \
    assert(p &&vref);                                                          \
    ASSERT_NO_STRIDE(dimCoordsRef);                                            \
    const index_type *dimCoords = MEMREF_GET_PAYLOAD(dimCoordsRef);            \
    std::ostream &file = *static_cast<std::ostream *>(p);                      \
    for (index_type d = 0; d < dimRank; d++)                                   \
      file << (dimCoords[d] + 1) << " ";                                       \
    V *value = MEMREF_GET_PAYLOAD(vref);                                       \
    file << *value << '\n';                                                    \
  }
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT)
#undef IMPL_OUTNEXT
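
// As a concrete illustration of the emitted text (extended FROSTT format,
// 1-based coordinates), writing a 2x3 matrix with two stored entries via the
// writer API produces:
//
//   # extended FROSTT format
//   2 2
//   2 3
//   1 1 1.5
//   2 3 2.5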

//===----------------------------------------------------------------------===//
//
// Public functions which accept only C-style data structures to interact
// with sparse tensors (which are only visible as opaque pointers externally).
//
//===----------------------------------------------------------------------===//

index_type sparseLvlSize(void *tensor, index_type l) {
  return static_cast<SparseTensorStorageBase *>(tensor)->getLvlSize(l);
}

index_type sparseDimSize(void *tensor, index_type d) {
  return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
}

void endLexInsert(void *tensor) {
  return static_cast<SparseTensorStorageBase *>(tensor)->endLexInsert();
}

void delSparseTensor(void *tensor) {
  delete static_cast<SparseTensorStorageBase *>(tensor);
}

char *getTensorFilename(index_type id) {
  constexpr size_t bufSize = 80;
  char var[bufSize];
  snprintf(var, bufSize, "TENSOR%" PRIu64, id);
  char *env = getenv(var);
  if (!env) {
    fprintf(stderr, "Environment variable %s is not set\n", var);
    exit(1);
  }
  return env;
}
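
// For example, a test driver typically exports something like
// `TENSOR0=/path/to/matrix.mtx` (illustrative path) in its environment and
// then calls `getTensorFilename(0)` to retrieve that path.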

index_type getSparseTensorReaderNSE(void *p) {
  return static_cast<SparseTensorReader *>(p)->getNSE();
}

void delSparseTensorReader(void *p) {
  delete static_cast<SparseTensorReader *>(p);
}

void *createSparseTensorWriter(char *filename) {
  std::ostream *file =
      (filename[0] == 0) ? &std::cout : new std::ofstream(filename);
  *file << "# extended FROSTT format\n";
  return static_cast<void *>(file);
}

void delSparseTensorWriter(void *p) {
  std::ostream *file = static_cast<std::ostream *>(p);
  file->flush();
  assert(file->good());
  if (file != &std::cout)
    delete file;
}

} // extern "C"

#undef MEMREF_GET_PAYLOAD
#undef ASSERT_USIZE_EQ
#undef MEMREF_GET_USIZE
#undef ASSERT_NO_STRIDE

#endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS

source code of mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp