1 | //===- SparseTensorDescriptor.h ---------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This header file defines utilities for the sparse memory layout. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_ |
14 | #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_ |
15 | |
16 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
17 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" |
18 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
19 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
20 | |
21 | namespace mlir { |
22 | namespace sparse_tensor { |
23 | |
24 | class SparseTensorSpecifier { |
25 | public: |
26 | explicit SparseTensorSpecifier(Value specifier) |
27 | : specifier(cast<TypedValue<StorageSpecifierType>>(specifier)) {} |
28 | |
29 | // Undef value for level-sizes, all zero values for memory-sizes. |
30 | static Value getInitValue(OpBuilder &builder, Location loc, |
31 | SparseTensorType stt); |
32 | |
33 | /*implicit*/ operator Value() { return specifier; } |
34 | |
35 | Value getSpecifierField(OpBuilder &builder, Location loc, |
36 | StorageSpecifierKind kind, std::optional<Level> lvl); |
37 | |
38 | void setSpecifierField(OpBuilder &builder, Location loc, Value v, |
39 | StorageSpecifierKind kind, std::optional<Level> lvl); |
40 | |
41 | private: |
42 | TypedValue<StorageSpecifierType> specifier; |
43 | }; |
44 | |
45 | /// A helper class around an array of values that corresponds to a sparse |
46 | /// tensor. This class provides a set of meaningful APIs to query and update |
47 | /// a particular field in a consistent way. Users should not make assumptions |
48 | /// on how a sparse tensor is laid out but instead rely on this class to access |
49 | /// the right value for the right field. |
50 | template <typename ValueArrayRef> |
51 | class SparseTensorDescriptorImpl { |
52 | protected: |
53 | SparseTensorDescriptorImpl(SparseTensorType stt, ValueArrayRef fields) |
54 | : rType(stt), fields(fields), layout(stt) { |
55 | assert(layout.getNumFields() == getNumFields()); |
56 | // We should make sure the class is trivially copyable (and should be small |
57 | // enough) such that we can pass it by value. |
58 | static_assert(std::is_trivially_copyable_v< |
59 | SparseTensorDescriptorImpl<ValueArrayRef>>); |
60 | } |
61 | |
62 | public: |
63 | FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind, |
64 | std::optional<Level> lvl) const { |
65 | // Delegates to storage layout. |
66 | return layout.getMemRefFieldIndex(kind, lvl); |
67 | } |
68 | |
69 | unsigned getNumFields() const { return fields.size(); } |
70 | |
71 | /// |
72 | /// Getters: get the value for required field. |
73 | /// |
74 | |
75 | Value getSpecifier() const { return fields.back(); } |
76 | |
77 | Value getSpecifierField(OpBuilder &builder, Location loc, |
78 | StorageSpecifierKind kind, |
79 | std::optional<Level> lvl) const { |
80 | SparseTensorSpecifier md(fields.back()); |
81 | return md.getSpecifierField(builder, loc, kind, lvl); |
82 | } |
83 | |
84 | Value getLvlSize(OpBuilder &builder, Location loc, Level lvl) const { |
85 | return getSpecifierField(builder, loc, StorageSpecifierKind::LvlSize, lvl); |
86 | } |
87 | |
88 | Value getPosMemRef(Level lvl) const { |
89 | return getMemRefField(SparseTensorFieldKind::PosMemRef, lvl); |
90 | } |
91 | |
92 | Value getValMemRef() const { |
93 | return getMemRefField(SparseTensorFieldKind::ValMemRef, std::nullopt); |
94 | } |
95 | |
96 | Value getMemRefField(SparseTensorFieldKind kind, |
97 | std::optional<Level> lvl) const { |
98 | return getField(fidx: getMemRefFieldIndex(kind, lvl)); |
99 | } |
100 | |
101 | Value getMemRefField(FieldIndex fidx) const { |
102 | assert(fidx < fields.size() - 1); |
103 | return getField(fidx); |
104 | } |
105 | |
106 | Value getPosMemSize(OpBuilder &builder, Location loc, Level lvl) const { |
107 | return getSpecifierField(builder, loc, StorageSpecifierKind::PosMemSize, |
108 | lvl); |
109 | } |
110 | |
111 | Value getCrdMemSize(OpBuilder &builder, Location loc, Level lvl) const { |
112 | return getSpecifierField(builder, loc, StorageSpecifierKind::CrdMemSize, |
113 | lvl); |
114 | } |
115 | |
116 | Value getValMemSize(OpBuilder &builder, Location loc) const { |
117 | return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize, |
118 | std::nullopt); |
119 | } |
120 | |
121 | Type getMemRefElementType(SparseTensorFieldKind kind, |
122 | std::optional<Level> lvl) const { |
123 | return getMemRefType(getMemRefField(kind, lvl)).getElementType(); |
124 | } |
125 | |
126 | Value getField(FieldIndex fidx) const { |
127 | assert(fidx < fields.size()); |
128 | return fields[fidx]; |
129 | } |
130 | |
131 | ValueRange getMemRefFields() const { |
132 | return fields.drop_back(); // drop the last metadata fields |
133 | } |
134 | |
135 | std::pair<FieldIndex, unsigned> getCrdMemRefIndexAndStride(Level lvl) const { |
136 | return layout.getFieldIndexAndStride(SparseTensorFieldKind::CrdMemRef, lvl); |
137 | } |
138 | |
139 | Value getAOSMemRef() const { |
140 | const Level cooStart = rType.getAoSCOOStart(); |
141 | assert(cooStart < rType.getLvlRank()); |
142 | return getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart); |
143 | } |
144 | |
145 | RankedTensorType getRankedTensorType() const { return rType; } |
146 | ValueArrayRef getFields() const { return fields; } |
147 | StorageLayout getLayout() const { return layout; } |
148 | |
149 | protected: |
150 | SparseTensorType rType; |
151 | ValueArrayRef fields; |
152 | StorageLayout layout; |
153 | }; |
154 | |
155 | /// Uses ValueRange for immutable descriptors. |
156 | class SparseTensorDescriptor : public SparseTensorDescriptorImpl<ValueRange> { |
157 | public: |
158 | SparseTensorDescriptor(SparseTensorType stt, ValueRange buffers) |
159 | : SparseTensorDescriptorImpl<ValueRange>(stt, buffers) {} |
160 | |
161 | Value getCrdMemRefOrView(OpBuilder &builder, Location loc, Level lvl) const; |
162 | }; |
163 | |
164 | /// Using SmallVector for mutable descriptor allows users to reuse it as a |
165 | /// tmp buffers to append value for some special cases, though users should |
166 | /// be responsible to restore the buffer to legal states after their use. It |
167 | /// is probably not a clean way, but it is the most efficient way to avoid |
168 | /// copying the fields into another SmallVector. If a more clear way is |
169 | /// wanted, we should change it to MutableArrayRef instead. |
170 | class MutSparseTensorDescriptor |
171 | : public SparseTensorDescriptorImpl<SmallVectorImpl<Value> &> { |
172 | public: |
173 | MutSparseTensorDescriptor(SparseTensorType stt, |
174 | SmallVectorImpl<Value> &buffers) |
175 | : SparseTensorDescriptorImpl<SmallVectorImpl<Value> &>(stt, buffers) {} |
176 | |
177 | // Allow implicit type conversion from mutable descriptors to immutable ones |
178 | // (but not vice versa). |
179 | /*implicit*/ operator SparseTensorDescriptor() const { |
180 | return SparseTensorDescriptor(rType, fields); |
181 | } |
182 | |
183 | /// |
184 | /// Adds additional setters for mutable descriptor, update the value for |
185 | /// required field. |
186 | /// |
187 | |
188 | void setMemRefField(SparseTensorFieldKind kind, std::optional<Level> lvl, |
189 | Value v) { |
190 | fields[getMemRefFieldIndex(kind, lvl)] = v; |
191 | } |
192 | |
193 | void setMemRefField(FieldIndex fidx, Value v) { |
194 | assert(fidx < fields.size() - 1); |
195 | fields[fidx] = v; |
196 | } |
197 | |
198 | void setField(FieldIndex fidx, Value v) { |
199 | assert(fidx < fields.size()); |
200 | fields[fidx] = v; |
201 | } |
202 | |
203 | void setSpecifier(Value newSpec) { fields.back() = newSpec; } |
204 | |
205 | void setSpecifierField(OpBuilder &builder, Location loc, |
206 | StorageSpecifierKind kind, std::optional<Level> lvl, |
207 | Value v) { |
208 | SparseTensorSpecifier md(fields.back()); |
209 | md.setSpecifierField(builder, loc, v, kind, lvl); |
210 | fields.back() = md; |
211 | } |
212 | |
213 | void setValMemSize(OpBuilder &builder, Location loc, Value v) { |
214 | setSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize, |
215 | std::nullopt, v); |
216 | } |
217 | |
218 | void setCrdMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) { |
219 | setSpecifierField(builder, loc, StorageSpecifierKind::CrdMemSize, lvl, v); |
220 | } |
221 | |
222 | void setPosMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) { |
223 | setSpecifierField(builder, loc, StorageSpecifierKind::PosMemSize, lvl, v); |
224 | } |
225 | |
226 | void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value v) { |
227 | setSpecifierField(builder, loc, StorageSpecifierKind::LvlSize, lvl, v); |
228 | } |
229 | }; |
230 | |
231 | /// Returns the "tuple" value of the adapted tensor. |
232 | inline UnrealizedConversionCastOp getTuple(Value tensor) { |
233 | return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp()); |
234 | } |
235 | |
236 | /// Packs the given values as a "tuple" value. |
237 | inline Value genTuple(OpBuilder &builder, Location loc, Type tp, |
238 | ValueRange values) { |
239 | return builder.create<UnrealizedConversionCastOp>(loc, TypeRange(tp), values) |
240 | .getResult(0); |
241 | } |
242 | |
243 | inline Value genTuple(OpBuilder &builder, Location loc, |
244 | SparseTensorDescriptor desc) { |
245 | return genTuple(builder, loc, desc.getRankedTensorType(), desc.getFields()); |
246 | } |
247 | |
248 | inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) { |
249 | auto tuple = getTuple(tensor); |
250 | SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0])); |
251 | return SparseTensorDescriptor(stt, tuple.getInputs()); |
252 | } |
253 | |
254 | inline MutSparseTensorDescriptor |
255 | getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) { |
256 | auto tuple = getTuple(tensor); |
257 | fields.assign(tuple.getInputs().begin(), tuple.getInputs().end()); |
258 | SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0])); |
259 | return MutSparseTensorDescriptor(stt, fields); |
260 | } |
261 | |
262 | } // namespace sparse_tensor |
263 | } // namespace mlir |
264 | |
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_
266 | |