1//===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// A pass that converts sparse tensor types and primitives to actual compiler
10// visible buffers and actual compiler IR that implements these primitives on
11// the selected sparse tensor storage schemes. This pass provides an alternative
12// to the SparseTensorConversion pass, eliminating the dependence on a runtime
13// support library (other than for file I/O), and providing many more
14// opportunities for subsequent compiler optimization of the generated code.
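//
// As an illustrative sketch only (the exact set of fields depends on the
// sparse tensor encoding), a 2-d CSR tensor type such as
//
//   tensor<?x?xf64, #CSR>
//
// is lowered into a flat list of buffers consisting of a positions array,
// a coordinates array, a values array, and a storage specifier, e.g.:
//
//   memref<?xindex>, memref<?xindex>, memref<?xf64>,
//   !sparse_tensor.storage_specifier<#CSR>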
15//
16//===----------------------------------------------------------------------===//
17
18#include "Utils/CodegenUtils.h"
19#include "Utils/SparseTensorDescriptor.h"
20
21#include "mlir/Dialect/Arith/Utils/Utils.h"
22#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/Linalg/Utils/Utils.h"
25#include "mlir/Dialect/MemRef/IR/MemRef.h"
26#include "mlir/Dialect/SparseTensor/IR/Enums.h"
27#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
28#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
29#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
30#include "mlir/Dialect/Tensor/IR/Tensor.h"
31#include "mlir/Transforms/DialectConversion.h"
32
33#include <optional>
34
35using namespace mlir;
36using namespace mlir::sparse_tensor;
37
38//===----------------------------------------------------------------------===//
39// Helper methods.
40//===----------------------------------------------------------------------===//
41
42/// Flatten the given value ranges into a single vector of values.
43static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
44 SmallVector<Value> result;
45 for (const auto &vals : values)
    llvm::append_range(result, vals);
47 return result;
48}
49
50/// Generates a load with proper `index` typing.
51static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
52 idx = genCast(builder, loc, idx, builder.getIndexType());
53 return builder.create<memref::LoadOp>(loc, mem, idx);
54}
55
56/// Generates a store with proper `index` typing and proper value.
57static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
58 Value idx) {
59 idx = genCast(builder, loc, idx, builder.getIndexType());
60 val = genCast(builder, loc, val,
61 cast<ShapedType>(mem.getType()).getElementType());
62 builder.create<memref::StoreOp>(loc, val, mem, idx);
63}
64
65/// Creates a straightforward counting for-loop.
66static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
67 MutableArrayRef<Value> fields,
68 Value lower = Value()) {
69 Type indexType = builder.getIndexType();
  if (!lower)
    lower = constantZero(builder, loc, indexType);
  Value one = constantOne(builder, loc, indexType);
73 scf::ForOp forOp = builder.create<scf::ForOp>(loc, lower, upper, one, fields);
74 for (unsigned i = 0, e = fields.size(); i < e; i++)
75 fields[i] = forOp.getRegionIterArg(i);
76 builder.setInsertionPointToStart(forOp.getBody());
77 return forOp;
78}
79
80/// Creates a push back operation.
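///
/// As an illustrative sketch (names are made up), pushing a value %v onto the
/// values array roughly generates
///
///   %buf, %sz = sparse_tensor.push_back %curSize, %values, %v
///                 : index, memref<?xf64>, f64
///
/// after which the descriptor's memref field and the corresponding memory-size
/// entry in the storage specifier are updated with %buf and %sz, respectively.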
81static void createPushback(OpBuilder &builder, Location loc,
82 MutSparseTensorDescriptor desc,
83 SparseTensorFieldKind kind, std::optional<Level> lvl,
84 Value value, Value repeat = Value()) {
85 Type etp = desc.getMemRefElementType(kind, lvl);
86 Value field = desc.getMemRefField(kind, lvl);
87 StorageSpecifierKind specFieldKind = toSpecifierKind(kind);
88
89 auto pushBackOp = builder.create<PushBackOp>(
90 loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field,
91 genCast(builder, loc, value, etp), repeat);
92
93 desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
94 desc.setSpecifierField(builder, loc, specFieldKind, lvl,
95 pushBackOp.getNewSize());
96}
97
98/// Generates code that allocates a sparse storage scheme for given rank.
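///
/// For example (sketch), for a 2-d CSR tensor and startLvl == 0, the dense
/// level 0 merely scales the linearized size by its level size, after which
/// the compressed level 1 appends that many zero entries to its positions
/// array and the routine returns.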
99static void allocSchemeForRank(OpBuilder &builder, Location loc,
100 MutSparseTensorDescriptor desc, Level startLvl) {
101 const SparseTensorType stt(desc.getRankedTensorType());
  Value linear = constantIndex(builder, loc, 1);
  const Level lvlRank = stt.getLvlRank();
  for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
    const auto lt = stt.getLvlType(lvl);
    if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
      // Append linear x positions, initialized to zero. Since each compressed
      // dimension initially already has a single zero entry, this maintains
      // the desired "linear + 1" length property at all times. For loose
      // compression, we multiply linear by two in order to append both the
      // lo/hi positions.
      Value posZero = constantZero(builder, loc, stt.getPosType());
      if (isLooseCompressedLT(lt)) {
        Value two = constantIndex(builder, loc, 2);
        linear = builder.create<arith::MulIOp>(loc, linear, two);
      }
      createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
                     /*value=*/posZero, /*repeat=*/linear);
119 return;
120 } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
121 return; // nothing to do
122 }
123 // Keep compounding the size, but nothing needs to be initialized
124 // at this level. We will eventually reach a compressed level or
125 // otherwise the values array for the from-here "all-dense" case.
126 assert(isDenseLT(lt));
127 Value size = desc.getLvlSize(builder, loc, lvl);
128 linear = builder.create<arith::MulIOp>(loc, linear, size);
129 }
130 // Reached values array so prepare for an insertion.
  Value valZero = constantZero(builder, loc, stt.getElementType());
  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                 std::nullopt, /*value=*/valZero, /*repeat=*/linear);
134}
135
136/// Creates allocation operation.
137static Value createAllocation(OpBuilder &builder, Location loc,
138 MemRefType memRefType, Value sz,
139 bool enableInit) {
140 Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
141 Type elemType = memRefType.getElementType();
142 if (enableInit) {
    Value fillValue = constantZero(builder, loc, elemType);
144 builder.create<linalg::FillOp>(loc, fillValue, buffer);
145 }
146 return buffer;
147}
148
149/// Creates the dim sizes array, filling in from dynamic sizes.
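///
/// For example (sketch), for a tensor with static/dynamic shape 4x? and
/// dynSizes = (%d), this yields dimSizesValues = [%c4, %d].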
150static void createDimSizes(OpBuilder &builder, Location loc,
151 SparseTensorType stt, ValueRange dynSizes,
152 /*out*/ SmallVectorImpl<Value> &dimSizesValues) {
153 const Dimension dimRank = stt.getDimRank();
154 dimSizesValues.clear();
  dimSizesValues.reserve(dimRank);
156 unsigned i = 0;
157 for (const Size sz : stt.getDimShape())
158 dimSizesValues.push_back(ShapedType::isDynamic(sz)
159 ? dynSizes[i++]
160 : constantIndex(builder, loc, sz));
161}
162
/// Creates an allocation for each field in the sparse tensor type. Note that
/// for all dynamic memrefs in the sparse tensor storage layout, the
/// memory size is really the capacity of the "vector", while the actual
/// size resides in the sizes array.
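///
/// A sketch of the capacity heuristics implemented below (subject to change):
/// with a size hint and an AoS COO start at level 0, positions get capacity 2
/// and coordinates get capacity lvlRank * hint; for a 2-d dense-then-compressed
/// (CSR-like) tensor, positions get hint + 1 and coordinates get hint;
/// otherwise a small constant (16) starts the reallocation chain.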
167static void createAllocFields(OpBuilder &builder, Location loc,
168 SparseTensorType stt, bool enableInit,
169 Value sizeHint,
170 SmallVectorImpl<Value> &lvlSizesValues,
171 /*out*/ SmallVectorImpl<Value> &fields) {
172 Level lvlRank = stt.getLvlRank();
173 // Set up some heuristic sizes. We try to set the initial
174 // size based on available information. Otherwise we just
175 // initialize a few elements to start the reallocation chain.
176 // TODO: refine this
177 Value posHeuristic, crdHeuristic, valHeuristic;
178 if (stt.isAllDense()) {
179 valHeuristic = lvlSizesValues[0];
180 for (Level lvl = 1; lvl < lvlRank; lvl++)
181 valHeuristic =
182 builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
183 } else if (sizeHint) {
184 if (stt.getAoSCOOStart() == 0) {
      posHeuristic = constantIndex(builder, loc, 2);
      crdHeuristic = builder.create<arith::MulIOp>(
          loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
    } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
      posHeuristic = builder.create<arith::AddIOp>(
          loc, sizeHint, constantIndex(builder, loc, 1));
      crdHeuristic = sizeHint;
    } else {
      posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
    }
    valHeuristic = sizeHint;
  } else {
    posHeuristic = crdHeuristic = valHeuristic =
        constantIndex(builder, loc, 16);
199 }
200 // Initializes all fields. An initial storage specifier and allocated
201 // positions/coordinates/values memrefs (with heuristic capacity).
202 foreachFieldAndTypeInSparseTensor(
203 stt,
204 [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
205 enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
206 Level /*lvl*/, LevelType /*lt*/) -> bool {
207 assert(fields.size() == fIdx);
208 Value field;
209 switch (fKind) {
210 case SparseTensorFieldKind::StorageSpec:
211 field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
212 break;
213 case SparseTensorFieldKind::PosMemRef:
214 field = createAllocation(builder, loc, cast<MemRefType>(fType),
215 posHeuristic, enableInit);
216 break;
217 case SparseTensorFieldKind::CrdMemRef:
218 field = createAllocation(builder, loc, cast<MemRefType>(fType),
219 crdHeuristic, enableInit);
220 break;
221 case SparseTensorFieldKind::ValMemRef:
222 field = createAllocation(builder, loc, cast<MemRefType>(fType),
223 valHeuristic, enableInit);
224 break;
225 }
226 assert(field);
        fields.push_back(field);
228 // Returns true to continue the iteration.
229 return true;
230 });
231 // Initialize the storage scheme to an empty tensor. Sets the lvlSizes
232 // and gives all position fields an initial zero entry, so that it is
233 // easier to maintain the "linear + 1" length property.
234 MutSparseTensorDescriptor desc(stt, fields);
  Value posZero = constantZero(builder, loc, stt.getPosType());
  for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
    desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
    const auto lt = stt.getLvlType(lvl);
    if (isCompressedLT(lt) || isLooseCompressedLT(lt))
      createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
                     /*value=*/posZero);
  }
  allocSchemeForRank(builder, loc, desc, /*startLvl=*/0);
244}
245
246/// Helper method that generates block specific to compressed case:
247///
248/// // given: parentPos = posCursor[lvl-1]
249/// pstart = desc.positions[lvl][parentPos]
250/// pstop = desc.positions[lvl][parentPos+1]
251/// plast = pstop - 1
252/// msz = desc.coordinates[lvl].size()
253/// if (pstart < pstop) {
254/// isPresent = (desc.coordinates[lvl][plast] == lvlCoords[lvl])
255/// } else { // first insertion
256/// isPresent = false
257/// desc.positions[lvl][parentPos] = msz
258/// }
259/// if (isPresent) { // coordinate is already present
260/// pnext = plast
261/// } else {
262/// desc.coordinates[lvl].push_back(lvlCoords[lvl])
263/// desc.positions[lvl][parentPos+1] = msz+1
264/// pnext = msz
265/// <prepare level lvl+1>
266/// }
267/// posCursor[lvl] = pnext
268static Value genCompressed(OpBuilder &builder, Location loc,
269 MutSparseTensorDescriptor desc, ValueRange lvlCoords,
270 Value /*unused*/, Value parentPos, Level lvl) {
271 const SparseTensorType stt(desc.getRankedTensorType());
272 const Level lvlRank = stt.getLvlRank();
273 assert(lvl < lvlRank && "Level is out of bounds");
274 assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
275 "Level-rank mismatch");
276 SmallVector<Type> types;
277 Type indexType = builder.getIndexType();
278 Type boolType = builder.getIntegerType(1);
279 unsigned crdFidx;
280 unsigned crdStride;
  std::tie(crdFidx, crdStride) = desc.getCrdMemRefIndexAndStride(lvl);
  const Value one = constantIndex(builder, loc, 1);
  const Value pp1 = builder.create<arith::AddIOp>(loc, parentPos, one);
  const Value positionsAtLvl = desc.getPosMemRef(lvl);
  const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
  const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
  const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl);
  const Value crdStrideC =
      crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
  const Value msz =
      crdStrideC ? builder.create<arith::DivUIOp>(loc, crdMsz, crdStrideC)
                 : crdMsz;
  const Value plast = builder.create<arith::SubIOp>(
      loc, genCast(builder, loc, pstop, indexType), one);
  // Conditional expression.
  Value lt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                           pstart, pstop);
  types.push_back(boolType);
299 scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
300 types.pop_back();
301 builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
302 Value crd =
303 genLoad(builder, loc, desc.getMemRefField(crdFidx),
304 crdStrideC ? builder.create<arith::MulIOp>(loc, plast, crdStrideC)
305 : plast);
306 Value eq = builder.create<arith::CmpIOp>(
307 loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
308 lvlCoords[lvl]);
309 builder.create<scf::YieldOp>(loc, eq);
310 builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
311 if (lvl > 0)
    genStore(builder, loc, msz, positionsAtLvl, parentPos);
313 builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
314 builder.setInsertionPointAfter(ifOp1);
315 // If present construct. Note that for a non-unique dimension level, we
316 // simply set the condition to false and rely on CSE/DCE to clean up the IR.
317 //
318 // TODO: generate less temporary IR?
319 //
320 for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
    types.push_back(desc.getField(i).getType());
  types.push_back(indexType);
  const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
                                       : constantI1(builder, loc, false);
325 scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true);
326 // If present (fields unaffected, update pnext to plast).
327 builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
328
  // FIXME: This does not look like a clean way to do it, but it is probably
  // the most efficient way.
331 desc.getFields().push_back(plast);
332 builder.create<scf::YieldOp>(loc, desc.getFields());
333 desc.getFields().pop_back();
334
335 // If !present (changes fields, update pnext).
336 builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
337 Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
  genStore(builder, loc, mszp1, positionsAtLvl, pp1);
  createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
340 /*value=*/lvlCoords[lvl]);
341 // Prepare the next level "as needed".
342 if ((lvl + 1) < lvlRank)
    allocSchemeForRank(builder, loc, desc, lvl + 1);
344
345 desc.getFields().push_back(msz);
346 builder.create<scf::YieldOp>(loc, desc.getFields());
347 desc.getFields().pop_back();
348
349 // Update fields and return next pos.
350 builder.setInsertionPointAfter(ifOp2);
351 unsigned o = 0;
352 for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
    desc.setField(i, ifOp2.getResult(o++));
354 return ifOp2.getResult(o);
355}
356
357/// Generates insertion finalization code.
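///
/// A sketch of the cleanup performed below for compressed levels: positions
/// that were never written during insertion (and hence are still zero) are
/// overwritten with the most recently written position value, so that every
/// interval [positions[i], positions[i+1]) is well defined afterwards.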
358static void genEndInsert(OpBuilder &builder, Location loc,
359 SparseTensorDescriptor desc) {
360 const SparseTensorType stt(desc.getRankedTensorType());
361 const Level lvlRank = stt.getLvlRank();
362 for (Level lvl = 0; lvl < lvlRank; lvl++) {
    const auto lt = stt.getLvlType(lvl);
364 if (isCompressedLT(lt)) {
365 // Compressed dimensions need a position cleanup for all entries
366 // that were not visited during the insertion pass.
367 //
368 // TODO: avoid cleanup and keep compressed scheme consistent at all
369 // times?
370 //
371 if (lvl > 0) {
372 Type posType = stt.getPosType();
373 Value posMemRef = desc.getPosMemRef(lvl);
374 Value hi = desc.getPosMemSize(builder, loc, lvl);
        Value zero = constantIndex(builder, loc, 0);
        Value one = constantIndex(builder, loc, 1);
        // Vector of only one, but needed by createFor's prototype.
        SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
379 scf::ForOp loop = createFor(builder, loc, hi, inits, one);
380 Value i = loop.getInductionVar();
381 Value oldv = loop.getRegionIterArg(0);
        Value newv = genLoad(builder, loc, posMemRef, i);
        Value posZero = constantZero(builder, loc, posType);
384 Value cond = builder.create<arith::CmpIOp>(
385 loc, arith::CmpIPredicate::eq, newv, posZero);
386 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType),
387 cond, /*else*/ true);
388 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
        genStore(builder, loc, oldv, posMemRef, i);
390 builder.create<scf::YieldOp>(loc, oldv);
391 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
392 builder.create<scf::YieldOp>(loc, newv);
393 builder.setInsertionPointAfter(ifOp);
394 builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
395 builder.setInsertionPointAfter(loop);
396 }
397 } else {
398 assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
399 isNOutOfMLT(lt));
400 }
401 }
402}
403
/// Generates a subview of the given memref, truncated to the given size.
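///
/// For a 1-d buffer, this roughly generates (illustrative sketch):
///
///   memref.subview %mem[0] [%sz] [1] : memref<?xf64> to memref<?xf64>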
405static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
406 Value sz) {
407 auto memTp = llvm::cast<MemRefType>(mem.getType());
408 // For higher-dimensional memrefs, we assume that the innermost
409 // dimension is always of the right size.
410 // TODO: generate complex truncating view here too?
411 if (memTp.getRank() > 1)
412 return mem;
413 // Truncate linear memrefs to given size.
414 return builder
415 .create<memref::SubViewOp>(
416 loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
417 mem, ValueRange{}, ValueRange{sz}, ValueRange{},
418 ArrayRef<int64_t>{0}, // static offset
419 ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
420 ArrayRef<int64_t>{1}) // static stride
421 .getResult();
422}
423
424/// Creates the reassociation array.
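///
/// For example, for a source memref of rank 4 and batchLvls == 2, this
/// produces the reassociation {{0}, {1}, {2, 3}}.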
425static SmallVector<ReassociationIndices>
426getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
427 SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
428 // Create reassociation in the form:
429 // {0}, {1}, ..., {batchLvl - 1}, {batchLvl, ..., rank}
430 for (unsigned i = 0; i < batchLvls; i++)
    ret[i].push_back(i);
432
433 for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
    ret.back().push_back(i);
435
436 return ret;
437}
438
439//===----------------------------------------------------------------------===//
440// Codegen rules.
441//===----------------------------------------------------------------------===//
442
443namespace {
444
445/// Helper class to help lowering sparse_tensor.insert operation.
446class SparseInsertGenerator
447 : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
448public:
449 SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
450 bool genCall)
451 : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp){};
452
453 /// Generates code along an insertion path without the need for a "cursor".
454 /// This current insertion strategy comes at the expense of some testing
455 /// overhead for each insertion. The strategy will be optimized later for
456 /// common insertion patterns. The current insertion strategy also assumes
457 /// insertions occur in "a reasonable order" that enables building the
458 /// storage scheme in an appending/inserting kind of fashion (i.e. no
459 /// in-between insertions that need data movement). The implementation
460 /// relies on CSE/DCE to clean up all bookkeeping that is not needed.
461 ///
462 /// TODO: better unord/not-unique; also generalize, optimize, specialize!
463 SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
464 OpBuilder &builder, Location loc) {
465 const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
466 const Level lvlRank = stt.getLvlRank();
467 // Extract fields and coordinates from args.
    SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
    MutSparseTensorDescriptor desc(stt, fields);
    const SmallVector<Value> coords =
        llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
472 Value value = args.back();
473 Value parentPos = constantZero(builder, loc, builder.getIndexType());
474 // Generate code for every level.
475 for (Level lvl = 0; lvl < lvlRank; lvl++) {
      const auto lt = stt.getLvlType(lvl);
477 if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
478 // Create:
479 // if (!present) {
480 // coordinates[lvl].push_back(coords[lvl])
481 // <update positions and prepare level lvl + 1>
482 // }
483 // positions[lvl] = coordinates.size() - 1
484 // <insert @ positions[lvl] at next level lvl + 1>
485 if (isLooseCompressedLT(lt)) {
          Value two = constantIndex(builder, loc, 2);
487 parentPos = builder.create<arith::MulIOp>(loc, parentPos, two);
488 }
        parentPos =
            genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
491 } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
492 // Create:
493 // coordinates[lvl].push_back(coords[lvl])
494 // positions[lvl] = positions[lvl-1]
495 // <insert @ positions[lvl] at next level lvl + 1>
        createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
                       lvl, /*value=*/coords[lvl]);
498 } else {
499 assert(isDenseLT(lt));
500 // Construct the new position as:
501 // positions[lvl] = size * positions[lvl-1] + coords[lvl]
502 // <insert @ positions[lvl] at next level lvl + 1>
503 Value size = desc.getLvlSize(builder, loc, lvl);
504 Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
505 parentPos = builder.create<arith::AddIOp>(loc, mult, coords[lvl]);
506 }
507 }
508 // Reached the actual value append/insert.
    if (!stt.isDenseLvl(lvlRank - 1))
      createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                     std::nullopt, value);
512 else
513 genStore(builder, loc, value, desc.getValMemRef(), parentPos);
514 return fields;
515 }
516
517 std::string getMangledFuncName() {
518 // The mangled name of the function has this format:
519 // <namePrefix>_<LT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
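    // For example (illustrative only), a statically shaped 4x8 CSR tensor of
    // f64 values with default position/coordinate widths might mangle to
    // something like "_insert_dense_compressed_4_8_f64_0_0".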
520 constexpr const char kInsertFuncNamePrefix[] = "_insert_";
521 const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
522 SmallString<32> nameBuffer;
523 llvm::raw_svector_ostream nameOstream(nameBuffer);
524 nameOstream << kInsertFuncNamePrefix;
525 const Level lvlRank = stt.getLvlRank();
526 for (Level l = 0; l < lvlRank; l++) {
      std::string lvlType = toMLIRString(stt.getLvlType(l));
      // Replace/remove punctuations in level properties.
      std::replace_if(
          lvlType.begin(), lvlType.end(),
          [](char c) { return c == '(' || c == ','; }, '_');
      llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
533 nameOstream << lvlType << "_";
534 }
535 // Static dim sizes are used in the generated code while dynamic sizes are
536 // loaded from the dimSizes buffer. This is the reason for adding the shape
537 // to the function name.
538 for (const auto sz : stt.getDimShape())
539 nameOstream << sz << "_";
540 // Permutation information is also used in generating insertion.
541 if (!stt.isIdentity())
542 nameOstream << stt.getDimToLvl() << "_";
543 nameOstream << stt.getElementType() << "_";
544 nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
545 return nameOstream.str().str();
546 }
547
548private:
549 TensorType rtp;
550};
551
552/// Sparse tensor storage conversion rule for returns.
553class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
554public:
555 using OpConversionPattern::OpConversionPattern;
556 LogicalResult
557 matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
558 ConversionPatternRewriter &rewriter) const override {
559 // Create a return with the flattened value extracted from sparse tensors.
560 rewriter.replaceOpWithNewOp<func::ReturnOp>(
561 op, flattenValues(adaptor.getOperands()));
562 return success();
563 }
564};
565
566/// Sparse tensor storage conversion rule for calls.
567class SparseCallConverter : public OpConversionPattern<func::CallOp> {
568public:
  // The default CallOp converter cannot handle 1:N type conversion.
570 using OpConversionPattern::OpConversionPattern;
571 LogicalResult
572 matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
573 ConversionPatternRewriter &rewriter) const override {
574 Location loc = op.getLoc();
    // In case of:
    //   sparse_tensor, f, sparse_tensor = call @foo(...)
    // ==>
    //   memref..., f, memref... = call @foo(...)
    // where each group of memrefs replaces its original sparse tensor result.
580 SmallVector<Type> finalRetTy;
581 if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
582 return failure();
583
584 // (1) Generates new call with flattened return value.
585 auto newCall = rewriter.create<func::CallOp>(
586 loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands()));
587 // (2) Gather sparse tensor returns.
588 SmallVector<SmallVector<Value>> packedResultVals;
    // Tracks the offset of the current return value (of the original call)
    // relative to the new call (after sparse tensor flattening).
    unsigned retOffset = 0;
    // Temporary buffer to hold the flattened list of types for
    // a sparse tensor.
594 SmallVector<Type> sparseFlat;
595 for (auto ret : op.getResults()) {
596 assert(retOffset < newCall.getNumResults());
597 auto retType = ret.getType();
598 if (failed(typeConverter->convertType(retType, sparseFlat)))
599 llvm_unreachable("Failed to convert type in sparse tensor codegen");
600
      // Converted types cannot be empty when the type conversion succeeds.
602 assert(!sparseFlat.empty());
603 if (sparseFlat.size() > 1) {
604 auto flatSize = sparseFlat.size();
605 packedResultVals.emplace_back();
606 llvm::append_range(packedResultVals.back(),
607 newCall.getResults().slice(retOffset, flatSize));
608 retOffset += flatSize;
609 } else {
        // If this is a 1:1 conversion, no need for casting.
611 packedResultVals.emplace_back();
612 packedResultVals.back().push_back(newCall.getResult(retOffset));
613 retOffset++;
614 }
615 sparseFlat.clear();
616 }
617
618 assert(packedResultVals.size() == op.getNumResults());
619 rewriter.replaceOpWithMultiple(op, std::move(packedResultVals));
620 return success();
621 }
622};
623
624/// Sparse codegen rule for level accesses.
625class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
626public:
627 using OpConversionPattern::OpConversionPattern;
628 LogicalResult
629 matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
630 ConversionPatternRewriter &rewriter) const override {
631 std::optional<int64_t> lvl = op.getConstantLvlIndex();
632 RankedTensorType srcType = op.getSource().getType();
633 if (!lvl || !getSparseTensorEncoding(srcType))
634 return failure();
635
636 auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
637 auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
638
639 rewriter.replaceOp(op, sz);
640 return success();
641 }
642};
643
644// TODO: use a new SortCOO operation here instead of reusing convert op.
645struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
646 using OpConversionPattern::OpConversionPattern;
647 LogicalResult
648 matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
649 ConversionPatternRewriter &rewriter) const override {
650 Location loc = op.getLoc();
651 MLIRContext *ctx = op.getContext();
652
653 SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
654 SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());
655
656 // Should have been verified.
657 assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
658 dstStt.isCOOType() && srcStt.isCOOType());
659 assert(dstStt.hasSameDimToLvl(srcStt));
660
661 // We don't need a mutable descriptor here as we perform sorting in-place.
662 auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
663 op.getInputCoo().getType());
664 auto nnz = desc.getValMemSize(rewriter, op.getLoc());
665 auto crd = desc.getAOSMemRef();
666 auto val = desc.getValMemRef();
667
668 // Otherwise we need another data shuffle and a non-identity map.
669 assert(dstStt.hasSameDimToLvl(srcStt));
670 (void)dstStt; // to silence warning when assertion is disabled
671
    auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
673
674 rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
675 rewriter.getIndexAttr(0), op.getAlgorithm());
676
    // Since we do in-place sorting, the destination tensor will have the same
    // set of memrefs as the source tensor.
679 rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()});
680 return success();
681 }
682};
683
684template <typename Op, StorageSpecifierKind kind>
685class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
686public:
687 using OpConversionPattern<Op>::OpConversionPattern;
688 using typename OpConversionPattern<Op>::OneToNOpAdaptor;
689
690 LogicalResult
691 matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
692 ConversionPatternRewriter &rewriter) const override {
    // Simply lowers to a specifier.get <field> operation.
694 auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
695 op.getSlice().getType());
696 auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
697 op.getDim().getZExtValue());
698
699 rewriter.replaceOp(op, v);
700 return success();
701 }
702};
703
704/// Sparse codegen rule for trivial tensor casts.
705class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
706public:
707 using OpConversionPattern::OpConversionPattern;
708 LogicalResult
709 matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
710 ConversionPatternRewriter &rewriter) const override {
711 // Only rewrite identically annotated source/dest.
712 auto encDst = getSparseTensorEncoding(op.getType());
713 auto encSrc = getSparseTensorEncoding(op.getSource().getType());
714 if (!encDst || encDst != encSrc)
715 return failure();
716 rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
717 return success();
718 }
719};
720
721class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
722public:
723 using OpConversionPattern::OpConversionPattern;
724 LogicalResult
725 matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
726 ConversionPatternRewriter &rewriter) const override {
727 // Simply fold the operation.
728 rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
729 return success();
730 }
731};
732
733/// Sparse codegen rule for the alloc operator.
734class SparseTensorAllocConverter
735 : public OpConversionPattern<bufferization::AllocTensorOp> {
736public:
737 using OpConversionPattern::OpConversionPattern;
738 SparseTensorAllocConverter(const TypeConverter &typeConverter,
739 MLIRContext *context, bool enableInit)
740 : OpConversionPattern(typeConverter, context),
741 enableBufferInitialization(enableInit) {}
742
743 LogicalResult
744 matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
745 ConversionPatternRewriter &rewriter) const override {
746 const auto resType = getSparseTensorType(op);
747 if (!resType.hasEncoding())
748 return failure();
749
750 Location loc = op.getLoc();
751 // Deal with copy.
752 if (op.getCopy()) {
753 auto desc = getDescriptorFromTensorTuple(
754 adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
755 SmallVector<Value> fields;
      fields.reserve(desc.getNumFields());
757 // Memcpy on memref fields.
758 for (auto field : desc.getMemRefFields()) {
759 auto memrefTp = cast<MemRefType>(field.getType());
760 auto size = rewriter.create<memref::DimOp>(loc, field, 0);
761 auto copied =
762 rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size});
763 rewriter.create<memref::CopyOp>(loc, field, copied);
764 fields.push_back(copied);
765 }
766 // Reuses specifier.
      fields.push_back(desc.getSpecifier());
768 assert(fields.size() == desc.getNumFields());
769 rewriter.replaceOpWithMultiple(op, {fields});
770 return success();
771 }
772
773 if (!resType.isIdentity()) {
774 return rewriter.notifyMatchFailure(
775 op, "try run --sparse-reinterpret-map before codegen");
776 }
    // Level sizes equal dimension sizes, since lvl2dim is an identity map.
778 SmallVector<Value> lvlSizesValues;
779 createDimSizes(rewriter, loc, resType,
780 flattenValues(adaptor.getDynamicSizes()),
781 /*dimSizesValues=*/lvlSizesValues);
782
783 // Construct allocation for each field.
784 Value sizeHint = op.getSizeHint();
785 SmallVector<Value> fields;
786 createAllocFields(rewriter, loc, resType, enableBufferInitialization,
787 sizeHint, lvlSizesValues, fields);
788
789 // Replace operation with resulting memrefs.
790 rewriter.replaceOpWithMultiple(op, {fields});
791 return success();
792 }
793
794private:
795 bool enableBufferInitialization;
796};
797
798/// Sparse codegen rule for the empty tensor operator.
799class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
800public:
801 using OpConversionPattern::OpConversionPattern;
802 SparseTensorEmptyConverter(const TypeConverter &typeConverter,
803 MLIRContext *context, bool enableInit)
804 : OpConversionPattern(typeConverter, context),
805 enableBufferInitialization(enableInit) {}
806
807 LogicalResult
808 matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
809 ConversionPatternRewriter &rewriter) const override {
810 const auto resType = getSparseTensorType(op);
811 if (!resType.hasEncoding())
812 return failure();
813
814 if (!resType.isIdentity()) {
815 return rewriter.notifyMatchFailure(
816 op, "try run --sparse-reinterpret-map before codegen");
817 }
818
819 Location loc = op.getLoc();
    // Level sizes equal dimension sizes, since lvl2dim is an identity map.
821 SmallVector<Value> lvlSizesValues;
822 createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
823 /*dimSizesValues=*/lvlSizesValues);
824 // Construct allocation for each field.
825 Value sizeHint; // none
826 SmallVector<Value> fields;
827 createAllocFields(rewriter, loc, resType, enableBufferInitialization,
828 sizeHint, lvlSizesValues, fields);
829
830 // Replace operation with resulting memrefs.
831 rewriter.replaceOpWithMultiple(op, {fields});
832 return success();
833 }
834
835private:
836 bool enableBufferInitialization;
837};
838
839/// Sparse codegen rule for the dealloc operator.
840class SparseTensorDeallocConverter
841 : public OpConversionPattern<bufferization::DeallocTensorOp> {
842public:
843 using OpConversionPattern::OpConversionPattern;
844 SparseTensorDeallocConverter(const TypeConverter &typeConverter,
845 MLIRContext *context, bool createDeallocs)
846 : OpConversionPattern(typeConverter, context),
847 createDeallocs(createDeallocs) {}
848
849 LogicalResult
850 matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
851 ConversionPatternRewriter &rewriter) const override {
852 auto enc = getSparseTensorEncoding(op.getTensor().getType());
853 if (!enc)
854 return failure();
855
    // If the user requests not to deallocate sparse tensors, simply erase the
    // operation.
858 if (createDeallocs) {
859 // Replace the sparse tensor deallocation with field deallocations.
860 Location loc = op.getLoc();
861 auto desc = getDescriptorFromTensorTuple(
862 adaptor.getTensor(),
863 cast<RankedTensorType>(op.getTensor().getType()));
864 for (auto input : desc.getMemRefFields())
865 // Deallocate every buffer used to store the sparse tensor handler.
866 rewriter.create<memref::DeallocOp>(loc, input);
867 }
    rewriter.eraseOp(op);
869 return success();
870 }
871
872private:
873 const bool createDeallocs;
874};
875
876/// Sparse codegen rule for tensor rematerialization.
877class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
878public:
879 using OpConversionPattern::OpConversionPattern;
880 LogicalResult
881 matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
882 ConversionPatternRewriter &rewriter) const override {
883 // Prepare descriptor.
884 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
885 op.getTensor().getType());
886 // Generate optional insertion finalization code.
887 if (op.getHasInserts())
888 genEndInsert(rewriter, op.getLoc(), desc);
889 // Replace operation with resulting memrefs.
890 rewriter.replaceOpWithMultiple(op, {desc.getFields()});
891 return success();
892 }
893};
894
895/// Sparse codegen rule for the expand op.
896class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
897public:
898 using OpConversionPattern::OpConversionPattern;
899 LogicalResult
900 matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
901 ConversionPatternRewriter &rewriter) const override {
902 if (!getSparseTensorEncoding(op.getTensor().getType()))
903 return failure();
904 Location loc = op->getLoc();
905 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
906 op.getTensor().getType());
907 const auto srcType = getSparseTensorType(op.getTensor());
908 Type eltType = srcType.getElementType();
909 Type boolType = rewriter.getIntegerType(1);
910 Type idxType = rewriter.getIndexType();
911 // All initialization should be done on entry of the loop nest.
912 rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
913
914 // Determine the size for access expansion (always the innermost stored
915 // level size).
916 const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
917 // Generate a memref for `sz` elements of type `t`.
918 const auto genAlloc = [&](Type t) {
919 const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
920 return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
921 };
922 // Allocate temporary buffers for values/filled-switch and added.
923 // We do not use stack buffers for this, since the expanded size may
924 // be rather large (as it envelops a single expanded dense dimension).
925 Value values = genAlloc(eltType);
926 Value filled = genAlloc(boolType);
927 Value added = genAlloc(idxType);
    Value zero = constantZero(rewriter, loc, idxType);
929 // Reset the values/filled-switch to all-zero/false. Note that this
930 // introduces an O(N) operation into the computation, but this reset
931 // operation is amortized over the innermost loops for the access
932 // pattern expansion. As noted in the operation doc, we would like
933 // to amortize this setup cost even between kernels.
934 rewriter.create<linalg::FillOp>(
935 loc, ValueRange{constantZero(rewriter, loc, eltType)},
936 ValueRange{values});
937 rewriter.create<linalg::FillOp>(
938 loc, ValueRange{constantZero(rewriter, loc, boolType)},
939 ValueRange{filled});
940 // Replace expansion op with these buffers and initial coordinate.
941 assert(op.getNumResults() == 4);
942 rewriter.replaceOp(op, {values, filled, added, zero});
943 return success();
944 }
945};
946
947/// Sparse codegen rule for the compress operator.
948class SparseCompressConverter : public OpConversionPattern<CompressOp> {
949public:
950 using OpConversionPattern::OpConversionPattern;
951 LogicalResult
952 matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
953 ConversionPatternRewriter &rewriter) const override {
954 Location loc = op->getLoc();
955 SmallVector<Value> fields;
956 auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
957 op.getTensor().getType());
958 Value values = llvm::getSingleElement(adaptor.getValues());
959 Value filled = llvm::getSingleElement(adaptor.getFilled());
960 Value added = llvm::getSingleElement(adaptor.getAdded());
961 Value count = llvm::getSingleElement(adaptor.getCount());
962 const SparseTensorType dstType(desc.getRankedTensorType());
963 Type eltType = dstType.getElementType();
964
965 // If the innermost level is ordered, we need to sort the coordinates
966 // in the "added" array prior to applying the compression.
967 if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
968 rewriter.create<SortOp>(
969 loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
970 rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
971 // While performing the insertions, we also need to reset the elements
972 // of the values/filled-switch by only iterating over the set elements,
973 // to ensure that the runtime complexity remains proportional to the
974 // sparsity of the expanded access pattern.
975 //
976 // Generate
977 // out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
978 // crd = added[i];
979 // value = values[crd];
980 // insert({lvlCoords, crd}, value);
981 // new_memrefs = insert(in_memrefs, {lvlCoords, crd}, value);
982 // values[crd] = 0;
983 // filled[crd] = false;
984 // yield new_memrefs
985 // }
986 scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
987 Value i = loop.getInductionVar();
988
    Value crd = genLoad(rewriter, loc, added, i);
    Value value = genLoad(rewriter, loc, values, crd);
    SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
    SmallVector<Type> flatSpTensorTps = llvm::to_vector(
        llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
    SmallVector<Value> flatLvlCoords = flattenValues(adaptor.getLvlCoords());
    params.append(flatLvlCoords.begin(), flatLvlCoords.end());
    params.push_back(crd);
    params.push_back(value);
    SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
                                    params, /*genCall=*/true);
    SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
    genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
    genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
1003 rewriter.create<scf::YieldOp>(loc, insertRet);
1004
1005 rewriter.setInsertionPointAfter(loop);
1006 // Deallocate the buffers on exit of the full loop nest.
1007 Operation *parent = getTop(op);
1008 rewriter.setInsertionPointAfter(parent);
1009 rewriter.create<memref::DeallocOp>(loc, values);
1010 rewriter.create<memref::DeallocOp>(loc, filled);
1011 rewriter.create<memref::DeallocOp>(loc, added);
1012 // Replace operation with resulting memrefs.
1013 rewriter.replaceOpWithMultiple(op, {loop->getResults()});
1014 return success();
1015 }
1016};
1017
1018/// Sparse codegen rule for the insert operator.
1019class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
1020public:
1021 using OpConversionPattern::OpConversionPattern;
1022 LogicalResult
1023 matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
1024 ConversionPatternRewriter &rewriter) const override {
1025 auto stt = getSparseTensorType(op.getDest());
1026 if (!stt.hasEncoding())
1027 return failure();
1028 assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
1029
1030 Location loc = op.getLoc();
1031 auto desc =
1032 getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
1033 TypeRange flatSpTensorTps = desc.getFields().getTypes();
1034 SmallVector<Value> params = llvm::to_vector(desc.getFields());
1035 SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
    params.append(flatIndices.begin(), flatIndices.end());
    params.push_back(llvm::getSingleElement(adaptor.getScalar()));
    SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
                                    params, /*genCall=*/true);
    SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
1041 // Replace operation with resulting memrefs.
1042 rewriter.replaceOpWithMultiple(op, {ret});
1043 return success();
1044 }
1045};
1046
1047/// Sparse codegen rule for position accesses.
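///
/// As a sketch, the requested positions array is obtained from the descriptor
/// and truncated (via a subview) to the actual size recorded in the storage
/// specifier, rather than exposing the full capacity of the buffer.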
1048class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
1049public:
1050 using OpAdaptor = typename ToPositionsOp::Adaptor;
1051 using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
1052 LogicalResult
1053 matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
1054 ConversionPatternRewriter &rewriter) const override {
1055 // Replace the requested position access with corresponding field.
1056 // The view is restricted to the actual size to ensure clients
1057 // of this operation truly observe size, not capacity!
1058 Location loc = op.getLoc();
1059 Level lvl = op.getLevel();
1060 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1061 op.getTensor().getType());
1062 auto mem = desc.getPosMemRef(lvl);
1063 auto size = desc.getPosMemSize(rewriter, loc, lvl);
1064 rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1065 return success();
1066 }
1067};
1068
1069/// Sparse codegen rule for accessing the coordinates arrays.
1070class SparseToCoordinatesConverter
1071 : public OpConversionPattern<ToCoordinatesOp> {
1072public:
1073 using OpAdaptor = typename ToCoordinatesOp::Adaptor;
1074 using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
1075 LogicalResult
1076 matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
1077 ConversionPatternRewriter &rewriter) const override {
1078 // Replace the requested coordinates access with corresponding field.
1079 // The view is restricted to the actual size to ensure clients
1080 // of this operation truly observe size, not capacity!
1081 Location loc = op.getLoc();
1082 Level lvl = op.getLevel();
1083 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1084 op.getTensor().getType());
1085 auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
1086 if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
1087 auto size = desc.getCrdMemSize(rewriter, loc, lvl);
1088 mem = genSliceToSize(rewriter, loc, mem, size);
1089 }
1090 rewriter.replaceOp(op, mem);
1091 return success();
1092 }
1093};
1094
1095/// Sparse codegen rule for accessing the linear coordinates buffer.
1096class SparseToCoordinatesBufferConverter
1097 : public OpConversionPattern<ToCoordinatesBufferOp> {
1098public:
1099 using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
1100 using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
1101 LogicalResult
1102 matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
1103 ConversionPatternRewriter &rewriter) const override {
1104 // Replace the requested coordinates access with corresponding field.
1105 // The view is restricted to the actual size to ensure clients
1106 // of this operation truly observe size, not capacity!
1107 Location loc = op.getLoc();
1108 Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
1109 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1110 op.getTensor().getType());
1111 auto mem = desc.getAOSMemRef();
1112 auto size = desc.getCrdMemSize(rewriter, loc, lvl);
1113 rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1114 return success();
1115 }
1116};
1117
1118/// Sparse codegen rule for value accesses.
1119class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
1120public:
1121 using OpAdaptor = typename ToValuesOp::Adaptor;
1122 using OpConversionPattern<ToValuesOp>::OpConversionPattern;
1123 LogicalResult
1124 matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
1125 ConversionPatternRewriter &rewriter) const override {
1126 // Replace the requested values access with corresponding field.
1127 // The view is restricted to the actual size to ensure clients
1128 // of this operation truly observe size, not capacity!
1129 Location loc = op.getLoc();
1130 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1131 op.getTensor().getType());
1132 auto mem = desc.getValMemRef();
1133 auto size = desc.getValMemSize(rewriter, loc);
1134 rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1135 return success();
1136 }
1137};
1138
1139/// Sparse codegen rule for the convert operator.
1140class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
1141public:
1142 using OpConversionPattern::OpConversionPattern;
1143 LogicalResult
1144 matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
1145 ConversionPatternRewriter &rewriter) const override {
1146 SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
1147 SparseTensorEncodingAttr encSrc =
1148 getSparseTensorEncoding(op.getSource().getType());
    // The output tensor cannot be a slice; such cases should have been
    // rejected by ConvertOp::verify() already.
    assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slice.");
    // Different encodings (except for different bitwidths) should be handled
    // by rewriting. We also need further rewrites if the input tensor is a
    // slice.
1155 if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
1156 encSrc.isSlice()) {
1157 return failure();
1158 }
1159
1160 Type retElemTp = op.getResult().getType().getElementType();
1161 Type srcElemTp = op.getSource().getType().getElementType();
1162 // Fold the trivial cases.
1163 if (retElemTp == srcElemTp && encDst == encSrc) {
1164 rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
1165 return success();
1166 }
1167 //
1168 // Do element-wise type conversion without using InsertOp.
1169 //
1170 // for each memref in srcTensor:
1171 // dst = memref.alloc
1172 // if srcMemRefType != dstMemRefType:
1173 // for every dst[i] = cast(src[i])
1174 // else:
1175 // dst = memref.copy(src)
1176 Location loc = op.getLoc();
1177 auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
1178 op.getSource().getType());
1179 SmallVector<Value> fields;
1180 foreachFieldAndTypeInSparseTensor(
1181 SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
1182 [&rewriter, &fields, srcDesc,
1183 loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
1184 LevelType /*lt*/) -> bool {
1185 // Simply reuses the storage specifier as it is an SSA value.
1186 if (fKind == SparseTensorFieldKind::StorageSpec) {
            fields.push_back(srcDesc.getSpecifier());
1188 } else {
1189 // Allocates new memrefs
1190 Value srcMem = srcDesc.getMemRefField(fIdx);
1191 // TODO: We can instead use the actual memSize in specifier, that
1192 // would require a subViewOp to avoid overflow when copying
1193 // values.
            Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
1195 auto dstMem = rewriter.create<memref::AllocOp>(
1196 loc, cast<MemRefType>(fTp), sz);
1197 if (fTp != srcMem.getType()) {
1198 // Converts elements type.
1199 scf::buildLoopNest(
                  rewriter, loc, constantIndex(rewriter, loc, 0), sz,
                  constantIndex(rewriter, loc, 1),
                  [srcMem, &dstMem](OpBuilder &builder, Location loc,
1203 ValueRange ivs) {
1204 Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
1205 Value casted = genCast(builder, loc, v,
1206 dstMem.getType().getElementType());
1207 builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
1208 });
1209 } else {
1210 // TODO: We can even reuse the same memref for the new tensor,
1211 // but that requires a `ref-counting` based memory management
1212 // for shared memrefs between multiple sparse tensors.
1213 rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
1214 }
            fields.push_back(dstMem);
1216 }
1217 return true;
1218 });
1219
1220 rewriter.replaceOpWithMultiple(op, {fields});
1221 return success();
1222 }
1223};
1224
1225class SparseExtractSliceConverter
1226 : public OpConversionPattern<tensor::ExtractSliceOp> {
1227public:
1228 using OpConversionPattern::OpConversionPattern;
1229 LogicalResult
1230 matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
1231 ConversionPatternRewriter &rewriter) const override {
1232 Location loc = op.getLoc();
1233 MLIRContext *ctx = op.getContext();
1234 auto srcEnc = getSparseTensorEncoding(op.getSourceType());
1235 auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
1236 // TODO: We should check these in ExtractSliceOp::verify.
1237 if (!srcEnc || !dstEnc || !dstEnc.isSlice())
1238 return failure();
1239 assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
1240
1241 SmallVector<Value> fields;
1242 auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
1243 op.getSource().getType());
1244
1245 auto newSpec = rewriter.create<StorageSpecifierInitOp>(
1246 loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
1247 desc.setSpecifier(newSpec);
1248
1249 // Fills in slice information.
1250 for (auto [idx, offset, size, stride] : llvm::enumerate(
1251 op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
1252 Dimension dim = idx;
1253
1254 Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
1255 Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);
1256 Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride);
      // TODO: We could probably only set the dynamic value here. But it would
      // require us to fill in the hole when casting a static slice to a
      // dynamic slice.
1260 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset,
1261 dim, offsetV);
1262
1263 // FIXME: we need to distinguish level sizes and dimension size for slices
1264 // here. Maybe we should store slice level sizes in a different array
1265 // instead of reusing it.
1266 assert(srcEnc.isIdentity());
1267 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
1268 sizeV);
1269 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
1270 dim, strideV);
1271 }
1272
    // NOTE: we cannot generate the tuple directly from the descriptor here, as
    // the descriptor holds the original type, yet we want the slice type here
    // (they share every memref but have an updated specifier).
1276 rewriter.replaceOpWithMultiple(op, {desc.getFields()});
1277 return success();
1278 }
1279};
1280
1281/// Sparse codegen rule for number of entries operator.
1282class SparseNumberOfEntriesConverter
1283 : public OpConversionPattern<NumberOfEntriesOp> {
1284public:
1285 using OpConversionPattern::OpConversionPattern;
1286 LogicalResult
1287 matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
1288 ConversionPatternRewriter &rewriter) const override {
1289 // Query memSizes for the actually stored values.
1290 // FIXME: the nse value computed in this way might be wrong when there is
1291 // any "loose_compressed" level.
1292 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1293 op.getTensor().getType());
1294 rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
1295 return success();
1296 }
1297};
1298
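/// Sparse codegen rule for the sparse_tensor.assemble operator.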
1299struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
1300 using OpConversionPattern::OpConversionPattern;
1301 LogicalResult
1302 matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
1303 ConversionPatternRewriter &rewriter) const override {
1304 Location loc = op.getLoc();
1305 const auto stt = getSparseTensorType(op.getResult());
1306
1307 SmallVector<Value> fields;
1308
1309 foreachFieldAndTypeInSparseTensor(
1310 stt,
1311 [&rewriter, &fields, &op, &stt,
1312 loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
1313 Level /*lvl*/, LevelType lt) -> bool {
1314 assert(fields.size() == fIdx);
1315 if (fKind == SparseTensorFieldKind::StorageSpec) {
            fields.push_back(
                SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
1318 } else {
1319 // Else simply takes the inputs.
1320 Value tensor = fKind == SparseTensorFieldKind::ValMemRef
1321 ? op.getValues()
1322 : op.getLevels()[fIdx];
1323 // TODO: handle batch.
            TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
1325 if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
1326 // Flattens the buffer to batchLvlRank.
1327 auto reassoc = getReassociationForFlattening(
1328 mem.getType(), stt.getBatchLvlRank());
1329 mem = rewriter.create<memref::CastOp>(
1330 loc, fType,
1331 rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
1332 } else {
1333 mem = rewriter.create<memref::CastOp>(loc, fType, mem);
1334 }
1335 fields.push_back(Elt: mem);
1336 }
1337 return true;
1338 });
1339
1340 MutSparseTensorDescriptor desc(stt, fields);
1341 Value c0 = constantIndex(builder&: rewriter, loc, i: 0);
1342 Value c1 = constantIndex(builder&: rewriter, loc, i: 1);
1343 Value c2 = constantIndex(builder&: rewriter, loc, i: 2);
1344 Value posBack = c0; // index to the last value in the position array
1345 Value memSize = c1; // memory size for current array
1346
1347 Level trailCOOStart = stt.getAoSCOOStart();
1348 Level trailCOORank = stt.getLvlRank() - trailCOOStart;
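    // The loop below derives, level by level, the memory sizes recorded in
    // the specifier from the last entry of each position buffer. As an
    // illustrative (hypothetical) walk-through for a 3x4 CSR tensor assembled
    // with positions [0, 2, 3, 5]:
    //   start          : memSize = 1
    //   lvl 0 (dense)  : memSize = 3 * 1 = 3
    //   lvl 1 (compr.) : posMemSize = 3 + 1 = 4, memSize = positions[3] = 5,
    //                    crdMemSize = 5
    //   values         : valMemSize = 5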
    // Sets up the SparseTensorSpecifier.
    for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
      assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));

      // Sets up the level size.
      auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
      desc.setLvlSize(rewriter, loc, lvl, lvlSize);
      // We use a single AoS array to store the trailing COO, so there is only
      // one memory size to set for the entire COO section.
      if (lvl > trailCOOStart)
        continue;

      // Sets up the memory size by reading the last value in the position
      // array.
      LevelType lt = stt.getLvlType(lvl);
      // Simply forwards the position index when this is a dense level.
      if (lt.isa<LevelFormat::Dense>()) {
        memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
        posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
        continue;
      }
      if (lt.isa<LevelFormat::Batch>()) {
        // Skip batch levels, as they are not linearized.
        // FIXME: this assumes that every batch has the same number of stored
        // entries (nse); it needs to be generalized to handle varied-size
        // batches.
        continue;
      }

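      // Note on position-buffer sizes (informal): a compressed level keeps one
      // position entry per parent segment plus the leading zero, whereas a
      // loose_compressed level stores a (lo, hi) pair per parent segment and
      // hence twice as many entries; the code below mirrors that layout.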
      if (isWithPosLT(lt)) {
        assert(isCompressedLT(lt) || isLooseCompressedLT(lt));
        if (isLooseCompressedLT(lt)) {
          memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
          posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
        } else {
          assert(isCompressedLT(lt));
          posBack = memSize;
          memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
        }
        desc.setPosMemSize(rewriter, loc, lvl, memSize);
        // The last value in the position array is the memory size for the
        // next level.
        // FIXME: this assumes that every batch has the same number of stored
        // entries (nse); it needs to be generalized to handle varied-size
        // batches.
        SmallVector<Value> batched(stt.getBatchLvlRank(),
                                   constantIndex(rewriter, loc, 0));
        batched.push_back(posBack);
        memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
        posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
      }
      assert(isWithCrdLT(lt) && lvl <= trailCOOStart);
      // FIXME: this seems to be unnecessarily complex; can we simplify it?
      if (lvl == trailCOOStart) {
        Value cooSz = rewriter.create<arith::MulIOp>(
            loc, memSize, constantIndex(rewriter, loc, trailCOORank));
        desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
      } else {
        desc.setCrdMemSize(rewriter, loc, lvl, memSize);
      }
    }
    desc.setValMemSize(rewriter, loc, memSize);

    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
    return success();
  }
};

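/// Sparse codegen rule for the disassemble operator: each internal buffer is
/// copied into the corresponding user-supplied output tensor, trimmed to the
/// length actually in use, and those lengths are returned alongside the
/// buffers.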
struct SparseDisassembleOpConverter
    : public OpConversionPattern<DisassembleOp> {
  using OpConversionPattern::OpConversionPattern;
  SparseDisassembleOpConverter(const TypeConverter &typeConverter,
                               MLIRContext *context)
      : OpConversionPattern(typeConverter, context) {}

  LogicalResult
  matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    Location loc = op.getLoc();
    SmallVector<Value> retMem;
    SmallVector<Value> retLen;
    desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem,
                                   &retLen](FieldIndex fid,
                                            SparseTensorFieldKind fKind,
                                            Level lvl, LevelType lt) -> bool {
      if (fKind == SparseTensorFieldKind::StorageSpec)
        return true;
      SparseTensorType stt(desc.getRankedTensorType());
      Value sz, src;
      TypedValue<BaseMemRefType> dst;
      if (fKind == SparseTensorFieldKind::ValMemRef) {
        sz = desc.getValMemSize(rewriter, loc);
        src = desc.getValMemRef();
        dst = genToMemref(rewriter, loc, op.getOutValues());

        retMem.push_back(dst);
        Type valLenTp = op.getValLen().getType();
        retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
      } else {
        assert(fKind == SparseTensorFieldKind::PosMemRef ||
               fKind == SparseTensorFieldKind::CrdMemRef);

        sz = fKind == SparseTensorFieldKind::PosMemRef
                 ? desc.getPosMemSize(rewriter, loc, lvl)
                 : desc.getCrdMemSize(rewriter, loc, lvl);
        src = desc.getMemRefField(fid);
        dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
        retMem.push_back(dst);
        // Retrieves the corresponding level length type.
        Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
      }
      Value flatOut = dst;
      if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
        auto reassoc =
            getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());
        flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
      }
      Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
      Value srcMem = genSliceToSize(rewriter, loc, src, sz);
      rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
      return true;
    });

    // Converts MemRefs back to Tensors.
    SmallVector<Value> retValues = llvm::to_vector(
        llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
          return rewriter.create<bufferization::ToTensorOp>(loc, v);
        }));
    // Appends the actual memory length used in each buffer returned.
    retValues.append(retLen.begin(), retLen.end());
    rewriter.replaceOp(op, retValues);
    return success();
  }
};

struct SparseNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    // Creating COO with NewOp is handled by direct IR codegen. All other cases
    // are handled by rewriting.
    if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
      return failure();

    // Implement as follows:
    //   %reader = @createCheckedSparseTensorReader(%filename)
    //   %nse = @getSparseTensorNSE(%reader)
    //   %coo = bufferization.alloc_tensor an ordered COO with
    //          dst dim ordering, size_hint = %nse
    //   %coordinates = sparse_tensor.coordinates_buffer(%coo)
    //   %values = sparse_tensor.values(%coo)
    //   %isSorted = @sparseTensorReaderReadToBuffers(%coordinates, %values)
    //   if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
    //   update storage specifier
    //   @delSparseTensorReader(%reader)
    SmallVector<Value> dimSizesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
                             dimSizesValues, dimSizesBuffer);

    // Get the number of stored entries.
    const Type indexTp = rewriter.getIndexType();
    Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
                               {indexTp}, {reader}, EmitCInterface::Off)
                    .getResult(0);

    // Construct the lvl sizes and the dim2lvl/lvl2dim buffers.
    SmallVector<Value> lvlSizesValues;
    Value dim2lvlBuffer;
    Value lvl2dimBuffer;
    genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
                  lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);

    // Construct allocation for each field.
    Value sizeHint = nse;
    SmallVector<Value> fields;
    createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint,
                      lvlSizesValues, fields);

    // Read the COO tensor data.
    MutSparseTensorDescriptor desc(dstTp, fields);
    Value xs = desc.getAOSMemRef();
    Value ys = desc.getValMemRef();
    const Type boolTp = rewriter.getIntegerType(1);
    const Type elemTp = dstTp.getElementType();
    const Type crdTp = dstTp.getCrdType();
    SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
                                          overheadTypeFunctionSuffix(crdTp),
                                          primaryTypeFunctionSuffix(elemTp)};
    Value isSorted =
        createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
                       {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
                       EmitCInterface::On)
            .getResult(0);

    // If the destination tensor is a sorted COO, we need to sort the COO
    // tensor data if the input elements aren't sorted yet.
    const Level lvlRank = dstTp.getLvlRank();
    if (dstTp.isOrderedLvl(lvlRank - 1)) {
      Value kFalse = constantI1(rewriter, loc, false);
      Value notSorted = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, isSorted, kFalse);
      scf::IfOp ifOp =
          rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
      rewriter.create<SortOp>(loc, nse, xs, ValueRange{ys}, xPerm,
                              rewriter.getIndexAttr(0),
                              SparseTensorSortKind::HybridQuickSort);
      rewriter.setInsertionPointAfter(ifOp);
    }

    // Set PosMemRef0[1] = nse.
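    // (With the trailing AoS COO starting at level 0, the first positions
    //  buffer describes a single segment over all nse entries, so only its
    //  upper bound at index 1 needs to be written here.)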
    const Value c1 = constantIndex(rewriter, loc, 1);
    const Value posMemref0 = desc.getPosMemRef(0);
    const Type posTp = dstTp.getPosType();
    const Value posNse = genCast(rewriter, loc, nse, posTp);
    rewriter.create<memref::StoreOp>(loc, posNse, posMemref0, c1);

    // Update storage specifier.
    Value coordinatesSize = rewriter.create<arith::MulIOp>(
        loc, nse, constantIndex(rewriter, loc, lvlRank));
    desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0,
                           coordinatesSize);
    desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize,
                           std::nullopt, nse);

    // Release the sparse tensor reader.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);

    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {fields});
    return success();
  }
};

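/// Sparse codegen rule for the has_runtime_library operator: since this
/// direct-codegen path does not depend on the runtime support library, the
/// operation simply folds to a constant false.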
struct SparseHasRuntimeLibraryConverter
    : public OpConversionPattern<HasRuntimeLibraryOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto i1Type = rewriter.getI1Type();
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
        op, i1Type, rewriter.getIntegerAttr(i1Type, 0));
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

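// A minimal sketch of how a conversion pass might drive these patterns
// (illustrative only; it assumes a sparse tensor type converter and a
// properly configured ConversionTarget, named `converter` and `target`):
//
//   RewritePatternSet patterns(op->getContext());
//   populateSparseTensorCodegenPatterns(converter, patterns,
//                                       /*createSparseDeallocs=*/false,
//                                       /*enableBufferInitialization=*/false);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();
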
/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorCodegenPatterns(
    const TypeConverter &typeConverter, RewritePatternSet &patterns,
    bool createSparseDeallocs, bool enableBufferInitialization) {
  patterns.add<
      SparseAssembleOpConverter, SparseDisassembleOpConverter,
      SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
      SparseCastConverter, SparseExtractSliceConverter,
      SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter,
      SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter,
      SparseSliceGetterOpConverter<ToSliceOffsetOp,
                                   StorageSpecifierKind::DimOffset>,
      SparseSliceGetterOpConverter<ToSliceStrideOp,
                                   StorageSpecifierKind::DimStride>,
      SparseToPositionsConverter, SparseToCoordinatesConverter,
      SparseToCoordinatesBufferConverter, SparseToValuesConverter,
      SparseConvertConverter, SparseNewConverter,
      SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
      typeConverter, patterns.getContext());
  patterns.add<SparseTensorDeallocConverter>(
      typeConverter, patterns.getContext(), createSparseDeallocs);
  patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
      typeConverter, patterns.getContext(), enableBufferInitialization);
}

