1 | //===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // A pass that converts sparse tensor types and primitives to actual compiler |
10 | // visible buffers and actual compiler IR that implements these primitives on |
11 | // the selected sparse tensor storage schemes. This pass provides an alternative |
12 | // to the SparseTensorConversion pass, eliminating the dependence on a runtime |
13 | // support library (other than for file I/O), and providing many more |
14 | // opportunities for subsequent compiler optimization of the generated code. |
15 | // |
16 | //===----------------------------------------------------------------------===// |
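
// Illustrative usage sketch (an assumption, not exercised in this file): the
// rewrites below are normally driven by the "sparse-tensor-codegen" pass,
// typically after "--sparse-reinterpret-map" has removed non-identity
// dim-to-lvl maps, for example:
//
//   mlir-opt --sparse-reinterpret-map --sparse-tensor-codegen input.mlir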
17 | |
18 | #include "Utils/CodegenUtils.h" |
19 | #include "Utils/SparseTensorDescriptor.h" |
20 | |
21 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
22 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
23 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
24 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
25 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
26 | #include "mlir/Dialect/SparseTensor/IR/Enums.h" |
27 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
28 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
29 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
30 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
31 | #include "mlir/Transforms/DialectConversion.h" |
32 | |
33 | #include <optional> |
34 | |
35 | using namespace mlir; |
36 | using namespace mlir::sparse_tensor; |
37 | |
38 | //===----------------------------------------------------------------------===// |
39 | // Helper methods. |
40 | //===----------------------------------------------------------------------===// |
41 | |
42 | /// Flatten the given value ranges into a single vector of values. |
43 | static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) { |
44 | SmallVector<Value> result; |
45 | for (const auto &vals : values) |
    llvm::append_range(result, vals);
47 | return result; |
48 | } |
49 | |
50 | /// Generates a load with proper `index` typing. |
51 | static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) { |
52 | idx = genCast(builder, loc, idx, builder.getIndexType()); |
53 | return builder.create<memref::LoadOp>(loc, mem, idx); |
54 | } |
55 | |
56 | /// Generates a store with proper `index` typing and proper value. |
57 | static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, |
58 | Value idx) { |
59 | idx = genCast(builder, loc, idx, builder.getIndexType()); |
60 | val = genCast(builder, loc, val, |
61 | cast<ShapedType>(mem.getType()).getElementType()); |
62 | builder.create<memref::StoreOp>(loc, val, mem, idx); |
63 | } |
64 | |
/// Creates a straightforward counting for-loop that threads the given fields
/// through as loop-carried arguments and leaves the insertion point at the
/// start of the loop body.
66 | static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper, |
67 | MutableArrayRef<Value> fields, |
68 | Value lower = Value()) { |
69 | Type indexType = builder.getIndexType(); |
70 | if (!lower) |
    lower = constantZero(builder, loc, indexType);
  Value one = constantOne(builder, loc, indexType);
73 | scf::ForOp forOp = builder.create<scf::ForOp>(loc, lower, upper, one, fields); |
74 | for (unsigned i = 0, e = fields.size(); i < e; i++) |
75 | fields[i] = forOp.getRegionIterArg(i); |
76 | builder.setInsertionPointToStart(forOp.getBody()); |
77 | return forOp; |
78 | } |
79 | |
80 | /// Creates a push back operation. |
81 | static void createPushback(OpBuilder &builder, Location loc, |
82 | MutSparseTensorDescriptor desc, |
83 | SparseTensorFieldKind kind, std::optional<Level> lvl, |
84 | Value value, Value repeat = Value()) { |
85 | Type etp = desc.getMemRefElementType(kind, lvl); |
86 | Value field = desc.getMemRefField(kind, lvl); |
87 | StorageSpecifierKind specFieldKind = toSpecifierKind(kind); |
88 | |
89 | auto pushBackOp = builder.create<PushBackOp>( |
90 | loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field, |
91 | genCast(builder, loc, value, etp), repeat); |
92 | |
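  // The push_back op yields both the (possibly reallocated) output buffer and
  // the updated size; mirror both back into the descriptor so subsequent code
  // observes the new state.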
93 | desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer()); |
94 | desc.setSpecifierField(builder, loc, specFieldKind, lvl, |
95 | pushBackOp.getNewSize()); |
96 | } |
97 | |
98 | /// Generates code that allocates a sparse storage scheme for given rank. |
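/// For example (illustrative): starting from a dense level, the level sizes
/// are folded into a running `linear` extent until the first compressed (or
/// loose compressed) level is reached, where `linear` zero positions are
/// appended in a single push_back; if all remaining levels are dense, `linear`
/// zero values are appended to the values array instead.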
99 | static void allocSchemeForRank(OpBuilder &builder, Location loc, |
100 | MutSparseTensorDescriptor desc, Level startLvl) { |
101 | const SparseTensorType stt(desc.getRankedTensorType()); |
  Value linear = constantIndex(builder, loc, 1);
103 | const Level lvlRank = stt.getLvlRank(); |
104 | for (Level lvl = startLvl; lvl < lvlRank; lvl++) { |
    const auto lt = stt.getLvlType(lvl);
106 | if (isCompressedLT(lt) || isLooseCompressedLT(lt)) { |
107 | // Append linear x positions, initialized to zero. Since each compressed |
108 | // dimension initially already has a single zero entry, this maintains |
109 | // the desired "linear + 1" length property at all times. For loose |
110 | // compression, we multiply linear by two in order to append both the |
111 | // lo/hi positions. |
      Value posZero = constantZero(builder, loc, stt.getPosType());
113 | if (isLooseCompressedLT(lt)) { |
114 | Value two = constantIndex(builder, loc, i: 2); |
115 | linear = builder.create<arith::MulIOp>(loc, linear, two); |
116 | } |
      createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
                     /*value=*/posZero, /*repeat=*/linear);
119 | return; |
120 | } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) { |
121 | return; // nothing to do |
122 | } |
123 | // Keep compounding the size, but nothing needs to be initialized |
124 | // at this level. We will eventually reach a compressed level or |
125 | // otherwise the values array for the from-here "all-dense" case. |
126 | assert(isDenseLT(lt)); |
127 | Value size = desc.getLvlSize(builder, loc, lvl); |
128 | linear = builder.create<arith::MulIOp>(loc, linear, size); |
129 | } |
130 | // Reached values array so prepare for an insertion. |
  Value valZero = constantZero(builder, loc, stt.getElementType());
  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                 std::nullopt, /*value=*/valZero, /*repeat=*/linear);
134 | } |
135 | |
136 | /// Creates allocation operation. |
137 | static Value createAllocation(OpBuilder &builder, Location loc, |
138 | MemRefType memRefType, Value sz, |
139 | bool enableInit) { |
140 | Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz); |
141 | Type elemType = memRefType.getElementType(); |
142 | if (enableInit) { |
    Value fillValue = constantZero(builder, loc, elemType);
144 | builder.create<linalg::FillOp>(loc, fillValue, buffer); |
145 | } |
146 | return buffer; |
147 | } |
148 | |
149 | /// Creates the dim sizes array, filling in from dynamic sizes. |
150 | static void createDimSizes(OpBuilder &builder, Location loc, |
151 | SparseTensorType stt, ValueRange dynSizes, |
152 | /*out*/ SmallVectorImpl<Value> &dimSizesValues) { |
153 | const Dimension dimRank = stt.getDimRank(); |
154 | dimSizesValues.clear(); |
  dimSizesValues.reserve(dimRank);
156 | unsigned i = 0; |
157 | for (const Size sz : stt.getDimShape()) |
158 | dimSizesValues.push_back(ShapedType::isDynamic(sz) |
159 | ? dynSizes[i++] |
160 | : constantIndex(builder, loc, sz)); |
161 | } |
162 | |
163 | /// Creates allocation for each field in sparse tensor type. Note that |
/// for all dynamic memrefs in the sparse tensor storage layout, the
165 | /// memory size is really the capacity of the "vector", while the actual |
166 | /// size resides in the sizes array. |
167 | static void createAllocFields(OpBuilder &builder, Location loc, |
168 | SparseTensorType stt, bool enableInit, |
169 | Value sizeHint, |
170 | SmallVectorImpl<Value> &lvlSizesValues, |
171 | /*out*/ SmallVectorImpl<Value> &fields) { |
172 | Level lvlRank = stt.getLvlRank(); |
173 | // Set up some heuristic sizes. We try to set the initial |
174 | // size based on available information. Otherwise we just |
175 | // initialize a few elements to start the reallocation chain. |
176 | // TODO: refine this |
177 | Value posHeuristic, crdHeuristic, valHeuristic; |
178 | if (stt.isAllDense()) { |
179 | valHeuristic = lvlSizesValues[0]; |
180 | for (Level lvl = 1; lvl < lvlRank; lvl++) |
181 | valHeuristic = |
182 | builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]); |
183 | } else if (sizeHint) { |
184 | if (stt.getAoSCOOStart() == 0) { |
      posHeuristic = constantIndex(builder, loc, 2);
      crdHeuristic = builder.create<arith::MulIOp>(
          loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
    } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
189 | posHeuristic = builder.create<arith::AddIOp>( |
190 | loc, sizeHint, constantIndex(builder, loc, 1)); |
191 | crdHeuristic = sizeHint; |
192 | } else { |
      posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
194 | } |
195 | valHeuristic = sizeHint; |
196 | } else { |
    posHeuristic = crdHeuristic = valHeuristic =
        constantIndex(builder, loc, 16);
199 | } |
  // Initialize all fields: an initial storage specifier and allocated
  // positions/coordinates/values memrefs (with heuristic capacity).
202 | foreachFieldAndTypeInSparseTensor( |
203 | stt, |
204 | [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic, |
205 | enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind, |
206 | Level /*lvl*/, LevelType /*lt*/) -> bool { |
207 | assert(fields.size() == fIdx); |
208 | Value field; |
209 | switch (fKind) { |
210 | case SparseTensorFieldKind::StorageSpec: |
211 | field = SparseTensorSpecifier::getInitValue(builder, loc, stt); |
212 | break; |
213 | case SparseTensorFieldKind::PosMemRef: |
214 | field = createAllocation(builder, loc, cast<MemRefType>(fType), |
215 | posHeuristic, enableInit); |
216 | break; |
217 | case SparseTensorFieldKind::CrdMemRef: |
218 | field = createAllocation(builder, loc, cast<MemRefType>(fType), |
219 | crdHeuristic, enableInit); |
220 | break; |
221 | case SparseTensorFieldKind::ValMemRef: |
222 | field = createAllocation(builder, loc, cast<MemRefType>(fType), |
223 | valHeuristic, enableInit); |
224 | break; |
225 | } |
226 | assert(field); |
        fields.push_back(field);
228 | // Returns true to continue the iteration. |
229 | return true; |
230 | }); |
231 | // Initialize the storage scheme to an empty tensor. Sets the lvlSizes |
232 | // and gives all position fields an initial zero entry, so that it is |
233 | // easier to maintain the "linear + 1" length property. |
234 | MutSparseTensorDescriptor desc(stt, fields); |
  Value posZero = constantZero(builder, loc, stt.getPosType());
236 | for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) { |
    desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
    const auto lt = stt.getLvlType(lvl);
239 | if (isCompressedLT(lt) || isLooseCompressedLT(lt)) |
      createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
                     /*value=*/posZero);
242 | } |
  allocSchemeForRank(builder, loc, desc, /*rank=*/0);
244 | } |
245 | |
246 | /// Helper method that generates block specific to compressed case: |
247 | /// |
248 | /// // given: parentPos = posCursor[lvl-1] |
249 | /// pstart = desc.positions[lvl][parentPos] |
250 | /// pstop = desc.positions[lvl][parentPos+1] |
251 | /// plast = pstop - 1 |
252 | /// msz = desc.coordinates[lvl].size() |
253 | /// if (pstart < pstop) { |
254 | /// isPresent = (desc.coordinates[lvl][plast] == lvlCoords[lvl]) |
255 | /// } else { // first insertion |
256 | /// isPresent = false |
257 | /// desc.positions[lvl][parentPos] = msz |
258 | /// } |
259 | /// if (isPresent) { // coordinate is already present |
260 | /// pnext = plast |
261 | /// } else { |
262 | /// desc.coordinates[lvl].push_back(lvlCoords[lvl]) |
263 | /// desc.positions[lvl][parentPos+1] = msz+1 |
264 | /// pnext = msz |
265 | /// <prepare level lvl+1> |
266 | /// } |
267 | /// posCursor[lvl] = pnext |
268 | static Value genCompressed(OpBuilder &builder, Location loc, |
269 | MutSparseTensorDescriptor desc, ValueRange lvlCoords, |
270 | Value /*unused*/, Value parentPos, Level lvl) { |
271 | const SparseTensorType stt(desc.getRankedTensorType()); |
272 | const Level lvlRank = stt.getLvlRank(); |
  assert(lvl < lvlRank && "Level is out of bounds");
  assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
         "Level-rank mismatch");
276 | SmallVector<Type> types; |
277 | Type indexType = builder.getIndexType(); |
278 | Type boolType = builder.getIntegerType(1); |
279 | unsigned crdFidx; |
280 | unsigned crdStride; |
  std::tie(crdFidx, crdStride) = desc.getCrdMemRefIndexAndStride(lvl);
  const Value one = constantIndex(builder, loc, 1);
283 | const Value pp1 = builder.create<arith::AddIOp>(loc, parentPos, one); |
284 | const Value positionsAtLvl = desc.getPosMemRef(lvl); |
  const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
  const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
287 | const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl); |
  const Value crdStrideC =
      crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
290 | const Value msz = |
291 | crdStrideC ? builder.create<arith::DivUIOp>(loc, crdMsz, crdStrideC) |
292 | : crdMsz; |
293 | const Value plast = builder.create<arith::SubIOp>( |
294 | loc, genCast(builder, loc, pstop, indexType), one); |
295 | // Conditional expression. |
296 | Value lt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, |
297 | pstart, pstop); |
  types.push_back(boolType);
299 | scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true); |
300 | types.pop_back(); |
301 | builder.setInsertionPointToStart(&ifOp1.getThenRegion().front()); |
302 | Value crd = |
303 | genLoad(builder, loc, desc.getMemRefField(crdFidx), |
304 | crdStrideC ? builder.create<arith::MulIOp>(loc, plast, crdStrideC) |
305 | : plast); |
306 | Value eq = builder.create<arith::CmpIOp>( |
307 | loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType), |
308 | lvlCoords[lvl]); |
309 | builder.create<scf::YieldOp>(loc, eq); |
310 | builder.setInsertionPointToStart(&ifOp1.getElseRegion().front()); |
311 | if (lvl > 0) |
    genStore(builder, loc, msz, positionsAtLvl, parentPos);
313 | builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false)); |
314 | builder.setInsertionPointAfter(ifOp1); |
315 | // If present construct. Note that for a non-unique dimension level, we |
316 | // simply set the condition to false and rely on CSE/DCE to clean up the IR. |
317 | // |
318 | // TODO: generate less temporary IR? |
319 | // |
320 | for (unsigned i = 0, e = desc.getNumFields(); i < e; i++) |
    types.push_back(desc.getField(i).getType());
  types.push_back(indexType);
  const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
                                       : constantI1(builder, loc, false);
325 | scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true); |
326 | // If present (fields unaffected, update pnext to plast). |
327 | builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); |
328 | |
  // FIXME: This does not look like a clean way to do it, but it is probably
  // the most efficient way.
331 | desc.getFields().push_back(plast); |
332 | builder.create<scf::YieldOp>(loc, desc.getFields()); |
333 | desc.getFields().pop_back(); |
334 | |
335 | // If !present (changes fields, update pnext). |
336 | builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); |
337 | Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one); |
  genStore(builder, loc, mszp1, positionsAtLvl, pp1);
  createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
                 /*value=*/lvlCoords[lvl]);
341 | // Prepare the next level "as needed". |
342 | if ((lvl + 1) < lvlRank) |
    allocSchemeForRank(builder, loc, desc, lvl + 1);
344 | |
345 | desc.getFields().push_back(msz); |
346 | builder.create<scf::YieldOp>(loc, desc.getFields()); |
347 | desc.getFields().pop_back(); |
348 | |
349 | // Update fields and return next pos. |
350 | builder.setInsertionPointAfter(ifOp2); |
351 | unsigned o = 0; |
352 | for (unsigned i = 0, e = desc.getNumFields(); i < e; i++) |
    desc.setField(i, ifOp2.getResult(o++));
354 | return ifOp2.getResult(o); |
355 | } |
356 | |
357 | /// Generates insertion finalization code. |
358 | static void genEndInsert(OpBuilder &builder, Location loc, |
359 | SparseTensorDescriptor desc) { |
360 | const SparseTensorType stt(desc.getRankedTensorType()); |
361 | const Level lvlRank = stt.getLvlRank(); |
362 | for (Level lvl = 0; lvl < lvlRank; lvl++) { |
    const auto lt = stt.getLvlType(lvl);
364 | if (isCompressedLT(lt)) { |
365 | // Compressed dimensions need a position cleanup for all entries |
366 | // that were not visited during the insertion pass. |
367 | // |
368 | // TODO: avoid cleanup and keep compressed scheme consistent at all |
369 | // times? |
370 | // |
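      // In pseudocode, the cleanup below amounts to:
      //
      //   oldv = positions[lvl][0]
      //   for (i = 1; i < memSize(positions[lvl]); i++)
      //     if (positions[lvl][i] == 0)
      //       positions[lvl][i] = oldv
      //     else
      //       oldv = positions[lvl][i]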
371 | if (lvl > 0) { |
372 | Type posType = stt.getPosType(); |
373 | Value posMemRef = desc.getPosMemRef(lvl); |
374 | Value hi = desc.getPosMemSize(builder, loc, lvl); |
        Value zero = constantIndex(builder, loc, 0);
        Value one = constantIndex(builder, loc, 1);
377 | // Vector of only one, but needed by createFor's prototype. |
        SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
379 | scf::ForOp loop = createFor(builder, loc, hi, inits, one); |
380 | Value i = loop.getInductionVar(); |
381 | Value oldv = loop.getRegionIterArg(0); |
        Value newv = genLoad(builder, loc, posMemRef, i);
        Value posZero = constantZero(builder, loc, posType);
384 | Value cond = builder.create<arith::CmpIOp>( |
385 | loc, arith::CmpIPredicate::eq, newv, posZero); |
386 | scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType), |
387 | cond, /*else*/ true); |
388 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
        genStore(builder, loc, oldv, posMemRef, i);
390 | builder.create<scf::YieldOp>(loc, oldv); |
391 | builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
392 | builder.create<scf::YieldOp>(loc, newv); |
393 | builder.setInsertionPointAfter(ifOp); |
394 | builder.create<scf::YieldOp>(loc, ifOp.getResult(0)); |
395 | builder.setInsertionPointAfter(loop); |
396 | } |
397 | } else { |
398 | assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) || |
399 | isNOutOfMLT(lt)); |
400 | } |
401 | } |
402 | } |
403 | |
404 | /// Generates a subview into the sizes. |
405 | static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, |
406 | Value sz) { |
407 | auto memTp = llvm::cast<MemRefType>(mem.getType()); |
408 | // For higher-dimensional memrefs, we assume that the innermost |
409 | // dimension is always of the right size. |
410 | // TODO: generate complex truncating view here too? |
411 | if (memTp.getRank() > 1) |
412 | return mem; |
413 | // Truncate linear memrefs to given size. |
414 | return builder |
415 | .create<memref::SubViewOp>( |
416 | loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()), |
417 | mem, ValueRange{}, ValueRange{sz}, ValueRange{}, |
418 | ArrayRef<int64_t>{0}, // static offset |
419 | ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size |
420 | ArrayRef<int64_t>{1}) // static stride |
421 | .getResult(); |
422 | } |
423 | |
424 | /// Creates the reassociation array. |
425 | static SmallVector<ReassociationIndices> |
426 | getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) { |
427 | SmallVector<ReassociationIndices> ret(batchLvls + 1, {}); |
428 | // Create reassociation in the form: |
429 | // {0}, {1}, ..., {batchLvl - 1}, {batchLvl, ..., rank} |
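  // For example (illustrative): with batchLvls = 1 and a rank-3 source type,
  // this yields the reassociation {{0}, {1, 2}}.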
430 | for (unsigned i = 0; i < batchLvls; i++) |
    ret[i].push_back(i);

  for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
    ret.back().push_back(i);
435 | |
436 | return ret; |
437 | } |
438 | |
439 | //===----------------------------------------------------------------------===// |
440 | // Codegen rules. |
441 | //===----------------------------------------------------------------------===// |
442 | |
443 | namespace { |
444 | |
/// Helper class for lowering the sparse_tensor.insert operation.
446 | class SparseInsertGenerator |
447 | : public FuncCallOrInlineGenerator<SparseInsertGenerator> { |
448 | public: |
449 | SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params, |
450 | bool genCall) |
451 | : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp){}; |
452 | |
453 | /// Generates code along an insertion path without the need for a "cursor". |
454 | /// This current insertion strategy comes at the expense of some testing |
455 | /// overhead for each insertion. The strategy will be optimized later for |
456 | /// common insertion patterns. The current insertion strategy also assumes |
457 | /// insertions occur in "a reasonable order" that enables building the |
458 | /// storage scheme in an appending/inserting kind of fashion (i.e. no |
459 | /// in-between insertions that need data movement). The implementation |
460 | /// relies on CSE/DCE to clean up all bookkeeping that is not needed. |
461 | /// |
462 | /// TODO: better unord/not-unique; also generalize, optimize, specialize! |
463 | SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args, |
464 | OpBuilder &builder, Location loc) { |
465 | const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp)); |
466 | const Level lvlRank = stt.getLvlRank(); |
467 | // Extract fields and coordinates from args. |
    SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
469 | MutSparseTensorDescriptor desc(stt, fields); |
    const SmallVector<Value> coords =
        llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
472 | Value value = args.back(); |
473 | Value parentPos = constantZero(builder, loc, builder.getIndexType()); |
474 | // Generate code for every level. |
475 | for (Level lvl = 0; lvl < lvlRank; lvl++) { |
      const auto lt = stt.getLvlType(lvl);
477 | if (isCompressedLT(lt) || isLooseCompressedLT(lt)) { |
478 | // Create: |
479 | // if (!present) { |
480 | // coordinates[lvl].push_back(coords[lvl]) |
481 | // <update positions and prepare level lvl + 1> |
482 | // } |
483 | // positions[lvl] = coordinates.size() - 1 |
484 | // <insert @ positions[lvl] at next level lvl + 1> |
485 | if (isLooseCompressedLT(lt)) { |
          Value two = constantIndex(builder, loc, 2);
487 | parentPos = builder.create<arith::MulIOp>(loc, parentPos, two); |
488 | } |
        parentPos =
            genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
491 | } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) { |
492 | // Create: |
493 | // coordinates[lvl].push_back(coords[lvl]) |
494 | // positions[lvl] = positions[lvl-1] |
495 | // <insert @ positions[lvl] at next level lvl + 1> |
        createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
                       lvl, /*value=*/coords[lvl]);
498 | } else { |
499 | assert(isDenseLT(lt)); |
500 | // Construct the new position as: |
501 | // positions[lvl] = size * positions[lvl-1] + coords[lvl] |
502 | // <insert @ positions[lvl] at next level lvl + 1> |
503 | Value size = desc.getLvlSize(builder, loc, lvl); |
504 | Value mult = builder.create<arith::MulIOp>(loc, size, parentPos); |
505 | parentPos = builder.create<arith::AddIOp>(loc, mult, coords[lvl]); |
506 | } |
507 | } |
508 | // Reached the actual value append/insert. |
    if (!stt.isDenseLvl(lvlRank - 1))
      createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                     std::nullopt, value);
512 | else |
513 | genStore(builder, loc, value, desc.getValMemRef(), parentPos); |
514 | return fields; |
515 | } |
516 | |
517 | std::string getMangledFuncName() { |
518 | // The mangled name of the function has this format: |
519 | // <namePrefix>_<LT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth> |
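    // For example (illustrative, assuming a 10x10 CSR-like tensor of f64 with
    // default index bitwidths), the mangled name would look like:
    //   _insert_dense_compressed_10_10_f64_0_0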
    constexpr const char kInsertFuncNamePrefix[] = "_insert_";
521 | const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp)); |
522 | SmallString<32> nameBuffer; |
523 | llvm::raw_svector_ostream nameOstream(nameBuffer); |
524 | nameOstream << kInsertFuncNamePrefix; |
525 | const Level lvlRank = stt.getLvlRank(); |
526 | for (Level l = 0; l < lvlRank; l++) { |
      std::string lvlType = toMLIRString(stt.getLvlType(l));
528 | // Replace/remove punctuations in level properties. |
      std::replace_if(
          lvlType.begin(), lvlType.end(),
          [](char c) { return c == '(' || c == ','; }, '_');
      llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
      nameOstream << lvlType << "_";
534 | } |
535 | // Static dim sizes are used in the generated code while dynamic sizes are |
536 | // loaded from the dimSizes buffer. This is the reason for adding the shape |
537 | // to the function name. |
538 | for (const auto sz : stt.getDimShape()) |
      nameOstream << sz << "_";
540 | // Permutation information is also used in generating insertion. |
541 | if (!stt.isIdentity()) |
      nameOstream << stt.getDimToLvl() << "_";
    nameOstream << stt.getElementType() << "_";
    nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
545 | return nameOstream.str().str(); |
546 | } |
547 | |
548 | private: |
549 | TensorType rtp; |
550 | }; |
551 | |
552 | /// Sparse tensor storage conversion rule for returns. |
553 | class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> { |
554 | public: |
555 | using OpConversionPattern::OpConversionPattern; |
556 | LogicalResult |
557 | matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor, |
558 | ConversionPatternRewriter &rewriter) const override { |
559 | // Create a return with the flattened value extracted from sparse tensors. |
560 | rewriter.replaceOpWithNewOp<func::ReturnOp>( |
561 | op, flattenValues(adaptor.getOperands())); |
562 | return success(); |
563 | } |
564 | }; |
565 | |
566 | /// Sparse tensor storage conversion rule for calls. |
567 | class SparseCallConverter : public OpConversionPattern<func::CallOp> { |
568 | public: |
  // The default CallOp converter cannot handle 1:N type conversion.
570 | using OpConversionPattern::OpConversionPattern; |
571 | LogicalResult |
572 | matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor, |
573 | ConversionPatternRewriter &rewriter) const override { |
574 | Location loc = op.getLoc(); |
575 | // In case of: |
576 | // sparse_tensor, f, sparse_tensor = call @foo(...) |
577 | // ==> |
578 | // memref..., f, memref = call @foo(...) replace with |
579 | // cast(memref...)->sparse_tensor, f, cast(memref...)->sparse_tensor |
580 | SmallVector<Type> finalRetTy; |
581 | if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy))) |
582 | return failure(); |
583 | |
584 | // (1) Generates new call with flattened return value. |
585 | auto newCall = rewriter.create<func::CallOp>( |
586 | loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands())); |
587 | // (2) Gather sparse tensor returns. |
588 | SmallVector<SmallVector<Value>> packedResultVals; |
    // Tracks the offset of the current return value (of the original call)
    // relative to the new call (after sparse tensor flattening).
    unsigned retOffset = 0;
    // Temporary buffer to hold the flattened list of types for
    // a sparse tensor.
594 | SmallVector<Type> sparseFlat; |
595 | for (auto ret : op.getResults()) { |
596 | assert(retOffset < newCall.getNumResults()); |
597 | auto retType = ret.getType(); |
598 | if (failed(typeConverter->convertType(retType, sparseFlat))) |
        llvm_unreachable("Failed to convert type in sparse tensor codegen");
600 | |
      // Converted types cannot be empty when the type conversion succeeds.
602 | assert(!sparseFlat.empty()); |
603 | if (sparseFlat.size() > 1) { |
604 | auto flatSize = sparseFlat.size(); |
605 | packedResultVals.emplace_back(); |
606 | llvm::append_range(packedResultVals.back(), |
607 | newCall.getResults().slice(retOffset, flatSize)); |
608 | retOffset += flatSize; |
609 | } else { |
        // If this is a 1:1 conversion, there is no need for casting.
611 | packedResultVals.emplace_back(); |
612 | packedResultVals.back().push_back(newCall.getResult(retOffset)); |
613 | retOffset++; |
614 | } |
615 | sparseFlat.clear(); |
616 | } |
617 | |
618 | assert(packedResultVals.size() == op.getNumResults()); |
619 | rewriter.replaceOpWithMultiple(op, std::move(packedResultVals)); |
620 | return success(); |
621 | } |
622 | }; |
623 | |
624 | /// Sparse codegen rule for level accesses. |
625 | class SparseLvlOpConverter : public OpConversionPattern<LvlOp> { |
626 | public: |
627 | using OpConversionPattern::OpConversionPattern; |
628 | LogicalResult |
629 | matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor, |
630 | ConversionPatternRewriter &rewriter) const override { |
631 | std::optional<int64_t> lvl = op.getConstantLvlIndex(); |
632 | RankedTensorType srcType = op.getSource().getType(); |
633 | if (!lvl || !getSparseTensorEncoding(srcType)) |
634 | return failure(); |
635 | |
636 | auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType); |
637 | auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl); |
638 | |
639 | rewriter.replaceOp(op, sz); |
640 | return success(); |
641 | } |
642 | }; |
643 | |
644 | // TODO: use a new SortCOO operation here instead of reusing convert op. |
645 | struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> { |
646 | using OpConversionPattern::OpConversionPattern; |
647 | LogicalResult |
648 | matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor, |
649 | ConversionPatternRewriter &rewriter) const override { |
650 | Location loc = op.getLoc(); |
651 | MLIRContext *ctx = op.getContext(); |
652 | |
653 | SparseTensorType srcStt = getSparseTensorType(op.getInputCoo()); |
654 | SparseTensorType dstStt = getSparseTensorType(op.getResultCoo()); |
655 | |
656 | // Should have been verified. |
657 | assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() && |
658 | dstStt.isCOOType() && srcStt.isCOOType()); |
659 | assert(dstStt.hasSameDimToLvl(srcStt)); |
660 | |
661 | // We don't need a mutable descriptor here as we perform sorting in-place. |
662 | auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(), |
663 | op.getInputCoo().getType()); |
664 | auto nnz = desc.getValMemSize(rewriter, op.getLoc()); |
665 | auto crd = desc.getAOSMemRef(); |
666 | auto val = desc.getValMemRef(); |
667 | |
668 | // Otherwise we need another data shuffle and a non-identity map. |
669 | assert(dstStt.hasSameDimToLvl(srcStt)); |
670 | (void)dstStt; // to silence warning when assertion is disabled |
671 | |
    auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
673 | |
674 | rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id, |
675 | rewriter.getIndexAttr(0), op.getAlgorithm()); |
676 | |
    // Since we do in-place sorting, the destination tensor will have the same
    // set of memrefs as the source tensor.
679 | rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()}); |
680 | return success(); |
681 | } |
682 | }; |
683 | |
684 | template <typename Op, StorageSpecifierKind kind> |
685 | class SparseSliceGetterOpConverter : public OpConversionPattern<Op> { |
686 | public: |
687 | using OpConversionPattern<Op>::OpConversionPattern; |
688 | using typename OpConversionPattern<Op>::OneToNOpAdaptor; |
689 | |
690 | LogicalResult |
691 | matchAndRewrite(Op op, OneToNOpAdaptor adaptor, |
692 | ConversionPatternRewriter &rewriter) const override { |
    // Simply lowers to a specifier.get <field> operation.
694 | auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(), |
695 | op.getSlice().getType()); |
696 | auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind, |
697 | op.getDim().getZExtValue()); |
698 | |
699 | rewriter.replaceOp(op, v); |
700 | return success(); |
701 | } |
702 | }; |
703 | |
704 | /// Sparse codegen rule for trivial tensor casts. |
705 | class SparseCastConverter : public OpConversionPattern<tensor::CastOp> { |
706 | public: |
707 | using OpConversionPattern::OpConversionPattern; |
708 | LogicalResult |
709 | matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor, |
710 | ConversionPatternRewriter &rewriter) const override { |
711 | // Only rewrite identically annotated source/dest. |
712 | auto encDst = getSparseTensorEncoding(op.getType()); |
713 | auto encSrc = getSparseTensorEncoding(op.getSource().getType()); |
714 | if (!encDst || encDst != encSrc) |
715 | return failure(); |
716 | rewriter.replaceOpWithMultiple(op, {adaptor.getSource()}); |
717 | return success(); |
718 | } |
719 | }; |
720 | |
721 | class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> { |
722 | public: |
723 | using OpConversionPattern::OpConversionPattern; |
724 | LogicalResult |
725 | matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor, |
726 | ConversionPatternRewriter &rewriter) const override { |
727 | // Simply fold the operation. |
728 | rewriter.replaceOpWithMultiple(op, {adaptor.getSource()}); |
729 | return success(); |
730 | } |
731 | }; |
732 | |
733 | /// Sparse codegen rule for the alloc operator. |
734 | class SparseTensorAllocConverter |
735 | : public OpConversionPattern<bufferization::AllocTensorOp> { |
736 | public: |
737 | using OpConversionPattern::OpConversionPattern; |
738 | SparseTensorAllocConverter(const TypeConverter &typeConverter, |
739 | MLIRContext *context, bool enableInit) |
740 | : OpConversionPattern(typeConverter, context), |
741 | enableBufferInitialization(enableInit) {} |
742 | |
743 | LogicalResult |
744 | matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor, |
745 | ConversionPatternRewriter &rewriter) const override { |
746 | const auto resType = getSparseTensorType(op); |
747 | if (!resType.hasEncoding()) |
748 | return failure(); |
749 | |
750 | Location loc = op.getLoc(); |
751 | // Deal with copy. |
752 | if (op.getCopy()) { |
753 | auto desc = getDescriptorFromTensorTuple( |
754 | adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType())); |
755 | SmallVector<Value> fields; |
      fields.reserve(desc.getNumFields());
757 | // Memcpy on memref fields. |
758 | for (auto field : desc.getMemRefFields()) { |
759 | auto memrefTp = cast<MemRefType>(field.getType()); |
760 | auto size = rewriter.create<memref::DimOp>(loc, field, 0); |
761 | auto copied = |
762 | rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size}); |
763 | rewriter.create<memref::CopyOp>(loc, field, copied); |
764 | fields.push_back(copied); |
765 | } |
766 | // Reuses specifier. |
      fields.push_back(desc.getSpecifier());
768 | assert(fields.size() == desc.getNumFields()); |
769 | rewriter.replaceOpWithMultiple(op, {fields}); |
770 | return success(); |
771 | } |
772 | |
773 | if (!resType.isIdentity()) { |
774 | return rewriter.notifyMatchFailure( |
775 | op, "try run --sparse-reinterpret-map before codegen" ); |
776 | } |
    // Level sizes equal dimension sizes since the lvl2dim map is an identity
    // map.
778 | SmallVector<Value> lvlSizesValues; |
779 | createDimSizes(rewriter, loc, resType, |
780 | flattenValues(adaptor.getDynamicSizes()), |
781 | /*dimSizesValues=*/lvlSizesValues); |
782 | |
783 | // Construct allocation for each field. |
784 | Value sizeHint = op.getSizeHint(); |
785 | SmallVector<Value> fields; |
786 | createAllocFields(rewriter, loc, resType, enableBufferInitialization, |
787 | sizeHint, lvlSizesValues, fields); |
788 | |
789 | // Replace operation with resulting memrefs. |
790 | rewriter.replaceOpWithMultiple(op, {fields}); |
791 | return success(); |
792 | } |
793 | |
794 | private: |
795 | bool enableBufferInitialization; |
796 | }; |
797 | |
798 | /// Sparse codegen rule for the empty tensor operator. |
799 | class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> { |
800 | public: |
801 | using OpConversionPattern::OpConversionPattern; |
802 | SparseTensorEmptyConverter(const TypeConverter &typeConverter, |
803 | MLIRContext *context, bool enableInit) |
804 | : OpConversionPattern(typeConverter, context), |
805 | enableBufferInitialization(enableInit) {} |
806 | |
807 | LogicalResult |
808 | matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor, |
809 | ConversionPatternRewriter &rewriter) const override { |
810 | const auto resType = getSparseTensorType(op); |
811 | if (!resType.hasEncoding()) |
812 | return failure(); |
813 | |
814 | if (!resType.isIdentity()) { |
815 | return rewriter.notifyMatchFailure( |
816 | op, "try run --sparse-reinterpret-map before codegen" ); |
817 | } |
818 | |
819 | Location loc = op.getLoc(); |
    // Level sizes equal dimension sizes since the lvl2dim map is an identity
    // map.
821 | SmallVector<Value> lvlSizesValues; |
822 | createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(), |
823 | /*dimSizesValues=*/lvlSizesValues); |
824 | // Construct allocation for each field. |
825 | Value sizeHint; // none |
826 | SmallVector<Value> fields; |
827 | createAllocFields(rewriter, loc, resType, enableBufferInitialization, |
828 | sizeHint, lvlSizesValues, fields); |
829 | |
830 | // Replace operation with resulting memrefs. |
831 | rewriter.replaceOpWithMultiple(op, {fields}); |
832 | return success(); |
833 | } |
834 | |
835 | private: |
836 | bool enableBufferInitialization; |
837 | }; |
838 | |
839 | /// Sparse codegen rule for the dealloc operator. |
840 | class SparseTensorDeallocConverter |
841 | : public OpConversionPattern<bufferization::DeallocTensorOp> { |
842 | public: |
843 | using OpConversionPattern::OpConversionPattern; |
844 | SparseTensorDeallocConverter(const TypeConverter &typeConverter, |
845 | MLIRContext *context, bool createDeallocs) |
846 | : OpConversionPattern(typeConverter, context), |
847 | createDeallocs(createDeallocs) {} |
848 | |
849 | LogicalResult |
850 | matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor, |
851 | ConversionPatternRewriter &rewriter) const override { |
852 | auto enc = getSparseTensorEncoding(op.getTensor().getType()); |
853 | if (!enc) |
854 | return failure(); |
855 | |
856 | // If user requests not to deallocate sparse tensors, simply erase the |
857 | // operation. |
858 | if (createDeallocs) { |
859 | // Replace the sparse tensor deallocation with field deallocations. |
860 | Location loc = op.getLoc(); |
861 | auto desc = getDescriptorFromTensorTuple( |
862 | adaptor.getTensor(), |
863 | cast<RankedTensorType>(op.getTensor().getType())); |
864 | for (auto input : desc.getMemRefFields()) |
865 | // Deallocate every buffer used to store the sparse tensor handler. |
866 | rewriter.create<memref::DeallocOp>(loc, input); |
867 | } |
    rewriter.eraseOp(op);
869 | return success(); |
870 | } |
871 | |
872 | private: |
873 | const bool createDeallocs; |
874 | }; |
875 | |
876 | /// Sparse codegen rule for tensor rematerialization. |
877 | class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> { |
878 | public: |
879 | using OpConversionPattern::OpConversionPattern; |
880 | LogicalResult |
881 | matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor, |
882 | ConversionPatternRewriter &rewriter) const override { |
883 | // Prepare descriptor. |
884 | auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), |
885 | op.getTensor().getType()); |
886 | // Generate optional insertion finalization code. |
887 | if (op.getHasInserts()) |
888 | genEndInsert(rewriter, op.getLoc(), desc); |
889 | // Replace operation with resulting memrefs. |
890 | rewriter.replaceOpWithMultiple(op, {desc.getFields()}); |
891 | return success(); |
892 | } |
893 | }; |
894 | |
895 | /// Sparse codegen rule for the expand op. |
896 | class SparseExpandConverter : public OpConversionPattern<ExpandOp> { |
897 | public: |
898 | using OpConversionPattern::OpConversionPattern; |
899 | LogicalResult |
900 | matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor, |
901 | ConversionPatternRewriter &rewriter) const override { |
902 | if (!getSparseTensorEncoding(op.getTensor().getType())) |
903 | return failure(); |
904 | Location loc = op->getLoc(); |
905 | auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), |
906 | op.getTensor().getType()); |
907 | const auto srcType = getSparseTensorType(op.getTensor()); |
908 | Type eltType = srcType.getElementType(); |
909 | Type boolType = rewriter.getIntegerType(1); |
910 | Type idxType = rewriter.getIndexType(); |
911 | // All initialization should be done on entry of the loop nest. |
912 | rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp()); |
913 | |
914 | // Determine the size for access expansion (always the innermost stored |
915 | // level size). |
916 | const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1); |
917 | // Generate a memref for `sz` elements of type `t`. |
918 | const auto genAlloc = [&](Type t) { |
919 | const auto memTp = MemRefType::get({ShapedType::kDynamic}, t); |
920 | return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz}); |
921 | }; |
922 | // Allocate temporary buffers for values/filled-switch and added. |
923 | // We do not use stack buffers for this, since the expanded size may |
924 | // be rather large (as it envelops a single expanded dense dimension). |
925 | Value values = genAlloc(eltType); |
926 | Value filled = genAlloc(boolType); |
927 | Value added = genAlloc(idxType); |
    Value zero = constantZero(rewriter, loc, idxType);
929 | // Reset the values/filled-switch to all-zero/false. Note that this |
930 | // introduces an O(N) operation into the computation, but this reset |
931 | // operation is amortized over the innermost loops for the access |
932 | // pattern expansion. As noted in the operation doc, we would like |
933 | // to amortize this setup cost even between kernels. |
934 | rewriter.create<linalg::FillOp>( |
935 | loc, ValueRange{constantZero(rewriter, loc, eltType)}, |
936 | ValueRange{values}); |
937 | rewriter.create<linalg::FillOp>( |
938 | loc, ValueRange{constantZero(rewriter, loc, boolType)}, |
939 | ValueRange{filled}); |
940 | // Replace expansion op with these buffers and initial coordinate. |
941 | assert(op.getNumResults() == 4); |
942 | rewriter.replaceOp(op, {values, filled, added, zero}); |
943 | return success(); |
944 | } |
945 | }; |
946 | |
947 | /// Sparse codegen rule for the compress operator. |
948 | class SparseCompressConverter : public OpConversionPattern<CompressOp> { |
949 | public: |
950 | using OpConversionPattern::OpConversionPattern; |
951 | LogicalResult |
952 | matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor, |
953 | ConversionPatternRewriter &rewriter) const override { |
954 | Location loc = op->getLoc(); |
955 | SmallVector<Value> fields; |
956 | auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields, |
957 | op.getTensor().getType()); |
958 | Value values = llvm::getSingleElement(adaptor.getValues()); |
959 | Value filled = llvm::getSingleElement(adaptor.getFilled()); |
960 | Value added = llvm::getSingleElement(adaptor.getAdded()); |
961 | Value count = llvm::getSingleElement(adaptor.getCount()); |
962 | const SparseTensorType dstType(desc.getRankedTensorType()); |
963 | Type eltType = dstType.getElementType(); |
964 | |
965 | // If the innermost level is ordered, we need to sort the coordinates |
966 | // in the "added" array prior to applying the compression. |
967 | if (dstType.isOrderedLvl(dstType.getLvlRank() - 1)) |
968 | rewriter.create<SortOp>( |
969 | loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1), |
970 | rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort); |
971 | // While performing the insertions, we also need to reset the elements |
972 | // of the values/filled-switch by only iterating over the set elements, |
973 | // to ensure that the runtime complexity remains proportional to the |
974 | // sparsity of the expanded access pattern. |
975 | // |
976 | // Generate |
977 | // out_memrefs = for (i = 0; i < count; i++)(in_memrefs) { |
978 | // crd = added[i]; |
979 | // value = values[crd]; |
980 | // insert({lvlCoords, crd}, value); |
981 | // new_memrefs = insert(in_memrefs, {lvlCoords, crd}, value); |
982 | // values[crd] = 0; |
983 | // filled[crd] = false; |
984 | // yield new_memrefs |
985 | // } |
986 | scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields()); |
987 | Value i = loop.getInductionVar(); |
988 | |
    Value crd = genLoad(rewriter, loc, added, i);
    Value value = genLoad(rewriter, loc, values, crd);
991 | SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end()); |
992 | SmallVector<Type> flatSpTensorTps = llvm::to_vector( |
993 | llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); })); |
994 | SmallVector<Value> flatLvlCoords = flattenValues(adaptor.getLvlCoords()); |
    params.append(flatLvlCoords.begin(), flatLvlCoords.end());
    params.push_back(crd);
    params.push_back(value);
998 | SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps, |
999 | params, /*genCall=*/true); |
    SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
    genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
    genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
1003 | rewriter.create<scf::YieldOp>(loc, insertRet); |
1004 | |
1005 | rewriter.setInsertionPointAfter(loop); |
1006 | // Deallocate the buffers on exit of the full loop nest. |
1007 | Operation *parent = getTop(op); |
1008 | rewriter.setInsertionPointAfter(parent); |
1009 | rewriter.create<memref::DeallocOp>(loc, values); |
1010 | rewriter.create<memref::DeallocOp>(loc, filled); |
1011 | rewriter.create<memref::DeallocOp>(loc, added); |
1012 | // Replace operation with resulting memrefs. |
1013 | rewriter.replaceOpWithMultiple(op, {loop->getResults()}); |
1014 | return success(); |
1015 | } |
1016 | }; |
1017 | |
1018 | /// Sparse codegen rule for the insert operator. |
1019 | class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> { |
1020 | public: |
1021 | using OpConversionPattern::OpConversionPattern; |
1022 | LogicalResult |
1023 | matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor, |
1024 | ConversionPatternRewriter &rewriter) const override { |
1025 | auto stt = getSparseTensorType(op.getDest()); |
1026 | if (!stt.hasEncoding()) |
1027 | return failure(); |
    assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
1029 | |
1030 | Location loc = op.getLoc(); |
1031 | auto desc = |
1032 | getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType()); |
1033 | TypeRange flatSpTensorTps = desc.getFields().getTypes(); |
1034 | SmallVector<Value> params = llvm::to_vector(desc.getFields()); |
1035 | SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices()); |
    params.append(flatIndices.begin(), flatIndices.end());
    params.push_back(llvm::getSingleElement(adaptor.getScalar()));
1038 | SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps, |
1039 | params, /*genCall=*/true); |
    SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
1041 | // Replace operation with resulting memrefs. |
1042 | rewriter.replaceOpWithMultiple(op, {ret}); |
1043 | return success(); |
1044 | } |
1045 | }; |
1046 | |
1047 | /// Sparse codegen rule for position accesses. |
1048 | class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> { |
1049 | public: |
1050 | using OpAdaptor = typename ToPositionsOp::Adaptor; |
1051 | using OpConversionPattern<ToPositionsOp>::OpConversionPattern; |
1052 | LogicalResult |
1053 | matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor, |
1054 | ConversionPatternRewriter &rewriter) const override { |
1055 | // Replace the requested position access with corresponding field. |
1056 | // The view is restricted to the actual size to ensure clients |
1057 | // of this operation truly observe size, not capacity! |
1058 | Location loc = op.getLoc(); |
1059 | Level lvl = op.getLevel(); |
1060 | auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), |
1061 | op.getTensor().getType()); |
1062 | auto mem = desc.getPosMemRef(lvl); |
1063 | auto size = desc.getPosMemSize(rewriter, loc, lvl); |
1064 | rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size)); |
1065 | return success(); |
1066 | } |
1067 | }; |
1068 | |
1069 | /// Sparse codegen rule for accessing the coordinates arrays. |
1070 | class SparseToCoordinatesConverter |
1071 | : public OpConversionPattern<ToCoordinatesOp> { |
1072 | public: |
1073 | using OpAdaptor = typename ToCoordinatesOp::Adaptor; |
1074 | using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern; |
1075 | LogicalResult |
1076 | matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor, |
1077 | ConversionPatternRewriter &rewriter) const override { |
1078 | // Replace the requested coordinates access with corresponding field. |
1079 | // The view is restricted to the actual size to ensure clients |
1080 | // of this operation truly observe size, not capacity! |
1081 | Location loc = op.getLoc(); |
1082 | Level lvl = op.getLevel(); |
1083 | auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), |
1084 | op.getTensor().getType()); |
1085 | auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl); |
1086 | if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) { |
1087 | auto size = desc.getCrdMemSize(rewriter, loc, lvl); |
1088 | mem = genSliceToSize(rewriter, loc, mem, size); |
1089 | } |
1090 | rewriter.replaceOp(op, mem); |
1091 | return success(); |
1092 | } |
1093 | }; |
1094 | |
1095 | /// Sparse codegen rule for accessing the linear coordinates buffer. |
1096 | class SparseToCoordinatesBufferConverter |
1097 | : public OpConversionPattern<ToCoordinatesBufferOp> { |
1098 | public: |
1099 | using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor; |
1100 | using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern; |
1101 | LogicalResult |
1102 | matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor, |
1103 | ConversionPatternRewriter &rewriter) const override { |
1104 | // Replace the requested coordinates access with corresponding field. |
1105 | // The view is restricted to the actual size to ensure clients |
1106 | // of this operation truly observe size, not capacity! |
1107 | Location loc = op.getLoc(); |
1108 | Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart(); |
1109 | auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), |
1110 | op.getTensor().getType()); |
1111 | auto mem = desc.getAOSMemRef(); |
1112 | auto size = desc.getCrdMemSize(rewriter, loc, lvl); |
1113 | rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size)); |
1114 | return success(); |
1115 | } |
1116 | }; |
1117 | |
1118 | /// Sparse codegen rule for value accesses. |
1119 | class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> { |
1120 | public: |
1121 | using OpAdaptor = typename ToValuesOp::Adaptor; |
1122 | using OpConversionPattern<ToValuesOp>::OpConversionPattern; |
1123 | LogicalResult |
1124 | matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor, |
1125 | ConversionPatternRewriter &rewriter) const override { |
1126 | // Replace the requested values access with corresponding field. |
1127 | // The view is restricted to the actual size to ensure clients |
1128 | // of this operation truly observe size, not capacity! |
1129 | Location loc = op.getLoc(); |
1130 | auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), |
1131 | op.getTensor().getType()); |
1132 | auto mem = desc.getValMemRef(); |
1133 | auto size = desc.getValMemSize(rewriter, loc); |
1134 | rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size)); |
1135 | return success(); |
1136 | } |
1137 | }; |
1138 | |
1139 | /// Sparse codegen rule for the convert operator. |
1140 | class SparseConvertConverter : public OpConversionPattern<ConvertOp> { |
1141 | public: |
1142 | using OpConversionPattern::OpConversionPattern; |
1143 | LogicalResult |
1144 | matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor, |
1145 | ConversionPatternRewriter &rewriter) const override { |
1146 | SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType()); |
1147 | SparseTensorEncodingAttr encSrc = |
1148 | getSparseTensorEncoding(op.getSource().getType()); |
    // The output tensor cannot be a slice; those cases should have been
    // rejected by ConvertOp::verify() already.
    assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slice.");
1152 | // Different encoding (except for different bitwidth) should be handled by |
1153 | // rewriting. |
1154 | // We need further rewrites if the input tensor is a slice too. |
1155 | if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() || |
1156 | encSrc.isSlice()) { |
1157 | return failure(); |
1158 | } |
1159 | |
1160 | Type retElemTp = op.getResult().getType().getElementType(); |
1161 | Type srcElemTp = op.getSource().getType().getElementType(); |
1162 | // Fold the trivial cases. |
1163 | if (retElemTp == srcElemTp && encDst == encSrc) { |
1164 | rewriter.replaceOpWithMultiple(op, {adaptor.getSource()}); |
1165 | return success(); |
1166 | } |
1167 | // |
1168 | // Do element-wise type conversion without using InsertOp. |
1169 | // |
1170 | // for each memref in srcTensor: |
1171 | // dst = memref.alloc |
1172 | // if srcMemRefType != dstMemRefType: |
1173 | // for every dst[i] = cast(src[i]) |
1174 | // else: |
1175 | // dst = memref.copy(src) |
1176 | Location loc = op.getLoc(); |
1177 | auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(), |
1178 | op.getSource().getType()); |
1179 | SmallVector<Value> fields; |
1180 | foreachFieldAndTypeInSparseTensor( |
1181 | SparseTensorType(cast<RankedTensorType>(op.getResult().getType())), |
1182 | [&rewriter, &fields, srcDesc, |
1183 | loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl, |
1184 | LevelType /*lt*/) -> bool { |
1185 | // Simply reuses the storage specifier as it is an SSA value. |
1186 | if (fKind == SparseTensorFieldKind::StorageSpec) { |
            fields.push_back(srcDesc.getSpecifier());
1188 | } else { |
1189 | // Allocates new memrefs |
1190 | Value srcMem = srcDesc.getMemRefField(fIdx); |
1191 | // TODO: We can instead use the actual memSize in specifier, that |
1192 | // would require a subViewOp to avoid overflow when copying |
1193 | // values. |
            Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
1195 | auto dstMem = rewriter.create<memref::AllocOp>( |
1196 | loc, cast<MemRefType>(fTp), sz); |
1197 | if (fTp != srcMem.getType()) { |
1198 | // Converts elements type. |
              scf::buildLoopNest(
                  rewriter, loc, constantIndex(rewriter, loc, 0), sz,
                  constantIndex(rewriter, loc, 1),
                  [srcMem, &dstMem](OpBuilder &builder, Location loc,
                                    ValueRange ivs) {
1204 | Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs); |
1205 | Value casted = genCast(builder, loc, v, |
1206 | dstMem.getType().getElementType()); |
1207 | builder.create<memref::StoreOp>(loc, casted, dstMem, ivs); |
1208 | }); |
1209 | } else { |
              // TODO: We could even reuse the same memref for the new tensor,
              // but that requires reference-counting based memory management
              // for memrefs shared between multiple sparse tensors.
1213 | rewriter.create<memref::CopyOp>(loc, srcMem, dstMem); |
1214 | } |
            fields.push_back(dstMem);
1216 | } |
1217 | return true; |
1218 | }); |
1219 | |
1220 | rewriter.replaceOpWithMultiple(op, {fields}); |
1221 | return success(); |
1222 | } |
1223 | }; |
1224 | |
/// Sparse codegen rule for the tensor.extract_slice operator.
class SparseExtractSliceConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
1227 | public: |
1228 | using OpConversionPattern::OpConversionPattern; |
1229 | LogicalResult |
1230 | matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor, |
1231 | ConversionPatternRewriter &rewriter) const override { |
1232 | Location loc = op.getLoc(); |
1233 | MLIRContext *ctx = op.getContext(); |
1234 | auto srcEnc = getSparseTensorEncoding(op.getSourceType()); |
1235 | auto dstEnc = getSparseTensorEncoding(op.getResult().getType()); |
1236 | // TODO: We should check these in ExtractSliceOp::verify. |
1237 | if (!srcEnc || !dstEnc || !dstEnc.isSlice()) |
1238 | return failure(); |
1239 | assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices()); |
1240 | |
1241 | SmallVector<Value> fields; |
1242 | auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields, |
1243 | op.getSource().getType()); |
1244 | |
1245 | auto newSpec = rewriter.create<StorageSpecifierInitOp>( |
1246 | loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier()); |
1247 | desc.setSpecifier(newSpec); |
1248 | |
1249 | // Fills in slice information. |
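    // For example (illustrative only, relying on the identity dim-to-lvl
    // mapping asserted below), lowering
    //   tensor.extract_slice %t[2, 0] [4, 8] [1, 2]
    // records DimOffset = (2, 0), LvlSize = (4, 8), and DimStride = (1, 2)
    // in the cloned specifier, while all memrefs remain shared with %t.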
1250 | for (auto [idx, offset, size, stride] : llvm::enumerate( |
1251 | op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) { |
1252 | Dimension dim = idx; |
1253 | |
1254 | Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset); |
1255 | Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size); |
1256 | Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride); |
      // TODO: We could probably set only the dynamic values here. But that
      // would require us to fill in the hole when casting a static slice to a
      // dynamic slice.
1260 | desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset, |
1261 | dim, offsetV); |
1262 | |
      // FIXME: we need to distinguish level sizes from dimension sizes for
      // slices here. Maybe we should store slice level sizes in a different
      // array instead of reusing it.
1266 | assert(srcEnc.isIdentity()); |
1267 | desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim, |
1268 | sizeV); |
1269 | desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride, |
1270 | dim, strideV); |
1271 | } |
1272 | |
    // NOTE: we cannot generate tuples directly from the descriptor here, as
    // the descriptor holds the original type, yet we want the slice type here
    // (they share every memref but with an updated specifier).
1276 | rewriter.replaceOpWithMultiple(op, {desc.getFields()}); |
1277 | return success(); |
1278 | } |
1279 | }; |
1280 | |
/// Sparse codegen rule for the number of entries operator.
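/// For example (an illustrative sketch; the exact printed form of the
/// specifier getter may differ), `sparse_tensor.number_of_entries %t` becomes
/// a single query of the ValMemSize field on the storage specifier of %t:
///   %nse = sparse_tensor.storage_specifier.get %spec val_mem_sz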
1282 | class SparseNumberOfEntriesConverter |
1283 | : public OpConversionPattern<NumberOfEntriesOp> { |
1284 | public: |
1285 | using OpConversionPattern::OpConversionPattern; |
1286 | LogicalResult |
1287 | matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor, |
1288 | ConversionPatternRewriter &rewriter) const override { |
1289 | // Query memSizes for the actually stored values. |
1290 | // FIXME: the nse value computed in this way might be wrong when there is |
1291 | // any "loose_compressed" level. |
1292 | auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), |
1293 | op.getTensor().getType()); |
1294 | rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc())); |
1295 | return success(); |
1296 | } |
1297 | }; |
1298 | |
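/// Sparse codegen rule for the sparse_tensor.assemble operator.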
1299 | struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> { |
1300 | using OpConversionPattern::OpConversionPattern; |
1301 | LogicalResult |
1302 | matchAndRewrite(AssembleOp op, OpAdaptor adaptor, |
1303 | ConversionPatternRewriter &rewriter) const override { |
1304 | Location loc = op.getLoc(); |
1305 | const auto stt = getSparseTensorType(op.getResult()); |
1306 | |
1307 | SmallVector<Value> fields; |
1308 | |
1309 | foreachFieldAndTypeInSparseTensor( |
1310 | stt, |
1311 | [&rewriter, &fields, &op, &stt, |
1312 | loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind, |
1313 | Level /*lvl*/, LevelType lt) -> bool { |
1314 | assert(fields.size() == fIdx); |
1315 | if (fKind == SparseTensorFieldKind::StorageSpec) { |
            fields.push_back(
                SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
1318 | } else { |
1319 | // Else simply takes the inputs. |
1320 | Value tensor = fKind == SparseTensorFieldKind::ValMemRef |
1321 | ? op.getValues() |
1322 | : op.getLevels()[fIdx]; |
1323 | // TODO: handle batch. |
            TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
1325 | if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) { |
1326 | // Flattens the buffer to batchLvlRank. |
1327 | auto reassoc = getReassociationForFlattening( |
1328 | mem.getType(), stt.getBatchLvlRank()); |
1329 | mem = rewriter.create<memref::CastOp>( |
1330 | loc, fType, |
1331 | rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc)); |
1332 | } else { |
1333 | mem = rewriter.create<memref::CastOp>(loc, fType, mem); |
1334 | } |
            fields.push_back(mem);
1336 | } |
1337 | return true; |
1338 | }); |
1339 | |
1340 | MutSparseTensorDescriptor desc(stt, fields); |
    Value c0 = constantIndex(rewriter, loc, 0);
    Value c1 = constantIndex(rewriter, loc, 1);
    Value c2 = constantIndex(rewriter, loc, 2);
1344 | Value posBack = c0; // index to the last value in the position array |
1345 | Value memSize = c1; // memory size for current array |
1346 | |
1347 | Level trailCOOStart = stt.getAoSCOOStart(); |
1348 | Level trailCOORank = stt.getLvlRank() - trailCOOStart; |
1349 | // Sets up SparseTensorSpecifier. |
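    // As a worked example (illustrative only), for a static 4x8 CSR input
    // (level 0 dense, level 1 compressed, no trailing COO, no batch levels):
    //   lvl 0 (dense):      memSize = 4 * 1 = 4, posBack = 3
    //   lvl 1 (compressed): posMemSize[1] = 4 + 1 = 5,
    //                       memSize = positions[4] (i.e. nse),
    //                       crdMemSize[1] = nse
    //   after the loop:     valMemSize = nse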
1350 | for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) { |
1351 | assert(!ShapedType::isDynamic(stt.getDimShape()[lvl])); |
1352 | |
1353 | // Sets up the level size. |
1354 | auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]); |
      desc.setLvlSize(rewriter, loc, lvl, lvlSize);
1356 | // We use a single AOS array to store the trailing COO, so there is only |
1357 | // one memory size to set for the entire COO section. |
1358 | if (lvl > trailCOOStart) |
1359 | continue; |
1360 | |
1361 | // Sets up the memory size by reading the last value in position array. |
1362 | LevelType lt = stt.getLvlType(lvl); |
1363 | // Simply forwards the position index when this is a dense level. |
1364 | if (lt.isa<LevelFormat::Dense>()) { |
1365 | memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize); |
1366 | posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1); |
1367 | continue; |
1368 | } |
1369 | if (lt.isa<LevelFormat::Batch>()) { |
        // Skips batch levels as they are not linearized.
        // FIXME: this assumes that every batch has the same number of nse;
        // needs to be generalized to handle varied-size batches.
1373 | continue; |
1374 | } |
1375 | |
1376 | if (isWithPosLT(lt)) { |
1377 | assert(isCompressedLT(lt) || isLooseCompressedLT(lt)); |
1378 | if (isLooseCompressedLT(lt)) { |
1379 | memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2); |
1380 | posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1); |
1381 | } else { |
1382 | assert(isCompressedLT(lt)); |
1383 | posBack = memSize; |
1384 | memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1); |
1385 | } |
        desc.setPosMemSize(rewriter, loc, lvl, memSize);
        // The last value in the position array is the memory size for the
        // next level.
        // FIXME: this assumes that every batch has the same number of nse;
        // needs to be generalized to handle varied-size batches.
        SmallVector<Value> batched(stt.getBatchLvlRank(),
                                   constantIndex(rewriter, loc, 0));
        batched.push_back(posBack);
1393 | memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched); |
1394 | posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1); |
1395 | } |
1396 | assert(isWithCrdLT(lt) && lvl <= trailCOOStart); |
1397 | // FIXME: This seems to be unnecessarily complex, can we simplify it? |
1398 | if (lvl == trailCOOStart) { |
1399 | Value cooSz = rewriter.create<arith::MulIOp>( |
1400 | loc, memSize, constantIndex(rewriter, loc, trailCOORank)); |
        desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
      } else {
        desc.setCrdMemSize(rewriter, loc, lvl, memSize);
1404 | } |
1405 | } |
    desc.setValMemSize(rewriter, loc, memSize);
1407 | |
1408 | rewriter.replaceOpWithMultiple(op, {desc.getFields()}); |
1409 | return success(); |
1410 | } |
1411 | }; |
1412 | |
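/// Sparse codegen rule for the sparse_tensor.disassemble operator. Copies each
/// underlying buffer, trimmed to its actually used size, into the provided
/// output tensors and additionally returns the used lengths.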
1413 | struct SparseDisassembleOpConverter |
1414 | : public OpConversionPattern<DisassembleOp> { |
1415 | using OpConversionPattern::OpConversionPattern; |
1416 | SparseDisassembleOpConverter(const TypeConverter &typeConverter, |
1417 | MLIRContext *context) |
1418 | : OpConversionPattern(typeConverter, context) {} |
1419 | |
1420 | LogicalResult |
1421 | matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor, |
1422 | ConversionPatternRewriter &rewriter) const override { |
1423 | auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), |
1424 | op.getTensor().getType()); |
1425 | Location loc = op.getLoc(); |
1426 | SmallVector<Value> retMem; |
1427 | SmallVector<Value> retLen; |
1428 | desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem, |
1429 | &retLen](FieldIndex fid, |
1430 | SparseTensorFieldKind fKind, |
1431 | Level lvl, LevelType lt) -> bool { |
1432 | if (fKind == SparseTensorFieldKind::StorageSpec) |
1433 | return true; |
1434 | SparseTensorType stt(desc.getRankedTensorType()); |
1435 | Value sz, src; |
1436 | TypedValue<BaseMemRefType> dst; |
1437 | if (fKind == SparseTensorFieldKind::ValMemRef) { |
1438 | sz = desc.getValMemSize(rewriter, loc); |
1439 | src = desc.getValMemRef(); |
1440 | dst = genToMemref(rewriter, loc, op.getOutValues()); |
1441 | |
        retMem.push_back(dst);
        Type valLenTp = op.getValLen().getType();
        retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
1445 | } else { |
1446 | assert(fKind == SparseTensorFieldKind::PosMemRef || |
1447 | fKind == SparseTensorFieldKind::CrdMemRef); |
1448 | |
1449 | sz = fKind == SparseTensorFieldKind::PosMemRef |
1450 | ? desc.getPosMemSize(rewriter, loc, lvl) |
1451 | : desc.getCrdMemSize(rewriter, loc, lvl); |
1452 | src = desc.getMemRefField(fid); |
1453 | dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]); |
        retMem.push_back(dst);
        // Retrieves the corresponding level length type.
        Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
1458 | } |
1459 | Value flatOut = dst; |
1460 | if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) { |
1461 | auto reassoc = |
1462 | getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank()); |
1463 | flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc); |
1464 | } |
      Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
      Value srcMem = genSliceToSize(rewriter, loc, src, sz);
1467 | rewriter.create<memref::CopyOp>(loc, srcMem, dstMem); |
1468 | return true; |
1469 | }); |
1470 | |
1471 | // Converts MemRefs back to Tensors. |
    SmallVector<Value> retValues = llvm::to_vector(
        llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
1474 | return rewriter.create<bufferization::ToTensorOp>(loc, v); |
1475 | })); |
1476 | // Appends the actual memory length used in each buffer returned. |
    retValues.append(retLen.begin(), retLen.end());
1478 | rewriter.replaceOp(op, retValues); |
1479 | return success(); |
1480 | } |
1481 | }; |
1482 | |
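/// Sparse codegen rule for the sparse_tensor.new operator.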
1483 | struct SparseNewConverter : public OpConversionPattern<NewOp> { |
1484 | using OpConversionPattern::OpConversionPattern; |
1485 | LogicalResult |
1486 | matchAndRewrite(NewOp op, OpAdaptor adaptor, |
1487 | ConversionPatternRewriter &rewriter) const override { |
1488 | Location loc = op.getLoc(); |
1489 | const auto dstTp = getSparseTensorType(op.getResult()); |
1490 | // Creating COO with NewOp is handled by direct IR codegen. All other cases |
1491 | // are handled by rewriting. |
1492 | if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0) |
1493 | return failure(); |
1494 | |
1495 | // Implement as follows: |
1496 | // %reader = @createCheckedSparseTensorReader(%filename) |
1497 | // %nse = @getSparseTensorNSE(%reader) |
1498 | // %coo = bufferization.alloc_tensor an ordered COO with |
1499 | // dst dim ordering, size_hint = %nse |
1500 | // %coordinates = sparse_tensor.coordinates_buffer(%coo) |
1501 | // %values = sparse_tensor.values(%coo) |
1502 | // %isSorted = @sparseTensorReaderReadToBuffers(%coordinates, %values) |
1503 | // if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values) |
1504 | // update storage specifier |
1505 | // @delSparseTensorReader(%reader) |
1506 | SmallVector<Value> dimSizesValues; |
1507 | Value dimSizesBuffer; |
1508 | Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0], |
1509 | dimSizesValues, dimSizesBuffer); |
1510 | |
1511 | // Get the number of stored entries. |
1512 | const Type indexTp = rewriter.getIndexType(); |
    Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
1514 | {indexTp}, {reader}, EmitCInterface::Off) |
1515 | .getResult(0); |
1516 | |
1517 | // Construct the lvl sizes and the dim2lvl/lvl2dim buffers. |
1518 | SmallVector<Value> lvlSizesValues; |
1519 | Value dim2lvlBuffer; |
1520 | Value lvl2dimBuffer; |
1521 | genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer, |
1522 | lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer); |
1523 | |
1524 | // Construct allocation for each field. |
1525 | Value sizeHint = nse; |
1526 | SmallVector<Value> fields; |
1527 | createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint, |
1528 | lvlSizesValues, fields); |
1529 | |
1530 | // Read the COO tensor data. |
1531 | MutSparseTensorDescriptor desc(dstTp, fields); |
1532 | Value xs = desc.getAOSMemRef(); |
1533 | Value ys = desc.getValMemRef(); |
1534 | const Type boolTp = rewriter.getIntegerType(1); |
1535 | const Type elemTp = dstTp.getElementType(); |
1536 | const Type crdTp = dstTp.getCrdType(); |
    SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
                                          overheadTypeFunctionSuffix(crdTp),
1539 | primaryTypeFunctionSuffix(elemTp)}; |
1540 | Value isSorted = |
1541 | createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp}, |
1542 | {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys}, |
1543 | EmitCInterface::On) |
1544 | .getResult(0); |
1545 | |
1546 | // If the destination tensor is a sorted COO, we need to sort the COO tensor |
1547 | // data if the input elements aren't sorted yet. |
1548 | const Level lvlRank = dstTp.getLvlRank(); |
1549 | if (dstTp.isOrderedLvl(lvlRank - 1)) { |
      Value kFalse = constantI1(rewriter, loc, false);
1551 | Value notSorted = rewriter.create<arith::CmpIOp>( |
1552 | loc, arith::CmpIPredicate::eq, isSorted, kFalse); |
1553 | scf::IfOp ifOp = |
1554 | rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false); |
1555 | rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
      auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
1557 | rewriter.create<SortOp>(loc, nse, xs, ValueRange{ys}, xPerm, |
1558 | rewriter.getIndexAttr(0), |
1559 | SparseTensorSortKind::HybridQuickSort); |
1560 | rewriter.setInsertionPointAfter(ifOp); |
1561 | } |
1562 | |
1563 | // Set PosMemRef0[1] = nse. |
    const Value c1 = constantIndex(rewriter, loc, 1);
    const Value posMemref0 = desc.getPosMemRef(0);
    const Type posTp = dstTp.getPosType();
    const Value posNse = genCast(rewriter, loc, nse, posTp);
1568 | rewriter.create<memref::StoreOp>(loc, posNse, posMemref0, c1); |
1569 | |
1570 | // Update storage specifier. |
1571 | Value coordinatesSize = rewriter.create<arith::MulIOp>( |
1572 | loc, nse, constantIndex(rewriter, loc, lvlRank)); |
1573 | desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0, |
1574 | coordinatesSize); |
1575 | desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize, |
1576 | std::nullopt, nse); |
1577 | |
1578 | // Release the sparse tensor reader. |
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);
1581 | |
1582 | // Replace operation with resulting memrefs. |
1583 | rewriter.replaceOpWithMultiple(op, {fields}); |
1584 | return success(); |
1585 | } |
1586 | }; |
1587 | |
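/// Sparse codegen rule for the sparse_tensor.has_runtime_library operator.
/// Since this codegen path does not rely on the runtime support library, the
/// query simply lowers to a constant false.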
1588 | struct SparseHasRuntimeLibraryConverter |
1589 | : public OpConversionPattern<HasRuntimeLibraryOp> { |
1590 | using OpConversionPattern::OpConversionPattern; |
1591 | LogicalResult |
1592 | matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor, |
1593 | ConversionPatternRewriter &rewriter) const override { |
1594 | auto i1Type = rewriter.getI1Type(); |
1595 | rewriter.replaceOpWithNewOp<arith::ConstantOp>( |
1596 | op, i1Type, rewriter.getIntegerAttr(i1Type, 0)); |
1597 | return success(); |
1598 | } |
1599 | }; |
1600 | |
1601 | } // namespace |
1602 | |
1603 | //===----------------------------------------------------------------------===// |
1604 | // Public method for populating conversion rules. |
1605 | //===----------------------------------------------------------------------===// |
1606 | |
/// Populates the given patterns list with conversion rules required to
/// lower sparse tensor types and primitives to buffer-based codegen.
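///
/// A minimal usage sketch (illustrative only; the actual driver lives in the
/// pass implementation, and the SparseTensorTypeToBufferConverter below is
/// assumed to be the codegen type converter):
///
///   SparseTensorTypeToBufferConverter converter;
///   RewritePatternSet patterns(ctx);
///   populateSparseTensorCodegenPatterns(converter, patterns,
///                                       /*createSparseDeallocs=*/false,
///                                       /*enableBufferInitialization=*/false);
///   (void)applyPartialConversion(op, target, std::move(patterns));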
1609 | void mlir::populateSparseTensorCodegenPatterns( |
1610 | const TypeConverter &typeConverter, RewritePatternSet &patterns, |
1611 | bool createSparseDeallocs, bool enableBufferInitialization) { |
1612 | patterns.add< |
1613 | SparseAssembleOpConverter, SparseDisassembleOpConverter, |
1614 | SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter, |
1615 | SparseCastConverter, SparseExtractSliceConverter, |
1616 | SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter, |
1617 | SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter, |
1618 | SparseSliceGetterOpConverter<ToSliceOffsetOp, |
1619 | StorageSpecifierKind::DimOffset>, |
1620 | SparseSliceGetterOpConverter<ToSliceStrideOp, |
1621 | StorageSpecifierKind::DimStride>, |
1622 | SparseToPositionsConverter, SparseToCoordinatesConverter, |
1623 | SparseToCoordinatesBufferConverter, SparseToValuesConverter, |
1624 | SparseConvertConverter, SparseNewConverter, |
1625 | SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>( |
1626 | typeConverter, patterns.getContext()); |
1627 | patterns.add<SparseTensorDeallocConverter>( |
      typeConverter, patterns.getContext(), createSparseDeallocs);
1629 | patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>( |
      typeConverter, patterns.getContext(), enableBufferInitialization);
1631 | } |
1632 | |