| 1 | //===- SparseTensorRuntime.cpp - SparseTensor runtime support lib ---------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements a light-weight runtime support library for |
| 10 | // manipulating sparse tensors from MLIR. More specifically, it provides |
| 11 | // C-API wrappers so that MLIR-generated code can call into the C++ runtime |
| 12 | // support library. The functionality provided in this library is meant |
| 13 | // to simplify benchmarking, testing, and debugging of MLIR code operating |
| 14 | // on sparse tensors. However, the provided functionality is **not** |
| 15 | // part of core MLIR itself. |
| 16 | // |
| 17 | // The following memory-resident sparse storage schemes are supported: |
| 18 | // |
| 19 | // (a) A coordinate scheme for temporarily storing and lexicographically |
| 20 | // sorting a sparse tensor by coordinate (SparseTensorCOO). |
| 21 | // |
| 22 | // (b) A "one-size-fits-all" sparse tensor storage scheme defined by |
// per-dimension sparse/dense annotations together with a dimension
| 24 | // ordering used by MLIR compiler-generated code (SparseTensorStorage). |
| 25 | // |
| 26 | // The following external formats are supported: |
| 27 | // |
| 28 | // (1) Matrix Market Exchange (MME): *.mtx |
| 29 | // https://math.nist.gov/MatrixMarket/formats.html |
| 30 | // |
| 31 | // (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns |
| 32 | // http://frostt.io/tensors/file-formats.html |
| 33 | // |
| 34 | // Two public APIs are supported: |
| 35 | // |
| 36 | // (I) Methods operating on MLIR buffers (memrefs) to interact with sparse |
| 37 | // tensors. These methods should be used exclusively by MLIR |
| 38 | // compiler-generated code. |
| 39 | // |
| 40 | // (II) Methods that accept C-style data structures to interact with sparse |
| 41 | // tensors. These methods can be used by any external runtime that wants |
| 42 | // to interact with MLIR compiler-generated code. |
| 43 | // |
| 44 | // In both cases (I) and (II), the SparseTensorStorage format is externally |
| 45 | // only visible as an opaque pointer. |
| 46 | // |
| 47 | //===----------------------------------------------------------------------===// |
| 48 | |
| 49 | #include "mlir/ExecutionEngine/SparseTensorRuntime.h" |
| 50 | |
| 51 | #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS |
| 52 | |
| 53 | #include "mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h" |
| 54 | #include "mlir/ExecutionEngine/SparseTensor/COO.h" |
| 55 | #include "mlir/ExecutionEngine/SparseTensor/File.h" |
| 56 | #include "mlir/ExecutionEngine/SparseTensor/Storage.h" |
| 57 | |
| 58 | #include <cstring> |
| 59 | #include <numeric> |
| 60 | |
| 61 | using namespace mlir::sparse_tensor; |
| 62 | |
| 63 | //===----------------------------------------------------------------------===// |
| 64 | // |
| 65 | // Utilities for manipulating `StridedMemRefType`. |
| 66 | // |
| 67 | //===----------------------------------------------------------------------===// |
| 68 | |
| 69 | namespace { |
| 70 | |
| 71 | #define ASSERT_NO_STRIDE(MEMREF) \ |
| 72 | do { \ |
| 73 | assert((MEMREF) && "Memref is nullptr"); \ |
| 74 | assert(((MEMREF)->strides[0] == 1) && "Memref has non-trivial stride"); \ |
| 75 | } while (false) |
| 76 | |
| 77 | #define MEMREF_GET_USIZE(MEMREF) \ |
| 78 | detail::checkOverflowCast<uint64_t>((MEMREF)->sizes[0]) |
| 79 | |
| 80 | #define ASSERT_USIZE_EQ(MEMREF, SZ) \ |
| 81 | assert(detail::safelyEQ(MEMREF_GET_USIZE(MEMREF), (SZ)) && \ |
| 82 | "Memref size mismatch") |
| 83 | |
| 84 | #define MEMREF_GET_PAYLOAD(MEMREF) ((MEMREF)->data + (MEMREF)->offset) |
| 85 | |
| 86 | /// Initializes the memref with the provided size and data pointer. This |
| 87 | /// is designed for functions which want to "return" a memref that aliases |
| 88 | /// into memory owned by some other object (e.g., `SparseTensorStorage`), |
| 89 | /// without doing any actual copying. (The "return" is in scarequotes |
| 90 | /// because the `_mlir_ciface_` calling convention migrates any returned |
| 91 | /// memrefs into an out-parameter passed before all the other function |
| 92 | /// parameters.) |
| 93 | template <typename DataSizeT, typename T> |
| 94 | static inline void aliasIntoMemref(DataSizeT size, T *data, |
| 95 | StridedMemRefType<T, 1> &ref) { |
| 96 | ref.basePtr = ref.data = data; |
| 97 | ref.offset = 0; |
| 98 | using MemrefSizeT = std::remove_reference_t<decltype(ref.sizes[0])>; |
| 99 | ref.sizes[0] = detail::checkOverflowCast<MemrefSizeT>(size); |
| 100 | ref.strides[0] = 1; |
| 101 | } |
| 102 | |
| 103 | } // anonymous namespace |
| 104 | |
| 105 | extern "C" { |
| 106 | |
| 107 | //===----------------------------------------------------------------------===// |
| 108 | // |
| 109 | // Public functions which operate on MLIR buffers (memrefs) to interact |
| 110 | // with sparse tensors (which are only visible as opaque pointers externally). |
| 111 | // |
| 112 | //===----------------------------------------------------------------------===// |
| 113 | |
/// Dispatch helper for `_mlir_ciface_newSparseTensor` below.
///
/// When the enclosing function's runtime type triple (`posTp`, `crdTp`,
/// `valTp`) equals (`p`, `c`, `v`), performs `action` with
/// `SparseTensorStorage<P, C, V>`, where `P`, `C`, `V` are the C++ types that
/// correspond to those enum values. The result is returned from the enclosing
/// function.
///
/// The expansion reads the enclosing function's locals `dimRank`, `dimSizes`,
/// `lvlRank`, `lvlSizes`, `lvlTypes`, `dim2lvl`, `lvl2dim`, `action` and
/// `ptr`. The meaning of `ptr` depends on the action:
///   - kFromReader: a `SparseTensorReader *`;
///   - kPack: an `intptr_t *` array of buffers;
///   - kSortCOOInPlace: the `SparseTensorStorage<P, C, V> *` to sort;
///   - kEmpty: unused.
/// An unknown `action` prints a diagnostic and exits with status 1.
#define CASE(p, c, v, P, C, V)                                                 \
  if (posTp == (p) && crdTp == (c) && valTp == (v)) {                          \
    switch (action) {                                                          \
    case Action::kEmpty: {                                                     \
      return SparseTensorStorage<P, C, V>::newEmpty(                           \
          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim);   \
    }                                                                          \
    case Action::kFromReader: {                                                \
      assert(ptr && "Received nullptr for SparseTensorReader object");         \
      SparseTensorReader &reader = *static_cast<SparseTensorReader *>(ptr);    \
      return static_cast<void *>(reader.readSparseTensor<P, C, V>(             \
          lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim));                     \
    }                                                                          \
    case Action::kPack: {                                                      \
      assert(ptr && "Received nullptr for buffers");                           \
      intptr_t *buffers = static_cast<intptr_t *>(ptr);                        \
      return SparseTensorStorage<P, C, V>::newFromBuffers(                     \
          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,    \
          dimRank, buffers);                                                   \
    }                                                                          \
    case Action::kSortCOOInPlace: {                                            \
      assert(ptr && "Received nullptr for SparseTensorStorage object");        \
      auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr);        \
      tensor.sortInPlace();                                                    \
      return ptr;                                                              \
    }                                                                          \
    }                                                                          \
    fprintf(stderr, "unknown action %d\n", static_cast<int>(action));          \
    exit(1);                                                                   \
  }

