1//===- SparseStorageSpecifierToLLVM.cpp - convert specifier to llvm -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Utils/CodegenUtils.h"
10
11#include "mlir/Conversion/LLVMCommon/StructBuilder.h"
12#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
13#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
14
15#include <optional>
16
17using namespace mlir;
18using namespace sparse_tensor;
19
20namespace {
21
22//===----------------------------------------------------------------------===//
23// Helper methods.
24//===----------------------------------------------------------------------===//
25
26static SmallVector<Type, 4> getSpecifierFields(StorageSpecifierType tp) {
27 MLIRContext *ctx = tp.getContext();
28 auto enc = tp.getEncoding();
29 const Level lvlRank = enc.getLvlRank();
30
31 SmallVector<Type, 4> result;
32 // TODO: how can we get the lowering type for index type in the later pipeline
33 // to be consistent? LLVM::StructureType does not allow index fields.
34 auto sizeType = IntegerType::get(tp.getContext(), 64);
35 auto lvlSizes = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
36 auto memSizes = LLVM::LLVMArrayType::get(ctx, sizeType,
37 getNumDataFieldsFromEncoding(enc));
38 result.push_back(Elt: lvlSizes);
39 result.push_back(Elt: memSizes);
40
41 if (enc.isSlice()) {
42 // Extra fields are required for the slice information.
43 auto dimOffset = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
44 auto dimStride = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
45
46 result.push_back(Elt: dimOffset);
47 result.push_back(Elt: dimStride);
48 }
49
50 return result;
51}
52
53static Type convertSpecifier(StorageSpecifierType tp) {
54 return LLVM::LLVMStructType::getLiteral(tp.getContext(),
55 getSpecifierFields(tp));
56}
57
58//===----------------------------------------------------------------------===//
59// Specifier struct builder.
60//===----------------------------------------------------------------------===//
61
// Positions of the fields inside the lowered specifier struct. They must match
// the order in which getSpecifierFields() appends the field types.

/// Position of the level-size array.
constexpr uint64_t kLvlSizePosInSpecifier = 0;
/// Position of the memory-size array.
constexpr uint64_t kMemSizePosInSpecifier = 1;
/// Position of the offset array (only present for slices).
constexpr uint64_t kDimOffsetPosInSpecifier = 2;
/// Position of the stride array (only present for slices).
constexpr uint64_t kDimStridePosInSpecifier = 3;
66
67class SpecifierStructBuilder : public StructBuilder {
68private:
69 Value extractField(OpBuilder &builder, Location loc,
70 ArrayRef<int64_t> indices) const {
71 return genCast(builder, loc,
72 builder.create<LLVM::ExtractValueOp>(loc, value, indices),
73 builder.getIndexType());
74 }
75
76 void insertField(OpBuilder &builder, Location loc, ArrayRef<int64_t> indices,
77 Value v) {
78 value = builder.create<LLVM::InsertValueOp>(
79 loc, value, genCast(builder, loc, v, builder.getIntegerType(64)),
80 indices);
81 }
82
83public:
84 explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) {
85 assert(value);
86 }
87
88 // Undef value for dimension sizes, all zero value for memory sizes.
89 static Value getInitValue(OpBuilder &builder, Location loc, Type structType,
90 Value source);
91
92 Value lvlSize(OpBuilder &builder, Location loc, Level lvl) const;
93 void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value size);
94
95 Value dimOffset(OpBuilder &builder, Location loc, Dimension dim) const;
96 void setDimOffset(OpBuilder &builder, Location loc, Dimension dim,
97 Value size);
98
99 Value dimStride(OpBuilder &builder, Location loc, Dimension dim) const;
100 void setDimStride(OpBuilder &builder, Location loc, Dimension dim,
101 Value size);
102
103 Value memSize(OpBuilder &builder, Location loc, FieldIndex fidx) const;
104 void setMemSize(OpBuilder &builder, Location loc, FieldIndex fidx,
105 Value size);
106
107 Value memSizeArray(OpBuilder &builder, Location loc) const;
108 void setMemSizeArray(OpBuilder &builder, Location loc, Value array);
109};
110
111Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
112 Type structType, Value source) {
113 Value metaData = builder.create<LLVM::UndefOp>(loc, structType);
114 SpecifierStructBuilder md(metaData);
115 if (!source) {
116 auto memSizeArrayType =
117 cast<LLVM::LLVMArrayType>(cast<LLVM::LLVMStructType>(structType)
118 .getBody()[kMemSizePosInSpecifier]);
119
120 Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
121 // Fill memSizes array with zero.
122 for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
123 md.setMemSize(builder, loc, fidx: i, size: zero);
124 } else {
125 // We copy non-slice information (memory sizes array) from source
126 SpecifierStructBuilder sourceMd(source);
127 md.setMemSizeArray(builder, loc, array: sourceMd.memSizeArray(builder, loc));
128 }
129 return md;
130}
131
132/// Builds IR extracting the pos-th offset from the descriptor.
133Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc,
134 Dimension dim) const {
135 return extractField(
136 builder, loc,
137 indices: ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)});
138}
139
140/// Builds IR inserting the pos-th offset into the descriptor.
141void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc,
142 Dimension dim, Value size) {
143 insertField(
144 builder, loc,
145 indices: ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)},
146 v: size);
147}
148
149/// Builds IR extracting the `lvl`-th level-size from the descriptor.
150Value SpecifierStructBuilder::lvlSize(OpBuilder &builder, Location loc,
151 Level lvl) const {
152 // This static_cast makes the narrowing of `lvl` explicit, as required
153 // by the braces notation for the ctor.
154 return extractField(
155 builder, loc,
156 indices: ArrayRef<int64_t>{kLvlSizePosInSpecifier, static_cast<int64_t>(lvl)});
157}
158
159/// Builds IR inserting the `lvl`-th level-size into the descriptor.
160void SpecifierStructBuilder::setLvlSize(OpBuilder &builder, Location loc,
161 Level lvl, Value size) {
162 // This static_cast makes the narrowing of `lvl` explicit, as required
163 // by the braces notation for the ctor.
164 insertField(
165 builder, loc,
166 indices: ArrayRef<int64_t>{kLvlSizePosInSpecifier, static_cast<int64_t>(lvl)},
167 v: size);
168}
169
170/// Builds IR extracting the pos-th stride from the descriptor.
171Value SpecifierStructBuilder::dimStride(OpBuilder &builder, Location loc,
172 Dimension dim) const {
173 return extractField(
174 builder, loc,
175 indices: ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)});
176}
177
178/// Builds IR inserting the pos-th stride into the descriptor.
179void SpecifierStructBuilder::setDimStride(OpBuilder &builder, Location loc,
180 Dimension dim, Value size) {
181 insertField(
182 builder, loc,
183 indices: ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)},
184 v: size);
185}
186
187/// Builds IR extracting the pos-th memory size into the descriptor.
188Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
189 FieldIndex fidx) const {
190 return extractField(
191 builder, loc,
192 indices: ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)});
193}
194
195/// Builds IR inserting the `fidx`-th memory-size into the descriptor.
196void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
197 FieldIndex fidx, Value size) {
198 insertField(
199 builder, loc,
200 indices: ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)},
201 v: size);
202}
203
204/// Builds IR extracting the memory size array from the descriptor.
205Value SpecifierStructBuilder::memSizeArray(OpBuilder &builder,
206 Location loc) const {
207 return builder.create<LLVM::ExtractValueOp>(loc, value,
208 kMemSizePosInSpecifier);
209}
210
211/// Builds IR inserting the memory size array into the descriptor.
212void SpecifierStructBuilder::setMemSizeArray(OpBuilder &builder, Location loc,
213 Value array) {
214 value = builder.create<LLVM::InsertValueOp>(loc, value, array,
215 kMemSizePosInSpecifier);
216}
217
218} // namespace
219
220//===----------------------------------------------------------------------===//
221// The sparse storage specifier type converter (defined in Passes.h).
222//===----------------------------------------------------------------------===//
223
224StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() {
225 addConversion(callback: [](Type type) { return type; });
226 addConversion(convertSpecifier);
227}
228
229//===----------------------------------------------------------------------===//
230// Storage specifier conversion rules.
231//===----------------------------------------------------------------------===//
232
233template <typename Base, typename SourceOp>
234class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
235public:
236 using OpAdaptor = typename SourceOp::Adaptor;
237 using OpConversionPattern<SourceOp>::OpConversionPattern;
238
239 LogicalResult
240 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
241 ConversionPatternRewriter &rewriter) const override {
242 SpecifierStructBuilder spec(adaptor.getSpecifier());
243 switch (op.getSpecifierKind()) {
244 case StorageSpecifierKind::LvlSize: {
245 Value v = Base::onLvlSize(rewriter, op, spec, (*op.getLevel()));
246 rewriter.replaceOp(op, v);
247 return success();
248 }
249 case StorageSpecifierKind::DimOffset: {
250 Value v = Base::onDimOffset(rewriter, op, spec, (*op.getLevel()));
251 rewriter.replaceOp(op, v);
252 return success();
253 }
254 case StorageSpecifierKind::DimStride: {
255 Value v = Base::onDimStride(rewriter, op, spec, (*op.getLevel()));
256 rewriter.replaceOp(op, v);
257 return success();
258 }
259 case StorageSpecifierKind::CrdMemSize:
260 case StorageSpecifierKind::PosMemSize:
261 case StorageSpecifierKind::ValMemSize: {
262 auto enc = op.getSpecifier().getType().getEncoding();
263 StorageLayout layout(enc);
264 std::optional<unsigned> lvl;
265 if (op.getLevel())
266 lvl = (*op.getLevel());
267 unsigned idx =
268 layout.getMemRefFieldIndex(toFieldKind(op.getSpecifierKind()), lvl);
269 Value v = Base::onMemSize(rewriter, op, spec, idx);
270 rewriter.replaceOp(op, v);
271 return success();
272 }
273 }
274 llvm_unreachable("unrecognized specifer kind");
275 }
276};
277
278struct StorageSpecifierSetOpConverter
279 : public SpecifierGetterSetterOpConverter<StorageSpecifierSetOpConverter,
280 SetStorageSpecifierOp> {
281 using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
282
283 static Value onLvlSize(OpBuilder &builder, SetStorageSpecifierOp op,
284 SpecifierStructBuilder &spec, Level lvl) {
285 spec.setLvlSize(builder, loc: op.getLoc(), lvl, size: op.getValue());
286 return spec;
287 }
288
289 static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op,
290 SpecifierStructBuilder &spec, Dimension d) {
291 spec.setDimOffset(builder, loc: op.getLoc(), dim: d, size: op.getValue());
292 return spec;
293 }
294
295 static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op,
296 SpecifierStructBuilder &spec, Dimension d) {
297 spec.setDimStride(builder, loc: op.getLoc(), dim: d, size: op.getValue());
298 return spec;
299 }
300
301 static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op,
302 SpecifierStructBuilder &spec, FieldIndex fidx) {
303 spec.setMemSize(builder, loc: op.getLoc(), fidx, size: op.getValue());
304 return spec;
305 }
306};
307
308struct StorageSpecifierGetOpConverter
309 : public SpecifierGetterSetterOpConverter<StorageSpecifierGetOpConverter,
310 GetStorageSpecifierOp> {
311 using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
312
313 static Value onLvlSize(OpBuilder &builder, GetStorageSpecifierOp op,
314 SpecifierStructBuilder &spec, Level lvl) {
315 return spec.lvlSize(builder, loc: op.getLoc(), lvl);
316 }
317
318 static Value onDimOffset(OpBuilder &builder, GetStorageSpecifierOp op,
319 const SpecifierStructBuilder &spec, Dimension d) {
320 return spec.dimOffset(builder, loc: op.getLoc(), dim: d);
321 }
322
323 static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op,
324 const SpecifierStructBuilder &spec, Dimension d) {
325 return spec.dimStride(builder, loc: op.getLoc(), dim: d);
326 }
327
328 static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op,
329 SpecifierStructBuilder &spec, FieldIndex fidx) {
330 return spec.memSize(builder, loc: op.getLoc(), fidx);
331 }
332};
333
334struct StorageSpecifierInitOpConverter
335 : public OpConversionPattern<StorageSpecifierInitOp> {
336public:
337 using OpConversionPattern::OpConversionPattern;
338 LogicalResult
339 matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor,
340 ConversionPatternRewriter &rewriter) const override {
341 Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
342 rewriter.replaceOp(
343 op, SpecifierStructBuilder::getInitValue(
344 builder&: rewriter, loc: op.getLoc(), structType: llvmType, source: adaptor.getSource()));
345 return success();
346 }
347};
348
349//===----------------------------------------------------------------------===//
350// Public method for populating conversion rules.
351//===----------------------------------------------------------------------===//
352
353void mlir::populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
354 RewritePatternSet &patterns) {
355 patterns.add<StorageSpecifierGetOpConverter, StorageSpecifierSetOpConverter,
356 StorageSpecifierInitOpConverter>(arg&: converter,
357 args: patterns.getContext());
358}
359

// Source: mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp