1 | //===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // A pass that converts sparse tensor types and primitives to actual compiler |
10 | // visible buffers and actual compiler IR that implements these primitives on |
11 | // the selected sparse tensor storage schemes. This pass provides an alternative |
12 | // to the SparseTensorConversion pass, eliminating the dependence on a runtime |
13 | // support library (other than for file I/O), and providing many more |
14 | // opportunities for subsequent compiler optimization of the generated code. |
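//
// As an illustration (not exact IR; the element type and index widths shown
// here are arbitrary), a 2-D CSR-like sparse tensor is lowered to a small set
// of flat buffers plus a storage specifier that tracks all sizes:
//
//   %positions   : memref<?xindex>  (one per compressed level)
//   %coordinates : memref<?xindex>  (one per compressed/singleton level)
//   %values      : memref<?xf64>    (the stored entries)
//   %specifier   : !sparse_tensor.storage_specifier<...>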
15 | // |
16 | //===----------------------------------------------------------------------===// |
17 | |
18 | #include "Utils/CodegenUtils.h" |
19 | #include "Utils/SparseTensorDescriptor.h" |
20 | |
21 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
22 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
23 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
24 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
25 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
26 | #include "mlir/Dialect/SparseTensor/IR/Enums.h" |
27 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
28 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
29 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
30 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
31 | #include "mlir/Transforms/DialectConversion.h" |
32 | |
33 | #include <optional> |
34 | |
35 | using namespace mlir; |
36 | using namespace mlir::sparse_tensor; |
37 | |
38 | //===----------------------------------------------------------------------===// |
39 | // Helper methods. |
40 | //===----------------------------------------------------------------------===// |
41 | |
42 | /// Flattens a list of operands that may contain sparse tensors. |
43 | static void flattenOperands(ValueRange operands, |
44 | SmallVectorImpl<Value> &flattened) { |
45 | // In case of |
46 | // sparse_tensor, c, sparse_tensor |
47 | // ==> |
48 | // memref ..., c, memref ... |
49 | for (auto operand : operands) { |
    if (getSparseTensorEncoding(operand.getType())) {
51 | auto tuple = getTuple(operand); |
      // An unrealized_conversion_cast is inserted by the type converter to
      // bridge the 1:N conversion between a sparse tensor and its fields.
      // In this case, take the operands of the cast and replace the sparse
      // tensor output with the flattened field array.
56 | flattened.append(tuple.getOperands().begin(), tuple.getOperands().end()); |
57 | } else { |
      flattened.push_back(operand);
59 | } |
60 | } |
61 | } |
62 | |
63 | /// Generates a load with proper `index` typing. |
64 | static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) { |
65 | idx = genCast(builder, loc, idx, builder.getIndexType()); |
66 | return builder.create<memref::LoadOp>(loc, mem, idx); |
67 | } |
68 | |
69 | /// Generates a store with proper `index` typing and proper value. |
70 | static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, |
71 | Value idx) { |
72 | idx = genCast(builder, loc, idx, builder.getIndexType()); |
73 | val = genCast(builder, loc, val, |
74 | cast<ShapedType>(mem.getType()).getElementType()); |
75 | builder.create<memref::StoreOp>(loc, val, mem, idx); |
76 | } |
77 | |
78 | /// Creates a straightforward counting for-loop. |
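/// The generated loop threads all storage `fields` through the loop as
/// iteration arguments; schematically (illustrative IR only, types elided):
///
///   %r:N = scf.for %i = %lb to %ub step %c1
///            iter_args(%f0 = %field0, ...) -> (...) {
///     ...  // body filled in by the caller
///     scf.yield ...
///   }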
79 | static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper, |
80 | MutableArrayRef<Value> fields, |
81 | Value lower = Value()) { |
82 | Type indexType = builder.getIndexType(); |
83 | if (!lower) |
    lower = constantZero(builder, loc, indexType);
  Value one = constantOne(builder, loc, indexType);
86 | scf::ForOp forOp = builder.create<scf::ForOp>(loc, lower, upper, one, fields); |
87 | for (unsigned i = 0, e = fields.size(); i < e; i++) |
88 | fields[i] = forOp.getRegionIterArg(i); |
89 | builder.setInsertionPointToStart(forOp.getBody()); |
90 | return forOp; |
91 | } |
92 | |
93 | /// Creates a push back operation. |
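/// In pseudo-code (not exact IR), this amounts to:
///
///   buffer        = desc.memRef[kind][lvl]
///   curSize       = desc.specifier[kind][lvl]
///   buffer, size  = push_back(curSize, buffer, cast(value, etp) [, repeat])
///   desc.memRef[kind][lvl]    = buffer
///   desc.specifier[kind][lvl] = size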
94 | static void createPushback(OpBuilder &builder, Location loc, |
95 | MutSparseTensorDescriptor desc, |
96 | SparseTensorFieldKind kind, std::optional<Level> lvl, |
97 | Value value, Value repeat = Value()) { |
98 | Type etp = desc.getMemRefElementType(kind, lvl); |
99 | Value field = desc.getMemRefField(kind, lvl); |
100 | StorageSpecifierKind specFieldKind = toSpecifierKind(kind); |
101 | |
102 | auto pushBackOp = builder.create<PushBackOp>( |
103 | loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field, |
104 | genCast(builder, loc, value, etp), repeat); |
105 | |
106 | desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer()); |
107 | desc.setSpecifierField(builder, loc, specFieldKind, lvl, |
108 | pushBackOp.getNewSize()); |
109 | } |
110 | |
111 | /// Generates code that allocates a sparse storage scheme for given rank. |
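/// For example (assuming a CSR-like (dense, compressed) layout and
/// startLvl == 0): the dense level scales `linear` by the level size, and the
/// compressed level then appends `linear` zero positions; together with the
/// single zero entry added by createAllocFields this maintains the
/// "linear + 1" position-array length expected for an empty tensor.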
112 | static void allocSchemeForRank(OpBuilder &builder, Location loc, |
113 | MutSparseTensorDescriptor desc, Level startLvl) { |
114 | const SparseTensorType stt(desc.getRankedTensorType()); |
  Value linear = constantIndex(builder, loc, 1);
116 | const Level lvlRank = stt.getLvlRank(); |
117 | for (Level lvl = startLvl; lvl < lvlRank; lvl++) { |
    const auto lt = stt.getLvlType(lvl);
119 | if (isCompressedLT(lt) || isLooseCompressedLT(lt)) { |
120 | // Append linear x positions, initialized to zero. Since each compressed |
121 | // dimension initially already has a single zero entry, this maintains |
122 | // the desired "linear + 1" length property at all times. For loose |
123 | // compression, we multiply linear by two in order to append both the |
124 | // lo/hi positions. |
      Value posZero = constantZero(builder, loc, stt.getPosType());
126 | if (isLooseCompressedLT(lt)) { |
        Value two = constantIndex(builder, loc, 2);
128 | linear = builder.create<arith::MulIOp>(loc, linear, two); |
129 | } |
      createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
131 | /*value=*/posZero, /*repeat=*/linear); |
132 | return; |
133 | } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) { |
134 | return; // nothing to do |
135 | } |
136 | // Keep compounding the size, but nothing needs to be initialized |
137 | // at this level. We will eventually reach a compressed level or |
138 | // otherwise the values array for the from-here "all-dense" case. |
139 | assert(isDenseLT(lt)); |
140 | Value size = desc.getLvlSize(builder, loc, lvl); |
141 | linear = builder.create<arith::MulIOp>(loc, linear, size); |
142 | } |
143 | // Reached values array so prepare for an insertion. |
  Value valZero = constantZero(builder, loc, stt.getElementType());
  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                 std::nullopt, /*value=*/valZero, /*repeat=*/linear);
147 | } |
148 | |
149 | /// Creates allocation operation. |
150 | static Value createAllocation(OpBuilder &builder, Location loc, |
151 | MemRefType memRefType, Value sz, |
152 | bool enableInit) { |
153 | Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz); |
154 | Type elemType = memRefType.getElementType(); |
155 | if (enableInit) { |
    Value fillValue = constantZero(builder, loc, elemType);
157 | builder.create<linalg::FillOp>(loc, fillValue, buffer); |
158 | } |
159 | return buffer; |
160 | } |
161 | |
162 | /// Creates the dim sizes array, filling in from dynamic sizes. |
163 | static void createDimSizes(OpBuilder &builder, Location loc, |
164 | SparseTensorType stt, ValueRange dynSizes, |
165 | /*out*/ SmallVectorImpl<Value> &dimSizesValues) { |
166 | const Dimension dimRank = stt.getDimRank(); |
167 | dimSizesValues.clear(); |
  dimSizesValues.reserve(dimRank);
169 | unsigned i = 0; |
170 | for (const Size sz : stt.getDimShape()) |
171 | dimSizesValues.push_back(ShapedType::isDynamic(sz) |
172 | ? dynSizes[i++] |
173 | : constantIndex(builder, loc, sz)); |
174 | } |
175 | |
176 | /// Creates allocation for each field in sparse tensor type. Note that |
/// for all dynamic memrefs in the sparse tensor storage layout, the
178 | /// memory size is really the capacity of the "vector", while the actual |
179 | /// size resides in the sizes array. |
180 | static void createAllocFields(OpBuilder &builder, Location loc, |
181 | SparseTensorType stt, bool enableInit, |
182 | Value sizeHint, |
183 | SmallVectorImpl<Value> &lvlSizesValues, |
184 | /*out*/ SmallVectorImpl<Value> &fields) { |
185 | Level lvlRank = stt.getLvlRank(); |
186 | // Set up some heuristic sizes. We try to set the initial |
187 | // size based on available information. Otherwise we just |
188 | // initialize a few elements to start the reallocation chain. |
189 | // TODO: refine this |
190 | Value posHeuristic, crdHeuristic, valHeuristic; |
191 | if (stt.isAllDense()) { |
192 | valHeuristic = lvlSizesValues[0]; |
193 | for (Level lvl = 1; lvl < lvlRank; lvl++) |
194 | valHeuristic = |
195 | builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]); |
196 | } else if (sizeHint) { |
197 | if (stt.getAoSCOOStart() == 0) { |
      posHeuristic = constantIndex(builder, loc, 2);
199 | crdHeuristic = builder.create<arith::MulIOp>( |
200 | loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS |
    } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
202 | posHeuristic = builder.create<arith::AddIOp>( |
203 | loc, sizeHint, constantIndex(builder, loc, 1)); |
204 | crdHeuristic = sizeHint; |
205 | } else { |
      posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
207 | } |
208 | valHeuristic = sizeHint; |
209 | } else { |
210 | posHeuristic = crdHeuristic = valHeuristic = |
        constantIndex(builder, loc, 16);
212 | } |
213 | // Initializes all fields. An initial storage specifier and allocated |
214 | // positions/coordinates/values memrefs (with heuristic capacity). |
215 | foreachFieldAndTypeInSparseTensor( |
216 | stt, |
217 | [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic, |
218 | enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind, |
219 | Level /*lvl*/, LevelType /*lt*/) -> bool { |
220 | assert(fields.size() == fIdx); |
221 | Value field; |
222 | switch (fKind) { |
223 | case SparseTensorFieldKind::StorageSpec: |
224 | field = SparseTensorSpecifier::getInitValue(builder, loc, stt); |
225 | break; |
226 | case SparseTensorFieldKind::PosMemRef: |
227 | field = createAllocation(builder, loc, cast<MemRefType>(fType), |
228 | posHeuristic, enableInit); |
229 | break; |
230 | case SparseTensorFieldKind::CrdMemRef: |
231 | field = createAllocation(builder, loc, cast<MemRefType>(fType), |
232 | crdHeuristic, enableInit); |
233 | break; |
234 | case SparseTensorFieldKind::ValMemRef: |
235 | field = createAllocation(builder, loc, cast<MemRefType>(fType), |
236 | valHeuristic, enableInit); |
237 | break; |
238 | } |
239 | assert(field); |
        fields.push_back(field);
241 | // Returns true to continue the iteration. |
242 | return true; |
243 | }); |
244 | // Initialize the storage scheme to an empty tensor. Sets the lvlSizes |
245 | // and gives all position fields an initial zero entry, so that it is |
246 | // easier to maintain the "linear + 1" length property. |
247 | MutSparseTensorDescriptor desc(stt, fields); |
  Value posZero = constantZero(builder, loc, stt.getPosType());
249 | for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) { |
    desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
    const auto lt = stt.getLvlType(lvl);
    if (isCompressedLT(lt) || isLooseCompressedLT(lt))
      createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
254 | /*value=*/posZero); |
255 | } |
  allocSchemeForRank(builder, loc, desc, /*startLvl=*/0);
257 | } |
258 | |
259 | /// Helper method that generates block specific to compressed case: |
260 | /// |
261 | /// // given: parentPos = posCursor[lvl-1] |
262 | /// pstart = desc.positions[lvl][parentPos] |
263 | /// pstop = desc.positions[lvl][parentPos+1] |
264 | /// plast = pstop - 1 |
265 | /// msz = desc.coordinates[lvl].size() |
266 | /// if (pstart < pstop) { |
267 | /// isPresent = (desc.coordinates[lvl][plast] == lvlCoords[lvl]) |
268 | /// } else { // first insertion |
269 | /// isPresent = false |
270 | /// desc.positions[lvl][parentPos] = msz |
271 | /// } |
272 | /// if (isPresent) { // coordinate is already present |
273 | /// pnext = plast |
274 | /// } else { |
275 | /// desc.coordinates[lvl].push_back(lvlCoords[lvl]) |
276 | /// desc.positions[lvl][parentPos+1] = msz+1 |
277 | /// pnext = msz |
278 | /// <prepare level lvl+1> |
279 | /// } |
280 | /// posCursor[lvl] = pnext |
281 | static Value genCompressed(OpBuilder &builder, Location loc, |
282 | MutSparseTensorDescriptor desc, ValueRange lvlCoords, |
283 | Value /*unused*/, Value parentPos, Level lvl) { |
284 | const SparseTensorType stt(desc.getRankedTensorType()); |
285 | const Level lvlRank = stt.getLvlRank(); |
  assert(lvl < lvlRank && "Level is out of bounds");
  assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
         "Level-rank mismatch");
289 | SmallVector<Type> types; |
290 | Type indexType = builder.getIndexType(); |
291 | Type boolType = builder.getIntegerType(1); |
292 | unsigned crdFidx; |
293 | unsigned crdStride; |
  std::tie(crdFidx, crdStride) = desc.getCrdMemRefIndexAndStride(lvl);
  const Value one = constantIndex(builder, loc, 1);
296 | const Value pp1 = builder.create<arith::AddIOp>(loc, parentPos, one); |
297 | const Value positionsAtLvl = desc.getPosMemRef(lvl); |
  const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
  const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
300 | const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl); |
301 | const Value crdStrideC = |
      crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
303 | const Value msz = |
304 | crdStrideC ? builder.create<arith::DivUIOp>(loc, crdMsz, crdStrideC) |
305 | : crdMsz; |
306 | const Value plast = builder.create<arith::SubIOp>( |
307 | loc, genCast(builder, loc, pstop, indexType), one); |
308 | // Conditional expression. |
309 | Value lt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, |
310 | pstart, pstop); |
  types.push_back(boolType);
312 | scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true); |
313 | types.pop_back(); |
314 | builder.setInsertionPointToStart(&ifOp1.getThenRegion().front()); |
315 | Value crd = |
316 | genLoad(builder, loc, desc.getMemRefField(crdFidx), |
317 | crdStrideC ? builder.create<arith::MulIOp>(loc, plast, crdStrideC) |
318 | : plast); |
319 | Value eq = builder.create<arith::CmpIOp>( |
320 | loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType), |
321 | lvlCoords[lvl]); |
322 | builder.create<scf::YieldOp>(loc, eq); |
323 | builder.setInsertionPointToStart(&ifOp1.getElseRegion().front()); |
324 | if (lvl > 0) |
    genStore(builder, loc, msz, positionsAtLvl, parentPos);
326 | builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false)); |
327 | builder.setInsertionPointAfter(ifOp1); |
328 | // If present construct. Note that for a non-unique dimension level, we |
329 | // simply set the condition to false and rely on CSE/DCE to clean up the IR. |
330 | // |
331 | // TODO: generate less temporary IR? |
332 | // |
333 | for (unsigned i = 0, e = desc.getNumFields(); i < e; i++) |
    types.push_back(desc.getField(i).getType());
  types.push_back(indexType);
  const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
                                       : constantI1(builder, loc, false);
338 | scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true); |
339 | // If present (fields unaffected, update pnext to plast). |
340 | builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); |
341 | |
  // FIXME: This does not look like a clean way to do this, but it is
  // probably the most efficient approach.
344 | desc.getFields().push_back(plast); |
345 | builder.create<scf::YieldOp>(loc, desc.getFields()); |
346 | desc.getFields().pop_back(); |
347 | |
348 | // If !present (changes fields, update pnext). |
349 | builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); |
350 | Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one); |
  genStore(builder, loc, mszp1, positionsAtLvl, pp1);
  createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
353 | /*value=*/lvlCoords[lvl]); |
354 | // Prepare the next level "as needed". |
355 | if ((lvl + 1) < lvlRank) |
    allocSchemeForRank(builder, loc, desc, lvl + 1);
357 | |
358 | desc.getFields().push_back(msz); |
359 | builder.create<scf::YieldOp>(loc, desc.getFields()); |
360 | desc.getFields().pop_back(); |
361 | |
362 | // Update fields and return next pos. |
363 | builder.setInsertionPointAfter(ifOp2); |
364 | unsigned o = 0; |
365 | for (unsigned i = 0, e = desc.getNumFields(); i < e; i++) |
    desc.setField(i, ifOp2.getResult(o++));
367 | return ifOp2.getResult(o); |
368 | } |
369 | |
370 | /// Generates insertion finalization code. |
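/// For a compressed level this patches up position entries that were never
/// visited during insertion; roughly (pseudo-code, not exact IR):
///
///   oldv = positions[0]
///   for (i = 1; i < posMemSize; i++)
///     if (positions[i] == 0) positions[i] = oldv; else oldv = positions[i]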
371 | static void genEndInsert(OpBuilder &builder, Location loc, |
372 | SparseTensorDescriptor desc) { |
373 | const SparseTensorType stt(desc.getRankedTensorType()); |
374 | const Level lvlRank = stt.getLvlRank(); |
375 | for (Level lvl = 0; lvl < lvlRank; lvl++) { |
    const auto lt = stt.getLvlType(lvl);
377 | if (isCompressedLT(lt)) { |
378 | // Compressed dimensions need a position cleanup for all entries |
379 | // that were not visited during the insertion pass. |
380 | // |
381 | // TODO: avoid cleanup and keep compressed scheme consistent at all |
382 | // times? |
383 | // |
384 | if (lvl > 0) { |
385 | Type posType = stt.getPosType(); |
386 | Value posMemRef = desc.getPosMemRef(lvl); |
387 | Value hi = desc.getPosMemSize(builder, loc, lvl); |
        Value zero = constantIndex(builder, loc, 0);
        Value one = constantIndex(builder, loc, 1);
        // Vector of only one, but needed by createFor's prototype.
        SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
392 | scf::ForOp loop = createFor(builder, loc, hi, inits, one); |
393 | Value i = loop.getInductionVar(); |
394 | Value oldv = loop.getRegionIterArg(0); |
        Value newv = genLoad(builder, loc, posMemRef, i);
        Value posZero = constantZero(builder, loc, posType);
397 | Value cond = builder.create<arith::CmpIOp>( |
398 | loc, arith::CmpIPredicate::eq, newv, posZero); |
399 | scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType), |
400 | cond, /*else*/ true); |
401 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
        genStore(builder, loc, oldv, posMemRef, i);
403 | builder.create<scf::YieldOp>(loc, oldv); |
404 | builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
405 | builder.create<scf::YieldOp>(loc, newv); |
406 | builder.setInsertionPointAfter(ifOp); |
407 | builder.create<scf::YieldOp>(loc, ifOp.getResult(0)); |
408 | builder.setInsertionPointAfter(loop); |
409 | } |
410 | } else { |
411 | assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) || |
412 | isNOutOfMLT(lt)); |
413 | } |
414 | } |
415 | } |
416 | |
417 | /// Generates a subview into the sizes. |
418 | static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, |
419 | Value sz) { |
420 | auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType(); |
421 | return builder |
422 | .create<memref::SubViewOp>( |
423 | loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem, |
424 | ValueRange{}, ValueRange{sz}, ValueRange{}, |
425 | ArrayRef<int64_t>{0}, // static offset |
426 | ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size |
427 | ArrayRef<int64_t>{1}) // static stride |
428 | .getResult(); |
429 | } |
430 | |
431 | /// Creates the reassociation array. |
432 | static SmallVector<ReassociationIndices> |
433 | getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) { |
434 | SmallVector<ReassociationIndices> ret(batchLvls + 1, {}); |
435 | // Create reassociation in the form: |
436 | // {0}, {1}, ..., {batchLvl - 1}, {batchLvl, ..., rank} |
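  // For example, batchLvls == 2 on a rank-5 source type yields:
  //   {0}, {1}, {2, 3, 4}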
437 | for (unsigned i = 0; i < batchLvls; i++) |
    ret[i].push_back(i);
439 | |
440 | for (int i = batchLvls, e = srcTp.getRank(); i < e; i++) |
    ret.back().push_back(i);
442 | |
443 | return ret; |
444 | } |
445 | |
446 | //===----------------------------------------------------------------------===// |
447 | // Codegen rules. |
448 | //===----------------------------------------------------------------------===// |
449 | |
450 | namespace { |
451 | |
/// Helper class for lowering the sparse_tensor.insert operation.
453 | class SparseInsertGenerator |
454 | : public FuncCallOrInlineGenerator<SparseInsertGenerator> { |
455 | public: |
456 | SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params, |
457 | bool genCall) |
458 | : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp){}; |
459 | |
460 | /// Generates code along an insertion path without the need for a "cursor". |
461 | /// This current insertion strategy comes at the expense of some testing |
462 | /// overhead for each insertion. The strategy will be optimized later for |
463 | /// common insertion patterns. The current insertion strategy also assumes |
464 | /// insertions occur in "a reasonable order" that enables building the |
465 | /// storage scheme in an appending/inserting kind of fashion (i.e. no |
466 | /// in-between insertions that need data movement). The implementation |
467 | /// relies on CSE/DCE to clean up all bookkeeping that is not needed. |
468 | /// |
469 | /// TODO: better unord/not-unique; also generalize, optimize, specialize! |
470 | SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args, |
471 | OpBuilder &builder, Location loc) { |
472 | const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp)); |
473 | const Level lvlRank = stt.getLvlRank(); |
474 | // Extract fields and coordinates from args. |
    SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
476 | MutSparseTensorDescriptor desc(stt, fields); |
477 | const SmallVector<Value> coords = |
        llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
479 | Value value = args.back(); |
480 | Value parentPos = constantZero(builder, loc, builder.getIndexType()); |
481 | // Generate code for every level. |
482 | for (Level lvl = 0; lvl < lvlRank; lvl++) { |
      const auto lt = stt.getLvlType(lvl);
484 | if (isCompressedLT(lt) || isLooseCompressedLT(lt)) { |
485 | // Create: |
486 | // if (!present) { |
487 | // coordinates[lvl].push_back(coords[lvl]) |
488 | // <update positions and prepare level lvl + 1> |
489 | // } |
490 | // positions[lvl] = coordinates.size() - 1 |
491 | // <insert @ positions[lvl] at next level lvl + 1> |
492 | if (isLooseCompressedLT(lt)) { |
          Value two = constantIndex(builder, loc, 2);
494 | parentPos = builder.create<arith::MulIOp>(loc, parentPos, two); |
495 | } |
496 | parentPos = |
            genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
498 | } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) { |
499 | // Create: |
500 | // coordinates[lvl].push_back(coords[lvl]) |
501 | // positions[lvl] = positions[lvl-1] |
502 | // <insert @ positions[lvl] at next level lvl + 1> |
        createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
504 | lvl, /*value=*/coords[lvl]); |
505 | } else { |
506 | assert(isDenseLT(lt)); |
507 | // Construct the new position as: |
508 | // positions[lvl] = size * positions[lvl-1] + coords[lvl] |
509 | // <insert @ positions[lvl] at next level lvl + 1> |
510 | Value size = desc.getLvlSize(builder, loc, lvl); |
511 | Value mult = builder.create<arith::MulIOp>(loc, size, parentPos); |
512 | parentPos = builder.create<arith::AddIOp>(loc, mult, coords[lvl]); |
513 | } |
514 | } |
515 | // Reached the actual value append/insert. |
    if (!stt.isDenseLvl(lvlRank - 1))
      createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                     std::nullopt, value);
519 | else |
520 | genStore(builder, loc, value, desc.getValMemRef(), parentPos); |
521 | return fields; |
522 | } |
523 | |
524 | std::string getMangledFuncName() { |
525 | // The mangled name of the function has this format: |
526 | // <namePrefix>_<LT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth> |
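    // For example (purely illustrative), a 100x100 CSR-like tensor of f64
    // with 32-bit position and coordinate widths would mangle to roughly:
    //   _insert_dense_compressed_100_100_f64_32_32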
    constexpr const char kInsertFuncNamePrefix[] = "_insert_";
528 | const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp)); |
529 | SmallString<32> nameBuffer; |
530 | llvm::raw_svector_ostream nameOstream(nameBuffer); |
531 | nameOstream << kInsertFuncNamePrefix; |
532 | const Level lvlRank = stt.getLvlRank(); |
533 | for (Level l = 0; l < lvlRank; l++) { |
      std::string lvlType = toMLIRString(stt.getLvlType(l));
535 | // Replace/remove punctuations in level properties. |
      std::replace_if(
          lvlType.begin(), lvlType.end(),
          [](char c) { return c == '(' || c == ','; }, '_');
      llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
      nameOstream << lvlType << "_";
541 | } |
542 | // Static dim sizes are used in the generated code while dynamic sizes are |
543 | // loaded from the dimSizes buffer. This is the reason for adding the shape |
544 | // to the function name. |
545 | for (const auto sz : stt.getDimShape()) |
      nameOstream << sz << "_";
547 | // Permutation information is also used in generating insertion. |
548 | if (!stt.isIdentity()) |
      nameOstream << stt.getDimToLvl() << "_";
    nameOstream << stt.getElementType() << "_";
551 | nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth(); |
552 | return nameOstream.str().str(); |
553 | } |
554 | |
555 | private: |
556 | TensorType rtp; |
557 | }; |
558 | |
559 | /// Sparse tensor storage conversion rule for returns. |
560 | class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> { |
561 | public: |
562 | using OpConversionPattern::OpConversionPattern; |
563 | LogicalResult |
564 | matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, |
565 | ConversionPatternRewriter &rewriter) const override { |
566 | SmallVector<Value> flattened; |
567 | flattenOperands(adaptor.getOperands(), flattened); |
568 | // Create a return with the flattened value extracted from sparse tensors. |
569 | rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened); |
570 | return success(); |
571 | } |
572 | }; |
573 | |
574 | /// Sparse tensor storage conversion rule for calls. |
575 | class SparseCallConverter : public OpConversionPattern<func::CallOp> { |
576 | public: |
577 | // The default CallOp converter can not handle 1:N type conversion. |
578 | using OpConversionPattern::OpConversionPattern; |
579 | LogicalResult |
580 | matchAndRewrite(func::CallOp op, OpAdaptor adaptor, |
581 | ConversionPatternRewriter &rewriter) const override { |
582 | Location loc = op.getLoc(); |
583 | // In case of: |
584 | // sparse_tensor, f, sparse_tensor = call @foo(...) |
585 | // ==> |
586 | // memref..., f, memref = call @foo(...) replace with |
587 | // cast(memref...)->sparse_tensor, f, cast(memref...)->sparse_tensor |
588 | SmallVector<Type> finalRetTy; |
589 | if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy))) |
590 | return failure(); |
591 | |
592 | // (1) Generates new call with flattened return value. |
593 | SmallVector<Value> flattened; |
594 | flattenOperands(adaptor.getOperands(), flattened); |
595 | auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(), |
596 | finalRetTy, flattened); |
597 | // (2) Create cast operation for sparse tensor returns. |
598 | SmallVector<Value> castedRet; |
    // Tracks the offset of the current return value (of the original call)
    // relative to the new call (after sparse tensor flattening).
601 | unsigned retOffset = 0; |
    // Temporary buffer to hold the flattened list of types for
    // a sparse tensor.
604 | SmallVector<Type> sparseFlat; |
605 | for (auto ret : op.getResults()) { |
606 | assert(retOffset < newCall.getNumResults()); |
607 | auto retType = ret.getType(); |
608 | if (failed(typeConverter->convertType(retType, sparseFlat))) |
609 | llvm_unreachable("Failed to convert type in sparse tensor codegen" ); |
610 | |
      // Converted types cannot be empty when the type conversion succeeds.
612 | assert(!sparseFlat.empty()); |
613 | if (sparseFlat.size() > 1) { |
614 | auto flatSize = sparseFlat.size(); |
615 | ValueRange fields(iterator_range<ResultRange::iterator>( |
616 | newCall.result_begin() + retOffset, |
617 | newCall.result_begin() + retOffset + flatSize)); |
618 | castedRet.push_back(genTuple(rewriter, loc, retType, fields)); |
619 | retOffset += flatSize; |
620 | } else { |
        // If this is a 1:1 conversion, no need for casting.
622 | castedRet.push_back(newCall.getResult(retOffset)); |
623 | retOffset++; |
624 | } |
625 | sparseFlat.clear(); |
626 | } |
627 | |
628 | assert(castedRet.size() == op.getNumResults()); |
629 | rewriter.replaceOp(op, castedRet); |
630 | return success(); |
631 | } |
632 | }; |
633 | |
634 | /// Sparse codegen rule for level accesses. |
635 | class SparseLvlOpConverter : public OpConversionPattern<LvlOp> { |
636 | public: |
637 | using OpConversionPattern::OpConversionPattern; |
638 | LogicalResult |
639 | matchAndRewrite(LvlOp op, OpAdaptor adaptor, |
640 | ConversionPatternRewriter &rewriter) const override { |
641 | std::optional<int64_t> lvl = op.getConstantLvlIndex(); |
642 | if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType())) |
643 | return failure(); |
644 | |
645 | auto desc = getDescriptorFromTensorTuple(adaptor.getSource()); |
646 | auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl); |
647 | |
648 | rewriter.replaceOp(op, sz); |
649 | return success(); |
650 | } |
651 | }; |
652 | |
653 | // TODO: use a new SortCOO operation here instead of reusing convert op. |
654 | struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> { |
655 | using OpConversionPattern::OpConversionPattern; |
656 | LogicalResult |
657 | matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor, |
658 | ConversionPatternRewriter &rewriter) const override { |
659 | Location loc = op.getLoc(); |
660 | MLIRContext *ctx = op.getContext(); |
661 | |
662 | SparseTensorType srcStt = getSparseTensorType(op.getInputCoo()); |
663 | SparseTensorType dstStt = getSparseTensorType(op.getResultCoo()); |
664 | |
665 | // Should have been verified. |
666 | assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() && |
667 | dstStt.isCOOType() && srcStt.isCOOType()); |
668 | assert(dstStt.hasSameDimToLvl(srcStt)); |
669 | |
670 | // We don't need a mutable descriptor here as we perform sorting in-place. |
671 | auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getInputCoo()); |
672 | auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo()); |
673 | auto crd = desc.getAOSMemRef(); |
674 | auto val = desc.getValMemRef(); |
675 | |
676 | // Otherwise we need another data shuffle and a non-identity map. |
677 | assert(dstStt.hasSameDimToLvl(srcStt)); |
678 | (void)dstStt; // to silence warning when assertion is disabled |
679 | |
    auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
681 | |
682 | rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id, |
683 | rewriter.getIndexAttr(0), op.getAlgorithm()); |
684 | |
    // Since we do in-place sorting, the destination tensor will have the same
    // set of memrefs as the source tensor.
687 | rewriter.replaceOp(op, adaptor.getInputCoo()); |
688 | return success(); |
689 | } |
690 | }; |
691 | |
692 | template <typename Op, StorageSpecifierKind kind> |
693 | class SparseSliceGetterOpConverter : public OpConversionPattern<Op> { |
694 | public: |
695 | using OpConversionPattern<Op>::OpConversionPattern; |
696 | LogicalResult |
697 | matchAndRewrite(Op op, typename Op::Adaptor adaptor, |
698 | ConversionPatternRewriter &rewriter) const override { |
    // Simply lowers to a specifier.get <field> operation.
700 | auto desc = getDescriptorFromTensorTuple(adaptor.getSlice()); |
701 | auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind, |
702 | op.getDim().getZExtValue()); |
703 | |
704 | rewriter.replaceOp(op, v); |
705 | return success(); |
706 | } |
707 | }; |
708 | |
709 | /// Sparse codegen rule for trivial tensor casts. |
710 | class SparseCastConverter : public OpConversionPattern<tensor::CastOp> { |
711 | public: |
712 | using OpConversionPattern::OpConversionPattern; |
713 | LogicalResult |
714 | matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, |
715 | ConversionPatternRewriter &rewriter) const override { |
716 | // Only rewrite identically annotated source/dest. |
717 | auto encDst = getSparseTensorEncoding(op.getType()); |
718 | auto encSrc = getSparseTensorEncoding(op.getSource().getType()); |
719 | if (!encDst || encDst != encSrc) |
720 | return failure(); |
721 | rewriter.replaceOp(op, adaptor.getOperands()); |
722 | return success(); |
723 | } |
724 | }; |
725 | |
726 | class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> { |
727 | public: |
728 | using OpConversionPattern::OpConversionPattern; |
729 | LogicalResult |
730 | matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor, |
731 | ConversionPatternRewriter &rewriter) const override { |
732 | // Simply fold the operation. |
733 | rewriter.replaceOp(op, adaptor.getSource()); |
734 | return success(); |
735 | } |
736 | }; |
737 | |
738 | /// Sparse codegen rule for the alloc operator. |
739 | class SparseTensorAllocConverter |
740 | : public OpConversionPattern<bufferization::AllocTensorOp> { |
741 | public: |
742 | using OpConversionPattern::OpConversionPattern; |
743 | SparseTensorAllocConverter(TypeConverter &typeConverter, MLIRContext *context, |
744 | bool enableInit) |
745 | : OpConversionPattern(typeConverter, context), |
746 | enableBufferInitialization(enableInit) {} |
747 | |
748 | LogicalResult |
749 | matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor, |
750 | ConversionPatternRewriter &rewriter) const override { |
751 | const auto resType = getSparseTensorType(op); |
752 | if (!resType.hasEncoding()) |
753 | return failure(); |
754 | |
755 | Location loc = op.getLoc(); |
756 | // Deal with copy. |
757 | if (op.getCopy()) { |
758 | auto desc = getDescriptorFromTensorTuple(adaptor.getCopy()); |
759 | SmallVector<Value> fields; |
      fields.reserve(desc.getNumFields());
761 | // Memcpy on memref fields. |
762 | for (auto field : desc.getMemRefFields()) { |
763 | auto memrefTp = cast<MemRefType>(field.getType()); |
764 | auto size = rewriter.create<memref::DimOp>(loc, field, 0); |
765 | auto copied = |
766 | rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size}); |
767 | rewriter.create<memref::CopyOp>(loc, field, copied); |
768 | fields.push_back(copied); |
769 | } |
770 | // Reuses specifier. |
      fields.push_back(desc.getSpecifier());
772 | assert(fields.size() == desc.getNumFields()); |
773 | rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields)); |
774 | return success(); |
775 | } |
776 | |
777 | if (!resType.isIdentity()) { |
778 | return rewriter.notifyMatchFailure( |
779 | op, "try run --sparse-reinterpret-map before codegen" ); |
780 | } |
    // Level sizes equal dimension sizes since lvl2dim is an identity map.
782 | SmallVector<Value> lvlSizesValues; |
783 | createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(), |
784 | /*dimSizesValues=*/lvlSizesValues); |
785 | |
786 | // Construct allocation for each field. |
787 | Value sizeHint = op.getSizeHint(); |
788 | SmallVector<Value> fields; |
789 | createAllocFields(rewriter, loc, resType, enableBufferInitialization, |
790 | sizeHint, lvlSizesValues, fields); |
791 | |
792 | // Replace operation with resulting memrefs. |
793 | rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields)); |
794 | return success(); |
795 | } |
796 | |
797 | private: |
798 | bool enableBufferInitialization; |
799 | }; |
800 | |
801 | /// Sparse codegen rule for the empty tensor operator. |
802 | class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> { |
803 | public: |
804 | using OpConversionPattern::OpConversionPattern; |
805 | SparseTensorEmptyConverter(TypeConverter &typeConverter, MLIRContext *context, |
806 | bool enableInit) |
807 | : OpConversionPattern(typeConverter, context), |
808 | enableBufferInitialization(enableInit) {} |
809 | |
810 | LogicalResult |
811 | matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor, |
812 | ConversionPatternRewriter &rewriter) const override { |
813 | const auto resType = getSparseTensorType(op); |
814 | if (!resType.hasEncoding()) |
815 | return failure(); |
816 | |
817 | if (!resType.isIdentity()) { |
818 | return rewriter.notifyMatchFailure( |
819 | op, "try run --sparse-reinterpret-map before codegen" ); |
820 | } |
821 | |
822 | Location loc = op.getLoc(); |
    // Level sizes equal dimension sizes since lvl2dim is an identity map.
824 | SmallVector<Value> lvlSizesValues; |
825 | createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(), |
826 | /*dimSizesValues=*/lvlSizesValues); |
827 | // Construct allocation for each field. |
828 | Value sizeHint; // none |
829 | SmallVector<Value> fields; |
830 | createAllocFields(rewriter, loc, resType, enableBufferInitialization, |
831 | sizeHint, lvlSizesValues, fields); |
832 | |
833 | // Replace operation with resulting memrefs. |
834 | rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields)); |
835 | return success(); |
836 | } |
837 | |
838 | private: |
839 | bool enableBufferInitialization; |
840 | }; |
841 | |
842 | /// Sparse codegen rule for the dealloc operator. |
843 | class SparseTensorDeallocConverter |
844 | : public OpConversionPattern<bufferization::DeallocTensorOp> { |
845 | public: |
846 | using OpConversionPattern::OpConversionPattern; |
847 | SparseTensorDeallocConverter(TypeConverter &typeConverter, |
848 | MLIRContext *context, bool createDeallocs) |
849 | : OpConversionPattern(typeConverter, context), |
850 | createDeallocs(createDeallocs) {} |
851 | |
852 | LogicalResult |
853 | matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor, |
854 | ConversionPatternRewriter &rewriter) const override { |
855 | auto enc = getSparseTensorEncoding(op.getTensor().getType()); |
856 | if (!enc) |
857 | return failure(); |
858 | |
859 | // If user requests not to deallocate sparse tensors, simply erase the |
860 | // operation. |
861 | if (createDeallocs) { |
862 | // Replace the sparse tensor deallocation with field deallocations. |
863 | Location loc = op.getLoc(); |
864 | auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); |
865 | for (auto input : desc.getMemRefFields()) |
866 | // Deallocate every buffer used to store the sparse tensor handler. |
867 | rewriter.create<memref::DeallocOp>(loc, input); |
868 | } |
    rewriter.eraseOp(op);
870 | return success(); |
871 | } |
872 | |
873 | private: |
874 | const bool createDeallocs; |
875 | }; |
876 | |
877 | /// Sparse codegen rule for tensor rematerialization. |
878 | class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> { |
879 | public: |
880 | using OpConversionPattern::OpConversionPattern; |
881 | LogicalResult |
882 | matchAndRewrite(LoadOp op, OpAdaptor adaptor, |
883 | ConversionPatternRewriter &rewriter) const override { |
884 | // Prepare descriptor. |
885 | auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); |
886 | // Generate optional insertion finalization code. |
887 | if (op.getHasInserts()) |
888 | genEndInsert(rewriter, op.getLoc(), desc); |
889 | // Replace operation with resulting memrefs. |
890 | rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc)); |
891 | return success(); |
892 | } |
893 | }; |
894 | |
895 | /// Sparse codegen rule for the expand op. |
896 | class SparseExpandConverter : public OpConversionPattern<ExpandOp> { |
897 | public: |
898 | using OpConversionPattern::OpConversionPattern; |
899 | LogicalResult |
900 | matchAndRewrite(ExpandOp op, OpAdaptor adaptor, |
901 | ConversionPatternRewriter &rewriter) const override { |
902 | if (!getSparseTensorEncoding(op.getTensor().getType())) |
903 | return failure(); |
904 | Location loc = op->getLoc(); |
905 | auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); |
906 | const auto srcType = getSparseTensorType(op.getTensor()); |
907 | Type eltType = srcType.getElementType(); |
908 | Type boolType = rewriter.getIntegerType(1); |
909 | Type idxType = rewriter.getIndexType(); |
910 | // All initialization should be done on entry of the loop nest. |
911 | rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp()); |
912 | |
913 | // Determine the size for access expansion (always the innermost stored |
914 | // level size). |
915 | const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1); |
916 | // Generate a memref for `sz` elements of type `t`. |
917 | const auto genAlloc = [&](Type t) { |
918 | const auto memTp = MemRefType::get({ShapedType::kDynamic}, t); |
919 | return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz}); |
920 | }; |
921 | // Allocate temporary buffers for values/filled-switch and added. |
922 | // We do not use stack buffers for this, since the expanded size may |
923 | // be rather large (as it envelops a single expanded dense dimension). |
924 | Value values = genAlloc(eltType); |
925 | Value filled = genAlloc(boolType); |
926 | Value added = genAlloc(idxType); |
    Value zero = constantZero(rewriter, loc, idxType);
928 | // Reset the values/filled-switch to all-zero/false. Note that this |
929 | // introduces an O(N) operation into the computation, but this reset |
930 | // operation is amortized over the innermost loops for the access |
931 | // pattern expansion. As noted in the operation doc, we would like |
932 | // to amortize this setup cost even between kernels. |
933 | rewriter.create<linalg::FillOp>( |
934 | loc, ValueRange{constantZero(rewriter, loc, eltType)}, |
935 | ValueRange{values}); |
936 | rewriter.create<linalg::FillOp>( |
937 | loc, ValueRange{constantZero(rewriter, loc, boolType)}, |
938 | ValueRange{filled}); |
939 | // Replace expansion op with these buffers and initial coordinate. |
940 | assert(op.getNumResults() == 4); |
941 | rewriter.replaceOp(op, {values, filled, added, zero}); |
942 | return success(); |
943 | } |
944 | }; |
945 | |
946 | /// Sparse codegen rule for the compress operator. |
947 | class SparseCompressConverter : public OpConversionPattern<CompressOp> { |
948 | public: |
949 | using OpConversionPattern::OpConversionPattern; |
950 | LogicalResult |
951 | matchAndRewrite(CompressOp op, OpAdaptor adaptor, |
952 | ConversionPatternRewriter &rewriter) const override { |
953 | Location loc = op->getLoc(); |
954 | SmallVector<Value> fields; |
955 | auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields); |
956 | Value values = adaptor.getValues(); |
957 | Value filled = adaptor.getFilled(); |
958 | Value added = adaptor.getAdded(); |
959 | Value count = adaptor.getCount(); |
960 | const SparseTensorType dstType(desc.getRankedTensorType()); |
961 | Type eltType = dstType.getElementType(); |
962 | |
963 | // If the innermost level is ordered, we need to sort the coordinates |
964 | // in the "added" array prior to applying the compression. |
965 | if (dstType.isOrderedLvl(dstType.getLvlRank() - 1)) |
966 | rewriter.create<SortOp>( |
967 | loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1), |
968 | rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort); |
969 | // While performing the insertions, we also need to reset the elements |
970 | // of the values/filled-switch by only iterating over the set elements, |
971 | // to ensure that the runtime complexity remains proportional to the |
972 | // sparsity of the expanded access pattern. |
973 | // |
974 | // Generate |
975 | // out_memrefs = for (i = 0; i < count; i++)(in_memrefs) { |
976 | // crd = added[i]; |
977 | // value = values[crd]; |
978 | // insert({lvlCoords, crd}, value); |
979 | // new_memrefs = insert(in_memrefs, {lvlCoords, crd}, value); |
980 | // values[crd] = 0; |
981 | // filled[crd] = false; |
982 | // yield new_memrefs |
983 | // } |
984 | scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields()); |
985 | Value i = loop.getInductionVar(); |
986 | |
    Value crd = genLoad(rewriter, loc, added, i);
    Value value = genLoad(rewriter, loc, values, crd);
989 | SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end()); |
990 | SmallVector<Type> flatSpTensorTps = llvm::to_vector( |
991 | llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); })); |
992 | params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end()); |
    params.push_back(crd);
    params.push_back(value);
995 | SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps, |
996 | params, /*genCall=*/true); |
    SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
    genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
    genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
1000 | rewriter.create<scf::YieldOp>(loc, insertRet); |
1001 | |
1002 | rewriter.setInsertionPointAfter(loop); |
1003 | Value result = genTuple(rewriter, loc, dstType, loop->getResults()); |
1004 | // Deallocate the buffers on exit of the full loop nest. |
1005 | Operation *parent = getTop(op); |
1006 | rewriter.setInsertionPointAfter(parent); |
1007 | rewriter.create<memref::DeallocOp>(loc, values); |
1008 | rewriter.create<memref::DeallocOp>(loc, filled); |
1009 | rewriter.create<memref::DeallocOp>(loc, added); |
1010 | // Replace operation with resulting memrefs. |
1011 | rewriter.replaceOp(op, result); |
1012 | return success(); |
1013 | } |
1014 | }; |
1015 | |
1016 | /// Sparse codegen rule for the insert operator. |
1017 | class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> { |
1018 | public: |
1019 | using OpConversionPattern::OpConversionPattern; |
1020 | LogicalResult |
1021 | matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor, |
1022 | ConversionPatternRewriter &rewriter) const override { |
1023 | auto stt = getSparseTensorType(adaptor.getDest()); |
1024 | if (!stt.hasEncoding()) |
1025 | return failure(); |
1026 | assert(stt.isIdentity() && "Run reinterpret-map before conversion." ); |
1027 | |
1028 | Location loc = op.getLoc(); |
1029 | auto desc = getDescriptorFromTensorTuple(adaptor.getDest()); |
1030 | TypeRange flatSpTensorTps = desc.getFields().getTypes(); |
1031 | SmallVector<Value> params = llvm::to_vector(desc.getFields()); |
1032 | params.append(adaptor.getIndices().begin(), adaptor.getIndices().end()); |
    params.push_back(adaptor.getScalar());
1034 | SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps, |
1035 | params, /*genCall=*/true); |
    SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
1037 | // Replace operation with resulting memrefs. |
1038 | rewriter.replaceOp(op, |
1039 | genTuple(rewriter, loc, op.getDest().getType(), ret)); |
1040 | return success(); |
1041 | } |
1042 | }; |
1043 | |
1044 | /// Sparse codegen rule for position accesses. |
1045 | class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> { |
1046 | public: |
1047 | using OpAdaptor = typename ToPositionsOp::Adaptor; |
1048 | using OpConversionPattern<ToPositionsOp>::OpConversionPattern; |
1049 | LogicalResult |
1050 | matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor, |
1051 | ConversionPatternRewriter &rewriter) const override { |
    // Replace the requested position access with the corresponding field.
    // The cast_op is inserted by the type converter to bridge the 1:N type
    // conversion.
1055 | auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); |
1056 | rewriter.replaceOp(op, desc.getPosMemRef(op.getLevel())); |
1057 | return success(); |
1058 | } |
1059 | }; |
1060 | |
1061 | /// Sparse codegen rule for accessing the coordinates arrays. |
1062 | class SparseToCoordinatesConverter |
1063 | : public OpConversionPattern<ToCoordinatesOp> { |
1064 | public: |
1065 | using OpAdaptor = typename ToCoordinatesOp::Adaptor; |
1066 | using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern; |
1067 | LogicalResult |
1068 | matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor, |
1069 | ConversionPatternRewriter &rewriter) const override { |
    // Replace the requested coordinates access with the corresponding field.
    // The cast_op is inserted by the type converter to bridge the 1:N type
    // conversion.
1073 | auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); |
1074 | rewriter.replaceOp( |
1075 | op, desc.getCrdMemRefOrView(rewriter, op.getLoc(), op.getLevel())); |
1076 | |
1077 | return success(); |
1078 | } |
1079 | }; |
1080 | |
1081 | /// Sparse codegen rule for accessing the linear coordinates buffer. |
1082 | class SparseToCoordinatesBufferConverter |
1083 | : public OpConversionPattern<ToCoordinatesBufferOp> { |
1084 | public: |
1085 | using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor; |
1086 | using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern; |
1087 | LogicalResult |
1088 | matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor, |
1089 | ConversionPatternRewriter &rewriter) const override { |
    // Replace the requested coordinates access with the corresponding field.
    // The cast_op is inserted by the type converter to bridge the 1:N type
    // conversion.
1093 | auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); |
1094 | rewriter.replaceOp(op, desc.getAOSMemRef()); |
1095 | |
1096 | return success(); |
1097 | } |
1098 | }; |
1099 | |
1100 | /// Sparse codegen rule for value accesses. |
1101 | class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> { |
1102 | public: |
1103 | using OpAdaptor = typename ToValuesOp::Adaptor; |
1104 | using OpConversionPattern<ToValuesOp>::OpConversionPattern; |
1105 | LogicalResult |
1106 | matchAndRewrite(ToValuesOp op, OpAdaptor adaptor, |
1107 | ConversionPatternRewriter &rewriter) const override { |
    // Replace the requested values access with the corresponding field.
    // The cast_op is inserted by the type converter to bridge the 1:N type
    // conversion.
1111 | auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); |
1112 | rewriter.replaceOp(op, desc.getValMemRef()); |
1113 | return success(); |
1114 | } |
1115 | }; |
1116 | |
1117 | /// Sparse codegen rule for the convert operator. |
1118 | class SparseConvertConverter : public OpConversionPattern<ConvertOp> { |
1119 | public: |
1120 | using OpConversionPattern::OpConversionPattern; |
1121 | LogicalResult |
1122 | matchAndRewrite(ConvertOp op, OpAdaptor adaptor, |
1123 | ConversionPatternRewriter &rewriter) const override { |
1124 | SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType()); |
1125 | SparseTensorEncodingAttr encSrc = |
1126 | getSparseTensorEncoding(op.getSource().getType()); |
    // The output tensor cannot be a slice; those cases should have been
    // rejected by ConvertOp::verify() already.
    assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slice.");
1130 | // Different encoding (except for different bitwidth) should be handled by |
1131 | // rewriting. |
1132 | // We need further rewrites if the input tensor is a slice too. |
1133 | if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() || |
1134 | encSrc.isSlice()) { |
1135 | return failure(); |
1136 | } |
1137 | |
1138 | Type retElemTp = op.getResult().getType().getElementType(); |
1139 | Type srcElemTp = op.getSource().getType().getElementType(); |
1140 | // Fold the trivial cases. |
1141 | if (retElemTp == srcElemTp && encDst == encSrc) { |
1142 | rewriter.replaceOp(op, adaptor.getSource()); |
1143 | return success(); |
1144 | } |
1145 | // |
1146 | // Do element-wise type conversion without using InsertOp. |
1147 | // |
1148 | // for each memref in srcTensor: |
1149 | // dst = memref.alloc |
1150 | // if srcMemRefType != dstMemRefType: |
1151 | // for every dst[i] = cast(src[i]) |
1152 | // else: |
1153 | // dst = memref.copy(src) |
1154 | Location loc = op.getLoc(); |
1155 | auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource()); |
1156 | SmallVector<Value> fields; |
1157 | foreachFieldAndTypeInSparseTensor( |
1158 | SparseTensorType(cast<RankedTensorType>(op.getResult().getType())), |
1159 | [&rewriter, &fields, srcDesc, |
1160 | loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl, |
1161 | LevelType /*lt*/) -> bool { |
1162 | // Simply reuses the storage specifier as it is an SSA value. |
1163 | if (fKind == SparseTensorFieldKind::StorageSpec) { |
            fields.push_back(srcDesc.getSpecifier());
1165 | } else { |
1166 | // Allocates new memrefs |
1167 | Value srcMem = srcDesc.getMemRefField(fIdx); |
1168 | // TODO: We can instead use the actual memSize in specifier, that |
1169 | // would require a subViewOp to avoid overflow when copying |
1170 | // values. |
            Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
1172 | auto dstMem = rewriter.create<memref::AllocOp>( |
1173 | loc, cast<MemRefType>(fTp), sz); |
1174 | if (fTp != srcMem.getType()) { |
1175 | // Converts elements type. |
1176 | scf::buildLoopNest( |
                rewriter, loc, constantIndex(rewriter, loc, 0), sz,
                constantIndex(rewriter, loc, 1),
                [srcMem, &dstMem](OpBuilder &builder, Location loc,
1180 | ValueRange ivs) { |
1181 | Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs); |
1182 | Value casted = genCast(builder, loc, v, |
1183 | dstMem.getType().getElementType()); |
1184 | builder.create<memref::StoreOp>(loc, casted, dstMem, ivs); |
1185 | }); |
1186 | } else { |
1187 | // TODO: We can even reuse the same memref for the new tensor, |
1188 | // but that requires a `ref-counting` based memory management |
1189 | // for shared memrefs between multiple sparse tensors. |
1190 | rewriter.create<memref::CopyOp>(loc, srcMem, dstMem); |
1191 | } |
            fields.push_back(dstMem);
1193 | } |
1194 | return true; |
1195 | }); |
1196 | |
1197 | rewriter.replaceOp( |
1198 | op, genTuple(rewriter, loc, op.getResult().getType(), fields)); |
1199 | return success(); |
1200 | } |
1201 | }; |
1202 | |
/// Sparse codegen rule for the extract slice operator.
class SparseExtractSliceConverter
1204 | : public OpConversionPattern<tensor::ExtractSliceOp> { |
1205 | public: |
1206 | using OpConversionPattern::OpConversionPattern; |
1207 | LogicalResult |
1208 | matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor, |
1209 | ConversionPatternRewriter &rewriter) const override { |
1210 | Location loc = op.getLoc(); |
1211 | MLIRContext *ctx = op.getContext(); |
1212 | auto srcEnc = getSparseTensorEncoding(op.getSourceType()); |
1213 | auto dstEnc = getSparseTensorEncoding(op.getResult().getType()); |
1214 | // TODO: We should check these in ExtractSliceOp::verify. |
1215 | if (!srcEnc || !dstEnc || !dstEnc.isSlice()) |
1216 | return failure(); |
1217 | assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices()); |
1218 | |
    SmallVector<Value> fields;
    auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);

    auto newSpec = rewriter.create<StorageSpecifierInitOp>(
        loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
    desc.setSpecifier(newSpec);

    // Fills in slice information.
    for (auto [idx, offset, size, stride] : llvm::enumerate(
             op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
      Dimension dim = idx;

      Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
      Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);
      Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride);
      // TODO: We could probably only set the dynamic values here, but that
      // would require us to fill the hole when casting a static slice to a
      // dynamic slice.
      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset,
                             dim, offsetV);

      // FIXME: We need to distinguish level sizes and dimension sizes for
      // slices here. Maybe we should store slice level sizes in a different
      // array instead of reusing it.
      assert(srcEnc.isIdentity());
      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
                             sizeV);
      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
                             dim, strideV);
    }

    // NOTE: We cannot generate the tuple directly from the descriptor here,
    // as the descriptor still holds the original type, yet we want the slice
    // type here (they share every memref but carry an updated specifier).
    rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(),
                                    desc.getFields()));
    return success();
  }
};

/// Sparse codegen rule for number of entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Query memSizes for the actually stored values.
    // FIXME: the nse value computed in this way might be wrong when there is
    // any "loose_compressed" level.
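    //
    // For example (roughly):
    //   %n = sparse_tensor.number_of_entries %t : tensor<?x?xf64, #CSR>
    // lowers to a read of the ValMemSize field from the storage specifier
    // carried in the flattened field tuple of %t.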
    rewriter.replaceOp(
        op, genValMemSize(rewriter, op.getLoc(), adaptor.getTensor()));
    return success();
  }
};

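/// Sparse codegen rule for the sparse_tensor.assemble operator.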
struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op.getResult());

    SmallVector<Value> fields;

    foreachFieldAndTypeInSparseTensor(
        stt,
        [&rewriter, &fields, &op, &stt,
         loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
              Level /*lvl*/, LevelType lt) -> bool {
          assert(fields.size() == fIdx);
          if (fKind == SparseTensorFieldKind::StorageSpec) {
            fields.push_back(
                SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
          } else {
            // Else simply takes the inputs.
            Value tensor = fKind == SparseTensorFieldKind::ValMemRef
                               ? op.getValues()
                               : op.getLevels()[fIdx];
            // TODO: handle batch.
            TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
            if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
              // Flattens the buffer to batchLvlRank.
              auto reassoc = getReassociationForFlattening(
                  mem.getType(), stt.getBatchLvlRank());
              mem = rewriter.create<memref::CastOp>(
                  loc, fType,
                  rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
            } else {
              mem = rewriter.create<memref::CastOp>(loc, fType, mem);
            }
            fields.push_back(mem);
          }
          return true;
        });

    MutSparseTensorDescriptor desc(stt, fields);
    Value c0 = constantIndex(rewriter, loc, 0);
    Value c1 = constantIndex(rewriter, loc, 1);
    Value c2 = constantIndex(rewriter, loc, 2);
    Value posBack = c0; // index to the last value in the position array
    Value memSize = c1; // memory size for current array

    Level trailCOOStart = stt.getAoSCOOStart();
    Level trailCOORank = stt.getLvlRank() - trailCOOStart;
    // Sets up SparseTensorSpecifier.
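    //
    // For example, assembling a 3x4 CSR tensor from (illustrative data)
    //   positions[1]   = [0, 2, 4, 7]
    //   coordinates[1] = <7 column coordinates>
    //   values         = <7 values>
    // the loop below sets the level sizes to 3 and 4, sets the position
    // memory size at level 1 to rows + 1 = 4, and then reads
    // positions[1][3] (= 7) as the memory size for both the coordinates
    // and the values array.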
    for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
      assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));

      // Sets up the level size.
      auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
      desc.setLvlSize(rewriter, loc, lvl, lvlSize);
      // We use a single AOS array to store the trailing COO, so there is only
      // one memory size to set for the entire COO section.
      if (lvl > trailCOOStart)
        continue;

      // Sets up the memory size by reading the last value in position array.
      LevelType lt = stt.getLvlType(lvl);
      // Simply forwards the position index when this is a dense level.
      if (lt.isa<LevelFormat::Dense>()) {
        memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
        posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
        continue;
      }
      if (lt.isa<LevelFormat::Batch>()) {
        // Skips batch levels as they are not linearized.
        // FIXME: this assumes that every batch has the same number of nse,
        // which needs to be generalized to handle batches of varying sizes.
        continue;
      }

      if (isWithPosLT(lt)) {
        assert(isCompressedLT(lt) || isLooseCompressedLT(lt));
        if (isLooseCompressedLT(lt)) {
          memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
          posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
        } else {
          assert(isCompressedLT(lt));
          posBack = memSize;
          memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
        }
        desc.setPosMemSize(rewriter, loc, lvl, memSize);
        // The last value in the position array is the memory size for the
        // next level.
        // FIXME: this assumes that every batch has the same number of nse,
        // which needs to be generalized to handle batches of varying sizes.
        SmallVector<Value> batched(stt.getBatchLvlRank(),
                                   constantIndex(rewriter, loc, 0));
        batched.push_back(posBack);
        memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
        posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
      }
      assert(isWithCrdLT(lt) && lvl <= trailCOOStart);
      // FIXME: This seems to be unnecessarily complex, can we simplify it?
      if (lvl == trailCOOStart) {
        Value cooSz = rewriter.create<arith::MulIOp>(
            loc, memSize, constantIndex(rewriter, loc, trailCOORank));
        desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
      } else {
        desc.setCrdMemSize(rewriter, loc, lvl, memSize);
      }
    }
    desc.setValMemSize(rewriter, loc, memSize);

    rewriter.replaceOp(op, genTuple(rewriter, loc, desc));
    return success();
  }
};

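/// Sparse codegen rule for the sparse_tensor.disassemble operator.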
struct SparseDisassembleOpConverter
    : public OpConversionPattern<DisassembleOp> {
  using OpConversionPattern::OpConversionPattern;
  SparseDisassembleOpConverter(TypeConverter &typeConverter,
                               MLIRContext *context)
      : OpConversionPattern(typeConverter, context) {}

  LogicalResult
  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
    Location loc = op.getLoc();
    SmallVector<Value> retMem;
    SmallVector<Value> retLen;
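    // For every non-specifier field, copies the used prefix (as recorded in
    // the storage specifier) of the internal buffer into the corresponding
    // user-provided output buffer, and also returns that length as a scalar
    // tensor alongside the buffer.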
    desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem,
                                   &retLen](FieldIndex fid,
                                            SparseTensorFieldKind fKind,
                                            Level lvl, LevelType lt) -> bool {
      if (fKind == SparseTensorFieldKind::StorageSpec)
        return true;
      SparseTensorType stt(desc.getRankedTensorType());
      Value sz, src;
      TypedValue<BaseMemRefType> dst;
      if (fKind == SparseTensorFieldKind::ValMemRef) {
        sz = desc.getValMemSize(rewriter, loc);
        src = desc.getValMemRef();
        dst = genToMemref(rewriter, loc, op.getOutValues());

        retMem.push_back(dst);
        Type valLenTp = op.getValLen().getType();
        retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
      } else {
        assert(fKind == SparseTensorFieldKind::PosMemRef ||
               fKind == SparseTensorFieldKind::CrdMemRef);

        sz = fKind == SparseTensorFieldKind::PosMemRef
                 ? desc.getPosMemSize(rewriter, loc, lvl)
                 : desc.getCrdMemSize(rewriter, loc, lvl);
        src = desc.getMemRefField(fid);
        dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
        retMem.push_back(dst);
        // Retrieves the corresponding level length type.
        Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
      }
      Value flatOut = dst;
      if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
        auto reassoc =
            getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());
        flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
      }
      Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
      Value srcMem = genSliceToSize(rewriter, loc, src, sz);
      rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
      return true;
    });

    // Converts MemRefs back to Tensors.
    SmallVector<Value> retValues = llvm::to_vector(
        llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
          return rewriter.create<bufferization::ToTensorOp>(loc, v);
        }));
    // Appends the actual memory length used in each buffer returned.
    retValues.append(retLen.begin(), retLen.end());
    rewriter.replaceOp(op, retValues);
    return success();
  }
};

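/// Sparse codegen rule for the sparse_tensor.new operator (direct IR codegen
/// is only used when the result is an AoS COO that starts at level 0; all
/// other cases are handled by rewriting beforehand).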
struct SparseNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    // Creating COO with NewOp is handled by direct IR codegen. All other cases
    // are handled by rewriting.
    if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
      return failure();

    // Implement as follows:
    //   %reader = @createCheckedSparseTensorReader(%filename)
    //   %nse = @getSparseTensorNSE(%reader)
    //   %coo = bufferization.alloc_tensor an ordered COO with
    //          dst dim ordering, size_hint = %nse
    //   %coordinates = sparse_tensor.coordinates_buffer(%coo)
    //   %values = sparse_tensor.values(%coo)
    //   %isSorted = @sparseTensorReaderReadToBuffers(%coordinates, %values)
    //   if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
    //   update storage specifier
    //   @delSparseTensorReader(%reader)
    SmallVector<Value> dimSizesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
                             dimSizesValues, dimSizesBuffer);

    // Get the number of stored entries.
    const Type indexTp = rewriter.getIndexType();
    Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
                               {indexTp}, {reader}, EmitCInterface::Off)
                    .getResult(0);

    // Construct the lvl sizes and the dim2lvl/lvl2dim buffers.
    SmallVector<Value> lvlSizesValues;
    Value dim2lvlBuffer;
    Value lvl2dimBuffer;
    genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
                  lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);

    // Construct allocation for each field.
    Value sizeHint = nse;
    SmallVector<Value> fields;
    createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint,
                      lvlSizesValues, fields);

    // Read the COO tensor data.
    MutSparseTensorDescriptor desc(dstTp, fields);
    Value xs = desc.getAOSMemRef();
    Value ys = desc.getValMemRef();
    const Type boolTp = rewriter.getIntegerType(1);
    const Type elemTp = dstTp.getElementType();
    const Type crdTp = dstTp.getCrdType();
    SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
                                          overheadTypeFunctionSuffix(crdTp),
                                          primaryTypeFunctionSuffix(elemTp)};
    Value isSorted =
        createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
                       {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
                       EmitCInterface::On)
            .getResult(0);

    // If the destination tensor is a sorted COO, we need to sort the COO
    // tensor data if the input elements aren't sorted yet.
    const Level lvlRank = dstTp.getLvlRank();
    if (dstTp.isOrderedLvl(lvlRank - 1)) {
      Value kFalse = constantI1(rewriter, loc, false);
      Value notSorted = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, isSorted, kFalse);
      scf::IfOp ifOp =
          rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
      rewriter.create<SortOp>(loc, nse, xs, ValueRange{ys}, xPerm,
                              rewriter.getIndexAttr(0),
                              SparseTensorSortKind::HybridQuickSort);
      rewriter.setInsertionPointAfter(ifOp);
    }

    // Set PosMemRef0[1] = nse.
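    // For an AoS COO that starts at level 0, the position buffer of the first
    // level has exactly two entries, and its last entry must record the
    // number of stored entries (nse) that were just read.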
    const Value c1 = constantIndex(rewriter, loc, 1);
    const Value posMemref0 = desc.getPosMemRef(0);
    const Type posTp = dstTp.getPosType();
    const Value posNse = genCast(rewriter, loc, nse, posTp);
    rewriter.create<memref::StoreOp>(loc, posNse, posMemref0, c1);

    // Update storage specifier.
    Value coordinatesSize = rewriter.create<arith::MulIOp>(
        loc, nse, constantIndex(rewriter, loc, lvlRank));
    desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0,
                           coordinatesSize);
    desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize,
                           std::nullopt, nse);

    // Release the sparse tensor reader.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);

    // Replace operation with resulting memrefs.
    rewriter.replaceOp(op, genTuple(rewriter, loc, dstTp, fields));
    return success();
  }
};

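/// Sparse codegen rule for the has_runtime_library operator. This lowering
/// path does not depend on the runtime support library, so the query simply
/// folds to false.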
struct SparseHasRuntimeLibraryConverter
    : public OpConversionPattern<HasRuntimeLibraryOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto i1Type = rewriter.getI1Type();
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
        op, i1Type, rewriter.getIntegerAttr(i1Type, 0));
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required to
/// convert sparse tensor types and primitives into explicit compiler-visible
/// buffers and code (sparse tensor codegen).
void mlir::populateSparseTensorCodegenPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    bool createSparseDeallocs, bool enableBufferInitialization) {
  patterns.add<
      SparseAssembleOpConverter, SparseDisassembleOpConverter,
      SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
      SparseCastConverter, SparseExtractSliceConverter,
      SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter,
      SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter,
      SparseSliceGetterOpConverter<ToSliceOffsetOp,
                                   StorageSpecifierKind::DimOffset>,
      SparseSliceGetterOpConverter<ToSliceStrideOp,
                                   StorageSpecifierKind::DimStride>,
      SparseToPositionsConverter, SparseToCoordinatesConverter,
      SparseToCoordinatesBufferConverter, SparseToValuesConverter,
      SparseConvertConverter, SparseNewConverter,
      SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
      typeConverter, patterns.getContext());
  patterns.add<SparseTensorDeallocConverter>(
      typeConverter, patterns.getContext(), createSparseDeallocs);
  patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
      typeConverter, patterns.getContext(), enableBufferInitialization);
}
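
// Usage sketch (illustrative only; the actual driver lives in the
// SparseTensorCodegen pass, and the type converter name and conversion
// target setup below are assumptions, not part of this file):
//
//   SparseTensorTypeToBufferConverter converter;
//   ConversionTarget target(*ctx);
//   RewritePatternSet patterns(ctx);
//   populateSparseTensorCodegenPatterns(converter, patterns,
//                                       /*createSparseDeallocs=*/false,
//                                       /*enableBufferInitialization=*/false);
//   (void)applyPartialConversion(module, target, std::move(patterns));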