| 1 | //===- File.cpp - Reading/writing sparse tensors from/to files ------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements reading and writing sparse tensor files. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/ExecutionEngine/SparseTensor/File.h" |
| 14 | |
| 15 | #include <cctype> |
| 16 | #include <cstring> |
| 17 | |
| 18 | using namespace mlir::sparse_tensor; |
| 19 | |
| 20 | /// Opens the file for reading. |
| 21 | void SparseTensorReader::openFile() { |
| 22 | if (file) { |
| 23 | fprintf(stderr, format: "Already opened file %s\n" , filename); |
| 24 | exit(status: 1); |
| 25 | } |
| 26 | file = fopen(filename: filename, modes: "r" ); |
| 27 | if (!file) { |
| 28 | fprintf(stderr, format: "Cannot find file %s\n" , filename); |
| 29 | exit(status: 1); |
| 30 | } |
| 31 | } |
| 32 | |
| 33 | /// Closes the file. |
| 34 | void SparseTensorReader::closeFile() { |
| 35 | if (file) { |
| 36 | fclose(stream: file); |
| 37 | file = nullptr; |
| 38 | } |
| 39 | } |
| 40 | |
| 41 | /// Attempts to read a line from the file. |
| 42 | void SparseTensorReader::readLine() { |
| 43 | if (!fgets(s: line, n: kColWidth, stream: file)) { |
| 44 | fprintf(stderr, format: "Cannot read next line of %s\n" , filename); |
| 45 | exit(status: 1); |
| 46 | } |
| 47 | } |
| 48 | |
| 49 | /// Reads and parses the file's header. |
| 50 | void SparseTensorReader::() { |
| 51 | assert(file && "Attempt to readHeader() before openFile()" ); |
| 52 | if (strstr(haystack: filename, needle: ".mtx" )) { |
| 53 | readMMEHeader(); |
| 54 | } else if (strstr(haystack: filename, needle: ".tns" )) { |
| 55 | readExtFROSTTHeader(); |
| 56 | } else { |
| 57 | fprintf(stderr, format: "Unknown format %s\n" , filename); |
| 58 | exit(status: 1); |
| 59 | } |
| 60 | assert(isValid() && "Failed to read the header" ); |
| 61 | } |
| 62 | |
| 63 | /// Asserts the shape subsumes the actual dimension sizes. Is only |
| 64 | /// valid after parsing the header. |
| 65 | void SparseTensorReader::assertMatchesShape(uint64_t rank, |
| 66 | const uint64_t *shape) const { |
| 67 | assert(rank == getRank() && "Rank mismatch" ); |
| 68 | for (uint64_t r = 0; r < rank; r++) |
| 69 | assert((shape[r] == 0 || shape[r] == idata[2 + r]) && |
| 70 | "Dimension size mismatch" ); |
| 71 | } |
| 72 | |
| 73 | bool SparseTensorReader::canReadAs(PrimaryType valTy) const { |
| 74 | switch (valueKind_) { |
| 75 | case ValueKind::kInvalid: |
| 76 | assert(false && "Must readHeader() before calling canReadAs()" ); |
| 77 | return false; // In case assertions are disabled. |
| 78 | case ValueKind::kPattern: |
| 79 | return true; |
| 80 | case ValueKind::kInteger: |
| 81 | // When the file is specified to store integer values, we still |
| 82 | // allow implicitly converting those to floating primary-types. |
| 83 | return isRealPrimaryType(valTy); |
| 84 | case ValueKind::kReal: |
| 85 | // When the file is specified to store real/floating values, then |
| 86 | // we disallow implicit conversion to integer primary-types. |
| 87 | return isFloatingPrimaryType(valTy); |
| 88 | case ValueKind::kComplex: |
| 89 | // When the file is specified to store complex values, then we |
| 90 | // require a complex primary-type. |
| 91 | return isComplexPrimaryType(valTy); |
| 92 | case ValueKind::kUndefined: |
| 93 | // The "extended" FROSTT format doesn't specify a ValueKind. |
| 94 | // So we allow implicitly converting the stored values to both |
| 95 | // integer and floating primary-types. |
| 96 | return isRealPrimaryType(valTy); |
| 97 | } |
| 98 | fprintf(stderr, format: "Unknown ValueKind: %d\n" , static_cast<uint8_t>(valueKind_)); |
| 99 | return false; |
| 100 | } |
| 101 | |
/// Helper to convert C-style strings (i.e., '\0' terminated) to lower case,
/// in place.
///
/// \param token  Mutable, NUL-terminated string; must not be null.
static inline void toLower(char *token) {
  for (char *c = token; *c; c++) {
    // `tolower` requires a value representable as `unsigned char` (or EOF);
    // passing a negative plain `char` is undefined behaviour, so convert
    // through `unsigned char` first.
    *c = static_cast<char>(std::tolower(static_cast<unsigned char>(*c)));
  }
}
| 107 | |
/// Idiomatic name for checking string equality.
///
/// \returns true iff the two NUL-terminated strings compare equal under
/// `strcmp`.  Neither argument may be null.
static inline bool streq(const char *lhs, const char *rhs) {
  return strcmp(lhs, rhs) == 0;
}
| 112 | |
/// Idiomatic name for checking string inequality.
///
/// \returns true iff the two NUL-terminated strings differ under `strcmp`.
/// Neither argument may be null.
static inline bool strne(const char *lhs, const char *rhs) {
  return strcmp(lhs, rhs) != 0;
}
| 117 | |
| 118 | /// Read the MME header of a general sparse matrix of type real. |
| 119 | void SparseTensorReader::() { |
| 120 | char [64]; |
| 121 | char object[64]; |
| 122 | char format[64]; |
| 123 | char field[64]; |
| 124 | char symmetry[64]; |
| 125 | // Read header line. |
| 126 | if (fscanf(stream: file, format: "%63s %63s %63s %63s %63s\n" , header, object, format, field, |
| 127 | symmetry) != 5) { |
| 128 | fprintf(stderr, format: "Corrupt header in %s\n" , filename); |
| 129 | exit(status: 1); |
| 130 | } |
| 131 | // Convert all to lowercase up front (to avoid accidental redundancy). |
| 132 | toLower(token: header); |
| 133 | toLower(token: object); |
| 134 | toLower(token: format); |
| 135 | toLower(token: field); |
| 136 | toLower(token: symmetry); |
| 137 | // Process `field`, which specify pattern or the data type of the values. |
| 138 | if (streq(lhs: field, rhs: "pattern" )) { |
| 139 | valueKind_ = ValueKind::kPattern; |
| 140 | } else if (streq(lhs: field, rhs: "real" )) { |
| 141 | valueKind_ = ValueKind::kReal; |
| 142 | } else if (streq(lhs: field, rhs: "integer" )) { |
| 143 | valueKind_ = ValueKind::kInteger; |
| 144 | } else if (streq(lhs: field, rhs: "complex" )) { |
| 145 | valueKind_ = ValueKind::kComplex; |
| 146 | } else { |
| 147 | fprintf(stderr, format: "Unexpected header field value in %s\n" , filename); |
| 148 | exit(status: 1); |
| 149 | } |
| 150 | // Set properties. |
| 151 | isSymmetric_ = streq(lhs: symmetry, rhs: "symmetric" ); |
| 152 | // Make sure this is a general sparse matrix. |
| 153 | if (strne(lhs: header, rhs: "%%matrixmarket" ) || strne(lhs: object, rhs: "matrix" ) || |
| 154 | strne(lhs: format, rhs: "coordinate" ) || |
| 155 | (strne(lhs: symmetry, rhs: "general" ) && !isSymmetric_)) { |
| 156 | fprintf(stderr, format: "Cannot find a general sparse matrix in %s\n" , filename); |
| 157 | exit(status: 1); |
| 158 | } |
| 159 | // Skip comments. |
| 160 | while (true) { |
| 161 | readLine(); |
| 162 | if (line[0] != '%') |
| 163 | break; |
| 164 | } |
| 165 | // Next line contains M N NNZ. |
| 166 | idata[0] = 2; // rank |
| 167 | if (sscanf(s: line, format: "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n" , idata + 2, idata + 3, |
| 168 | idata + 1) != 3) { |
| 169 | fprintf(stderr, format: "Cannot find size in %s\n" , filename); |
| 170 | exit(status: 1); |
| 171 | } |
| 172 | } |
| 173 | |
| 174 | /// Read the "extended" FROSTT header. Although not part of the documented |
| 175 | /// format, we assume that the file starts with optional comments followed |
| 176 | /// by two lines that define the rank, the number of nonzeros, and the |
| 177 | /// dimensions sizes (one per rank) of the sparse tensor. |
| 178 | void SparseTensorReader::() { |
| 179 | // Skip comments. |
| 180 | while (true) { |
| 181 | readLine(); |
| 182 | if (line[0] != '#') |
| 183 | break; |
| 184 | } |
| 185 | // Next line contains RANK and NNZ. |
| 186 | if (sscanf(s: line, format: "%" PRIu64 "%" PRIu64 "\n" , idata, idata + 1) != 2) { |
| 187 | fprintf(stderr, format: "Cannot find metadata in %s\n" , filename); |
| 188 | exit(status: 1); |
| 189 | } |
| 190 | // Followed by a line with the dimension sizes (one per rank). |
| 191 | for (uint64_t r = 0; r < idata[0]; r++) { |
| 192 | if (fscanf(stream: file, format: "%" PRIu64, idata + 2 + r) != 1) { |
| 193 | fprintf(stderr, format: "Cannot find dimension size %s\n" , filename); |
| 194 | exit(status: 1); |
| 195 | } |
| 196 | } |
| 197 | readLine(); // end of line |
| 198 | // The FROSTT format does not define the data type of the nonzero elements. |
| 199 | valueKind_ = ValueKind::kUndefined; |
| 200 | } |
| 201 | |