1 | //===- SparseStorageSpecifierToLLVM.cpp - convert specifier to llvm -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "Utils/CodegenUtils.h" |
10 | |
11 | #include "mlir/Conversion/LLVMCommon/StructBuilder.h" |
12 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" |
13 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
14 | |
15 | #include <optional> |
16 | |
17 | using namespace mlir; |
18 | using namespace sparse_tensor; |
19 | |
20 | namespace { |
21 | |
22 | //===----------------------------------------------------------------------===// |
23 | // Helper methods. |
24 | //===----------------------------------------------------------------------===// |
25 | |
26 | static SmallVector<Type, 4> getSpecifierFields(StorageSpecifierType tp) { |
27 | MLIRContext *ctx = tp.getContext(); |
28 | auto enc = tp.getEncoding(); |
29 | const Level lvlRank = enc.getLvlRank(); |
30 | |
31 | SmallVector<Type, 4> result; |
32 | // TODO: how can we get the lowering type for index type in the later pipeline |
33 | // to be consistent? LLVM::StructureType does not allow index fields. |
34 | auto sizeType = IntegerType::get(tp.getContext(), 64); |
35 | auto lvlSizes = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank); |
36 | auto memSizes = LLVM::LLVMArrayType::get(ctx, sizeType, |
37 | getNumDataFieldsFromEncoding(enc)); |
38 | result.push_back(Elt: lvlSizes); |
39 | result.push_back(Elt: memSizes); |
40 | |
41 | if (enc.isSlice()) { |
42 | // Extra fields are required for the slice information. |
43 | auto dimOffset = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank); |
44 | auto dimStride = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank); |
45 | |
46 | result.push_back(Elt: dimOffset); |
47 | result.push_back(Elt: dimStride); |
48 | } |
49 | |
50 | return result; |
51 | } |
52 | |
53 | static Type convertSpecifier(StorageSpecifierType tp) { |
54 | return LLVM::LLVMStructType::getLiteral(tp.getContext(), |
55 | getSpecifierFields(tp)); |
56 | } |
57 | |
58 | //===----------------------------------------------------------------------===// |
59 | // Specifier struct builder. |
60 | //===----------------------------------------------------------------------===// |
61 | |
// Positions of the top-level fields in the specifier struct, matching the
// order in which getSpecifierFields() emits them. The last two fields exist
// only for slice encodings.
constexpr uint64_t kLvlSizePosInSpecifier{0};   // level-sizes array
constexpr uint64_t kMemSizePosInSpecifier{1};   // memory-sizes array
constexpr uint64_t kDimOffsetPosInSpecifier{2}; // dimension offsets (slices)
constexpr uint64_t kDimStridePosInSpecifier{3}; // dimension strides (slices)
66 | |
67 | class SpecifierStructBuilder : public StructBuilder { |
68 | private: |
69 | Value (OpBuilder &builder, Location loc, |
70 | ArrayRef<int64_t> indices) const { |
71 | return genCast(builder, loc, |
72 | builder.create<LLVM::ExtractValueOp>(loc, value, indices), |
73 | builder.getIndexType()); |
74 | } |
75 | |
76 | void insertField(OpBuilder &builder, Location loc, ArrayRef<int64_t> indices, |
77 | Value v) { |
78 | value = builder.create<LLVM::InsertValueOp>( |
79 | loc, value, genCast(builder, loc, v, builder.getIntegerType(64)), |
80 | indices); |
81 | } |
82 | |
83 | public: |
84 | explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) { |
85 | assert(value); |
86 | } |
87 | |
88 | // Undef value for dimension sizes, all zero value for memory sizes. |
89 | static Value getInitValue(OpBuilder &builder, Location loc, Type structType, |
90 | Value source); |
91 | |
92 | Value lvlSize(OpBuilder &builder, Location loc, Level lvl) const; |
93 | void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value size); |
94 | |
95 | Value dimOffset(OpBuilder &builder, Location loc, Dimension dim) const; |
96 | void setDimOffset(OpBuilder &builder, Location loc, Dimension dim, |
97 | Value size); |
98 | |
99 | Value dimStride(OpBuilder &builder, Location loc, Dimension dim) const; |
100 | void setDimStride(OpBuilder &builder, Location loc, Dimension dim, |
101 | Value size); |
102 | |
103 | Value memSize(OpBuilder &builder, Location loc, FieldIndex fidx) const; |
104 | void setMemSize(OpBuilder &builder, Location loc, FieldIndex fidx, |
105 | Value size); |
106 | |
107 | Value memSizeArray(OpBuilder &builder, Location loc) const; |
108 | void setMemSizeArray(OpBuilder &builder, Location loc, Value array); |
109 | }; |
110 | |
111 | Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc, |
112 | Type structType, Value source) { |
113 | Value metaData = builder.create<LLVM::UndefOp>(loc, structType); |
114 | SpecifierStructBuilder md(metaData); |
115 | if (!source) { |
116 | auto memSizeArrayType = |
117 | cast<LLVM::LLVMArrayType>(cast<LLVM::LLVMStructType>(structType) |
118 | .getBody()[kMemSizePosInSpecifier]); |
119 | |
120 | Value zero = constantZero(builder, loc, memSizeArrayType.getElementType()); |
121 | // Fill memSizes array with zero. |
122 | for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++) |
123 | md.setMemSize(builder, loc, fidx: i, size: zero); |
124 | } else { |
125 | // We copy non-slice information (memory sizes array) from source |
126 | SpecifierStructBuilder sourceMd(source); |
127 | md.setMemSizeArray(builder, loc, array: sourceMd.memSizeArray(builder, loc)); |
128 | } |
129 | return md; |
130 | } |
131 | |
132 | /// Builds IR extracting the pos-th offset from the descriptor. |
133 | Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc, |
134 | Dimension dim) const { |
135 | return extractField( |
136 | builder, loc, |
137 | indices: ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)}); |
138 | } |
139 | |
140 | /// Builds IR inserting the pos-th offset into the descriptor. |
141 | void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc, |
142 | Dimension dim, Value size) { |
143 | insertField( |
144 | builder, loc, |
145 | indices: ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)}, |
146 | v: size); |
147 | } |
148 | |
149 | /// Builds IR extracting the `lvl`-th level-size from the descriptor. |
150 | Value SpecifierStructBuilder::lvlSize(OpBuilder &builder, Location loc, |
151 | Level lvl) const { |
152 | // This static_cast makes the narrowing of `lvl` explicit, as required |
153 | // by the braces notation for the ctor. |
154 | return extractField( |
155 | builder, loc, |
156 | indices: ArrayRef<int64_t>{kLvlSizePosInSpecifier, static_cast<int64_t>(lvl)}); |
157 | } |
158 | |
159 | /// Builds IR inserting the `lvl`-th level-size into the descriptor. |
160 | void SpecifierStructBuilder::setLvlSize(OpBuilder &builder, Location loc, |
161 | Level lvl, Value size) { |
162 | // This static_cast makes the narrowing of `lvl` explicit, as required |
163 | // by the braces notation for the ctor. |
164 | insertField( |
165 | builder, loc, |
166 | indices: ArrayRef<int64_t>{kLvlSizePosInSpecifier, static_cast<int64_t>(lvl)}, |
167 | v: size); |
168 | } |
169 | |
170 | /// Builds IR extracting the pos-th stride from the descriptor. |
171 | Value SpecifierStructBuilder::dimStride(OpBuilder &builder, Location loc, |
172 | Dimension dim) const { |
173 | return extractField( |
174 | builder, loc, |
175 | indices: ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)}); |
176 | } |
177 | |
178 | /// Builds IR inserting the pos-th stride into the descriptor. |
179 | void SpecifierStructBuilder::setDimStride(OpBuilder &builder, Location loc, |
180 | Dimension dim, Value size) { |
181 | insertField( |
182 | builder, loc, |
183 | indices: ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)}, |
184 | v: size); |
185 | } |
186 | |
187 | /// Builds IR extracting the pos-th memory size into the descriptor. |
188 | Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc, |
189 | FieldIndex fidx) const { |
190 | return extractField( |
191 | builder, loc, |
192 | indices: ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)}); |
193 | } |
194 | |
195 | /// Builds IR inserting the `fidx`-th memory-size into the descriptor. |
196 | void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc, |
197 | FieldIndex fidx, Value size) { |
198 | insertField( |
199 | builder, loc, |
200 | indices: ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)}, |
201 | v: size); |
202 | } |
203 | |
204 | /// Builds IR extracting the memory size array from the descriptor. |
205 | Value SpecifierStructBuilder::memSizeArray(OpBuilder &builder, |
206 | Location loc) const { |
207 | return builder.create<LLVM::ExtractValueOp>(loc, value, |
208 | kMemSizePosInSpecifier); |
209 | } |
210 | |
211 | /// Builds IR inserting the memory size array into the descriptor. |
212 | void SpecifierStructBuilder::setMemSizeArray(OpBuilder &builder, Location loc, |
213 | Value array) { |
214 | value = builder.create<LLVM::InsertValueOp>(loc, value, array, |
215 | kMemSizePosInSpecifier); |
216 | } |
217 | |
218 | } // namespace |
219 | |
220 | //===----------------------------------------------------------------------===// |
221 | // The sparse storage specifier type converter (defined in Passes.h). |
222 | //===----------------------------------------------------------------------===// |
223 | |
224 | StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() { |
225 | addConversion(callback: [](Type type) { return type; }); |
226 | addConversion(convertSpecifier); |
227 | } |
228 | |
229 | //===----------------------------------------------------------------------===// |
230 | // Storage specifier conversion rules. |
231 | //===----------------------------------------------------------------------===// |
232 | |
233 | template <typename Base, typename SourceOp> |
234 | class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> { |
235 | public: |
236 | using OpAdaptor = typename SourceOp::Adaptor; |
237 | using OpConversionPattern<SourceOp>::OpConversionPattern; |
238 | |
239 | LogicalResult |
240 | matchAndRewrite(SourceOp op, OpAdaptor adaptor, |
241 | ConversionPatternRewriter &rewriter) const override { |
242 | SpecifierStructBuilder spec(adaptor.getSpecifier()); |
243 | switch (op.getSpecifierKind()) { |
244 | case StorageSpecifierKind::LvlSize: { |
245 | Value v = Base::onLvlSize(rewriter, op, spec, (*op.getLevel())); |
246 | rewriter.replaceOp(op, v); |
247 | return success(); |
248 | } |
249 | case StorageSpecifierKind::DimOffset: { |
250 | Value v = Base::onDimOffset(rewriter, op, spec, (*op.getLevel())); |
251 | rewriter.replaceOp(op, v); |
252 | return success(); |
253 | } |
254 | case StorageSpecifierKind::DimStride: { |
255 | Value v = Base::onDimStride(rewriter, op, spec, (*op.getLevel())); |
256 | rewriter.replaceOp(op, v); |
257 | return success(); |
258 | } |
259 | case StorageSpecifierKind::CrdMemSize: |
260 | case StorageSpecifierKind::PosMemSize: |
261 | case StorageSpecifierKind::ValMemSize: { |
262 | auto enc = op.getSpecifier().getType().getEncoding(); |
263 | StorageLayout layout(enc); |
264 | std::optional<unsigned> lvl; |
265 | if (op.getLevel()) |
266 | lvl = (*op.getLevel()); |
267 | unsigned idx = |
268 | layout.getMemRefFieldIndex(toFieldKind(op.getSpecifierKind()), lvl); |
269 | Value v = Base::onMemSize(rewriter, op, spec, idx); |
270 | rewriter.replaceOp(op, v); |
271 | return success(); |
272 | } |
273 | } |
274 | llvm_unreachable("unrecognized specifer kind" ); |
275 | } |
276 | }; |
277 | |
278 | struct StorageSpecifierSetOpConverter |
279 | : public SpecifierGetterSetterOpConverter<StorageSpecifierSetOpConverter, |
280 | SetStorageSpecifierOp> { |
281 | using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter; |
282 | |
283 | static Value onLvlSize(OpBuilder &builder, SetStorageSpecifierOp op, |
284 | SpecifierStructBuilder &spec, Level lvl) { |
285 | spec.setLvlSize(builder, loc: op.getLoc(), lvl, size: op.getValue()); |
286 | return spec; |
287 | } |
288 | |
289 | static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op, |
290 | SpecifierStructBuilder &spec, Dimension d) { |
291 | spec.setDimOffset(builder, loc: op.getLoc(), dim: d, size: op.getValue()); |
292 | return spec; |
293 | } |
294 | |
295 | static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op, |
296 | SpecifierStructBuilder &spec, Dimension d) { |
297 | spec.setDimStride(builder, loc: op.getLoc(), dim: d, size: op.getValue()); |
298 | return spec; |
299 | } |
300 | |
301 | static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op, |
302 | SpecifierStructBuilder &spec, FieldIndex fidx) { |
303 | spec.setMemSize(builder, loc: op.getLoc(), fidx, size: op.getValue()); |
304 | return spec; |
305 | } |
306 | }; |
307 | |
308 | struct StorageSpecifierGetOpConverter |
309 | : public SpecifierGetterSetterOpConverter<StorageSpecifierGetOpConverter, |
310 | GetStorageSpecifierOp> { |
311 | using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter; |
312 | |
313 | static Value onLvlSize(OpBuilder &builder, GetStorageSpecifierOp op, |
314 | SpecifierStructBuilder &spec, Level lvl) { |
315 | return spec.lvlSize(builder, loc: op.getLoc(), lvl); |
316 | } |
317 | |
318 | static Value onDimOffset(OpBuilder &builder, GetStorageSpecifierOp op, |
319 | const SpecifierStructBuilder &spec, Dimension d) { |
320 | return spec.dimOffset(builder, loc: op.getLoc(), dim: d); |
321 | } |
322 | |
323 | static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op, |
324 | const SpecifierStructBuilder &spec, Dimension d) { |
325 | return spec.dimStride(builder, loc: op.getLoc(), dim: d); |
326 | } |
327 | |
328 | static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op, |
329 | SpecifierStructBuilder &spec, FieldIndex fidx) { |
330 | return spec.memSize(builder, loc: op.getLoc(), fidx); |
331 | } |
332 | }; |
333 | |
334 | struct StorageSpecifierInitOpConverter |
335 | : public OpConversionPattern<StorageSpecifierInitOp> { |
336 | public: |
337 | using OpConversionPattern::OpConversionPattern; |
338 | LogicalResult |
339 | matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor, |
340 | ConversionPatternRewriter &rewriter) const override { |
341 | Type llvmType = getTypeConverter()->convertType(op.getResult().getType()); |
342 | rewriter.replaceOp( |
343 | op, SpecifierStructBuilder::getInitValue( |
344 | builder&: rewriter, loc: op.getLoc(), structType: llvmType, source: adaptor.getSource())); |
345 | return success(); |
346 | } |
347 | }; |
348 | |
349 | //===----------------------------------------------------------------------===// |
350 | // Public method for populating conversion rules. |
351 | //===----------------------------------------------------------------------===// |
352 | |
353 | void mlir::populateStorageSpecifierToLLVMPatterns(TypeConverter &converter, |
354 | RewritePatternSet &patterns) { |
355 | patterns.add<StorageSpecifierGetOpConverter, StorageSpecifierSetOpConverter, |
356 | StorageSpecifierInitOpConverter>(arg&: converter, |
357 | args: patterns.getContext()); |
358 | } |
359 | |