1//===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// A pass that converts sparse tensor types and primitives to actual compiler
10// visible buffers and actual compiler IR that implements these primitives on
11// the selected sparse tensor storage schemes. This pass provides an alternative
12// to the SparseTensorConversion pass, eliminating the dependence on a runtime
13// support library (other than for file I/O), and providing many more
14// opportunities for subsequent compiler optimization of the generated code.
15//
16//===----------------------------------------------------------------------===//
17
18#include "Utils/CodegenUtils.h"
19#include "Utils/SparseTensorDescriptor.h"
20
21#include "mlir/Dialect/Arith/Utils/Utils.h"
22#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/Linalg/Utils/Utils.h"
25#include "mlir/Dialect/MemRef/IR/MemRef.h"
26#include "mlir/Dialect/SparseTensor/IR/Enums.h"
27#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
28#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
29#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
30#include "mlir/Dialect/Tensor/IR/Tensor.h"
31#include "mlir/Transforms/DialectConversion.h"
32
33#include <optional>
34
35using namespace mlir;
36using namespace mlir::sparse_tensor;
37
38//===----------------------------------------------------------------------===//
39// Helper methods.
40//===----------------------------------------------------------------------===//
41
42/// Flattens a list of operands that may contain sparse tensors.
43static void flattenOperands(ValueRange operands,
44 SmallVectorImpl<Value> &flattened) {
45 // In case of
46 // sparse_tensor, c, sparse_tensor
47 // ==>
48 // memref ..., c, memref ...
49 for (auto operand : operands) {
50 if (getSparseTensorEncoding(type: operand.getType())) {
51 auto tuple = getTuple(operand);
52 // An unrealized_conversion_cast will be inserted by type converter to
53 // inter-mix the gap between 1:N conversion between sparse tensors and
54 // fields. In this case, take the operands in the cast and replace the
55 // sparse tensor output with the flattened type array.
56 flattened.append(tuple.getOperands().begin(), tuple.getOperands().end());
57 } else {
58 flattened.push_back(Elt: operand);
59 }
60 }
61}
62
63/// Generates a load with proper `index` typing.
64static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
65 idx = genCast(builder, loc, idx, builder.getIndexType());
66 return builder.create<memref::LoadOp>(loc, mem, idx);
67}
68
69/// Generates a store with proper `index` typing and proper value.
70static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
71 Value idx) {
72 idx = genCast(builder, loc, idx, builder.getIndexType());
73 val = genCast(builder, loc, val,
74 cast<ShapedType>(mem.getType()).getElementType());
75 builder.create<memref::StoreOp>(loc, val, mem, idx);
76}
77
78/// Creates a straightforward counting for-loop.
79static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
80 MutableArrayRef<Value> fields,
81 Value lower = Value()) {
82 Type indexType = builder.getIndexType();
83 if (!lower)
84 lower = constantZero(builder, loc, tp: indexType);
85 Value one = constantOne(builder, loc, tp: indexType);
86 scf::ForOp forOp = builder.create<scf::ForOp>(loc, lower, upper, one, fields);
87 for (unsigned i = 0, e = fields.size(); i < e; i++)
88 fields[i] = forOp.getRegionIterArg(i);
89 builder.setInsertionPointToStart(forOp.getBody());
90 return forOp;
91}
92
93/// Creates a push back operation.
94static void createPushback(OpBuilder &builder, Location loc,
95 MutSparseTensorDescriptor desc,
96 SparseTensorFieldKind kind, std::optional<Level> lvl,
97 Value value, Value repeat = Value()) {
98 Type etp = desc.getMemRefElementType(kind, lvl);
99 Value field = desc.getMemRefField(kind, lvl);
100 StorageSpecifierKind specFieldKind = toSpecifierKind(kind);
101
102 auto pushBackOp = builder.create<PushBackOp>(
103 loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field,
104 genCast(builder, loc, value, etp), repeat);
105
106 desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
107 desc.setSpecifierField(builder, loc, specFieldKind, lvl,
108 pushBackOp.getNewSize());
109}
110
111/// Generates code that allocates a sparse storage scheme for given rank.
112static void allocSchemeForRank(OpBuilder &builder, Location loc,
113 MutSparseTensorDescriptor desc, Level startLvl) {
114 const SparseTensorType stt(desc.getRankedTensorType());
115 Value linear = constantIndex(builder, loc, i: 1);
116 const Level lvlRank = stt.getLvlRank();
117 for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
118 const auto lt = stt.getLvlType(l: lvl);
119 if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
120 // Append linear x positions, initialized to zero. Since each compressed
121 // dimension initially already has a single zero entry, this maintains
122 // the desired "linear + 1" length property at all times. For loose
123 // compression, we multiply linear by two in order to append both the
124 // lo/hi positions.
125 Value posZero = constantZero(builder, loc, tp: stt.getPosType());
126 if (isLooseCompressedLT(lt)) {
127 Value two = constantIndex(builder, loc, i: 2);
128 linear = builder.create<arith::MulIOp>(loc, linear, two);
129 }
130 createPushback(builder, loc, desc, kind: SparseTensorFieldKind::PosMemRef, lvl,
131 /*value=*/posZero, /*repeat=*/linear);
132 return;
133 } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
134 return; // nothing to do
135 }
136 // Keep compounding the size, but nothing needs to be initialized
137 // at this level. We will eventually reach a compressed level or
138 // otherwise the values array for the from-here "all-dense" case.
139 assert(isDenseLT(lt));
140 Value size = desc.getLvlSize(builder, loc, lvl);
141 linear = builder.create<arith::MulIOp>(loc, linear, size);
142 }
143 // Reached values array so prepare for an insertion.
144 Value valZero = constantZero(builder, loc, tp: stt.getElementType());
145 createPushback(builder, loc, desc, kind: SparseTensorFieldKind::ValMemRef,
146 lvl: std::nullopt, /*value=*/valZero, /*repeat=*/linear);
147}
148
149/// Creates allocation operation.
150static Value createAllocation(OpBuilder &builder, Location loc,
151 MemRefType memRefType, Value sz,
152 bool enableInit) {
153 Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
154 Type elemType = memRefType.getElementType();
155 if (enableInit) {
156 Value fillValue = constantZero(builder, loc, tp: elemType);
157 builder.create<linalg::FillOp>(loc, fillValue, buffer);
158 }
159 return buffer;
160}
161
162/// Creates the dim sizes array, filling in from dynamic sizes.
163static void createDimSizes(OpBuilder &builder, Location loc,
164 SparseTensorType stt, ValueRange dynSizes,
165 /*out*/ SmallVectorImpl<Value> &dimSizesValues) {
166 const Dimension dimRank = stt.getDimRank();
167 dimSizesValues.clear();
168 dimSizesValues.reserve(N: dimRank);
169 unsigned i = 0;
170 for (const Size sz : stt.getDimShape())
171 dimSizesValues.push_back(ShapedType::isDynamic(sz)
172 ? dynSizes[i++]
173 : constantIndex(builder, loc, sz));
174}
175
176/// Creates allocation for each field in sparse tensor type. Note that
177/// for all dynamic memrefs in the sparse tensor stroage layout, the
178/// memory size is really the capacity of the "vector", while the actual
179/// size resides in the sizes array.
180static void createAllocFields(OpBuilder &builder, Location loc,
181 SparseTensorType stt, bool enableInit,
182 Value sizeHint,
183 SmallVectorImpl<Value> &lvlSizesValues,
184 /*out*/ SmallVectorImpl<Value> &fields) {
185 Level lvlRank = stt.getLvlRank();
186 // Set up some heuristic sizes. We try to set the initial
187 // size based on available information. Otherwise we just
188 // initialize a few elements to start the reallocation chain.
189 // TODO: refine this
190 Value posHeuristic, crdHeuristic, valHeuristic;
191 if (stt.isAllDense()) {
192 valHeuristic = lvlSizesValues[0];
193 for (Level lvl = 1; lvl < lvlRank; lvl++)
194 valHeuristic =
195 builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
196 } else if (sizeHint) {
197 if (stt.getAoSCOOStart() == 0) {
198 posHeuristic = constantIndex(builder, loc, i: 2);
199 crdHeuristic = builder.create<arith::MulIOp>(
200 loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
201 } else if (lvlRank == 2 && stt.isDenseLvl(l: 0) && stt.isCompressedLvl(l: 1)) {
202 posHeuristic = builder.create<arith::AddIOp>(
203 loc, sizeHint, constantIndex(builder, loc, 1));
204 crdHeuristic = sizeHint;
205 } else {
206 posHeuristic = crdHeuristic = constantIndex(builder, loc, i: 16);
207 }
208 valHeuristic = sizeHint;
209 } else {
210 posHeuristic = crdHeuristic = valHeuristic =
211 constantIndex(builder, loc, i: 16);
212 }
213 // Initializes all fields. An initial storage specifier and allocated
214 // positions/coordinates/values memrefs (with heuristic capacity).
215 foreachFieldAndTypeInSparseTensor(
216 stt,
217 [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
218 enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
219 Level /*lvl*/, LevelType /*lt*/) -> bool {
220 assert(fields.size() == fIdx);
221 Value field;
222 switch (fKind) {
223 case SparseTensorFieldKind::StorageSpec:
224 field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
225 break;
226 case SparseTensorFieldKind::PosMemRef:
227 field = createAllocation(builder, loc, cast<MemRefType>(fType),
228 posHeuristic, enableInit);
229 break;
230 case SparseTensorFieldKind::CrdMemRef:
231 field = createAllocation(builder, loc, cast<MemRefType>(fType),
232 crdHeuristic, enableInit);
233 break;
234 case SparseTensorFieldKind::ValMemRef:
235 field = createAllocation(builder, loc, cast<MemRefType>(fType),
236 valHeuristic, enableInit);
237 break;
238 }
239 assert(field);
240 fields.push_back(Elt: field);
241 // Returns true to continue the iteration.
242 return true;
243 });
244 // Initialize the storage scheme to an empty tensor. Sets the lvlSizes
245 // and gives all position fields an initial zero entry, so that it is
246 // easier to maintain the "linear + 1" length property.
247 MutSparseTensorDescriptor desc(stt, fields);
248 Value posZero = constantZero(builder, loc, tp: stt.getPosType());
249 for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
250 desc.setLvlSize(builder, loc, lvl, v: lvlSizesValues[lvl]);
251 const auto lt = stt.getLvlType(l: lvl);
252 if (isCompressedLT(lt) || isLooseCompressedLT(lt))
253 createPushback(builder, loc, desc, kind: SparseTensorFieldKind::PosMemRef, lvl,
254 /*value=*/posZero);
255 }
256 allocSchemeForRank(builder, loc, desc, /*rank=*/startLvl: 0);
257}
258
259/// Helper method that generates block specific to compressed case:
260///
261/// // given: parentPos = posCursor[lvl-1]
262/// pstart = desc.positions[lvl][parentPos]
263/// pstop = desc.positions[lvl][parentPos+1]
264/// plast = pstop - 1
265/// msz = desc.coordinates[lvl].size()
266/// if (pstart < pstop) {
267/// isPresent = (desc.coordinates[lvl][plast] == lvlCoords[lvl])
268/// } else { // first insertion
269/// isPresent = false
270/// desc.positions[lvl][parentPos] = msz
271/// }
272/// if (isPresent) { // coordinate is already present
273/// pnext = plast
274/// } else {
275/// desc.coordinates[lvl].push_back(lvlCoords[lvl])
276/// desc.positions[lvl][parentPos+1] = msz+1
277/// pnext = msz
278/// <prepare level lvl+1>
279/// }
280/// posCursor[lvl] = pnext
281static Value genCompressed(OpBuilder &builder, Location loc,
282 MutSparseTensorDescriptor desc, ValueRange lvlCoords,
283 Value /*unused*/, Value parentPos, Level lvl) {
284 const SparseTensorType stt(desc.getRankedTensorType());
285 const Level lvlRank = stt.getLvlRank();
286 assert(lvl < lvlRank && "Level is out of bounds");
287 assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
288 "Level-rank mismatch");
289 SmallVector<Type> types;
290 Type indexType = builder.getIndexType();
291 Type boolType = builder.getIntegerType(1);
292 unsigned crdFidx;
293 unsigned crdStride;
294 std::tie(args&: crdFidx, args&: crdStride) = desc.getCrdMemRefIndexAndStride(lvl);
295 const Value one = constantIndex(builder, loc, i: 1);
296 const Value pp1 = builder.create<arith::AddIOp>(loc, parentPos, one);
297 const Value positionsAtLvl = desc.getPosMemRef(lvl);
298 const Value pstart = genLoad(builder, loc, mem: positionsAtLvl, idx: parentPos);
299 const Value pstop = genLoad(builder, loc, mem: positionsAtLvl, idx: pp1);
300 const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl);
301 const Value crdStrideC =
302 crdStride > 1 ? constantIndex(builder, loc, i: crdStride) : Value();
303 const Value msz =
304 crdStrideC ? builder.create<arith::DivUIOp>(loc, crdMsz, crdStrideC)
305 : crdMsz;
306 const Value plast = builder.create<arith::SubIOp>(
307 loc, genCast(builder, loc, pstop, indexType), one);
308 // Conditional expression.
309 Value lt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
310 pstart, pstop);
311 types.push_back(Elt: boolType);
312 scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
313 types.pop_back();
314 builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
315 Value crd =
316 genLoad(builder, loc, desc.getMemRefField(crdFidx),
317 crdStrideC ? builder.create<arith::MulIOp>(loc, plast, crdStrideC)
318 : plast);
319 Value eq = builder.create<arith::CmpIOp>(
320 loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
321 lvlCoords[lvl]);
322 builder.create<scf::YieldOp>(loc, eq);
323 builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
324 if (lvl > 0)
325 genStore(builder, loc, val: msz, mem: positionsAtLvl, idx: parentPos);
326 builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
327 builder.setInsertionPointAfter(ifOp1);
328 // If present construct. Note that for a non-unique dimension level, we
329 // simply set the condition to false and rely on CSE/DCE to clean up the IR.
330 //
331 // TODO: generate less temporary IR?
332 //
333 for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
334 types.push_back(Elt: desc.getField(i).getType());
335 types.push_back(Elt: indexType);
336 const Value p = stt.isUniqueLvl(l: lvl) ? ifOp1.getResult(0)
337 : constantI1(builder, loc, b: false);
338 scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true);
339 // If present (fields unaffected, update pnext to plast).
340 builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
341
342 // FIXME: This does not looks like a clean way, but probably the most
343 // efficient way.
344 desc.getFields().push_back(plast);
345 builder.create<scf::YieldOp>(loc, desc.getFields());
346 desc.getFields().pop_back();
347
348 // If !present (changes fields, update pnext).
349 builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
350 Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
351 genStore(builder, loc, val: mszp1, mem: positionsAtLvl, idx: pp1);
352 createPushback(builder, loc, desc, kind: SparseTensorFieldKind::CrdMemRef, lvl,
353 /*value=*/lvlCoords[lvl]);
354 // Prepare the next level "as needed".
355 if ((lvl + 1) < lvlRank)
356 allocSchemeForRank(builder, loc, desc, startLvl: lvl + 1);
357
358 desc.getFields().push_back(msz);
359 builder.create<scf::YieldOp>(loc, desc.getFields());
360 desc.getFields().pop_back();
361
362 // Update fields and return next pos.
363 builder.setInsertionPointAfter(ifOp2);
364 unsigned o = 0;
365 for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
366 desc.setField(fidx: i, v: ifOp2.getResult(o++));
367 return ifOp2.getResult(o);
368}
369
370/// Generates insertion finalization code.
371static void genEndInsert(OpBuilder &builder, Location loc,
372 SparseTensorDescriptor desc) {
373 const SparseTensorType stt(desc.getRankedTensorType());
374 const Level lvlRank = stt.getLvlRank();
375 for (Level lvl = 0; lvl < lvlRank; lvl++) {
376 const auto lt = stt.getLvlType(l: lvl);
377 if (isCompressedLT(lt)) {
378 // Compressed dimensions need a position cleanup for all entries
379 // that were not visited during the insertion pass.
380 //
381 // TODO: avoid cleanup and keep compressed scheme consistent at all
382 // times?
383 //
384 if (lvl > 0) {
385 Type posType = stt.getPosType();
386 Value posMemRef = desc.getPosMemRef(lvl);
387 Value hi = desc.getPosMemSize(builder, loc, lvl);
388 Value zero = constantIndex(builder, loc, i: 0);
389 Value one = constantIndex(builder, loc, i: 1);
390 // Vector of only one, but needed by createFor's prototype.
391 SmallVector<Value, 1> inits{genLoad(builder, loc, mem: posMemRef, idx: zero)};
392 scf::ForOp loop = createFor(builder, loc, hi, inits, one);
393 Value i = loop.getInductionVar();
394 Value oldv = loop.getRegionIterArg(0);
395 Value newv = genLoad(builder, loc, mem: posMemRef, idx: i);
396 Value posZero = constantZero(builder, loc, tp: posType);
397 Value cond = builder.create<arith::CmpIOp>(
398 loc, arith::CmpIPredicate::eq, newv, posZero);
399 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType),
400 cond, /*else*/ true);
401 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
402 genStore(builder, loc, val: oldv, mem: posMemRef, idx: i);
403 builder.create<scf::YieldOp>(loc, oldv);
404 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
405 builder.create<scf::YieldOp>(loc, newv);
406 builder.setInsertionPointAfter(ifOp);
407 builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
408 builder.setInsertionPointAfter(loop);
409 }
410 } else {
411 assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
412 isNOutOfMLT(lt));
413 }
414 }
415}
416
417/// Generates a subview into the sizes.
418static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
419 Value sz) {
420 auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
421 return builder
422 .create<memref::SubViewOp>(
423 loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
424 ValueRange{}, ValueRange{sz}, ValueRange{},
425 ArrayRef<int64_t>{0}, // static offset
426 ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
427 ArrayRef<int64_t>{1}) // static stride
428 .getResult();
429}
430
431/// Creates the reassociation array.
432static SmallVector<ReassociationIndices>
433getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
434 SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
435 // Create reassociation in the form:
436 // {0}, {1}, ..., {batchLvl - 1}, {batchLvl, ..., rank}
437 for (unsigned i = 0; i < batchLvls; i++)
438 ret[i].push_back(Elt: i);
439
440 for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
441 ret.back().push_back(Elt: i);
442
443 return ret;
444}
445
446//===----------------------------------------------------------------------===//
447// Codegen rules.
448//===----------------------------------------------------------------------===//
449
450namespace {
451
452/// Helper class to help lowering sparse_tensor.insert operation.
453class SparseInsertGenerator
454 : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
455public:
456 SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
457 bool genCall)
458 : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp){};
459
460 /// Generates code along an insertion path without the need for a "cursor".
461 /// This current insertion strategy comes at the expense of some testing
462 /// overhead for each insertion. The strategy will be optimized later for
463 /// common insertion patterns. The current insertion strategy also assumes
464 /// insertions occur in "a reasonable order" that enables building the
465 /// storage scheme in an appending/inserting kind of fashion (i.e. no
466 /// in-between insertions that need data movement). The implementation
467 /// relies on CSE/DCE to clean up all bookkeeping that is not needed.
468 ///
469 /// TODO: better unord/not-unique; also generalize, optimize, specialize!
470 SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
471 OpBuilder &builder, Location loc) {
472 const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
473 const Level lvlRank = stt.getLvlRank();
474 // Extract fields and coordinates from args.
475 SmallVector<Value> fields = llvm::to_vector(Range: args.drop_back(n: lvlRank + 1));
476 MutSparseTensorDescriptor desc(stt, fields);
477 const SmallVector<Value> coords =
478 llvm::to_vector(Range: args.take_back(n: lvlRank + 1).drop_back());
479 Value value = args.back();
480 Value parentPos = constantZero(builder, loc, builder.getIndexType());
481 // Generate code for every level.
482 for (Level lvl = 0; lvl < lvlRank; lvl++) {
483 const auto lt = stt.getLvlType(l: lvl);
484 if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
485 // Create:
486 // if (!present) {
487 // coordinates[lvl].push_back(coords[lvl])
488 // <update positions and prepare level lvl + 1>
489 // }
490 // positions[lvl] = coordinates.size() - 1
491 // <insert @ positions[lvl] at next level lvl + 1>
492 if (isLooseCompressedLT(lt)) {
493 Value two = constantIndex(builder, loc, i: 2);
494 parentPos = builder.create<arith::MulIOp>(loc, parentPos, two);
495 }
496 parentPos =
497 genCompressed(builder, loc, desc, lvlCoords: coords, value, parentPos, lvl);
498 } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
499 // Create:
500 // coordinates[lvl].push_back(coords[lvl])
501 // positions[lvl] = positions[lvl-1]
502 // <insert @ positions[lvl] at next level lvl + 1>
503 createPushback(builder, loc, desc, kind: SparseTensorFieldKind::CrdMemRef,
504 lvl, /*value=*/coords[lvl]);
505 } else {
506 assert(isDenseLT(lt));
507 // Construct the new position as:
508 // positions[lvl] = size * positions[lvl-1] + coords[lvl]
509 // <insert @ positions[lvl] at next level lvl + 1>
510 Value size = desc.getLvlSize(builder, loc, lvl);
511 Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
512 parentPos = builder.create<arith::AddIOp>(loc, mult, coords[lvl]);
513 }
514 }
515 // Reached the actual value append/insert.
516 if (!stt.isDenseLvl(l: lvlRank - 1))
517 createPushback(builder, loc, desc, kind: SparseTensorFieldKind::ValMemRef,
518 lvl: std::nullopt, value);
519 else
520 genStore(builder, loc, value, desc.getValMemRef(), parentPos);
521 return fields;
522 }
523
524 std::string getMangledFuncName() {
525 // The mangled name of the function has this format:
526 // <namePrefix>_<LT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
527 constexpr const char kInsertFuncNamePrefix[] = "_insert_";
528 const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
529 SmallString<32> nameBuffer;
530 llvm::raw_svector_ostream nameOstream(nameBuffer);
531 nameOstream << kInsertFuncNamePrefix;
532 const Level lvlRank = stt.getLvlRank();
533 for (Level l = 0; l < lvlRank; l++) {
534 std::string lvlType = toMLIRString(lt: stt.getLvlType(l));
535 // Replace/remove punctuations in level properties.
536 std::replace_if(
537 first: lvlType.begin(), last: lvlType.end(),
538 pred: [](char c) { return c == '(' || c == ','; }, new_value: '_');
539 llvm::erase_if(C&: lvlType, P: [](char c) { return c == ')' || c == ' '; });
540 nameOstream << lvlType << "_";
541 }
542 // Static dim sizes are used in the generated code while dynamic sizes are
543 // loaded from the dimSizes buffer. This is the reason for adding the shape
544 // to the function name.
545 for (const auto sz : stt.getDimShape())
546 nameOstream << sz << "_";
547 // Permutation information is also used in generating insertion.
548 if (!stt.isIdentity())
549 nameOstream << stt.getDimToLvl() << "_";
550 nameOstream << stt.getElementType() << "_";
551 nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
552 return nameOstream.str().str();
553 }
554
555private:
556 TensorType rtp;
557};
558
559/// Sparse tensor storage conversion rule for returns.
560class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
561public:
562 using OpConversionPattern::OpConversionPattern;
563 LogicalResult
564 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
565 ConversionPatternRewriter &rewriter) const override {
566 SmallVector<Value> flattened;
567 flattenOperands(adaptor.getOperands(), flattened);
568 // Create a return with the flattened value extracted from sparse tensors.
569 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
570 return success();
571 }
572};
573
574/// Sparse tensor storage conversion rule for calls.
575class SparseCallConverter : public OpConversionPattern<func::CallOp> {
576public:
577 // The default CallOp converter can not handle 1:N type conversion.
578 using OpConversionPattern::OpConversionPattern;
579 LogicalResult
580 matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
581 ConversionPatternRewriter &rewriter) const override {
582 Location loc = op.getLoc();
583 // In case of:
584 // sparse_tensor, f, sparse_tensor = call @foo(...)
585 // ==>
586 // memref..., f, memref = call @foo(...) replace with
587 // cast(memref...)->sparse_tensor, f, cast(memref...)->sparse_tensor
588 SmallVector<Type> finalRetTy;
589 if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
590 return failure();
591
592 // (1) Generates new call with flattened return value.
593 SmallVector<Value> flattened;
594 flattenOperands(adaptor.getOperands(), flattened);
595 auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
596 finalRetTy, flattened);
597 // (2) Create cast operation for sparse tensor returns.
598 SmallVector<Value> castedRet;
599 // Tracks the offset of current return value (of the original call)
600 // relative to the new call (after sparse tensor flattening);
601 unsigned retOffset = 0;
602 // Temporal buffer to hold the flattened list of type for
603 // a sparse tensor.
604 SmallVector<Type> sparseFlat;
605 for (auto ret : op.getResults()) {
606 assert(retOffset < newCall.getNumResults());
607 auto retType = ret.getType();
608 if (failed(typeConverter->convertType(retType, sparseFlat)))
609 llvm_unreachable("Failed to convert type in sparse tensor codegen");
610
611 // Converted types can not be empty when the type conversion succeed.
612 assert(!sparseFlat.empty());
613 if (sparseFlat.size() > 1) {
614 auto flatSize = sparseFlat.size();
615 ValueRange fields(iterator_range<ResultRange::iterator>(
616 newCall.result_begin() + retOffset,
617 newCall.result_begin() + retOffset + flatSize));
618 castedRet.push_back(genTuple(rewriter, loc, retType, fields));
619 retOffset += flatSize;
620 } else {
621 // If this is an 1:1 conversion, no need for casting.
622 castedRet.push_back(newCall.getResult(retOffset));
623 retOffset++;
624 }
625 sparseFlat.clear();
626 }
627
628 assert(castedRet.size() == op.getNumResults());
629 rewriter.replaceOp(op, castedRet);
630 return success();
631 }
632};
633
634/// Sparse codegen rule for level accesses.
635class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
636public:
637 using OpConversionPattern::OpConversionPattern;
638 LogicalResult
639 matchAndRewrite(LvlOp op, OpAdaptor adaptor,
640 ConversionPatternRewriter &rewriter) const override {
641 std::optional<int64_t> lvl = op.getConstantLvlIndex();
642 if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType()))
643 return failure();
644
645 auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
646 auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
647
648 rewriter.replaceOp(op, sz);
649 return success();
650 }
651};
652
653// TODO: use a new SortCOO operation here instead of reusing convert op.
654struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
655 using OpConversionPattern::OpConversionPattern;
656 LogicalResult
657 matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor,
658 ConversionPatternRewriter &rewriter) const override {
659 Location loc = op.getLoc();
660 MLIRContext *ctx = op.getContext();
661
662 SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
663 SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());
664
665 // Should have been verified.
666 assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
667 dstStt.isCOOType() && srcStt.isCOOType());
668 assert(dstStt.hasSameDimToLvl(srcStt));
669
670 // We don't need a mutable descriptor here as we perform sorting in-place.
671 auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getInputCoo());
672 auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo());
673 auto crd = desc.getAOSMemRef();
674 auto val = desc.getValMemRef();
675
676 // Otherwise we need another data shuffle and a non-identity map.
677 assert(dstStt.hasSameDimToLvl(srcStt));
678 (void)dstStt; // to silence warning when assertion is disabled
679
680 auto id = AffineMap::getMultiDimIdentityMap(numDims: srcStt.getLvlRank(), context: ctx);
681
682 rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
683 rewriter.getIndexAttr(0), op.getAlgorithm());
684
685 // Since we do in-place sorting, the destinate tensor will have the same set
686 // of memrefs as the source tensor.
687 rewriter.replaceOp(op, adaptor.getInputCoo());
688 return success();
689 }
690};
691
692template <typename Op, StorageSpecifierKind kind>
693class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
694public:
695 using OpConversionPattern<Op>::OpConversionPattern;
696 LogicalResult
697 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
698 ConversionPatternRewriter &rewriter) const override {
699 // Simply lowers to specifer.get <field> operation.
700 auto desc = getDescriptorFromTensorTuple(adaptor.getSlice());
701 auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
702 op.getDim().getZExtValue());
703
704 rewriter.replaceOp(op, v);
705 return success();
706 }
707};
708
709/// Sparse codegen rule for trivial tensor casts.
710class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
711public:
712 using OpConversionPattern::OpConversionPattern;
713 LogicalResult
714 matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
715 ConversionPatternRewriter &rewriter) const override {
716 // Only rewrite identically annotated source/dest.
717 auto encDst = getSparseTensorEncoding(op.getType());
718 auto encSrc = getSparseTensorEncoding(op.getSource().getType());
719 if (!encDst || encDst != encSrc)
720 return failure();
721 rewriter.replaceOp(op, adaptor.getOperands());
722 return success();
723 }
724};
725
726class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
727public:
728 using OpConversionPattern::OpConversionPattern;
729 LogicalResult
730 matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
731 ConversionPatternRewriter &rewriter) const override {
732 // Simply fold the operation.
733 rewriter.replaceOp(op, adaptor.getSource());
734 return success();
735 }
736};
737
738/// Sparse codegen rule for the alloc operator.
739class SparseTensorAllocConverter
740 : public OpConversionPattern<bufferization::AllocTensorOp> {
741public:
742 using OpConversionPattern::OpConversionPattern;
743 SparseTensorAllocConverter(TypeConverter &typeConverter, MLIRContext *context,
744 bool enableInit)
745 : OpConversionPattern(typeConverter, context),
746 enableBufferInitialization(enableInit) {}
747
748 LogicalResult
749 matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
750 ConversionPatternRewriter &rewriter) const override {
751 const auto resType = getSparseTensorType(op);
752 if (!resType.hasEncoding())
753 return failure();
754
755 Location loc = op.getLoc();
756 // Deal with copy.
757 if (op.getCopy()) {
758 auto desc = getDescriptorFromTensorTuple(adaptor.getCopy());
759 SmallVector<Value> fields;
760 fields.reserve(N: desc.getNumFields());
761 // Memcpy on memref fields.
762 for (auto field : desc.getMemRefFields()) {
763 auto memrefTp = cast<MemRefType>(field.getType());
764 auto size = rewriter.create<memref::DimOp>(loc, field, 0);
765 auto copied =
766 rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size});
767 rewriter.create<memref::CopyOp>(loc, field, copied);
768 fields.push_back(copied);
769 }
770 // Reuses specifier.
771 fields.push_back(Elt: desc.getSpecifier());
772 assert(fields.size() == desc.getNumFields());
773 rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
774 return success();
775 }
776
777 if (!resType.isIdentity()) {
778 return rewriter.notifyMatchFailure(
779 op, "try run --sparse-reinterpret-map before codegen");
780 }
781 // Level size equals to dimension size since lvl2dim map is an identity map.
782 SmallVector<Value> lvlSizesValues;
783 createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
784 /*dimSizesValues=*/lvlSizesValues);
785
786 // Construct allocation for each field.
787 Value sizeHint = op.getSizeHint();
788 SmallVector<Value> fields;
789 createAllocFields(rewriter, loc, resType, enableBufferInitialization,
790 sizeHint, lvlSizesValues, fields);
791
792 // Replace operation with resulting memrefs.
793 rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
794 return success();
795 }
796
797private:
798 bool enableBufferInitialization;
799};
800
801/// Sparse codegen rule for the empty tensor operator.
802class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
803public:
804 using OpConversionPattern::OpConversionPattern;
805 SparseTensorEmptyConverter(TypeConverter &typeConverter, MLIRContext *context,
806 bool enableInit)
807 : OpConversionPattern(typeConverter, context),
808 enableBufferInitialization(enableInit) {}
809
810 LogicalResult
811 matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
812 ConversionPatternRewriter &rewriter) const override {
813 const auto resType = getSparseTensorType(op);
814 if (!resType.hasEncoding())
815 return failure();
816
817 if (!resType.isIdentity()) {
818 return rewriter.notifyMatchFailure(
819 op, "try run --sparse-reinterpret-map before codegen");
820 }
821
822 Location loc = op.getLoc();
823 // Level size equals to dimension size since lvl2dim map is an identity map.
824 SmallVector<Value> lvlSizesValues;
825 createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
826 /*dimSizesValues=*/lvlSizesValues);
827 // Construct allocation for each field.
828 Value sizeHint; // none
829 SmallVector<Value> fields;
830 createAllocFields(rewriter, loc, resType, enableBufferInitialization,
831 sizeHint, lvlSizesValues, fields);
832
833 // Replace operation with resulting memrefs.
834 rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
835 return success();
836 }
837
838private:
839 bool enableBufferInitialization;
840};
841
842/// Sparse codegen rule for the dealloc operator.
843class SparseTensorDeallocConverter
844 : public OpConversionPattern<bufferization::DeallocTensorOp> {
845public:
846 using OpConversionPattern::OpConversionPattern;
847 SparseTensorDeallocConverter(TypeConverter &typeConverter,
848 MLIRContext *context, bool createDeallocs)
849 : OpConversionPattern(typeConverter, context),
850 createDeallocs(createDeallocs) {}
851
852 LogicalResult
853 matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
854 ConversionPatternRewriter &rewriter) const override {
855 auto enc = getSparseTensorEncoding(op.getTensor().getType());
856 if (!enc)
857 return failure();
858
859 // If user requests not to deallocate sparse tensors, simply erase the
860 // operation.
861 if (createDeallocs) {
862 // Replace the sparse tensor deallocation with field deallocations.
863 Location loc = op.getLoc();
864 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
865 for (auto input : desc.getMemRefFields())
866 // Deallocate every buffer used to store the sparse tensor handler.
867 rewriter.create<memref::DeallocOp>(loc, input);
868 }
869 rewriter.eraseOp(op: op);
870 return success();
871 }
872
873private:
874 const bool createDeallocs;
875};
876
877/// Sparse codegen rule for tensor rematerialization.
878class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
879public:
880 using OpConversionPattern::OpConversionPattern;
881 LogicalResult
882 matchAndRewrite(LoadOp op, OpAdaptor adaptor,
883 ConversionPatternRewriter &rewriter) const override {
884 // Prepare descriptor.
885 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
886 // Generate optional insertion finalization code.
887 if (op.getHasInserts())
888 genEndInsert(rewriter, op.getLoc(), desc);
889 // Replace operation with resulting memrefs.
890 rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
891 return success();
892 }
893};
894
895/// Sparse codegen rule for the expand op.
896class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
897public:
898 using OpConversionPattern::OpConversionPattern;
899 LogicalResult
900 matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
901 ConversionPatternRewriter &rewriter) const override {
902 if (!getSparseTensorEncoding(op.getTensor().getType()))
903 return failure();
904 Location loc = op->getLoc();
905 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
906 const auto srcType = getSparseTensorType(op.getTensor());
907 Type eltType = srcType.getElementType();
908 Type boolType = rewriter.getIntegerType(1);
909 Type idxType = rewriter.getIndexType();
910 // All initialization should be done on entry of the loop nest.
911 rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
912
913 // Determine the size for access expansion (always the innermost stored
914 // level size).
915 const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
916 // Generate a memref for `sz` elements of type `t`.
917 const auto genAlloc = [&](Type t) {
918 const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
919 return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
920 };
921 // Allocate temporary buffers for values/filled-switch and added.
922 // We do not use stack buffers for this, since the expanded size may
923 // be rather large (as it envelops a single expanded dense dimension).
924 Value values = genAlloc(eltType);
925 Value filled = genAlloc(boolType);
926 Value added = genAlloc(idxType);
927 Value zero = constantZero(builder&: rewriter, loc, tp: idxType);
928 // Reset the values/filled-switch to all-zero/false. Note that this
929 // introduces an O(N) operation into the computation, but this reset
930 // operation is amortized over the innermost loops for the access
931 // pattern expansion. As noted in the operation doc, we would like
932 // to amortize this setup cost even between kernels.
933 rewriter.create<linalg::FillOp>(
934 loc, ValueRange{constantZero(rewriter, loc, eltType)},
935 ValueRange{values});
936 rewriter.create<linalg::FillOp>(
937 loc, ValueRange{constantZero(rewriter, loc, boolType)},
938 ValueRange{filled});
939 // Replace expansion op with these buffers and initial coordinate.
940 assert(op.getNumResults() == 4);
941 rewriter.replaceOp(op, {values, filled, added, zero});
942 return success();
943 }
944};
945
946/// Sparse codegen rule for the compress operator.
947class SparseCompressConverter : public OpConversionPattern<CompressOp> {
948public:
949 using OpConversionPattern::OpConversionPattern;
950 LogicalResult
951 matchAndRewrite(CompressOp op, OpAdaptor adaptor,
952 ConversionPatternRewriter &rewriter) const override {
953 Location loc = op->getLoc();
954 SmallVector<Value> fields;
955 auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
956 Value values = adaptor.getValues();
957 Value filled = adaptor.getFilled();
958 Value added = adaptor.getAdded();
959 Value count = adaptor.getCount();
960 const SparseTensorType dstType(desc.getRankedTensorType());
961 Type eltType = dstType.getElementType();
962
963 // If the innermost level is ordered, we need to sort the coordinates
964 // in the "added" array prior to applying the compression.
965 if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
966 rewriter.create<SortOp>(
967 loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
968 rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
969 // While performing the insertions, we also need to reset the elements
970 // of the values/filled-switch by only iterating over the set elements,
971 // to ensure that the runtime complexity remains proportional to the
972 // sparsity of the expanded access pattern.
973 //
974 // Generate
975 // out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
976 // crd = added[i];
977 // value = values[crd];
978 // insert({lvlCoords, crd}, value);
979 // new_memrefs = insert(in_memrefs, {lvlCoords, crd}, value);
980 // values[crd] = 0;
981 // filled[crd] = false;
982 // yield new_memrefs
983 // }
984 scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
985 Value i = loop.getInductionVar();
986
987 Value crd = genLoad(builder&: rewriter, loc, mem: added, idx: i);
988 Value value = genLoad(builder&: rewriter, loc, mem: values, idx: crd);
989 SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
990 SmallVector<Type> flatSpTensorTps = llvm::to_vector(
991 llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
992 params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
993 params.push_back(Elt: crd);
994 params.push_back(Elt: value);
995 SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
996 params, /*genCall=*/true);
997 SmallVector<Value> insertRet = insertGen.genCallOrInline(builder&: rewriter, loc);
998 genStore(builder&: rewriter, loc, val: constantZero(builder&: rewriter, loc, tp: eltType), mem: values, idx: crd);
999 genStore(builder&: rewriter, loc, val: constantI1(builder&: rewriter, loc, b: false), mem: filled, idx: crd);
1000 rewriter.create<scf::YieldOp>(loc, insertRet);
1001
1002 rewriter.setInsertionPointAfter(loop);
1003 Value result = genTuple(rewriter, loc, dstType, loop->getResults());
1004 // Deallocate the buffers on exit of the full loop nest.
1005 Operation *parent = getTop(op);
1006 rewriter.setInsertionPointAfter(parent);
1007 rewriter.create<memref::DeallocOp>(loc, values);
1008 rewriter.create<memref::DeallocOp>(loc, filled);
1009 rewriter.create<memref::DeallocOp>(loc, added);
1010 // Replace operation with resulting memrefs.
1011 rewriter.replaceOp(op, result);
1012 return success();
1013 }
1014};
1015
1016/// Sparse codegen rule for the insert operator.
1017class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
1018public:
1019 using OpConversionPattern::OpConversionPattern;
1020 LogicalResult
1021 matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
1022 ConversionPatternRewriter &rewriter) const override {
1023 auto stt = getSparseTensorType(adaptor.getDest());
1024 if (!stt.hasEncoding())
1025 return failure();
1026 assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
1027
1028 Location loc = op.getLoc();
1029 auto desc = getDescriptorFromTensorTuple(adaptor.getDest());
1030 TypeRange flatSpTensorTps = desc.getFields().getTypes();
1031 SmallVector<Value> params = llvm::to_vector(desc.getFields());
1032 params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
1033 params.push_back(Elt: adaptor.getScalar());
1034 SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
1035 params, /*genCall=*/true);
1036 SmallVector<Value> ret = insertGen.genCallOrInline(builder&: rewriter, loc);
1037 // Replace operation with resulting memrefs.
1038 rewriter.replaceOp(op,
1039 genTuple(rewriter, loc, op.getDest().getType(), ret));
1040 return success();
1041 }
1042};
1043
1044/// Sparse codegen rule for position accesses.
1045class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
1046public:
1047 using OpAdaptor = typename ToPositionsOp::Adaptor;
1048 using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
1049 LogicalResult
1050 matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
1051 ConversionPatternRewriter &rewriter) const override {
1052 // Replace the requested position access with corresponding field.
1053 // The cast_op is inserted by type converter to intermix 1:N type
1054 // conversion.
1055 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
1056 rewriter.replaceOp(op, desc.getPosMemRef(op.getLevel()));
1057 return success();
1058 }
1059};
1060
1061/// Sparse codegen rule for accessing the coordinates arrays.
1062class SparseToCoordinatesConverter
1063 : public OpConversionPattern<ToCoordinatesOp> {
1064public:
1065 using OpAdaptor = typename ToCoordinatesOp::Adaptor;
1066 using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
1067 LogicalResult
1068 matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
1069 ConversionPatternRewriter &rewriter) const override {
1070 // Replace the requested coordinates access with corresponding field.
1071 // The cast_op is inserted by type converter to intermix 1:N type
1072 // conversion.
1073 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
1074 rewriter.replaceOp(
1075 op, desc.getCrdMemRefOrView(rewriter, op.getLoc(), op.getLevel()));
1076
1077 return success();
1078 }
1079};
1080
1081/// Sparse codegen rule for accessing the linear coordinates buffer.
1082class SparseToCoordinatesBufferConverter
1083 : public OpConversionPattern<ToCoordinatesBufferOp> {
1084public:
1085 using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
1086 using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
1087 LogicalResult
1088 matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
1089 ConversionPatternRewriter &rewriter) const override {
1090 // Replace the requested coordinates access with corresponding field.
1091 // The cast_op is inserted by type converter to intermix 1:N type
1092 // conversion.
1093 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
1094 rewriter.replaceOp(op, desc.getAOSMemRef());
1095
1096 return success();
1097 }
1098};
1099
1100/// Sparse codegen rule for value accesses.
1101class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
1102public:
1103 using OpAdaptor = typename ToValuesOp::Adaptor;
1104 using OpConversionPattern<ToValuesOp>::OpConversionPattern;
1105 LogicalResult
1106 matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
1107 ConversionPatternRewriter &rewriter) const override {
1108 // Replace the requested values access with corresponding field.
1109 // The cast_op is inserted by type converter to intermix 1:N type
1110 // conversion.
1111 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
1112 rewriter.replaceOp(op, desc.getValMemRef());
1113 return success();
1114 }
1115};
1116
1117/// Sparse codegen rule for the convert operator.
1118class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
1119public:
1120 using OpConversionPattern::OpConversionPattern;
1121 LogicalResult
1122 matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
1123 ConversionPatternRewriter &rewriter) const override {
1124 SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
1125 SparseTensorEncodingAttr encSrc =
1126 getSparseTensorEncoding(op.getSource().getType());
1127 // The output tensor can not be a slice and those cases should have been
1128 // rejected by ConvertOp::verify() already.
1129 assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slices.");
1130 // Different encoding (except for different bitwidth) should be handled by
1131 // rewriting.
1132 // We need further rewrites if the input tensor is a slice too.
1133 if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
1134 encSrc.isSlice()) {
1135 return failure();
1136 }
1137
1138 Type retElemTp = op.getResult().getType().getElementType();
1139 Type srcElemTp = op.getSource().getType().getElementType();
1140 // Fold the trivial cases.
1141 if (retElemTp == srcElemTp && encDst == encSrc) {
1142 rewriter.replaceOp(op, adaptor.getSource());
1143 return success();
1144 }
1145 //
1146 // Do element-wise type conversion without using InsertOp.
1147 //
1148 // for each memref in srcTensor:
1149 // dst = memref.alloc
1150 // if srcMemRefType != dstMemRefType:
1151 // for every dst[i] = cast(src[i])
1152 // else:
1153 // dst = memref.copy(src)
1154 Location loc = op.getLoc();
1155 auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource());
1156 SmallVector<Value> fields;
1157 foreachFieldAndTypeInSparseTensor(
1158 SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
1159 [&rewriter, &fields, srcDesc,
1160 loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
1161 LevelType /*lt*/) -> bool {
1162 // Simply reuses the storage specifier as it is an SSA value.
1163 if (fKind == SparseTensorFieldKind::StorageSpec) {
1164 fields.push_back(Elt: srcDesc.getSpecifier());
1165 } else {
1166 // Allocates new memrefs
1167 Value srcMem = srcDesc.getMemRefField(fIdx);
1168 // TODO: We can instead use the actual memSize in specifier, that
1169 // would require a subViewOp to avoid overflow when copying
1170 // values.
1171 Value sz = linalg::createOrFoldDimOp(b&: rewriter, loc, val: srcMem, dim: 0);
1172 auto dstMem = rewriter.create<memref::AllocOp>(
1173 loc, cast<MemRefType>(fTp), sz);
1174 if (fTp != srcMem.getType()) {
1175 // Converts elements type.
1176 scf::buildLoopNest(
1177 builder&: rewriter, loc, lbs: constantIndex(builder&: rewriter, loc, i: 0), ubs: sz,
1178 steps: constantIndex(builder&: rewriter, loc, i: 1),
1179 bodyBuilder: [srcMem, &dstMem](OpBuilder &builder, Location loc,
1180 ValueRange ivs) {
1181 Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
1182 Value casted = genCast(builder, loc, v,
1183 dstMem.getType().getElementType());
1184 builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
1185 });
1186 } else {
1187 // TODO: We can even reuse the same memref for the new tensor,
1188 // but that requires a `ref-counting` based memory management
1189 // for shared memrefs between multiple sparse tensors.
1190 rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
1191 }
1192 fields.push_back(Elt: dstMem);
1193 }
1194 return true;
1195 });
1196
1197 rewriter.replaceOp(
1198 op, genTuple(rewriter, loc, op.getResult().getType(), fields));
1199 return success();
1200 }
1201};
1202
1203class SparseExtractSliceConverter
1204 : public OpConversionPattern<tensor::ExtractSliceOp> {
1205public:
1206 using OpConversionPattern::OpConversionPattern;
1207 LogicalResult
1208 matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
1209 ConversionPatternRewriter &rewriter) const override {
1210 Location loc = op.getLoc();
1211 MLIRContext *ctx = op.getContext();
1212 auto srcEnc = getSparseTensorEncoding(op.getSourceType());
1213 auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
1214 // TODO: We should check these in ExtractSliceOp::verify.
1215 if (!srcEnc || !dstEnc || !dstEnc.isSlice())
1216 return failure();
1217 assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
1218
1219 SmallVector<Value> fields;
1220 auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);
1221
1222 auto newSpec = rewriter.create<StorageSpecifierInitOp>(
1223 loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
1224 desc.setSpecifier(newSpec);
1225
1226 // Fills in slice information.
1227 for (auto [idx, offset, size, stride] : llvm::enumerate(
1228 op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
1229 Dimension dim = idx;
1230
1231 Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
1232 Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);
1233 Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride);
1234 // TODO: We could probably only set dynamic value here. But it would
1235 // requires us to fill the hole when casting a static slice to dynamic
1236 // slice.
1237 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset,
1238 dim, offsetV);
1239
1240 // FIXME: we need to distinguish level sizes and dimension size for slices
1241 // here. Maybe we should store slice level sizes in a different array
1242 // instead of reusing it.
1243 assert(srcEnc.isIdentity());
1244 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
1245 sizeV);
1246 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
1247 dim, strideV);
1248 }
1249
1250 // NOTE: we can not generate tuples directly from descriptor here, as the
1251 // descriptor is holding the original type, yet we want the slice type
1252 // here (they shared every memref but with an updated specifier).
1253 rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(),
1254 desc.getFields()));
1255 return success();
1256 }
1257};
1258
1259/// Sparse codegen rule for number of entries operator.
1260class SparseNumberOfEntriesConverter
1261 : public OpConversionPattern<NumberOfEntriesOp> {
1262public:
1263 using OpConversionPattern::OpConversionPattern;
1264 LogicalResult
1265 matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
1266 ConversionPatternRewriter &rewriter) const override {
1267 // Query memSizes for the actually stored values.
1268 // FIXME: the nse value computed in this way might be wrong when there is
1269 // any "loose_compressed" level.
1270 rewriter.replaceOp(
1271 op, genValMemSize(rewriter, op.getLoc(), adaptor.getTensor()));
1272 return success();
1273 }
1274};
1275
1276struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
1277 using OpConversionPattern::OpConversionPattern;
1278 LogicalResult
1279 matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
1280 ConversionPatternRewriter &rewriter) const override {
1281 Location loc = op.getLoc();
1282 const auto stt = getSparseTensorType(op.getResult());
1283
1284 SmallVector<Value> fields;
1285
1286 foreachFieldAndTypeInSparseTensor(
1287 stt,
1288 [&rewriter, &fields, &op, &stt,
1289 loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
1290 Level /*lvl*/, LevelType lt) -> bool {
1291 assert(fields.size() == fIdx);
1292 if (fKind == SparseTensorFieldKind::StorageSpec) {
1293 fields.push_back(
1294 Elt: SparseTensorSpecifier::getInitValue(builder&: rewriter, loc, stt: stt));
1295 } else {
1296 // Else simply takes the inputs.
1297 Value tensor = fKind == SparseTensorFieldKind::ValMemRef
1298 ? op.getValues()
1299 : op.getLevels()[fIdx];
1300 // TODO: handle batch.
1301 TypedValue<BaseMemRefType> mem = genToMemref(builder&: rewriter, loc, tensor);
1302 if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
1303 // Flattens the buffer to batchLvlRank.
1304 auto reassoc = getReassociationForFlattening(
1305 mem.getType(), stt.getBatchLvlRank());
1306 mem = rewriter.create<memref::CastOp>(
1307 loc, fType,
1308 rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
1309 } else {
1310 mem = rewriter.create<memref::CastOp>(loc, fType, mem);
1311 }
1312 fields.push_back(Elt: mem);
1313 }
1314 return true;
1315 });
1316
1317 MutSparseTensorDescriptor desc(stt, fields);
1318 Value c0 = constantIndex(builder&: rewriter, loc, i: 0);
1319 Value c1 = constantIndex(builder&: rewriter, loc, i: 1);
1320 Value c2 = constantIndex(builder&: rewriter, loc, i: 2);
1321 Value posBack = c0; // index to the last value in the position array
1322 Value memSize = c1; // memory size for current array
1323
1324 Level trailCOOStart = stt.getAoSCOOStart();
1325 Level trailCOORank = stt.getLvlRank() - trailCOOStart;
    // Sets up SparseTensorSpecifier.
    for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
      assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));

      // Sets up the level size.
      auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
      desc.setLvlSize(rewriter, loc, lvl, lvlSize);
      // We use a single AOS array to store the trailing COO, so there is only
      // one memory size to set for the entire COO section.
      if (lvl > trailCOOStart)
        continue;

      // Sets up the memory size by reading the last value in position array.
      LevelType lt = stt.getLvlType(lvl);
      // Simply forwards the position index when this is a dense level.
      if (lt.isa<LevelFormat::Dense>()) {
        memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
        posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
        continue;
      }
      if (lt.isa<LevelFormat::Batch>()) {
        // Skips batch levels as they are not linearized.
        // FIXME: this assumes that every batch has the same nse; it needs to
        // be generalized to handle varied-size batches.
        continue;
      }

      if (isWithPosLT(lt)) {
        assert(isCompressedLT(lt) || isLooseCompressedLT(lt));
        if (isLooseCompressedLT(lt)) {
          memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
          posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
        } else {
          assert(isCompressedLT(lt));
          posBack = memSize;
          memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
        }
        desc.setPosMemSize(rewriter, loc, lvl, memSize);
        // The last value in the position array is the memory size for the
        // next level.
        // FIXME: this assumes that every batch has the same nse; it needs to
        // be generalized to handle varied-size batches.
        SmallVector<Value> batched(stt.getBatchLvlRank(),
                                   constantIndex(rewriter, loc, 0));
        batched.push_back(posBack);
        memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
        posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
      }
      assert(isWithCrdLT(lt) && lvl <= trailCOOStart);
      // FIXME: This seems to be unnecessarily complex; can we simplify it?
      if (lvl == trailCOOStart) {
        Value cooSz = rewriter.create<arith::MulIOp>(
            loc, memSize, constantIndex(rewriter, loc, trailCOORank));
        desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
      } else {
        desc.setCrdMemSize(rewriter, loc, lvl, memSize);
      }
    }
    desc.setValMemSize(rewriter, loc, memSize);

    rewriter.replaceOp(op, genTuple(rewriter, loc, desc));
    return success();
  }
};

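/// Sparse codegen rule for the disassemble operator, the inverse of assemble:
/// the position/coordinate/value storage of the sparse tensor is copied into
/// the user-provided output buffers (trimmed to the sizes actually in use),
/// and the used length of each buffer is returned as well.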
struct SparseDisassembleOpConverter
    : public OpConversionPattern<DisassembleOp> {
  using OpConversionPattern::OpConversionPattern;
  SparseDisassembleOpConverter(TypeConverter &typeConverter,
                               MLIRContext *context)
      : OpConversionPattern(typeConverter, context) {}

  LogicalResult
  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
    Location loc = op.getLoc();
    SmallVector<Value> retMem;
    SmallVector<Value> retLen;
    desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem,
                                   &retLen](FieldIndex fid,
                                            SparseTensorFieldKind fKind,
                                            Level lvl, LevelType lt) -> bool {
      if (fKind == SparseTensorFieldKind::StorageSpec)
        return true;
      SparseTensorType stt(desc.getRankedTensorType());
      Value sz, src;
      TypedValue<BaseMemRefType> dst;
      if (fKind == SparseTensorFieldKind::ValMemRef) {
        sz = desc.getValMemSize(rewriter, loc);
        src = desc.getValMemRef();
        dst = genToMemref(rewriter, loc, op.getOutValues());

        retMem.push_back(dst);
        Type valLenTp = op.getValLen().getType();
        retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
      } else {
        assert(fKind == SparseTensorFieldKind::PosMemRef ||
               fKind == SparseTensorFieldKind::CrdMemRef);

        sz = fKind == SparseTensorFieldKind::PosMemRef
                 ? desc.getPosMemSize(rewriter, loc, lvl)
                 : desc.getCrdMemSize(rewriter, loc, lvl);
        src = desc.getMemRefField(fid);
        dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
        retMem.push_back(dst);
        // Retrieves the corresponding level length type.
        Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
      }
      Value flatOut = dst;
      if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
        auto reassoc =
            getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());
        flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
      }
      Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
      Value srcMem = genSliceToSize(rewriter, loc, src, sz);
      rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
      return true;
    });

    // Converts MemRefs back to Tensors.
    SmallVector<Value> retValues = llvm::to_vector(
        llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
          return rewriter.create<bufferization::ToTensorOp>(loc, v);
        }));
    // Appends the actual memory length used in each buffer returned.
    retValues.append(retLen.begin(), retLen.end());
    rewriter.replaceOp(op, retValues);
    return success();
  }
};

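/// Sparse codegen rule for the new operator when the result is an AoS COO
/// tensor whose COO region starts at level 0, e.g. (a sketch, assuming a
/// hypothetical #SortedCOO encoding):
///   %t = sparse_tensor.new %src : !llvm.ptr to tensor<?x?xf64, #SortedCOO>
/// Reading directly into any other format is rewritten elsewhere into a new
/// into COO followed by a convert, so only the COO case is handled here.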
struct SparseNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    // Creating COO with NewOp is handled by direct IR codegen. All other cases
    // are handled by rewriting.
    if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
      return failure();

    // Implement as follows:
    //   %reader = @createCheckedSparseTensorReader(%filename)
    //   %nse = @getSparseTensorNSE(%reader)
    //   %coo = bufferization.alloc_tensor an ordered COO with
    //          dst dim ordering, size_hint = %nse
    //   %coordinates = sparse_tensor.coordinates_buffer(%coo)
    //   %values = sparse_tensor.values(%coo)
    //   %isSorted = @sparseTensorReaderReadToBuffers(%coordinates, %values)
    //   if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
    //   update storage specifier
    //   @delSparseTensorReader(%reader)
    SmallVector<Value> dimSizesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
                             dimSizesValues, dimSizesBuffer);

    // Get the number of stored entries.
    const Type indexTp = rewriter.getIndexType();
    Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
                               {indexTp}, {reader}, EmitCInterface::Off)
                    .getResult(0);

    // Construct the lvl sizes and the dim2lvl/lvl2dim buffers.
    SmallVector<Value> lvlSizesValues;
    Value dim2lvlBuffer;
    Value lvl2dimBuffer;
    genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
                  lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);

    // Construct allocation for each field.
    Value sizeHint = nse;
    SmallVector<Value> fields;
    createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint,
                      lvlSizesValues, fields);

    // Read the COO tensor data.
    MutSparseTensorDescriptor desc(dstTp, fields);
    Value xs = desc.getAOSMemRef();
    Value ys = desc.getValMemRef();
    const Type boolTp = rewriter.getIntegerType(1);
    const Type elemTp = dstTp.getElementType();
    const Type crdTp = dstTp.getCrdType();
    SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
                                          overheadTypeFunctionSuffix(crdTp),
                                          primaryTypeFunctionSuffix(elemTp)};
    Value isSorted =
        createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
                       {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
                       EmitCInterface::On)
            .getResult(0);

    // If the destination tensor is a sorted COO, we need to sort the COO
    // tensor data if the input elements aren't sorted yet.
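    // Roughly, this compares %isSorted against false and, inside an scf.if
    // guard, emits a sparse_tensor.sort (hybrid quick sort) over the AoS
    // coordinate buffer that jointly permutes the value buffer.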
    const Level lvlRank = dstTp.getLvlRank();
    if (dstTp.isOrderedLvl(lvlRank - 1)) {
      Value kFalse = constantI1(rewriter, loc, false);
      Value notSorted = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, isSorted, kFalse);
      scf::IfOp ifOp =
          rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
      rewriter.create<SortOp>(loc, nse, xs, ValueRange{ys}, xPerm,
                              rewriter.getIndexAttr(0),
                              SparseTensorSortKind::HybridQuickSort);
      rewriter.setInsertionPointAfter(ifOp);
    }

    // Set PosMemRef0[1] = nse.
    const Value c1 = constantIndex(rewriter, loc, 1);
    const Value posMemref0 = desc.getPosMemRef(0);
    const Type posTp = dstTp.getPosType();
    const Value posNse = genCast(rewriter, loc, nse, posTp);
    rewriter.create<memref::StoreOp>(loc, posNse, posMemref0, c1);

    // Update storage specifier.
    Value coordinatesSize = rewriter.create<arith::MulIOp>(
        loc, nse, constantIndex(rewriter, loc, lvlRank));
    desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0,
                           coordinatesSize);
    desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize,
                           std::nullopt, nse);

    // Release the sparse tensor reader.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);

    // Replace operation with resulting memrefs.
    rewriter.replaceOp(op, genTuple(rewriter, loc, dstTp, fields));
    return success();
  }
};

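/// Sparse codegen rule for the has_runtime_library operator. This pass
/// generates code without the runtime support library, so the query simply
/// lowers to the constant false.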
struct SparseHasRuntimeLibraryConverter
    : public OpConversionPattern<HasRuntimeLibraryOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto i1Type = rewriter.getI1Type();
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
        op, i1Type, rewriter.getIntegerAttr(i1Type, 0));
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorCodegenPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    bool createSparseDeallocs, bool enableBufferInitialization) {
  patterns.add<
      SparseAssembleOpConverter, SparseDisassembleOpConverter,
      SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
      SparseCastConverter, SparseExtractSliceConverter,
      SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter,
      SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter,
      SparseSliceGetterOpConverter<ToSliceOffsetOp,
                                   StorageSpecifierKind::DimOffset>,
      SparseSliceGetterOpConverter<ToSliceStrideOp,
                                   StorageSpecifierKind::DimStride>,
      SparseToPositionsConverter, SparseToCoordinatesConverter,
      SparseToCoordinatesBufferConverter, SparseToValuesConverter,
      SparseConvertConverter, SparseNewConverter,
      SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
      typeConverter, patterns.getContext());
  patterns.add<SparseTensorDeallocConverter>(
      typeConverter, patterns.getContext(), createSparseDeallocs);
  patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
      typeConverter, patterns.getContext(), enableBufferInitialization);
}
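
// A minimal usage sketch (not part of this file): a conversion pass would
// typically combine these patterns with a sparse tensor type converter and a
// partial conversion driver, assuming the SparseTensorTypeToBufferConverter
// declared in Passes.h:
//
//   RewritePatternSet patterns(ctx);
//   SparseTensorTypeToBufferConverter converter;
//   populateSparseTensorCodegenPatterns(converter, patterns,
//                                       /*createSparseDeallocs=*/true,
//                                       /*enableBufferInitialization=*/false);
//   (void)applyPartialConversion(getOperation(), target, std::move(patterns));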