1//===- SparseTensor.h - Sparse tensor dialect -------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSOR_H_
10#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSOR_H_
11
12#include "mlir/Bytecode/BytecodeOpInterface.h"
13#include "mlir/Dialect/SparseTensor/IR/Enums.h"
14#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
15#include "mlir/IR/BuiltinTypes.h"
16#include "mlir/IR/Dialect.h"
17#include "mlir/IR/OpDefinition.h"
18#include "mlir/IR/OpImplementation.h"
19#include "mlir/IR/TensorEncoding.h"
20#include "mlir/Interfaces/ControlFlowInterfaces.h"
21#include "mlir/Interfaces/InferTypeOpInterface.h"
22#include "mlir/Interfaces/LoopLikeInterface.h"
23#include "mlir/Interfaces/SideEffectInterfaces.h"
24
25#include "llvm/ADT/bit.h"
26
//===----------------------------------------------------------------------===//
//
// Type aliases to help code be more self-documenting. Note that these are
// plain aliases, not distinct types: the compiler does not type-check them
// (e.g. `Dimension` and `Level` are both `uint64_t` and freely interchangeable),
// so they only document intent rather than preventing mixups.
//
//===----------------------------------------------------------------------===//
34
35namespace mlir {
36namespace sparse_tensor {
37
/// Identifier type for a dimension; also used to express dimension-ranks.
using Dimension = uint64_t;

/// Identifier type for a level; also used to express level-ranks.
using Level = uint64_t;

/// Type of a single component of a compile-time shape. Besides concrete
/// extents it may hold the sentinel `ShapedType::kDynamic`.
using Size = int64_t;
47
48/// A simple structure that encodes a range of levels in the sparse tensors
49/// that forms a COO segment.
50struct COOSegment {
51 std::pair<Level, Level> lvlRange; // [low, high)
52 bool isSoA;
53
54 bool isAoS() const { return !isSoA; }
55 bool isSegmentStart(Level l) const { return l == lvlRange.first; }
56 bool inSegment(Level l) const {
57 return l >= lvlRange.first && l < lvlRange.second;
58 }
59};
60
61/// A simple wrapper to encode a bitset of (at most 64) levels, currently used
62/// by `sparse_tensor.iterate` operation for the set of levels on which the
63/// coordinates should be loaded.
64class I64BitSet {
65 uint64_t storage = 0;
66
67public:
68 using const_set_bits_iterator = llvm::const_set_bits_iterator_impl<I64BitSet>;
69 const_set_bits_iterator begin() const {
70 return const_set_bits_iterator(*this);
71 }
72 const_set_bits_iterator end() const {
73 return const_set_bits_iterator(*this, -1);
74 }
75 iterator_range<const_set_bits_iterator> bits() const {
76 return make_range(x: begin(), y: end());
77 }
78
79 I64BitSet() = default;
80 explicit I64BitSet(uint64_t bits) : storage(bits) {}
81 operator uint64_t() const { return storage; }
82
83 I64BitSet &set(unsigned i) {
84 assert(i < 64);
85 storage |= static_cast<uint64_t>(0x01u) << i;
86 return *this;
87 }
88
89 I64BitSet &operator|=(I64BitSet lhs) {
90 storage |= static_cast<uint64_t>(lhs);
91 return *this;
92 }
93
94 I64BitSet &lshift(unsigned offset) {
95 storage = storage << offset;
96 return *this;
97 }
98
99 bool isSubSetOf(const I64BitSet p) const {
100 I64BitSet tmp = *this;
101 tmp |= p;
102 return tmp == p;
103 }
104
105 // Needed by `llvm::const_set_bits_iterator_impl`.
106 int find_first() const { return min(); }
107 int find_next(unsigned prev) const {
108 if (prev >= max() - 1)
109 return -1;
110
111 uint64_t b = storage >> (prev + static_cast<int64_t>(1));
112 assert(b != 0);
113
114 return llvm::countr_zero(Val: b) + prev + static_cast<int64_t>(1);
115 }
116
117 bool operator[](unsigned i) const {
118 assert(i < 64);
119 return (storage & (static_cast<int64_t>(1) << i)) != 0;
120 }
121 unsigned min() const {
122 unsigned m = llvm::countr_zero(Val: storage);
123 return m == 64 ? -1 : m;
124 }
125 unsigned max() const { return llvm::bit_width(Value: storage); }
126 unsigned count() const { return llvm::popcount(Value: storage); }
127 bool empty() const { return storage == 0; }
128};
129
130} // namespace sparse_tensor
131} // namespace mlir
132
133//===----------------------------------------------------------------------===//
134// TableGen-defined classes
135//===----------------------------------------------------------------------===//
136
137#define GET_ATTRDEF_CLASSES
138#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.h.inc"
139
140#define GET_ATTRDEF_CLASSES
141#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.h.inc"
142
143#define GET_TYPEDEF_CLASSES
144#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.h.inc"
145
146#define GET_OP_CLASSES
147#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.h.inc"
148
149#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.h.inc"
150
151//===----------------------------------------------------------------------===//
152// Additional convenience methods.
153//===----------------------------------------------------------------------===//
154
155namespace mlir {
156namespace sparse_tensor {
157
158/// Convenience method to abbreviate casting `getType()`.
159template <typename T>
160inline RankedTensorType getRankedTensorType(T &&t) {
161 assert(static_cast<bool>(std::forward<T>(t)) &&
162 "getRankedTensorType got null argument");
163 return dyn_cast<RankedTensorType>(std::forward<T>(t).getType());
164}
165
166/// Convenience method to abbreviate casting `getType()`.
167template <typename T>
168inline MemRefType getMemRefType(T &&t) {
169 assert(static_cast<bool>(std::forward<T>(t)) &&
170 "getMemRefType got null argument");
171 return cast<MemRefType>(std::forward<T>(t).getType());
172}
173
174/// Convenience method to get a sparse encoding attribute from a type.
175/// Returns null-attribute for any type without an encoding.
176SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
177
178/// Returns true iff the type range has any sparse tensor type.
179inline bool hasAnySparseType(TypeRange types) {
180 return llvm::any_of(Range&: types, P: [](Type type) {
181 return getSparseTensorEncoding(type) != nullptr;
182 });
183}
184
185/// Returns true iff MLIR operand has any sparse operand.
186inline bool hasAnySparseOperand(Operation *op) {
187 return hasAnySparseType(types: op->getOperands().getTypes());
188}
189
190/// Returns true iff MLIR operand has any sparse result.
191inline bool hasAnySparseResult(Operation *op) {
192 return hasAnySparseType(types: op->getResults().getTypes());
193}
194
195/// Returns true iff MLIR operand has any sparse operand or result.
196inline bool hasAnySparseOperandOrResult(Operation *op) {
197 return hasAnySparseOperand(op) || hasAnySparseResult(op);
198}
199
200/// Returns true iff MLIR operation has any sparse tensor with non-identity
201/// dim2lvl maps.
202bool hasAnyNonIdentityOperandsOrResults(Operation *op);
203
//
// Inference.
//
207
208/// Given the dimToLvl map, infers the lvlToDim map, or returns
209/// empty Affine map when inference fails.
210AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context);
211
212/// Returns the lvlToDim map for the given dimToLvl map specific
213/// to the block sparse cases.
214/// Asserts on failure (so only use when known to succeed).
215AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context);
216
217/// Given the dimToLvl map, returns the block sizes in a vector.
218/// For instance, a 2x3 block will return [2, 3]. Unblocked dimension i
219/// will return 0, and i floordiv 1, i mod 1 will return 1. Therefore,
220/// the example below will return [0, 1].
221/// map = ( i, j ) ->
222/// ( i : dense,
223/// j floordiv 1 : compressed,
224/// j mod 1 : dense
225/// )
226/// Only valid block sparsity will be accepted.
227SmallVector<unsigned> getBlockSize(AffineMap dimToLvl);
228
229/// Given the dimToLvl map, returns if it's block sparsity.
230bool isBlockSparsity(AffineMap dimToLvl);
231
//
// Reordering.
//
235
236/// Convenience method to translate the given level to the corresponding
237/// dimension.
238/// Requires: `enc` has a permuted dim2lvl map and `0 <= l < lvlRank`.
239Dimension toDim(SparseTensorEncodingAttr enc, Level l);
240
241/// Convenience method to translate the given dimension to the corresponding
242/// level.
243/// Requires: `enc` has a permuted dim2lvl map and `0 <= d < dimRank`.
244Level toLvl(SparseTensorEncodingAttr enc, Dimension d);
245
246} // namespace sparse_tensor
247} // namespace mlir
248
249#endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSOR_H_
250

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h