1//===- Enums.h - Enums for the SparseTensor dialect -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Typedefs and enums shared between MLIR code for manipulating the
10// IR, and the lightweight runtime support library for sparse tensor
11// manipulations. That is, all the enums are used to define the API
12// of the runtime library and hence are also needed when generating
// calls into the runtime library. Moreover, the `LevelType` enum
14// is also used as the internal IR encoding of dimension level types,
15// to avoid code duplication (e.g., for the predicates).
16//
17// This file also defines x-macros <https://en.wikipedia.org/wiki/X_Macro>
18// so that we can generate variations of the public functions for each
19// supported primary- and/or overhead-type.
20//
21// Because this file defines a library which is a dependency of the
22// runtime library itself, this file must not depend on any MLIR internals
23// (e.g., operators, attributes, ArrayRefs, etc) lest the runtime library
24// inherit those dependencies.
25//
26//===----------------------------------------------------------------------===//
27
28#ifndef MLIR_DIALECT_SPARSETENSOR_IR_ENUMS_H
29#define MLIR_DIALECT_SPARSETENSOR_IR_ENUMS_H
30
31// NOTE: Client code will need to include "mlir/ExecutionEngine/Float16bits.h"
32// if they want to use the `MLIR_SPARSETENSOR_FOREVERY_V` macro.
33
#include <cassert>
#include <cinttypes>
#include <complex>
#include <optional>
#include <string>
#include <type_traits>
#include <vector>
39
40namespace mlir {
41namespace sparse_tensor {
42
/// This type is used in the public API at all places where MLIR expects
/// values with the built-in type "index". For now, we simply assume that
/// type is 64-bit, but targets with different "index" bitwidths should
/// link with an alternatively built runtime support library.
/// (Corresponds to `OverheadType::kIndex` below.)
using index_type = uint64_t;
48
/// Encoding of overhead types (both position overhead and coordinate
/// overhead), for "overloading" @newSparseTensor.
/// These values are part of the runtime-library API, so they must not
/// be renumbered.
enum class OverheadType : uint32_t {
  kIndex = 0, // the architecture-dependent `index_type` (see above)
  kU64 = 1,
  kU32 = 2,
  kU16 = 3,
  kU8 = 4
};
58
// This x-macro calls its argument on every overhead type which has
// fixed-width. It excludes `index_type` because that type is often
// handled specially (e.g., by translating it into the architecture-dependent
// equivalent fixed-width overhead type).
// `DO` is invoked as `DO(bitwidth-suffix, c-type)`.
#define MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DO)                                 \
  DO(64, uint64_t)                                                             \
  DO(32, uint32_t)                                                             \
  DO(16, uint16_t)                                                             \
  DO(8, uint8_t)
68
// This x-macro calls its argument on every overhead type, including
// `index_type` (denoted by the suffix `0`, since it has no fixed bitwidth).
#define MLIR_SPARSETENSOR_FOREVERY_O(DO)                                       \
  MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DO)                                       \
  DO(0, index_type)
74
// These are not just shorthands but indicate the particular
// implementation used (e.g., as opposed to C99's `complex double`,
// or MLIR's `ComplexType`). They correspond to `PrimaryType::kC64`
// and `PrimaryType::kC32` below.
using complex64 = std::complex<double>;
using complex32 = std::complex<float>;
80
/// Encoding of the elemental type, for "overloading" @newSparseTensor.
/// NOTE: the enumerator values and their ordering are relied upon by the
/// range-based predicates below (`isFloatingPrimaryType` etc.) and by the
/// runtime-library API; do not renumber or reorder.
enum class PrimaryType : uint32_t {
  kF64 = 1,
  kF32 = 2,
  kF16 = 3,
  kBF16 = 4,
  kI64 = 5,
  kI32 = 6,
  kI16 = 7,
  kI8 = 8,
  kC64 = 9,
  kC32 = 10
};
94
// This x-macro includes all `V` (value) types.
// NOTE: the `f16`/`bf16` types used here require clients to include
// "mlir/ExecutionEngine/Float16bits.h" (see the note at the top of file).
#define MLIR_SPARSETENSOR_FOREVERY_V(DO)                                       \
  DO(F64, double)                                                              \
  DO(F32, float)                                                               \
  DO(F16, f16)                                                                 \
  DO(BF16, bf16)                                                               \
  DO(I64, int64_t)                                                             \
  DO(I32, int32_t)                                                             \
  DO(I16, int16_t)                                                             \
  DO(I8, int8_t)                                                               \
  DO(C64, complex64)                                                           \
  DO(C32, complex32)
107
// This x-macro includes all `V` types and supports variadic arguments,
// which are forwarded to `DO` after the type-suffix/type pair.
#define MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, ...)                              \
  DO(F64, double, __VA_ARGS__)                                                 \
  DO(F32, float, __VA_ARGS__)                                                  \
  DO(F16, f16, __VA_ARGS__)                                                    \
  DO(BF16, bf16, __VA_ARGS__)                                                  \
  DO(I64, int64_t, __VA_ARGS__)                                                \
  DO(I32, int32_t, __VA_ARGS__)                                                \
  DO(I16, int16_t, __VA_ARGS__)                                                \
  DO(I8, int8_t, __VA_ARGS__)                                                  \
  DO(C64, complex64, __VA_ARGS__)                                              \
  DO(C32, complex32, __VA_ARGS__)
120
// This x-macro calls its argument on every pair of overhead and `V` types,
// i.e., `DO` is invoked as `DO(value-suffix, value-type, overhead-suffix,
// overhead-type)` -- including the `index_type` overhead (suffix `0`).
#define MLIR_SPARSETENSOR_FOREVERY_V_O(DO)                                     \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 64, uint64_t)                           \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 32, uint32_t)                           \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 16, uint16_t)                           \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 8, uint8_t)                             \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 0, index_type)
128
129constexpr bool isFloatingPrimaryType(PrimaryType valTy) {
130 return PrimaryType::kF64 <= valTy && valTy <= PrimaryType::kBF16;
131}
132
133constexpr bool isIntegralPrimaryType(PrimaryType valTy) {
134 return PrimaryType::kI64 <= valTy && valTy <= PrimaryType::kI8;
135}
136
137constexpr bool isRealPrimaryType(PrimaryType valTy) {
138 return PrimaryType::kF64 <= valTy && valTy <= PrimaryType::kI8;
139}
140
141constexpr bool isComplexPrimaryType(PrimaryType valTy) {
142 return PrimaryType::kC64 <= valTy && valTy <= PrimaryType::kC32;
143}
144
/// The actions performed by @newSparseTensor.
/// These values are part of the runtime-library API, so they must not
/// be renumbered.
enum class Action : uint32_t {
  kEmpty = 0,
  kFromReader = 1,
  kPack = 2,
  kSortCOOInPlace = 3,
};
152
/// This enum defines all supported storage format without the level
/// properties. Each format occupies a single bit in the 16-bit field at
/// bits [16, 32) of a `LevelType` encoding (see the bit-layout comment in
/// `LevelType`), except `Undef`, which has no bit set.
enum class LevelFormat : uint64_t {
  Undef = 0x00000000,           // no bit set
  Dense = 0x00010000,           // bit 16
  Batch = 0x00020000,           // bit 17
  Compressed = 0x00040000,      // bit 18
  Singleton = 0x00080000,       // bit 19
  LooseCompressed = 0x00100000, // bit 20
  NOutOfM = 0x00200000,         // bit 21
};
163
164constexpr bool encPowOfTwo(LevelFormat fmt) {
165 auto enc = static_cast<std::underlying_type_t<LevelFormat>>(fmt);
166 return (enc & (enc - 1)) == 0;
167}
168
169// All LevelFormats must have only one bit set (power of two).
170static_assert(encPowOfTwo(fmt: LevelFormat::Dense) &&
171 encPowOfTwo(fmt: LevelFormat::Batch) &&
172 encPowOfTwo(fmt: LevelFormat::Compressed) &&
173 encPowOfTwo(fmt: LevelFormat::Singleton) &&
174 encPowOfTwo(fmt: LevelFormat::LooseCompressed) &&
175 encPowOfTwo(fmt: LevelFormat::NOutOfM));
176
177template <LevelFormat... targets>
178constexpr bool isAnyOfFmt(LevelFormat fmt) {
179 return (... || (targets == fmt));
180}
181
182/// Returns string representation of the given level format.
183constexpr const char *toFormatString(LevelFormat lvlFmt) {
184 switch (lvlFmt) {
185 case LevelFormat::Undef:
186 return "undef";
187 case LevelFormat::Dense:
188 return "dense";
189 case LevelFormat::Batch:
190 return "batch";
191 case LevelFormat::Compressed:
192 return "compressed";
193 case LevelFormat::Singleton:
194 return "singleton";
195 case LevelFormat::LooseCompressed:
196 return "loose_compressed";
197 case LevelFormat::NOutOfM:
198 return "structured";
199 }
200 return "";
201}
202
/// This enum defines all the nondefault properties for storage formats.
/// Property bits live in the low 16 bits of a `LevelType` encoding and
/// may be OR-ed together.
enum class LevelPropNonDefault : uint64_t {
  Nonunique = 0x0001,  // 0b001
  Nonordered = 0x0002, // 0b010
  SoA = 0x0004,        // 0b100
};
209
210/// Returns string representation of the given level properties.
211constexpr const char *toPropString(LevelPropNonDefault lvlProp) {
212 switch (lvlProp) {
213 case LevelPropNonDefault::Nonunique:
214 return "nonunique";
215 case LevelPropNonDefault::Nonordered:
216 return "nonordered";
217 case LevelPropNonDefault::SoA:
218 return "soa";
219 }
220 return "";
221}
222
223/// This enum defines all the sparse representations supportable by
224/// the SparseTensor dialect. We use a lightweight encoding to encode
225/// the "format" per se (dense, compressed, singleton, loose_compressed,
226/// n-out-of-m), the "properties" (ordered, unique) as well as n and m when
227/// the format is NOutOfM.
228/// The encoding is chosen for performance of the runtime library, and thus may
229/// change in future versions; consequently, client code should use the
230/// predicate functions defined below, rather than relying on knowledge
231/// about the particular binary encoding.
232///
233/// The `Undef` "format" is a special value used internally for cases
234/// where we need to store an undefined or indeterminate `LevelType`.
235/// It should not be used externally, since it does not indicate an
236/// actual/representable format.
237
238struct LevelType {
239public:
240 /// Check that the `LevelType` contains a valid (possibly undefined) value.
241 static constexpr bool isValidLvlBits(uint64_t lvlBits) {
242 auto fmt = static_cast<LevelFormat>(lvlBits & 0xffff0000);
243 const uint64_t propertyBits = lvlBits & 0xffff;
244 // If undefined/dense/batch/NOutOfM, then must be unique and ordered.
245 // Otherwise, the format must be one of the known ones.
246 return (isAnyOfFmt<LevelFormat::Undef, LevelFormat::Dense,
247 LevelFormat::Batch, LevelFormat::NOutOfM>(fmt))
248 ? (propertyBits == 0)
249 : (isAnyOfFmt<LevelFormat::Compressed, LevelFormat::Singleton,
250 LevelFormat::LooseCompressed>(fmt));
251 }
252
253 /// Convert a LevelFormat to its corresponding LevelType with the given
254 /// properties. Returns std::nullopt when the properties are not applicable
255 /// for the input level format.
256 static std::optional<LevelType>
257 buildLvlType(LevelFormat lf,
258 const std::vector<LevelPropNonDefault> &properties,
259 uint64_t n = 0, uint64_t m = 0) {
260 assert((n & 0xff) == n && (m & 0xff) == m);
261 uint64_t newN = n << 32;
262 uint64_t newM = m << 40;
263 uint64_t ltBits = static_cast<uint64_t>(lf) | newN | newM;
264 for (auto p : properties)
265 ltBits |= static_cast<uint64_t>(p);
266
267 return isValidLvlBits(lvlBits: ltBits) ? std::optional(LevelType(ltBits))
268 : std::nullopt;
269 }
270 static std::optional<LevelType> buildLvlType(LevelFormat lf, bool ordered,
271 bool unique, uint64_t n = 0,
272 uint64_t m = 0) {
273 std::vector<LevelPropNonDefault> properties;
274 if (!ordered)
275 properties.push_back(x: LevelPropNonDefault::Nonordered);
276 if (!unique)
277 properties.push_back(x: LevelPropNonDefault::Nonunique);
278 return buildLvlType(lf, properties, n, m);
279 }
280
281 /// Explicit conversion from uint64_t.
282 constexpr explicit LevelType(uint64_t bits) : lvlBits(bits) {
283 assert(isValidLvlBits(bits));
284 };
285
286 /// Constructs a LevelType with the given format using all default properties.
287 /*implicit*/ LevelType(LevelFormat f) : lvlBits(static_cast<uint64_t>(f)) {
288 assert(isValidLvlBits(lvlBits) && !isa<LevelFormat::NOutOfM>());
289 };
290
291 /// Converts to uint64_t
292 explicit operator uint64_t() const { return lvlBits; }
293
294 bool operator==(const LevelType lhs) const {
295 return static_cast<uint64_t>(lhs) == lvlBits;
296 }
297 bool operator!=(const LevelType lhs) const { return !(*this == lhs); }
298
299 LevelType stripStorageIrrelevantProperties() const {
300 // Properties other than `SoA` do not change the storage scheme of the
301 // sparse tensor.
302 constexpr uint64_t mask =
303 0xffff & ~static_cast<uint64_t>(LevelPropNonDefault::SoA);
304 return LevelType(lvlBits & ~mask);
305 }
306
307 /// Get N of NOutOfM level type.
308 constexpr uint64_t getN() const {
309 assert(isa<LevelFormat::NOutOfM>());
310 return (lvlBits >> 32) & 0xff;
311 }
312
313 /// Get M of NOutOfM level type.
314 constexpr uint64_t getM() const {
315 assert(isa<LevelFormat::NOutOfM>());
316 return (lvlBits >> 40) & 0xff;
317 }
318
319 /// Get the `LevelFormat` of the `LevelType`.
320 constexpr LevelFormat getLvlFmt() const {
321 return static_cast<LevelFormat>(lvlBits & 0xffff0000);
322 }
323
324 /// Check if the `LevelType` is in the `LevelFormat`.
325 template <LevelFormat... fmt>
326 constexpr bool isa() const {
327 return (... || (getLvlFmt() == fmt)) || false;
328 }
329
330 /// Check if the `LevelType` has the properties
331 template <LevelPropNonDefault p>
332 constexpr bool isa() const {
333 return lvlBits & static_cast<uint64_t>(p);
334 }
335
336 /// Check if the `LevelType` is considered to be sparse.
337 constexpr bool hasSparseSemantic() const {
338 return isa<LevelFormat::Compressed, LevelFormat::Singleton,
339 LevelFormat::LooseCompressed, LevelFormat::NOutOfM>();
340 }
341
342 /// Check if the `LevelType` is considered to be dense-like.
343 constexpr bool hasDenseSemantic() const {
344 return isa<LevelFormat::Dense, LevelFormat::Batch>();
345 }
346
347 /// Check if the `LevelType` needs positions array.
348 constexpr bool isWithPosLT() const {
349 assert(!isa<LevelFormat::Undef>());
350 return isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>();
351 }
352
353 /// Check if the `LevelType` needs coordinates array.
354 constexpr bool isWithCrdLT() const {
355 assert(!isa<LevelFormat::Undef>());
356 // All sparse levels has coordinate array.
357 return hasSparseSemantic();
358 }
359
360 std::string toMLIRString() const {
361 std::string lvlStr = toFormatString(lvlFmt: getLvlFmt());
362 std::string propStr = "";
363 if (isa<LevelFormat::NOutOfM>()) {
364 lvlStr +=
365 "[" + std::to_string(val: getN()) + ", " + std::to_string(val: getM()) + "]";
366 }
367 if (isa<LevelPropNonDefault::Nonunique>())
368 propStr += toPropString(lvlProp: LevelPropNonDefault::Nonunique);
369
370 if (isa<LevelPropNonDefault::Nonordered>()) {
371 if (!propStr.empty())
372 propStr += ", ";
373 propStr += toPropString(lvlProp: LevelPropNonDefault::Nonordered);
374 }
375 if (isa<LevelPropNonDefault::SoA>()) {
376 if (!propStr.empty())
377 propStr += ", ";
378 propStr += toPropString(lvlProp: LevelPropNonDefault::SoA);
379 }
380 if (!propStr.empty())
381 lvlStr += ("(" + propStr + ")");
382 return lvlStr;
383 }
384
385private:
386 /// Bit manipulations for LevelType:
387 ///
388 /// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty |
389 ///
390 uint64_t lvlBits;
391};
392
// For backward-compatibility. TODO: remove below after the migration is done.
/// Packs `n` into bits [32, 40) of a `LevelType` encoding.
constexpr uint64_t nToBits(uint64_t n) {
  return n << 32;
}
/// Packs `m` into bits [40, 48) of a `LevelType` encoding.
constexpr uint64_t mToBits(uint64_t m) {
  return m << 40;
}
396
/// Convenience wrapper around `LevelType::buildLvlType`; returns
/// `std::nullopt` when the properties are not applicable to the format.
inline std::optional<LevelType>
buildLevelType(LevelFormat lf,
               const std::vector<LevelPropNonDefault> &properties,
               uint64_t n = 0, uint64_t m = 0) {
  return LevelType::buildLvlType(lf, properties, n, m);
}
/// Convenience wrapper around the `ordered`/`unique` overload of
/// `LevelType::buildLvlType`.
inline std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
                                               bool unique, uint64_t n = 0,
                                               uint64_t m = 0) {
  return LevelType::buildLvlType(lf, ordered, unique, n, m);
}
/// Free-function wrappers over the corresponding `LevelType` queries,
/// kept for backward-compatibility.
inline bool isUndefLT(LevelType lt) { return lt.isa<LevelFormat::Undef>(); }
inline bool isDenseLT(LevelType lt) { return lt.isa<LevelFormat::Dense>(); }
inline bool isBatchLT(LevelType lt) { return lt.isa<LevelFormat::Batch>(); }
inline bool isCompressedLT(LevelType lt) {
  return lt.isa<LevelFormat::Compressed>();
}
inline bool isLooseCompressedLT(LevelType lt) {
  return lt.isa<LevelFormat::LooseCompressed>();
}
inline bool isSingletonLT(LevelType lt) {
  return lt.isa<LevelFormat::Singleton>();
}
inline bool isNOutOfMLT(LevelType lt) { return lt.isa<LevelFormat::NOutOfM>(); }
inline bool isOrderedLT(LevelType lt) {
  return !lt.isa<LevelPropNonDefault::Nonordered>();
}
inline bool isUniqueLT(LevelType lt) {
  return !lt.isa<LevelPropNonDefault::Nonunique>();
}
inline bool isWithCrdLT(LevelType lt) { return lt.isWithCrdLT(); }
inline bool isWithPosLT(LevelType lt) { return lt.isWithPosLT(); }
429inline bool isValidLT(LevelType lt) {
430 return LevelType::isValidLvlBits(lvlBits: static_cast<uint64_t>(lt));
431}
432inline std::optional<LevelFormat> getLevelFormat(LevelType lt) {
433 LevelFormat fmt = lt.getLvlFmt();
434 if (fmt == LevelFormat::Undef)
435 return std::nullopt;
436 return fmt;
437}
/// Returns N of an NOutOfM level type (asserts on other formats).
inline uint64_t getN(LevelType lt) { return lt.getN(); }
/// Returns M of an NOutOfM level type (asserts on other formats).
inline uint64_t getM(LevelType lt) { return lt.getM(); }
/// Returns true iff `lt` is an NOutOfM level with exactly the given n and m.
inline bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
  return isNOutOfMLT(lt) && lt.getN() == n && lt.getM() == m;
}
/// Returns the MLIR assembly representation of `lt`.
inline std::string toMLIRString(LevelType lt) { return lt.toMLIRString(); }
444
445/// Bit manipulations for affine encoding.
446///
447/// Note that because the indices in the mappings refer to dimensions
448/// and levels (and *not* the sizes of these dimensions and levels), the
449/// 64-bit encoding gives ample room for a compact encoding of affine
450/// operations in the higher bits. Pure permutations still allow for
451/// 60-bit indices. But non-permutations reserve 20-bits for the
452/// potential three components (index i, constant, index ii).
453///
454/// The compact encoding is as follows:
455///
456/// 0xffffffffffffffff
457/// |0000 | 60-bit idx| e.g. i
458/// |0001 floor| 20-bit const|20-bit idx| e.g. i floor c
459/// |0010 mod | 20-bit const|20-bit idx| e.g. i mod c
460/// |0011 mul |20-bit idx|20-bit const|20-bit idx| e.g. i + c * ii
461///
462/// This encoding provides sufficient generality for currently supported
463/// sparse tensor types. To generalize this more, we will need to provide
464/// a broader encoding scheme for affine functions. Also, the library
465/// encoding may be replaced with pure "direct-IR" code in the future.
466///
/// Encodes a dimension expression: plain index `i`, `i floor cf`, or
/// `i mod cm` (see the compact-encoding table above).
constexpr uint64_t encodeDim(uint64_t i, uint64_t cf, uint64_t cm) {
  constexpr uint64_t kMask20 = 0xfffffu;
  if (cf != 0) {
    // Tag 0b0001: `i floor cf`.
    assert(cf <= kMask20 && cm == 0 && i <= kMask20);
    return (uint64_t{0x01} << 60) | (cf << 20) | i;
  }
  if (cm != 0) {
    // Tag 0b0010: `i mod cm`.
    assert(cm <= kMask20 && i <= kMask20);
    return (uint64_t{0x02} << 60) | (cm << 20) | i;
  }
  // Tag 0b0000: a plain 60-bit index.
  assert(i <= 0x0fffffffffffffffu);
  return i;
}
/// Encodes a level expression: plain index `i`, or `i + c * ii` when the
/// constant `c` is nonzero (see the compact-encoding table above).
constexpr uint64_t encodeLvl(uint64_t i, uint64_t c, uint64_t ii) {
  constexpr uint64_t kMask20 = 0xfffffu;
  if (c != 0) {
    // Tag 0b0011: `i + c * ii`.
    assert(c <= kMask20 && ii <= kMask20 && i <= kMask20);
    return (uint64_t{0x03} << 60) | (c << 20) | (ii << 40) | i;
  }
  // Tag 0b0000: a plain 60-bit index.
  assert(i <= 0x0fffffffffffffffu);
  return i;
}
/// Predicates for the operation tag stored in the top four bits of an
/// affine encoding (see the table above).
constexpr bool isEncodedFloor(uint64_t v) {
  const uint64_t tag = v >> 60;
  return tag == 0x01u;
}
constexpr bool isEncodedMod(uint64_t v) {
  const uint64_t tag = v >> 60;
  return tag == 0x02u;
}
constexpr bool isEncodedMul(uint64_t v) {
  const uint64_t tag = v >> 60;
  return tag == 0x03u;
}
/// Accessors for the 20-bit index/constant components of an encoding.
constexpr uint64_t decodeIndex(uint64_t v) { return v & 0xfffffu; }
constexpr uint64_t decodeConst(uint64_t v) { return (v >> 20) & 0xfffffu; }
constexpr uint64_t decodeMulc(uint64_t v) { return (v >> 20) & 0xfffffu; }
constexpr uint64_t decodeMuli(uint64_t v) { return (v >> 40) & 0xfffffu; }
494
495} // namespace sparse_tensor
496} // namespace mlir
497
498#endif // MLIR_DIALECT_SPARSETENSOR_IR_ENUMS_H
499

source code of mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h