//===- SparseTensorDescriptor.h ---------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines utilities for the sparse memory layout.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"

namespace mlir {
namespace sparse_tensor {

class SparseTensorSpecifier {
public:
  explicit SparseTensorSpecifier(Value specifier)
      : specifier(cast<TypedValue<StorageSpecifierType>>(specifier)) {}

  // Undef value for level-sizes, all zero values for memory-sizes.
  static Value getInitValue(OpBuilder &builder, Location loc,
                            SparseTensorType stt);

  /*implicit*/ operator Value() { return specifier; }

  Value getSpecifierField(OpBuilder &builder, Location loc,
                          StorageSpecifierKind kind, std::optional<Level> lvl);

  void setSpecifierField(OpBuilder &builder, Location loc, Value v,
                         StorageSpecifierKind kind, std::optional<Level> lvl);

private:
  TypedValue<StorageSpecifierType> specifier;
};
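
// Example usage (a minimal sketch, not part of the original header; it assumes
// an OpBuilder `builder`, a Location `loc`, a SparseTensorType `stt`, and a
// Value `lvlSz` holding the dynamic size of level 0, e.g., inside a conversion
// pattern):
//
//   Value spec = SparseTensorSpecifier::getInitValue(builder, loc, stt);
//   SparseTensorSpecifier md(spec);
//   // Record the level size in the specifier.
//   md.setSpecifierField(builder, loc, lvlSz, StorageSpecifierKind::LvlSize,
//                        /*lvl=*/0);
//   Value updated = md; // implicit conversion back to a plain Value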

/// A helper class around an array of values that corresponds to a sparse
/// tensor. This class provides a set of meaningful APIs to query and update
/// a particular field in a consistent way. Users should not make assumptions
/// about how a sparse tensor is laid out but instead rely on this class to
/// access the right value for the right field.
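///
/// For example (a sketch, not from the original header; `desc` is one of the
/// concrete descriptors declared further below, and `builder`, `loc`, and
/// `lvl` are assumed to be in scope):
///
///   Value positions = desc.getPosMemRef(lvl);
///   Value values    = desc.getValMemRef();
///   Value lvlSize   = desc.getLvlSize(builder, loc, lvl);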
template <typename ValueArrayRef>
class SparseTensorDescriptorImpl {
protected:
  SparseTensorDescriptorImpl(SparseTensorType stt, ValueArrayRef fields)
      : rType(stt), fields(fields), layout(stt) {
    assert(layout.getNumFields() == getNumFields());
    // Make sure the class is trivially copyable (and small enough) so that
    // we can pass it by value.
    static_assert(std::is_trivially_copyable_v<
                  SparseTensorDescriptorImpl<ValueArrayRef>>);
  }

public:
  FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind,
                                 std::optional<Level> lvl) const {
    // Delegates to storage layout.
    return layout.getMemRefFieldIndex(kind, lvl);
  }

  unsigned getNumFields() const { return fields.size(); }

  ///
  /// Getters: get the value for the required field.
  ///

  Value getSpecifier() const { return fields.back(); }

  Value getSpecifierField(OpBuilder &builder, Location loc,
                          StorageSpecifierKind kind,
                          std::optional<Level> lvl) const {
    SparseTensorSpecifier md(fields.back());
    return md.getSpecifierField(builder, loc, kind, lvl);
  }

  Value getLvlSize(OpBuilder &builder, Location loc, Level lvl) const {
    return getSpecifierField(builder, loc, StorageSpecifierKind::LvlSize, lvl);
  }

  Value getPosMemRef(Level lvl) const {
    return getMemRefField(SparseTensorFieldKind::PosMemRef, lvl);
  }

  Value getValMemRef() const {
    return getMemRefField(SparseTensorFieldKind::ValMemRef, std::nullopt);
  }

  Value getMemRefField(SparseTensorFieldKind kind,
                       std::optional<Level> lvl) const {
    return getField(getMemRefFieldIndex(kind, lvl));
  }

  Value getMemRefField(FieldIndex fidx) const {
    assert(fidx < fields.size() - 1);
    return getField(fidx);
  }

  Value getPosMemSize(OpBuilder &builder, Location loc, Level lvl) const {
    return getSpecifierField(builder, loc, StorageSpecifierKind::PosMemSize,
                             lvl);
  }

  Value getCrdMemSize(OpBuilder &builder, Location loc, Level lvl) const {
    return getSpecifierField(builder, loc, StorageSpecifierKind::CrdMemSize,
                             lvl);
  }

  Value getValMemSize(OpBuilder &builder, Location loc) const {
    return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
                             std::nullopt);
  }

  Type getMemRefElementType(SparseTensorFieldKind kind,
                            std::optional<Level> lvl) const {
    return getMemRefType(getMemRefField(kind, lvl)).getElementType();
  }

  Value getField(FieldIndex fidx) const {
    assert(fidx < fields.size());
    return fields[fidx];
  }

  ValueRange getMemRefFields() const {
    return fields.drop_back(); // drop the last metadata field (the specifier)
  }

  std::pair<FieldIndex, unsigned> getCrdMemRefIndexAndStride(Level lvl) const {
    return layout.getFieldIndexAndStride(SparseTensorFieldKind::CrdMemRef, lvl);
  }

  Value getAOSMemRef() const {
    const Level cooStart = rType.getAoSCOOStart();
    assert(cooStart < rType.getLvlRank());
    return getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart);
  }

  RankedTensorType getRankedTensorType() const { return rType; }
  ValueArrayRef getFields() const { return fields; }
  StorageLayout getLayout() const { return layout; }

protected:
  SparseTensorType rType;
  ValueArrayRef fields;
  StorageLayout layout;
};

/// Uses ValueRange for immutable descriptors.
class SparseTensorDescriptor : public SparseTensorDescriptorImpl<ValueRange> {
public:
  SparseTensorDescriptor(SparseTensorType stt, ValueRange buffers)
      : SparseTensorDescriptorImpl<ValueRange>(stt, buffers) {}

  Value getCrdMemRefOrView(OpBuilder &builder, Location loc, Level lvl) const;
};

/// Using a SmallVector for the mutable descriptor allows users to reuse it as
/// a temporary buffer to append values in some special cases, though users are
/// responsible for restoring the buffer to a legal state after such use. This
/// is probably not the cleanest approach, but it is the most efficient way to
/// avoid copying the fields into another SmallVector. If a cleaner way is
/// wanted, we should change it to MutableArrayRef instead.
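///
/// Typical use (a sketch, not from the original header; `stt`, `builder`,
/// `loc`, a ValueRange `lowered` holding the lowered fields, and the
/// replacement values `newSize` and `newValBuffer` are all assumed to be in
/// scope):
///
///   SmallVector<Value> fields(lowered.begin(), lowered.end());
///   MutSparseTensorDescriptor desc(stt, fields);
///   desc.setValMemSize(builder, loc, newSize);  // update the specifier
///   desc.setMemRefField(SparseTensorFieldKind::ValMemRef, std::nullopt,
///                       newValBuffer);          // install a new value buffer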
class MutSparseTensorDescriptor
    : public SparseTensorDescriptorImpl<SmallVectorImpl<Value> &> {
public:
  MutSparseTensorDescriptor(SparseTensorType stt,
                            SmallVectorImpl<Value> &buffers)
      : SparseTensorDescriptorImpl<SmallVectorImpl<Value> &>(stt, buffers) {}

  // Allow implicit type conversion from mutable descriptors to immutable ones
  // (but not vice versa).
  /*implicit*/ operator SparseTensorDescriptor() const {
    return SparseTensorDescriptor(rType, fields);
  }

  ///
  /// Setters: additional methods of the mutable descriptor that update the
  /// value for the required field.
  ///

  void setMemRefField(SparseTensorFieldKind kind, std::optional<Level> lvl,
                      Value v) {
    fields[getMemRefFieldIndex(kind, lvl)] = v;
  }

  void setMemRefField(FieldIndex fidx, Value v) {
    assert(fidx < fields.size() - 1);
    fields[fidx] = v;
  }

  void setField(FieldIndex fidx, Value v) {
    assert(fidx < fields.size());
    fields[fidx] = v;
  }

  void setSpecifier(Value newSpec) { fields.back() = newSpec; }

  void setSpecifierField(OpBuilder &builder, Location loc,
                         StorageSpecifierKind kind, std::optional<Level> lvl,
                         Value v) {
    SparseTensorSpecifier md(fields.back());
    md.setSpecifierField(builder, loc, v, kind, lvl);
    fields.back() = md;
  }

  void setValMemSize(OpBuilder &builder, Location loc, Value v) {
    setSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
                      std::nullopt, v);
  }

  void setCrdMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
    setSpecifierField(builder, loc, StorageSpecifierKind::CrdMemSize, lvl, v);
  }

  void setPosMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
    setSpecifierField(builder, loc, StorageSpecifierKind::PosMemSize, lvl, v);
  }

  void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
    setSpecifierField(builder, loc, StorageSpecifierKind::LvlSize, lvl, v);
  }
};

/// Packs the given values as a "tuple" value.
inline Value genTuple(OpBuilder &builder, Location loc, Type tp,
                      ValueRange values) {
  return builder.create<UnrealizedConversionCastOp>(loc, TypeRange(tp), values)
      .getResult(0);
}

inline Value genTuple(OpBuilder &builder, Location loc,
                      SparseTensorDescriptor desc) {
  return genTuple(builder, loc, desc.getRankedTensorType(), desc.getFields());
}
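
// Example (a sketch, not from the original header, illustrating the genTuple
// helpers above; a SparseTensorDescriptor `desc`, an OpBuilder `builder`, and
// a Location `loc` are assumed): the descriptor's fields can be re-packed into
// a single SSA value that stands for the whole sparse tensor.
//
//   Value tensor = genTuple(builder, loc, desc);
//   // Equivalent, using the explicit-type overload:
//   Value alsoTensor =
//       genTuple(builder, loc, desc.getRankedTensorType(), desc.getFields());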

inline SparseTensorDescriptor
getDescriptorFromTensorTuple(ValueRange adaptorValues, RankedTensorType type) {
  return SparseTensorDescriptor(SparseTensorType(type), adaptorValues);
}

inline MutSparseTensorDescriptor
getMutDescriptorFromTensorTuple(ValueRange adaptorValues,
                                SmallVectorImpl<Value> &fields,
                                RankedTensorType type) {
  fields.assign(adaptorValues.begin(), adaptorValues.end());
  return MutSparseTensorDescriptor(SparseTensorType(type), fields);
}
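
// Example round trip (a sketch, not from the original header; `loweredValues`
// is a ValueRange holding the lowered fields of a sparse tensor, `rtp` is its
// original RankedTensorType, and `builder`, `loc`, and `newSize` are assumed
// to be in scope):
//
//   SmallVector<Value> fields;
//   MutSparseTensorDescriptor desc =
//       getMutDescriptorFromTensorTuple(loweredValues, fields, rtp);
//   desc.setValMemSize(builder, loc, newSize);     // mutate a specifier field
//   Value repacked = genTuple(builder, loc, desc); // pack back into one value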

} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_