1//===- SparseTensorDescriptor.h ---------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This header file defines utilities for the sparse memory layout.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_
14#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_
15
16#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
17#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
18#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
19#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
20
21namespace mlir {
22namespace sparse_tensor {
23
24class SparseTensorSpecifier {
25public:
26 explicit SparseTensorSpecifier(Value specifier)
27 : specifier(cast<TypedValue<StorageSpecifierType>>(specifier)) {}
28
29 // Undef value for level-sizes, all zero values for memory-sizes.
30 static Value getInitValue(OpBuilder &builder, Location loc,
31 SparseTensorType stt);
32
33 /*implicit*/ operator Value() { return specifier; }
34
35 Value getSpecifierField(OpBuilder &builder, Location loc,
36 StorageSpecifierKind kind, std::optional<Level> lvl);
37
38 void setSpecifierField(OpBuilder &builder, Location loc, Value v,
39 StorageSpecifierKind kind, std::optional<Level> lvl);
40
41private:
42 TypedValue<StorageSpecifierType> specifier;
43};
44
45/// A helper class around an array of values that corresponds to a sparse
46/// tensor. This class provides a set of meaningful APIs to query and update
47/// a particular field in a consistent way. Users should not make assumptions
48/// on how a sparse tensor is laid out but instead rely on this class to access
49/// the right value for the right field.
50template <typename ValueArrayRef>
51class SparseTensorDescriptorImpl {
52protected:
53 SparseTensorDescriptorImpl(SparseTensorType stt, ValueArrayRef fields)
54 : rType(stt), fields(fields), layout(stt) {
55 assert(layout.getNumFields() == getNumFields());
56 // We should make sure the class is trivially copyable (and should be small
57 // enough) such that we can pass it by value.
58 static_assert(std::is_trivially_copyable_v<
59 SparseTensorDescriptorImpl<ValueArrayRef>>);
60 }
61
62public:
63 FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind,
64 std::optional<Level> lvl) const {
65 // Delegates to storage layout.
66 return layout.getMemRefFieldIndex(kind, lvl);
67 }
68
69 unsigned getNumFields() const { return fields.size(); }
70
71 ///
72 /// Getters: get the value for required field.
73 ///
74
75 Value getSpecifier() const { return fields.back(); }
76
77 Value getSpecifierField(OpBuilder &builder, Location loc,
78 StorageSpecifierKind kind,
79 std::optional<Level> lvl) const {
80 SparseTensorSpecifier md(fields.back());
81 return md.getSpecifierField(builder, loc, kind, lvl);
82 }
83
84 Value getLvlSize(OpBuilder &builder, Location loc, Level lvl) const {
85 return getSpecifierField(builder, loc, StorageSpecifierKind::LvlSize, lvl);
86 }
87
88 Value getPosMemRef(Level lvl) const {
89 return getMemRefField(SparseTensorFieldKind::PosMemRef, lvl);
90 }
91
92 Value getValMemRef() const {
93 return getMemRefField(SparseTensorFieldKind::ValMemRef, std::nullopt);
94 }
95
96 Value getMemRefField(SparseTensorFieldKind kind,
97 std::optional<Level> lvl) const {
98 return getField(fidx: getMemRefFieldIndex(kind, lvl));
99 }
100
101 Value getMemRefField(FieldIndex fidx) const {
102 assert(fidx < fields.size() - 1);
103 return getField(fidx);
104 }
105
106 Value getPosMemSize(OpBuilder &builder, Location loc, Level lvl) const {
107 return getSpecifierField(builder, loc, StorageSpecifierKind::PosMemSize,
108 lvl);
109 }
110
111 Value getCrdMemSize(OpBuilder &builder, Location loc, Level lvl) const {
112 return getSpecifierField(builder, loc, StorageSpecifierKind::CrdMemSize,
113 lvl);
114 }
115
116 Value getValMemSize(OpBuilder &builder, Location loc) const {
117 return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
118 std::nullopt);
119 }
120
121 Type getMemRefElementType(SparseTensorFieldKind kind,
122 std::optional<Level> lvl) const {
123 return getMemRefType(getMemRefField(kind, lvl)).getElementType();
124 }
125
126 Value getField(FieldIndex fidx) const {
127 assert(fidx < fields.size());
128 return fields[fidx];
129 }
130
131 ValueRange getMemRefFields() const {
132 return fields.drop_back(); // drop the last metadata fields
133 }
134
135 std::pair<FieldIndex, unsigned> getCrdMemRefIndexAndStride(Level lvl) const {
136 return layout.getFieldIndexAndStride(SparseTensorFieldKind::CrdMemRef, lvl);
137 }
138
139 Value getAOSMemRef() const {
140 const Level cooStart = rType.getAoSCOOStart();
141 assert(cooStart < rType.getLvlRank());
142 return getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart);
143 }
144
145 RankedTensorType getRankedTensorType() const { return rType; }
146 ValueArrayRef getFields() const { return fields; }
147 StorageLayout getLayout() const { return layout; }
148
149protected:
150 SparseTensorType rType;
151 ValueArrayRef fields;
152 StorageLayout layout;
153};
154
155/// Uses ValueRange for immutable descriptors.
156class SparseTensorDescriptor : public SparseTensorDescriptorImpl<ValueRange> {
157public:
158 SparseTensorDescriptor(SparseTensorType stt, ValueRange buffers)
159 : SparseTensorDescriptorImpl<ValueRange>(stt, buffers) {}
160
161 Value getCrdMemRefOrView(OpBuilder &builder, Location loc, Level lvl) const;
162};
163
164/// Using SmallVector for mutable descriptor allows users to reuse it as a
165/// tmp buffers to append value for some special cases, though users should
166/// be responsible to restore the buffer to legal states after their use. It
167/// is probably not a clean way, but it is the most efficient way to avoid
168/// copying the fields into another SmallVector. If a more clear way is
169/// wanted, we should change it to MutableArrayRef instead.
170class MutSparseTensorDescriptor
171 : public SparseTensorDescriptorImpl<SmallVectorImpl<Value> &> {
172public:
173 MutSparseTensorDescriptor(SparseTensorType stt,
174 SmallVectorImpl<Value> &buffers)
175 : SparseTensorDescriptorImpl<SmallVectorImpl<Value> &>(stt, buffers) {}
176
177 // Allow implicit type conversion from mutable descriptors to immutable ones
178 // (but not vice versa).
179 /*implicit*/ operator SparseTensorDescriptor() const {
180 return SparseTensorDescriptor(rType, fields);
181 }
182
183 ///
184 /// Adds additional setters for mutable descriptor, update the value for
185 /// required field.
186 ///
187
188 void setMemRefField(SparseTensorFieldKind kind, std::optional<Level> lvl,
189 Value v) {
190 fields[getMemRefFieldIndex(kind, lvl)] = v;
191 }
192
193 void setMemRefField(FieldIndex fidx, Value v) {
194 assert(fidx < fields.size() - 1);
195 fields[fidx] = v;
196 }
197
198 void setField(FieldIndex fidx, Value v) {
199 assert(fidx < fields.size());
200 fields[fidx] = v;
201 }
202
203 void setSpecifier(Value newSpec) { fields.back() = newSpec; }
204
205 void setSpecifierField(OpBuilder &builder, Location loc,
206 StorageSpecifierKind kind, std::optional<Level> lvl,
207 Value v) {
208 SparseTensorSpecifier md(fields.back());
209 md.setSpecifierField(builder, loc, v, kind, lvl);
210 fields.back() = md;
211 }
212
213 void setValMemSize(OpBuilder &builder, Location loc, Value v) {
214 setSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
215 std::nullopt, v);
216 }
217
218 void setCrdMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
219 setSpecifierField(builder, loc, StorageSpecifierKind::CrdMemSize, lvl, v);
220 }
221
222 void setPosMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
223 setSpecifierField(builder, loc, StorageSpecifierKind::PosMemSize, lvl, v);
224 }
225
226 void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
227 setSpecifierField(builder, loc, StorageSpecifierKind::LvlSize, lvl, v);
228 }
229};
230
231/// Returns the "tuple" value of the adapted tensor.
232inline UnrealizedConversionCastOp getTuple(Value tensor) {
233 return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
234}
235
236/// Packs the given values as a "tuple" value.
237inline Value genTuple(OpBuilder &builder, Location loc, Type tp,
238 ValueRange values) {
239 return builder.create<UnrealizedConversionCastOp>(loc, TypeRange(tp), values)
240 .getResult(0);
241}
242
243inline Value genTuple(OpBuilder &builder, Location loc,
244 SparseTensorDescriptor desc) {
245 return genTuple(builder, loc, desc.getRankedTensorType(), desc.getFields());
246}
247
248inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
249 auto tuple = getTuple(tensor);
250 SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
251 return SparseTensorDescriptor(stt, tuple.getInputs());
252}
253
254inline MutSparseTensorDescriptor
255getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
256 auto tuple = getTuple(tensor);
257 fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
258 SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
259 return MutSparseTensorDescriptor(stt, fields);
260}
261
262} // namespace sparse_tensor
263} // namespace mlir
264
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_
266

source code of mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h