/// Shorthand for CASE where positions and coordinates share one overhead
/// type (`p`/`P`).
#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
| 146 | |
| 147 | // Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor |
| 148 | // can safely rewrite kIndex to kU64. We make this assertion to guarantee |
| 149 | // that this file cannot get out of sync with its header. |
| 150 | static_assert(std::is_same<index_type, uint64_t>::value, |
| 151 | "Expected index_type == uint64_t" ); |
| 152 | |
| 153 | // The Swiss-army-knife for sparse tensor creation. |
| 154 | void *_mlir_ciface_newSparseTensor( // NOLINT |
| 155 | StridedMemRefType<index_type, 1> *dimSizesRef, |
| 156 | StridedMemRefType<index_type, 1> *lvlSizesRef, |
| 157 | StridedMemRefType<LevelType, 1> *lvlTypesRef, |
| 158 | StridedMemRefType<index_type, 1> *dim2lvlRef, |
| 159 | StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp, |
| 160 | OverheadType crdTp, PrimaryType valTp, Action action, void *ptr) { |
| 161 | ASSERT_NO_STRIDE(dimSizesRef); |
| 162 | ASSERT_NO_STRIDE(lvlSizesRef); |
| 163 | ASSERT_NO_STRIDE(lvlTypesRef); |
| 164 | ASSERT_NO_STRIDE(dim2lvlRef); |
| 165 | ASSERT_NO_STRIDE(lvl2dimRef); |
| 166 | const uint64_t dimRank = MEMREF_GET_USIZE(dimSizesRef); |
| 167 | const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef); |
| 168 | ASSERT_USIZE_EQ(lvlTypesRef, lvlRank); |
| 169 | ASSERT_USIZE_EQ(dim2lvlRef, lvlRank); |
| 170 | ASSERT_USIZE_EQ(lvl2dimRef, dimRank); |
| 171 | const index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef); |
| 172 | const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef); |
| 173 | const LevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef); |
| 174 | const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef); |
| 175 | const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef); |
| 176 | |
| 177 | // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases. |
| 178 | // This is safe because of the static_assert above. |
| 179 | if (posTp == OverheadType::kIndex) |
| 180 | posTp = OverheadType::kU64; |
| 181 | if (crdTp == OverheadType::kIndex) |
| 182 | crdTp = OverheadType::kU64; |
| 183 | |
| 184 | // Double matrices with all combinations of overhead storage. |
| 185 | CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t, |
| 186 | uint64_t, double); |
| 187 | CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t, |
| 188 | uint32_t, double); |
| 189 | CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t, |
| 190 | uint16_t, double); |
| 191 | CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t, |
| 192 | uint8_t, double); |
| 193 | CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t, |
| 194 | uint64_t, double); |
| 195 | CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t, |
| 196 | uint32_t, double); |
| 197 | CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t, |
| 198 | uint16_t, double); |
| 199 | CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t, |
| 200 | uint8_t, double); |
| 201 | CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t, |
| 202 | uint64_t, double); |
| 203 | CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t, |
| 204 | uint32_t, double); |
| 205 | CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t, |
| 206 | uint16_t, double); |
| 207 | CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t, |
| 208 | uint8_t, double); |
| 209 | CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t, |
| 210 | uint64_t, double); |
| 211 | CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t, |
| 212 | uint32_t, double); |
| 213 | CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t, |
| 214 | uint16_t, double); |
| 215 | CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t, |
| 216 | uint8_t, double); |
| 217 | |
| 218 | // Float matrices with all combinations of overhead storage. |
| 219 | CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t, |
| 220 | uint64_t, float); |
| 221 | CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t, |
| 222 | uint32_t, float); |
| 223 | CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t, |
| 224 | uint16_t, float); |
| 225 | CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t, |
| 226 | uint8_t, float); |
| 227 | CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t, |
| 228 | uint64_t, float); |
| 229 | CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t, |
| 230 | uint32_t, float); |
| 231 | CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t, |
| 232 | uint16_t, float); |
| 233 | CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t, |
| 234 | uint8_t, float); |
| 235 | CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t, |
| 236 | uint64_t, float); |
| 237 | CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t, |
| 238 | uint32_t, float); |
| 239 | CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t, |
| 240 | uint16_t, float); |
| 241 | CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t, |
| 242 | uint8_t, float); |
| 243 | CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t, |
| 244 | uint64_t, float); |
| 245 | CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t, |
| 246 | uint32_t, float); |
| 247 | CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t, |
| 248 | uint16_t, float); |
| 249 | CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t, |
| 250 | uint8_t, float); |
| 251 | |
| 252 | // Two-byte floats with both overheads of the same type. |
| 253 | CASE_SECSAME(OverheadType::kU64, PrimaryType::kF16, uint64_t, f16); |
| 254 | CASE_SECSAME(OverheadType::kU64, PrimaryType::kBF16, uint64_t, bf16); |
| 255 | CASE_SECSAME(OverheadType::kU32, PrimaryType::kF16, uint32_t, f16); |
| 256 | CASE_SECSAME(OverheadType::kU32, PrimaryType::kBF16, uint32_t, bf16); |
| 257 | CASE_SECSAME(OverheadType::kU16, PrimaryType::kF16, uint16_t, f16); |
| 258 | CASE_SECSAME(OverheadType::kU16, PrimaryType::kBF16, uint16_t, bf16); |
| 259 | CASE_SECSAME(OverheadType::kU8, PrimaryType::kF16, uint8_t, f16); |
| 260 | CASE_SECSAME(OverheadType::kU8, PrimaryType::kBF16, uint8_t, bf16); |
| 261 | |
| 262 | // Integral matrices with both overheads of the same type. |
| 263 | CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t); |
| 264 | CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t); |
| 265 | CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t); |
| 266 | CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t); |
| 267 | CASE_SECSAME(OverheadType::kU32, PrimaryType::kI64, uint32_t, int64_t); |
| 268 | CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t); |
| 269 | CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t); |
| 270 | CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t); |
| 271 | CASE_SECSAME(OverheadType::kU16, PrimaryType::kI64, uint16_t, int64_t); |
| 272 | CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t); |
| 273 | CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t); |
| 274 | CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t); |
| 275 | CASE_SECSAME(OverheadType::kU8, PrimaryType::kI64, uint8_t, int64_t); |
| 276 | CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t); |
| 277 | CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t); |
| 278 | CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t); |
| 279 | |
| 280 | // Complex matrices with wide overhead. |
| 281 | CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64); |
| 282 | CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32); |
| 283 | |
| 284 | // Unsupported case (add above if needed). |
| 285 | fprintf(stderr, format: "unsupported combination of types: <P=%d, C=%d, V=%d>\n" , |
| 286 | static_cast<int>(posTp), static_cast<int>(crdTp), |
| 287 | static_cast<int>(valTp)); |
| 288 | exit(status: 1); |
| 289 | } |
| 290 | #undef CASE |
| 291 | #undef CASE_SECSAME |
| 292 | |
| 293 | #define IMPL_SPARSEVALUES(VNAME, V) \ |
| 294 | void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref, \ |
| 295 | void *tensor) { \ |
| 296 | assert(ref &&tensor); \ |
| 297 | std::vector<V> *v; \ |
| 298 | static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v); \ |
| 299 | assert(v); \ |
| 300 | aliasIntoMemref(v->size(), v->data(), *ref); \ |
| 301 | } |
| 302 | MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES) |
| 303 | #undef IMPL_SPARSEVALUES |
| 304 | |
| 305 | #define IMPL_GETOVERHEAD(NAME, TYPE, LIB) \ |
| 306 | void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor, \ |
| 307 | index_type lvl) { \ |
| 308 | assert(ref &&tensor); \ |
| 309 | std::vector<TYPE> *v; \ |
| 310 | static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, lvl); \ |
| 311 | assert(v); \ |
| 312 | aliasIntoMemref(v->size(), v->data(), *ref); \ |
| 313 | } |
| 314 | |
| 315 | #define IMPL_SPARSEPOSITIONS(PNAME, P) \ |
| 316 | IMPL_GETOVERHEAD(sparsePositions##PNAME, P, getPositions) |
| 317 | MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSEPOSITIONS) |
| 318 | #undef IMPL_SPARSEPOSITIONS |
| 319 | |
| 320 | #define IMPL_SPARSECOORDINATES(CNAME, C) \ |
| 321 | IMPL_GETOVERHEAD(sparseCoordinates##CNAME, C, getCoordinates) |
| 322 | MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES) |
| 323 | #undef IMPL_SPARSECOORDINATES |
| 324 | |
| 325 | #define IMPL_SPARSECOORDINATESBUFFER(CNAME, C) \ |
| 326 | IMPL_GETOVERHEAD(sparseCoordinatesBuffer##CNAME, C, getCoordinatesBuffer) |
| 327 | MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATESBUFFER) |
| 328 | #undef IMPL_SPARSECOORDINATESBUFFER |
| 329 | |
| 330 | #undef IMPL_GETOVERHEAD |
| 331 | |
| 332 | #define IMPL_LEXINSERT(VNAME, V) \ |
| 333 | void _mlir_ciface_lexInsert##VNAME( \ |
| 334 | void *t, StridedMemRefType<index_type, 1> *lvlCoordsRef, \ |
| 335 | StridedMemRefType<V, 0> *vref) { \ |
| 336 | assert(t &&vref); \ |
| 337 | auto &tensor = *static_cast<SparseTensorStorageBase *>(t); \ |
| 338 | ASSERT_NO_STRIDE(lvlCoordsRef); \ |
| 339 | index_type *lvlCoords = MEMREF_GET_PAYLOAD(lvlCoordsRef); \ |
| 340 | assert(lvlCoords); \ |
| 341 | V *value = MEMREF_GET_PAYLOAD(vref); \ |
| 342 | tensor.lexInsert(lvlCoords, *value); \ |
| 343 | } |
| 344 | MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT) |
| 345 | #undef IMPL_LEXINSERT |
| 346 | |
| 347 | #define IMPL_EXPINSERT(VNAME, V) \ |
| 348 | void _mlir_ciface_expInsert##VNAME( \ |
| 349 | void *t, StridedMemRefType<index_type, 1> *lvlCoordsRef, \ |
| 350 | StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref, \ |
| 351 | StridedMemRefType<index_type, 1> *aref, index_type count) { \ |
| 352 | assert(t); \ |
| 353 | auto &tensor = *static_cast<SparseTensorStorageBase *>(t); \ |
| 354 | ASSERT_NO_STRIDE(lvlCoordsRef); \ |
| 355 | ASSERT_NO_STRIDE(vref); \ |
| 356 | ASSERT_NO_STRIDE(fref); \ |
| 357 | ASSERT_NO_STRIDE(aref); \ |
| 358 | ASSERT_USIZE_EQ(vref, MEMREF_GET_USIZE(fref)); \ |
| 359 | index_type *lvlCoords = MEMREF_GET_PAYLOAD(lvlCoordsRef); \ |
| 360 | V *values = MEMREF_GET_PAYLOAD(vref); \ |
| 361 | bool *filled = MEMREF_GET_PAYLOAD(fref); \ |
| 362 | index_type *added = MEMREF_GET_PAYLOAD(aref); \ |
| 363 | uint64_t expsz = vref->sizes[0]; \ |
| 364 | tensor.expInsert(lvlCoords, values, filled, added, count, expsz); \ |
| 365 | } |
| 366 | MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT) |
| 367 | #undef IMPL_EXPINSERT |
| 368 | |
| 369 | void *_mlir_ciface_createCheckedSparseTensorReader( |
| 370 | char *filename, StridedMemRefType<index_type, 1> *dimShapeRef, |
| 371 | PrimaryType valTp) { |
| 372 | ASSERT_NO_STRIDE(dimShapeRef); |
| 373 | const uint64_t dimRank = MEMREF_GET_USIZE(dimShapeRef); |
| 374 | const index_type *dimShape = MEMREF_GET_PAYLOAD(dimShapeRef); |
| 375 | auto *reader = SparseTensorReader::create(filename, dimRank, dimShape, valTp); |
| 376 | return static_cast<void *>(reader); |
| 377 | } |
| 378 | |
| 379 | void _mlir_ciface_getSparseTensorReaderDimSizes( |
| 380 | StridedMemRefType<index_type, 1> *out, void *p) { |
| 381 | assert(out && p); |
| 382 | SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p); |
| 383 | auto *dimSizes = const_cast<uint64_t *>(reader.getDimSizes()); |
| 384 | aliasIntoMemref(size: reader.getRank(), data: dimSizes, ref&: *out); |
| 385 | } |
| 386 | |
| 387 | #define IMPL_GETNEXT(VNAME, V, CNAME, C) \ |
| 388 | bool _mlir_ciface_getSparseTensorReaderReadToBuffers##CNAME##VNAME( \ |
| 389 | void *p, StridedMemRefType<index_type, 1> *dim2lvlRef, \ |
| 390 | StridedMemRefType<index_type, 1> *lvl2dimRef, \ |
| 391 | StridedMemRefType<C, 1> *cref, StridedMemRefType<V, 1> *vref) { \ |
| 392 | assert(p); \ |
| 393 | auto &reader = *static_cast<SparseTensorReader *>(p); \ |
| 394 | ASSERT_NO_STRIDE(dim2lvlRef); \ |
| 395 | ASSERT_NO_STRIDE(lvl2dimRef); \ |
| 396 | ASSERT_NO_STRIDE(cref); \ |
| 397 | ASSERT_NO_STRIDE(vref); \ |
| 398 | const uint64_t dimRank = reader.getRank(); \ |
| 399 | const uint64_t lvlRank = MEMREF_GET_USIZE(dim2lvlRef); \ |
| 400 | const uint64_t cSize = MEMREF_GET_USIZE(cref); \ |
| 401 | const uint64_t vSize = MEMREF_GET_USIZE(vref); \ |
| 402 | ASSERT_USIZE_EQ(lvl2dimRef, dimRank); \ |
| 403 | assert(cSize >= lvlRank * reader.getNSE()); \ |
| 404 | assert(vSize >= reader.getNSE()); \ |
| 405 | (void)dimRank; \ |
| 406 | (void)cSize; \ |
| 407 | (void)vSize; \ |
| 408 | index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef); \ |
| 409 | index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef); \ |
| 410 | C *lvlCoordinates = MEMREF_GET_PAYLOAD(cref); \ |
| 411 | V *values = MEMREF_GET_PAYLOAD(vref); \ |
| 412 | return reader.readToBuffers<C, V>(lvlRank, dim2lvl, lvl2dim, \ |
| 413 | lvlCoordinates, values); \ |
| 414 | } |
| 415 | MLIR_SPARSETENSOR_FOREVERY_V_O(IMPL_GETNEXT) |
| 416 | #undef IMPL_GETNEXT |
| 417 | |
| 418 | void _mlir_ciface_outSparseTensorWriterMetaData( |
| 419 | void *p, index_type dimRank, index_type nse, |
| 420 | StridedMemRefType<index_type, 1> *dimSizesRef) { |
| 421 | assert(p); |
| 422 | ASSERT_NO_STRIDE(dimSizesRef); |
| 423 | assert(dimRank != 0); |
| 424 | index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef); |
| 425 | std::ostream &file = *static_cast<std::ostream *>(p); |
| 426 | file << dimRank << " " << nse << '\n'; |
| 427 | for (index_type d = 0; d < dimRank - 1; d++) |
| 428 | file << dimSizes[d] << " " ; |
| 429 | file << dimSizes[dimRank - 1] << '\n'; |
| 430 | } |
| 431 | |
| 432 | #define IMPL_OUTNEXT(VNAME, V) \ |
| 433 | void _mlir_ciface_outSparseTensorWriterNext##VNAME( \ |
| 434 | void *p, index_type dimRank, \ |
| 435 | StridedMemRefType<index_type, 1> *dimCoordsRef, \ |
| 436 | StridedMemRefType<V, 0> *vref) { \ |
| 437 | assert(p &&vref); \ |
| 438 | ASSERT_NO_STRIDE(dimCoordsRef); \ |
| 439 | const index_type *dimCoords = MEMREF_GET_PAYLOAD(dimCoordsRef); \ |
| 440 | std::ostream &file = *static_cast<std::ostream *>(p); \ |
| 441 | for (index_type d = 0; d < dimRank; d++) \ |
| 442 | file << (dimCoords[d] + 1) << " "; \ |
| 443 | V *value = MEMREF_GET_PAYLOAD(vref); \ |
| 444 | file << *value << '\n'; \ |
| 445 | } |
| 446 | MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT) |
| 447 | #undef IMPL_OUTNEXT |
| 448 | |
| 449 | //===----------------------------------------------------------------------===// |
| 450 | // |
| 451 | // Public functions which accept only C-style data structures to interact |
| 452 | // with sparse tensors (which are only visible as opaque pointers externally). |
| 453 | // |
| 454 | //===----------------------------------------------------------------------===// |
| 455 | |
| 456 | index_type sparseLvlSize(void *tensor, index_type l) { |
| 457 | return static_cast<SparseTensorStorageBase *>(tensor)->getLvlSize(l); |
| 458 | } |
| 459 | |
| 460 | index_type sparseDimSize(void *tensor, index_type d) { |
| 461 | return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d); |
| 462 | } |
| 463 | |
| 464 | void endLexInsert(void *tensor) { |
| 465 | return static_cast<SparseTensorStorageBase *>(tensor)->endLexInsert(); |
| 466 | } |
| 467 | |
| 468 | void delSparseTensor(void *tensor) { |
| 469 | delete static_cast<SparseTensorStorageBase *>(tensor); |
| 470 | } |
| 471 | |
| 472 | char *getTensorFilename(index_type id) { |
| 473 | constexpr size_t bufSize = 80; |
| 474 | char var[bufSize]; |
| 475 | snprintf(s: var, maxlen: bufSize, format: "TENSOR%" PRIu64, id); |
| 476 | char *env = getenv(name: var); |
| 477 | if (!env) { |
| 478 | fprintf(stderr, format: "Environment variable %s is not set\n" , var); |
| 479 | exit(status: 1); |
| 480 | } |
| 481 | return env; |
| 482 | } |
| 483 | |
| 484 | index_type getSparseTensorReaderNSE(void *p) { |
| 485 | return static_cast<SparseTensorReader *>(p)->getNSE(); |
| 486 | } |
| 487 | |
| 488 | void delSparseTensorReader(void *p) { |
| 489 | delete static_cast<SparseTensorReader *>(p); |
| 490 | } |
| 491 | |
/// Creates a writer for an extended FROSTT file and returns it as an opaque
/// `std::ostream *`.
///
/// An empty `filename` selects `std::cout`. Otherwise a new `std::ofstream`
/// is opened on `filename`; `delSparseTensorWriter` deletes it later. The
/// "# extended FROSTT format" header line is written immediately.
///
/// If the file cannot be opened, prints a diagnostic to stderr and exits with
/// status 1. Previously the failure went unnoticed and every later write was
/// silently dropped.
void *createSparseTensorWriter(char *filename) {
  assert(filename && "Received nullptr for filename");
  std::ostream *file;
  if (filename[0] == 0) {
    file = &std::cout;
  } else {
    auto *out = new std::ofstream(filename);
    if (!out->is_open()) {
      delete out;
      fprintf(stderr, "Cannot open file: %s\n", filename);
      exit(1);
    }
    file = out;
  }
  *file << "# extended FROSTT format\n";
  return static_cast<void *>(file);
}
| 498 | |
/// Flushes the writer stream behind the opaque pointer `p` and destroys it,
/// unless it is `std::cout`, which is only flushed. In debug builds it
/// asserts that the stream is still in a good state after flushing.
void delSparseTensorWriter(void *p) {
  auto *file = static_cast<std::ostream *>(p);
  file->flush();
  assert(file->good());
  if (file == &std::cout)
    return;
  delete file;
}
| 506 | |
| 507 | } // extern "C" |
| 508 | |
| 509 | #undef MEMREF_GET_PAYLOAD |
| 510 | #undef ASSERT_USIZE_EQ |
| 511 | #undef MEMREF_GET_USIZE |
| 512 | #undef ASSERT_NO_STRIDE |
| 513 | |
| 514 | #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS |
| 515 | |