1 | //===- File.cpp - Reading/writing sparse tensors from/to files ------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements reading and writing sparse tensor files. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/ExecutionEngine/SparseTensor/File.h" |
14 | |
15 | #include <cctype> |
16 | #include <cstring> |
17 | |
18 | using namespace mlir::sparse_tensor; |
19 | |
20 | /// Opens the file for reading. |
21 | void SparseTensorReader::openFile() { |
22 | if (file) { |
23 | fprintf(stderr, format: "Already opened file %s\n" , filename); |
24 | exit(status: 1); |
25 | } |
26 | file = fopen(filename: filename, modes: "r" ); |
27 | if (!file) { |
28 | fprintf(stderr, format: "Cannot find file %s\n" , filename); |
29 | exit(status: 1); |
30 | } |
31 | } |
32 | |
33 | /// Closes the file. |
34 | void SparseTensorReader::closeFile() { |
35 | if (file) { |
36 | fclose(stream: file); |
37 | file = nullptr; |
38 | } |
39 | } |
40 | |
41 | /// Attempts to read a line from the file. |
42 | void SparseTensorReader::readLine() { |
43 | if (!fgets(s: line, n: kColWidth, stream: file)) { |
44 | fprintf(stderr, format: "Cannot read next line of %s\n" , filename); |
45 | exit(status: 1); |
46 | } |
47 | } |
48 | |
49 | /// Reads and parses the file's header. |
50 | void SparseTensorReader::() { |
51 | assert(file && "Attempt to readHeader() before openFile()" ); |
52 | if (strstr(haystack: filename, needle: ".mtx" )) { |
53 | readMMEHeader(); |
54 | } else if (strstr(haystack: filename, needle: ".tns" )) { |
55 | readExtFROSTTHeader(); |
56 | } else { |
57 | fprintf(stderr, format: "Unknown format %s\n" , filename); |
58 | exit(status: 1); |
59 | } |
60 | assert(isValid() && "Failed to read the header" ); |
61 | } |
62 | |
63 | /// Asserts the shape subsumes the actual dimension sizes. Is only |
64 | /// valid after parsing the header. |
65 | void SparseTensorReader::assertMatchesShape(uint64_t rank, |
66 | const uint64_t *shape) const { |
67 | assert(rank == getRank() && "Rank mismatch" ); |
68 | for (uint64_t r = 0; r < rank; r++) |
69 | assert((shape[r] == 0 || shape[r] == idata[2 + r]) && |
70 | "Dimension size mismatch" ); |
71 | } |
72 | |
73 | bool SparseTensorReader::canReadAs(PrimaryType valTy) const { |
74 | switch (valueKind_) { |
75 | case ValueKind::kInvalid: |
76 | assert(false && "Must readHeader() before calling canReadAs()" ); |
77 | return false; // In case assertions are disabled. |
78 | case ValueKind::kPattern: |
79 | return true; |
80 | case ValueKind::kInteger: |
81 | // When the file is specified to store integer values, we still |
82 | // allow implicitly converting those to floating primary-types. |
83 | return isRealPrimaryType(valTy); |
84 | case ValueKind::kReal: |
85 | // When the file is specified to store real/floating values, then |
86 | // we disallow implicit conversion to integer primary-types. |
87 | return isFloatingPrimaryType(valTy); |
88 | case ValueKind::kComplex: |
89 | // When the file is specified to store complex values, then we |
90 | // require a complex primary-type. |
91 | return isComplexPrimaryType(valTy); |
92 | case ValueKind::kUndefined: |
93 | // The "extended" FROSTT format doesn't specify a ValueKind. |
94 | // So we allow implicitly converting the stored values to both |
95 | // integer and floating primary-types. |
96 | return isRealPrimaryType(valTy); |
97 | } |
98 | fprintf(stderr, format: "Unknown ValueKind: %d\n" , static_cast<uint8_t>(valueKind_)); |
99 | return false; |
100 | } |
101 | |
/// Helper to convert C-style strings (i.e., '\0' terminated) to lower case,
/// in place.
static inline void toLower(char *token) {
  for (char *c = token; *c; c++) {
    // `tolower` requires a value representable as `unsigned char` (or EOF);
    // passing a negative plain `char` is undefined behavior, so cast first.
    *c = static_cast<char>(std::tolower(static_cast<unsigned char>(*c)));
  }
}
107 | |
/// Idiomatic name for checking string equality.  Returns true iff the two
/// '\0'-terminated strings compare equal under `strcmp`.
static inline bool streq(const char *lhs, const char *rhs) {
  return std::strcmp(lhs, rhs) == 0;
}
112 | |
/// Idiomatic name for checking string inequality.  Returns true iff the two
/// '\0'-terminated strings differ under `strcmp`.
static inline bool strne(const char *lhs, const char *rhs) {
  return std::strcmp(lhs, rhs) != 0;
}
117 | |
118 | /// Read the MME header of a general sparse matrix of type real. |
119 | void SparseTensorReader::() { |
120 | char [64]; |
121 | char object[64]; |
122 | char format[64]; |
123 | char field[64]; |
124 | char symmetry[64]; |
125 | // Read header line. |
126 | if (fscanf(stream: file, format: "%63s %63s %63s %63s %63s\n" , header, object, format, field, |
127 | symmetry) != 5) { |
128 | fprintf(stderr, format: "Corrupt header in %s\n" , filename); |
129 | exit(status: 1); |
130 | } |
131 | // Convert all to lowercase up front (to avoid accidental redundancy). |
132 | toLower(token: header); |
133 | toLower(token: object); |
134 | toLower(token: format); |
135 | toLower(token: field); |
136 | toLower(token: symmetry); |
137 | // Process `field`, which specify pattern or the data type of the values. |
138 | if (streq(lhs: field, rhs: "pattern" )) { |
139 | valueKind_ = ValueKind::kPattern; |
140 | } else if (streq(lhs: field, rhs: "real" )) { |
141 | valueKind_ = ValueKind::kReal; |
142 | } else if (streq(lhs: field, rhs: "integer" )) { |
143 | valueKind_ = ValueKind::kInteger; |
144 | } else if (streq(lhs: field, rhs: "complex" )) { |
145 | valueKind_ = ValueKind::kComplex; |
146 | } else { |
147 | fprintf(stderr, format: "Unexpected header field value in %s\n" , filename); |
148 | exit(status: 1); |
149 | } |
150 | // Set properties. |
151 | isSymmetric_ = streq(lhs: symmetry, rhs: "symmetric" ); |
152 | // Make sure this is a general sparse matrix. |
153 | if (strne(lhs: header, rhs: "%%matrixmarket" ) || strne(lhs: object, rhs: "matrix" ) || |
154 | strne(lhs: format, rhs: "coordinate" ) || |
155 | (strne(lhs: symmetry, rhs: "general" ) && !isSymmetric_)) { |
156 | fprintf(stderr, format: "Cannot find a general sparse matrix in %s\n" , filename); |
157 | exit(status: 1); |
158 | } |
159 | // Skip comments. |
160 | while (true) { |
161 | readLine(); |
162 | if (line[0] != '%') |
163 | break; |
164 | } |
165 | // Next line contains M N NNZ. |
166 | idata[0] = 2; // rank |
167 | if (sscanf(s: line, format: "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n" , idata + 2, idata + 3, |
168 | idata + 1) != 3) { |
169 | fprintf(stderr, format: "Cannot find size in %s\n" , filename); |
170 | exit(status: 1); |
171 | } |
172 | } |
173 | |
174 | /// Read the "extended" FROSTT header. Although not part of the documented |
175 | /// format, we assume that the file starts with optional comments followed |
176 | /// by two lines that define the rank, the number of nonzeros, and the |
177 | /// dimensions sizes (one per rank) of the sparse tensor. |
178 | void SparseTensorReader::() { |
179 | // Skip comments. |
180 | while (true) { |
181 | readLine(); |
182 | if (line[0] != '#') |
183 | break; |
184 | } |
185 | // Next line contains RANK and NNZ. |
186 | if (sscanf(s: line, format: "%" PRIu64 "%" PRIu64 "\n" , idata, idata + 1) != 2) { |
187 | fprintf(stderr, format: "Cannot find metadata in %s\n" , filename); |
188 | exit(status: 1); |
189 | } |
190 | // Followed by a line with the dimension sizes (one per rank). |
191 | for (uint64_t r = 0; r < idata[0]; r++) { |
192 | if (fscanf(stream: file, format: "%" PRIu64, idata + 2 + r) != 1) { |
193 | fprintf(stderr, format: "Cannot find dimension size %s\n" , filename); |
194 | exit(status: 1); |
195 | } |
196 | } |
197 | readLine(); // end of line |
198 | // The FROSTT format does not define the data type of the nonzero elements. |
199 | valueKind_ = ValueKind::kUndefined; |
200 | } |
201 | |