1//===- Sparsification.cpp - Implementation of sparsification --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements converting sparse tensor types to actual sparse code.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Utils/CodegenEnv.h"
14#include "Utils/CodegenUtils.h"
15#include "Utils/LoopEmitter.h"
16
17#include "mlir/Dialect/Affine/IR/AffineOps.h"
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
20#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
21#include "mlir/Dialect/Func/IR/FuncOps.h"
22#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23#include "mlir/Dialect/Linalg/IR/Linalg.h"
24#include "mlir/Dialect/Linalg/Utils/Utils.h"
25#include "mlir/Dialect/MemRef/IR/MemRef.h"
26#include "mlir/Dialect/SCF/IR/SCF.h"
27#include "mlir/Dialect/SCF/Transforms/Transforms.h"
28#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
29#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
30#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
31#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
32#include "mlir/Dialect/Tensor/IR/Tensor.h"
33#include "mlir/IR/AffineExprVisitor.h"
34#include "mlir/IR/Matchers.h"
35#include "mlir/IR/TensorEncoding.h"
36#include "llvm/ADT/SmallBitVector.h"
37
38#include <optional>
39
40using namespace mlir;
41using namespace mlir::sparse_tensor;
42
43//===----------------------------------------------------------------------===//
44// Sparsifier analysis methods.
45//===----------------------------------------------------------------------===//
46
/// Returns true iff the affine expression is invariant. Sets the
/// parameter `isCurrentLoop` when the expression just becomes invariant
/// at the current loop.
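///
/// For illustration: with `curr == 2`, the dim expression `d1` just becomes
/// invariant (and sets `isCurrentLoop`), `d0` is already invariant since its
/// loop was generated earlier, and `d3` is not invariant yet.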
49static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) {
50 switch (a.getKind()) {
51 case AffineExprKind::DimId: {
52 const LoopId i = cast<AffineDimExpr>(Val&: a).getPosition();
53 if (i + 1 == curr) {
54 isCurrentLoop = true;
55 return true; // becomes invariant at current loop
56 }
57 return i < curr; // invariant when already generated
58 }
59 case AffineExprKind::Add:
60 case AffineExprKind::Mul: {
61 auto binOp = cast<AffineBinaryOpExpr>(Val&: a);
62 return isInvariantAffine(a: binOp.getLHS(), curr, isCurrentLoop) &&
63 isInvariantAffine(a: binOp.getRHS(), curr, isCurrentLoop);
64 }
65 default: {
66 assert(isa<AffineConstantExpr>(a));
67 return true;
68 }
69 }
70}
71
72/// Helper method to inspect affine expressions. Rejects cases where the
73/// same index is used more than once. Also rejects compound affine
74/// expressions in sparse dimensions.
75static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
76 LevelType lt, bool setLvlFormat = true) {
77 switch (a.getKind()) {
78 case AffineExprKind::DimId: {
79 const LoopId idx = merger.makeLoopId(i: cast<AffineDimExpr>(Val&: a).getPosition());
80 if (!isUndefLT(lt: merger.getLvlType(t: tid, i: idx)))
81 return false; // used more than once
82 if (setLvlFormat)
83 merger.setLevelAndType(t: tid, i: idx, lvl, lt);
84 return true;
85 }
86 case AffineExprKind::Add:
87 case AffineExprKind::Mul:
88 case AffineExprKind::Constant: {
89 assert(lt.hasDenseSemantic());
90 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(Val&: a)) {
      // We do not set the dim level format for an affine expression like
      // d0 + d1 on either loop index d0 or d1. We continue the recursion
      // merely to check whether the current affine expression is admissible.
94 return findAffine(merger, tid, lvl, a: binOp.getLHS(), lt, setLvlFormat: false) &&
95 findAffine(merger, tid, lvl, a: binOp.getRHS(), lt, setLvlFormat: false);
96 }
    // Fall through for a constant affine expression.
98 return true;
99 }
100 default:
101 return false;
102 }
103}
104
105/// Helper method to inspect affine expressions for index variable reduction
106/// based codegen. It finds the dependent index set for all tensor levels in the
107/// current expression we are generating.
108///
109/// For example, when handling A[i+j][j+k], we build the two way mapping in
110/// merger between (tensor, level) pairs and their dependent index variable set:
111/// A_0 <=> [i, j] and A_1 <=> [j, k]
112///
/// It rejects cases (returns false) when:
///   1. the same index is used more than once, e.g., A[i+j][i];
///   2. multiplication is used in the non-trivial index expression;
///   3. a constant operand is used in the non-trivial index expression.
///
/// TODO: constants should be easy to handle.
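///
/// For illustration, an expression such as `2 * d0 + d1` on level B_0 of a
/// tensor B records d0 with coefficient 2 and d1 with coefficient 1 as the
/// dependent index set of (B, 0).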
119static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
120 AffineExpr a, LevelType lt, bool isSubExp = false,
121 int64_t coefficient = 1) {
122 switch (a.getKind()) {
123 case AffineExprKind::DimId: {
124 // Only allow positive coefficients on AffineDimExpr.
125 if (coefficient <= 0)
126 return false;
127
128 const LoopId idx = merger.makeLoopId(i: cast<AffineDimExpr>(Val&: a).getPosition());
129 if (!isUndefLT(lt: merger.getLvlType(t: tensor, i: idx)))
130 return false; // used more than once, e.g., A[i][i]
131
    // TODO: Generalize the following two cases. A[i] (with a trivial index
    // expression) can be treated as a special affine index expression. We do
    // not necessarily need to differentiate them.
135 if (!isSubExp) {
136 assert(coefficient == 1);
137 merger.setLevelAndType(t: tensor, i: idx, lvl, lt);
138 }
139
140 if (isSubExp) {
      // The current loop appears in more than one affine expression on the
      // same tensor. We cannot handle this case, e.g., A[i+j][i+k], where
      // `i` is used twice.
144 if (merger.hasDependentLvl(i: idx, t: tensor)) {
        // TODO: This can be supported by coiterating slices if the loop index
        // appears in affine index expressions of different tensors, or by
        // taking a slice along multiple dimensions when it is on the same
        // tensor.
        // E.g.,
        // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
        // d0_1 = getNextSliceOffset t0 along lvl0
        // d0_2 = getNextSliceOffset t1 along lvl0
        // if d0_1 == d0_2 then d0 = d0_1 = d0_2
        // else increase min(d0_1, d0_2).
154 return false;
155 }
156 merger.setLoopDependentTensorLevel(i: idx, t: tensor, lvl, lt, coefficient);
157 }
158 return true;
159 }
160 case AffineExprKind::Constant:
161 case AffineExprKind::Mul: {
    // TODO: Support index expressions like `2 * d0`; we currently only support
    // more complicated cases like `2 * d0 + d1`.
164 if (!isSubExp)
165 return false;
166
167 // TODO: Support Constant AffineExp for slice-based codegen
168 if (isa<AffineConstantExpr>(Val: a))
169 llvm_unreachable("Not yet implemented");
170
171 auto binOp = cast<AffineBinaryOpExpr>(Val&: a);
172 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
173 if (isa<AffineConstantExpr>(Val: rhs))
174 std::swap(a&: lhs, b&: rhs);
175 // Must be in form of `constant * d`.
176 assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
177 int64_t coefficient = cast<AffineConstantExpr>(Val&: lhs).getValue();
178 return findDepIdxSet(merger, tensor, lvl, a: rhs, lt, isSubExp, coefficient);
179 }
180 case AffineExprKind::Add: {
181 auto binOp = cast<AffineBinaryOpExpr>(Val&: a);
182 return findDepIdxSet(merger, tensor, lvl, a: binOp.getLHS(), lt, isSubExp: true) &&
183 findDepIdxSet(merger, tensor, lvl, a: binOp.getRHS(), lt, isSubExp: true);
184 }
185 default:
186 return false;
187 }
188}
189
190/// Gets the total number of compound affine expressions in the
191/// `getMatchingIndexingMap` for the given tensor. For the following inputs:
192///
193/// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
194///
195/// Returns 1 (because the first level is compressed and its corresponding
196/// indexing-expression is `d0 + d1`)
197static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
198 Value tensor) {
199 // The `tensor` is not guaranteed to have `RankedTensorType`, therefore
200 // we can't use `getRankedTensorType`/`getSparseTensorType` here.
201 // However, we don't need to handle `StorageSpecifierType`, so we
202 // can use `SparseTensorType` once we guard against non-tensors.
203 const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
204 if (!rtp)
205 return 0;
206 const SparseTensorType stt(rtp);
207
208 const Level lvlRank = stt.getLvlRank();
209 const auto exprs = map.getResults();
210 assert(static_cast<Dimension>(exprs.size()) == lvlRank &&
211 "AffineMap does not have dimension-rank many results");
212 unsigned num = 0;
213 for (Level l = 0; l < lvlRank; l++) {
214 if (!isa<AffineDimExpr>(Val: exprs[l]) && !stt.getLvlType(l).hasDenseSemantic())
215 num++;
216 }
217 return num;
218}
219
220/// Gets the total number of sparse levels with compound affine
221/// expressions, summed over all operands of the `GenericOp`.
222static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
223 unsigned num = 0;
224 for (OpOperand &t : op->getOpOperands())
225 num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t),
226 t.get());
227 return num;
228}
229
230// Returns true iff output has nontrivial affine indices.
231static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
232 OpOperand *out = op.getDpsInitOperand(0);
233 if (getSparseTensorType(val: out->get()).isAllDense())
234 return false;
235 return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out),
236 out->get());
237}
238
239/// Helper method to inspect sparse encodings in the tensor types.
240/// Fills the per-dimension sparsity information for all tensors.
241/// Returns true if the sparse annotations and affine subscript
242/// expressions of all tensors are admissible. Returns false if
243/// no annotations are found or inadmissible constructs occur.
/// We currently support two different ways to handle non-trivial index
/// expressions on sparse tensors, and they accept different affine expressions.
/// When using the dependent-index-reduction-based approach, only affine
/// addition index expressions are currently supported.
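///
/// As a purely illustrative example (tensor names and encoding are made up),
/// a matrix-vector kernel such as
///
///   #CSR = #sparse_tensor.encoding<{
///     map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
///   ... linalg.generic ins(%A : tensor<?x?xf64, #CSR>, %b : tensor<?xf64>)
///                      outs(%x : tensor<?xf64>) ...
///
/// is annotated: the (dense, compressed) level types of %A are recorded for
/// the loops selected by its indexing map, and this routine returns true.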
248static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
249 bool annotated = false;
250 for (OpOperand &t : env.op()->getOpOperands()) {
251 const TensorId tid = env.makeTensorId(t.getOperandNumber());
252 const auto map = env.op().getMatchingIndexingMap(&t);
253 const auto enc = getSparseTensorEncoding(t.get().getType());
254 if (enc)
255 annotated = true;
256 const Level lvlRank = map.getNumResults();
257 assert(!enc || lvlRank == enc.getLvlRank());
258 assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
    // We only need to do index reduction if there is at least one
    // non-trivial index expression on sparse levels. If all non-trivial
    // index expressions are on dense levels, we can efficiently rely on
    // random access to locate the elements.
263 bool needIdxReduc =
264 enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0;
    // If the current tensor being inspected requires an affine index, it
    // needs to be sliced.
267 for (Level l = 0; l < lvlRank; l++) {
268 const AffineExpr a = map.getResult(l);
269 const LevelType lt = enc.getLvlType(l);
270 if (idxReducBased && needIdxReduc) {
271 if (!findDepIdxSet(env.merger(), tid, l, a, lt))
272 return false; // inadmissible affine expression
273 } else {
274 if (!findAffine(env.merger(), tid, l, a, lt))
275 return false; // inadmissible affine expression
276 }
277 }
278 }
279 return annotated;
280}
281
282//===----------------------------------------------------------------------===//
283// Sparsifier synthesis methods (statements and expressions).
284//===----------------------------------------------------------------------===//
285
286/// Local bufferization of all dense and sparse data structures.
287static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
288 linalg::GenericOp op = env.op();
289 Location loc = op.getLoc();
290 assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
291
292 SmallVector<Range, 4> loopRange =
293 llvm::cast<linalg::LinalgOp>(op.getOperation())
294 .createLoopRanges(builder, loc);
295
296 env.emitter().initializeLoopEmit(
297 builder, loc,
298 /// Generates buffer for the output tensor.
      /// Note that all sparse kernels assume that when all elements are written
      /// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized
      /// to all zeroes and only nonzero values are computed and written out.
      /// For updates (viz. x(i) += y(i) * z(i)), only nonzero values are used
      /// for the updates and no assumption on the original contents of the
      /// output buffer is necessary.
305 [&op](OpBuilder &builder, Location loc, Value memref,
306 Value tensor) -> Value {
307 // Must not be a sparse tensor.
308 assert(!getSparseTensorEncoding(tensor.getType()));
309 // Two output tensor references should point to the same object.
310 OpOperand *lhs = op.getDpsInitOperand(0);
311 assert(lhs->get() == tensor);
312 // An output tensor can simply materialize from the buffer of the tensor
313 // that appears in the outs() clause. For updates, this has the
        // advantage that only the nonzero values are involved in the
315 // computation, keeping the operation O(nnz). In all other cases, we are
316 // forced to zero out the buffer to enforce the assumption above, which
317 // may negatively impact running complexity (viz. O(n^2 + nnz) vs.
318 // O(nnz) for matrices).
319 // TODO: use better analysis to avoid zeroing out the buffer?
320 bool isInit = op.isInitTensor(lhs);
321 Value init = memref;
322 if (!isInit) {
323 Value zero = constantZero(builder, loc,
324 tp: getElementTypeOrSelf(type: tensor.getType()));
325 builder.create<linalg::FillOp>(loc, ValueRange{zero},
326 ValueRange{init});
327 }
328 return init;
329 },
330 [&loopRange](OpBuilder &b, Location loc, Level l) {
331 assert(l < loopRange.size());
332 return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size);
333 });
334}
335
336/// Generates index for load/store on sparse tensor.
337static Value genIndex(CodegenEnv &env, OpOperand *t) {
338 const auto map = env.op().getMatchingIndexingMap(t);
339 const auto stt = getSparseTensorType(val: t->get());
340 const Level lvlRank = stt.getLvlRank();
341 assert(static_cast<Level>(map.getNumResults()) == lvlRank);
342 const AffineExpr a = map.getResult(lvlRank - 1);
343 assert(a.getKind() == AffineExprKind::DimId);
344 const LoopId idx = env.makeLoopId(i: cast<AffineDimExpr>(Val: a).getPosition());
345 return env.getLoopVar(i: idx);
346}
347
348/// Generates subscript for load/store on a dense or sparse tensor.
349static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
350 SmallVectorImpl<Value> &args) {
351 const Location loc = env.op().getLoc();
352 const TensorId tid = env.makeTensorId(t: t->getOperandNumber());
353 const auto map = env.op().getMatchingIndexingMap(t);
354 const auto stt = getSparseTensorType(val: t->get());
355 if (stt.hasEncoding()) {
356 // For sparse tensors we only push the last-level's position onto `args`.
357 const auto pos = env.emitter().getValPosits(tid);
358 assert(!pos.empty());
359 args.append(RHS: pos);
360 // Simply returns the tensor to extract value using iterators.
361 if (env.options().sparseEmitStrategy == SparseEmitStrategy::kSparseIterator)
362 return t->get();
363 } else {
364 // For dense tensors we push all level's coordinates onto `args`.
365 const Level lvlRank = stt.getLvlRank();
366 assert(static_cast<Level>(map.getNumResults()) == lvlRank);
367 for (Level l = 0; l < lvlRank; l++) {
368 const auto lvlExpr = map.getResult(l);
369 const auto lvlCrd = env.emitter().genAffine(builder, loc, a: lvlExpr);
370 args.push_back(Elt: lvlCrd);
371 }
372 }
373 return env.emitter().getValBuffer()[tid];
374}
375
376/// Generates insertion code to implement dynamic tensor load.
377static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
378 OpOperand *t) {
379 linalg::GenericOp op = env.op();
380 Location loc = op.getLoc();
381 // Direct lexicographic coordinate order, tensor loads as zero.
382 if (!env.isExpand()) {
383 Type tp = getElementTypeOrSelf(type: t->get().getType());
384 return constantZero(builder, loc, tp);
385 }
386 // Load from expanded access pattern.
387 Value index = genIndex(env, t);
388 return builder.create<memref::LoadOp>(loc, env.getExpandValues(), index);
389}
390
391/// Generates insertion code to implement dynamic tensor load for reduction.
392static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
393 OpOperand *t) {
394 linalg::GenericOp op = env.op();
395 Location loc = op.getLoc();
396 Value identity = env.getCustomRedId();
397 // Direct lexicographic coordinate order, tensor loads as identity.
398 if (!env.isExpand())
399 return identity;
400 // Load from expanded access pattern if filled, identity otherwise.
401 Value values = env.getExpandValues();
402 Value filled = env.getExpandFilled();
403 Value index = genIndex(env, t);
404 Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
405 Value valAtIndex = builder.create<memref::LoadOp>(loc, values, index);
406 return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
407}
408
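/// Emits a conditional insertion into the sparse output. Roughly, this builds
///   %r = scf.if %cond -> (tensor<...>) {
///     %t = tensor.insert %v into %sparseOut[%ivs...]
///     scf.yield %t
///   } else {
///     scf.yield %sparseOut
///   }
/// and returns %r (the types shown are illustrative).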
409static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
410 Value sparseOut, ValueRange ivs, Value v) {
411 scf::IfOp condInsert =
412 builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, true);
413 // True branch.
414 builder.setInsertionPointToStart(condInsert.thenBlock());
415 Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs);
416 builder.create<scf::YieldOp>(loc, res);
417 // False branch.
418 builder.setInsertionPointToStart(condInsert.elseBlock());
419 builder.create<scf::YieldOp>(loc, sparseOut);
420 // Value assignment.
421 builder.setInsertionPointAfter(condInsert);
422 return condInsert.getResult(0);
423}
424
425/// Generates insertion code to implement dynamic tensor store.
426static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
427 Value rhs) {
428 linalg::GenericOp op = env.op();
429 Location loc = op.getLoc();
430 // Direct insertion in lexicographic coordinate order.
431 if (!env.isExpand()) {
432 const LoopId numLoops = op.getRank(t);
    // Retrieves the first `numLoops` induction variables.
434 SmallVector<Value> ivs = llvm::to_vector(Range: llvm::drop_end(
435 RangeOrContainer: env.emitter().getLoopIVsRange(), N: env.getCurrentDepth() - numLoops));
436 Value chain = env.getInsertionChain();
437 if (env.isValidLexInsert()) {
438 // Generates runtime check for a valid lex during reduction,
439 // to avoid inserting the identity value for empty reductions.
440 // if (validLexInsert) then
441 // insert(rhs) into chain
442 // return updated chain
443 // else
444 // return unmodified chain
445 Value out = genConditionalInsert(loc, builder, cond: env.getValidLexInsert(),
446 sparseOut: chain, ivs, v: rhs);
447 env.updateInsertionChain(chain: out);
448 } else {
449 Value sparseOut;
450 if (!hasAnySparseType(env.op().getInputs().getTypes())) {
451 // This is an all-dense -> sparse kernel, test rhs != 0 before
452 // insertion.
453 Value nz = genIsNonzero(builder, loc, v: rhs);
454 sparseOut = genConditionalInsert(loc, builder, cond: nz, sparseOut: chain, ivs, v: rhs);
455 } else {
456 sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
457 }
458 // Generates regular insertion chain.
459 env.updateInsertionChain(chain: sparseOut);
460 }
461 return;
462 }
463 // Generates insertion code along expanded access pattern.
464 // if (!expFilled[i]) then
465 // expFilled[i] = true
466 // expAdded[inserts++] = i
467 // endif
468 // values[i] = rhs
469 Value values = env.getExpandValues();
470 Value filled = env.getExpandFilled();
471 Value added = env.getExpandAdded();
472 Value count = env.getExpandCount();
473 Value index = genIndex(env, t);
474 Value fval = constantI1(builder, loc, b: false);
475 Value tval = constantI1(builder, loc, b: true);
476 // If statement.
477 Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
478 Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
479 isFilled, fval);
480 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
481 /*else=*/true);
482 // True branch.
483 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
484 builder.create<memref::StoreOp>(loc, tval, filled, index);
485 builder.create<memref::StoreOp>(loc, index, added, count);
486 Value one = constantIndex(builder, loc, i: 1);
487 Value add = builder.create<arith::AddIOp>(loc, count, one);
488 builder.create<scf::YieldOp>(loc, add);
489 // False branch.
490 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
491 builder.create<scf::YieldOp>(loc, count);
492 builder.setInsertionPointAfter(ifOp);
493 // Value assignment.
494 env.updateExpandCount(count: ifOp.getResult(0));
495 builder.create<memref::StoreOp>(loc, rhs, values, index);
496}
497
498/// Generates a load on a dense or sparse tensor.
499static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
500 // Test if the load was hoisted to a higher loop nest.
501 Value val = env.exp(e: exp).val;
502 if (val)
503 return val;
504 // Get tensor operand.
505 linalg::GenericOp op = env.op();
506 Location loc = op.getLoc();
507 OpOperand *t = &op->getOpOperand(env.exp(e: exp).tensor);
508 // Fold binary-valued tensor into explicit value.
509 const auto stt = getSparseTensorType(val: t->get());
510 if (auto explVal = stt.getExplicitVal())
511 return genValFromAttr(builder, loc, explVal);
512 // Load during insertion.
513 if (env.isSparseOutput(o: t)) {
514 if (env.isCustomReduc())
515 return genInsertionLoadReduce(env, builder, t);
516 return genInsertionLoad(env, builder, t);
517 }
518
519 // Actual load.
520 SmallVector<Value> args;
521 Value ptr = genSubscript(env, builder, t, args);
522 if (llvm::isa<TensorType>(Val: ptr.getType())) {
523 assert(env.options().sparseEmitStrategy ==
524 SparseEmitStrategy::kSparseIterator);
525 return builder.create<ExtractValOp>(loc, ptr, llvm::getSingleElement(args));
526 }
527 return builder.create<memref::LoadOp>(loc, ptr, args);
528}
529
530/// Generates a store on a dense or sparse tensor.
531static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
532 Value rhs) {
  // Only unary and binary ops are allowed to return an uninitialized rhs
  // to indicate a missing output, or otherwise a custom reduction that
  // received no value to accumulate.
536 if (!rhs) {
537 assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
538 env.exp(exp).kind == TensorExp::Kind::kBinary ||
539 env.exp(exp).kind == TensorExp::Kind::kReduce);
540 return;
541 }
542 // Test if this is a scalarized reduction.
543 if (env.isReduc()) {
544 env.updateReduc(val: rhs);
545 return;
546 }
547 // Regular store.
548 linalg::GenericOp op = env.op();
549 Location loc = op.getLoc();
550 OpOperand *t = op.getDpsInitOperand(0);
551 if (!env.isSparseOutput(o: t)) {
552 SmallVector<Value> args;
553 Value ptr = genSubscript(env, builder, t, args);
554 builder.create<memref::StoreOp>(loc, rhs, ptr, args);
555 return;
556 }
557 // Store during sparse insertion.
558 if (env.exp(e: exp).kind != TensorExp::Kind::kSelect) {
559 genInsertionStore(env, builder, t, rhs);
560 return;
561 }
562 // Select operation insertion.
563 Value chain = env.getInsertionChain();
564 scf::IfOp ifOp =
565 builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
566 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
567 // Existing value was preserved to be used here.
568 assert(env.exp(exp).val);
569 Value v0 = env.exp(e: exp).val;
570 genInsertionStore(env, builder, t, rhs: v0);
571 env.merger().clearExprValue(e: exp);
572 // Yield modified insertion chain along true branch.
573 Value mchain = env.getInsertionChain();
574 builder.create<scf::YieldOp>(op.getLoc(), mchain);
575 // Yield original insertion chain along false branch.
576 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
577 builder.create<scf::YieldOp>(loc, chain);
578 // Done with if statement.
579 env.updateInsertionChain(chain: ifOp->getResult(0));
580 builder.setInsertionPointAfter(ifOp);
581}
582
583/// Generates an invariant value.
584inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
585 return env.exp(e: exp).val;
586}
587
588/// Semi-ring branches are simply inlined by the sparsifier. Prior
589/// analysis has verified that all computations are "local" to the inlined
590/// branch or otherwise invariantly defined outside the loop nest, with the
591/// exception of index computations, which need to be relinked to actual
592/// inlined cloned code.
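/// For example, a `linalg.index` that was cloned into a semi-ring region is
/// replaced below by the corresponding loop induction variable of the
/// generated loop nest.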
593static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
594 Value e) {
595 if (auto arg = dyn_cast<BlockArgument>(Val&: e)) {
596 // Direct arguments of the original linalg op must be converted
597 // into dense tensor loads. Note that we should not encounter
598 // anything else. This needs to be verified by semi-ring ops.
599 linalg::GenericOp op = env.op();
600 if (arg.getOwner()->getParentOp() == op) {
601 const TensorId tid = env.makeTensorId(t: arg.getArgNumber());
602 OpOperand *t = &op->getOpOperand(tid);
603 assert(!getSparseTensorType(t->get()).hasEncoding()); // dense!
604 SmallVector<Value> args;
605 Value ptr = genSubscript(env, builder&: rewriter, t, args);
606 return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
607 }
608 } else if (Operation *def = e.getDefiningOp()) {
609 // Handle index computation.
610 if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
611 return env.getLoopVar(i: env.makeLoopId(i: indexOp.getDim()));
612 // When still defined in new body, recurse into operands.
613 if (def->getBlock() == block) {
614 rewriter.setInsertionPoint(def);
615 for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
616 rewriter.modifyOpInPlace(root: def, callable: [&]() {
617 def->setOperand(
618 idx: i, value: relinkBranch(env, rewriter, block, e: def->getOperand(idx: i)));
619 });
620 }
621 }
622 }
623 return e;
624}
625
626/// Recursively generates tensor expression.
627static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
628 if (e == ::mlir::sparse_tensor::detail::kInvalidId)
629 return Value();
630
631 linalg::GenericOp op = env.op();
632 Location loc = op.getLoc();
633 const TensorExp &exp = env.exp(e);
634 const auto kind = exp.kind;
635 if (kind == TensorExp::Kind::kTensor)
636 return genTensorLoad(env, builder&: rewriter, exp: e);
637 if (kind == TensorExp::Kind::kInvariant)
638 return genInvariantValue(env, exp: e);
639 if (kind == TensorExp::Kind::kLoopVar)
640 return env.getLoopVar(i: exp.loop);
641
642 if (kind == TensorExp::Kind::kReduce)
643 env.startCustomReduc(exp: e); // enter custom
644
645 // If either lhs/rhs is a synthetic zero, we infer the type for the zero value
646 // based on the type of the other operand.
647 Value v0, v1;
648 if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
649 env.exp(e: exp.children.e0).kind == TensorExp::Kind::kSynZero) {
650 v1 = genExp(env, rewriter, e: exp.children.e1);
651 v0 = constantZero(builder&: rewriter, loc, tp: v1.getType());
652 } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
653 env.exp(e: exp.children.e1).kind == TensorExp::Kind::kSynZero) {
654 v0 = genExp(env, rewriter, e: exp.children.e0);
655 v1 = constantZero(builder&: rewriter, loc, tp: v0.getType());
656 } else {
657 v0 = genExp(env, rewriter, e: exp.children.e0);
658 v1 = genExp(env, rewriter, e: exp.children.e1);
659 }
660
661 Value ee;
662 if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) {
663 // custom reduce did not receive a value
664 } else {
665 ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
666 if (ee &&
667 (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary ||
668 kind == TensorExp::Kind::kBinaryBranch ||
669 kind == TensorExp::Kind::kReduce ||
670 kind == TensorExp::Kind::kSelect)) {
671 OpBuilder::InsertionGuard guard(rewriter);
672 ee = relinkBranch(env, rewriter, block: ee.getParentBlock(), e: ee);
673 }
674 }
675
676 if (kind == TensorExp::Kind::kReduce)
677 env.endCustomReduc(); // exit custom
678
679 if (kind == TensorExp::Kind::kSelect)
680 env.merger().setExprValue(e, v: v0); // Preserve value for later use.
681
682 return ee;
683}
684
685/// Hoists loop invariant tensor loads for which indices have been exhausted.
686static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
687 LoopId curr, bool isStart) {
688 if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
689 return;
690 if (env.exp(e: exp).kind == TensorExp::Kind::kTensor) {
691 // Inspect tensor indices.
692 linalg::GenericOp op = env.op();
693 OpOperand &t = op->getOpOperand(env.exp(e: exp).tensor);
694 const auto map = op.getMatchingIndexingMap(&t);
695 const auto stt = getSparseTensorType(val: t.get());
696 const Level lvlRank = stt.getLvlRank();
697 assert(static_cast<Level>(map.getNumResults()) == lvlRank);
698 bool isCurrentLoop = curr == 0; // for scalar tensors
699 for (Level l = 0; l < lvlRank; l++) {
700 const AffineExpr a = map.getResult(l);
701 if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop))
702 return; // still in play
703 }
704 // All exhausted at current level.
705 if (!isCurrentLoop)
706 return;
707 // Generate code for a scalarized reduction or invariant. Note that
708 // because custom reduction lhs may occur several times in the IR,
709 // we have a built-in safety for only initializing and wrapping-up
710 // the scalarized reduction once.
711 OpOperand *lhs = op.getDpsInitOperand(0);
712 if (lhs == &t) {
713 // Start or end a scalarized reduction.
714 if (isStart) {
715 if (env.isCustomReduc()) {
716 if (!env.isReduc())
717 env.startReduc(exp, val: env.getCustomRedId());
718 } else {
719 env.startReduc(exp, val: genTensorLoad(env, builder, exp));
720 }
721 if (env.hasSparseOutput())
722 env.startValidLexInsert(
723 val: constantI1(builder, env.op().getLoc(), false));
724 } else {
725 if (!env.isCustomReduc() || env.isReduc())
726 genTensorStore(env, builder, exp, rhs: env.endReduc());
727 if (env.hasSparseOutput())
728 env.endValidLexInsert();
729 }
730 } else {
731 // Start or end loop invariant hoisting of a tensor load.
732 if (isStart) {
733 env.merger().setExprValue(e: exp, v: genTensorLoad(env, builder, exp));
734 } else {
735 env.merger().clearExprValue(e: exp);
736 }
737 }
738 } else if (env.exp(e: exp).kind != TensorExp::Kind::kInvariant &&
739 env.exp(e: exp).kind != TensorExp::Kind::kLoopVar &&
740 env.exp(e: exp).kind != TensorExp::Kind::kSynZero) {
741 // Traverse into the binary operations. Note that we only hoist
742 // tensor loads, since subsequent MLIR/LLVM passes know how to
743 // deal with all other kinds of derived loop invariants.
744 if (env.exp(e: exp).kind == TensorExp::Kind::kReduce)
745 env.startCustomReduc(exp); // enter custom
746 const ExprId e0 = env.exp(e: exp).children.e0;
747 const ExprId e1 = env.exp(e: exp).children.e1;
748 genInvariants(env, builder, exp: e0, curr, isStart);
749 genInvariants(env, builder, exp: e1, curr, isStart);
750 if (env.exp(e: exp).kind == TensorExp::Kind::kReduce)
751 env.endCustomReduc(); // exit custom
752 }
753}
754
755/// Generates an expanded access pattern in innermost dimension.
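/// At the start, this roughly emits (types omitted for brevity)
///   %values, %filled, %added, %count = sparse_tensor.expand %output
/// and at the end it folds the scattered updates back into the insertion
/// chain with a corresponding `sparse_tensor.compress`.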
756static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr,
757 bool isStart) {
758 linalg::GenericOp op = env.op();
759 OpOperand *lhs = op.getDpsInitOperand(0);
760 if (!env.atExpandLevel(o: lhs, rank: op.getRank(lhs), n: curr))
761 return; // not needed at current level
762 assert(!env.isReduc());
763 // Generate start or end of an expanded access pattern. Note that because
764 // an expansion does not rely on the ongoing contents of the sparse storage
765 // scheme, we can use the original tensor as incoming SSA value (which
766 // simplifies codegen a bit). If expansion on the actual contents is ever
767 // needed, we will need to use the SSA value in the insertion chain instead.
768 Value tensor = lhs->get();
769 Location loc = op.getLoc();
770 if (isStart) {
771 auto dynShape = {ShapedType::kDynamic};
772 Type etp = cast<ShapedType>(tensor.getType()).getElementType();
773 Type t1 = MemRefType::get(dynShape, etp);
774 Type t2 = MemRefType::get(dynShape, builder.getI1Type());
775 Type t3 = MemRefType::get(dynShape, builder.getIndexType());
776 Type t4 = builder.getIndexType();
777 auto r = builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
778 assert(r.getNumResults() == 4);
779 env.startExpand(values: r.getResult(0), filled: r.getResult(1), added: r.getResult(2),
780 count: r.getResult(3));
781 } else {
782 SmallVector<Value> indices;
783 for (LoopId i = 0; i < curr; i++)
784 indices.push_back(Elt: env.emitter().getLoopIV(n: i));
785 Value values = env.getExpandValues();
786 Value filled = env.getExpandFilled();
787 Value added = env.getExpandAdded();
788 Value count = env.getExpandCount();
789 Value chain = env.getInsertionChain();
790 Value compress = builder.create<CompressOp>(loc, values, filled, added,
791 count, chain, indices);
792 env.updateInsertionChain(chain: compress);
793 env.endExpand();
794 }
795}
796
797/// Returns parallelization strategy. Any implicit loop in the Linalg
798/// operation that is marked "parallel" is a candidate. Whether it is actually
799/// converted to a parallel operation depends on the requested strategy.
800static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
801 // Reject parallelization of sparse output.
802 if (env.hasSparseOutput())
803 return false;
804 // Parallel loops on tensor expansion can cause data races.
805 if (env.isExpand())
806 return false;
807 // Inspect strategy.
808 switch (env.options().parallelizationStrategy) {
809 case SparseParallelizationStrategy::kNone:
810 return false;
811 case SparseParallelizationStrategy::kDenseOuterLoop:
812 return isOuter && !isSparse;
813 case SparseParallelizationStrategy::kAnyStorageOuterLoop:
814 return isOuter;
815 case SparseParallelizationStrategy::kDenseAnyLoop:
816 return !isSparse;
817 case SparseParallelizationStrategy::kAnyStorageAnyLoop:
818 return true;
819 }
820 llvm_unreachable("unexpected parallelization strategy");
821}
822
/// Whether or not the current loop being generated should be parallelized
/// (if possible) according to the configuration.
825static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
826 ArrayRef<TensorLevel> tidLvls) {
827 linalg::GenericOp op = env.op();
828 auto iteratorTypes = op.getIteratorTypesArray();
829 bool isSparse = llvm::any_of(Range&: tidLvls, P: [curr, &env](TensorLevel tidLvl) {
830 // Queries the LT based on the tensor and loop id, as requested by
831 // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
832 // should be consistent with the LT indexed by <TensorId, Level>.
833 const auto lt = env.lt(t: env.unpackTensorLevel(tl: tidLvl).first, i: curr);
834 return lt.hasSparseSemantic();
835 });
836 return isParallelFor(env, /*isOuter=*/curr == 0, isSparse);
837}
838
839/// Emit a loop to coiterate over the list of tensor levels. The generated loop
840/// can either be a for loop or while loop depending on whether there is at most
841/// one sparse level in the list.
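/// For instance, iterating a single compressed level typically lowers to a
/// simple `scf.for` over its position range, whereas co-iterating two sparse
/// operands lowers to an `scf.while` that advances both positions and selects
/// the minimum coordinate each step (illustrative; the exact form depends on
/// the emit strategy).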
842static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
843 ArrayRef<TensorLevel> tidLvls,
844 unsigned numCases, bool tryParallel,
845 bool needsUniv) {
846 Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
847 // Construct while-loop with a parameter for each index.
848 return env.emitter().enterCoIterationOverTensorsAtLvls(
849 builder, env.op().getLoc(), tidLvls, numCases, reduc, tryParallel,
850 needsUniv);
851 });
852 assert(loop);
853 return loop;
854}
855
856/// Generates a for-loop or a while-loop, depending on whether it implements
857/// singleton iteration or co-iteration over the given conjunction.
858static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
859 unsigned numCases, bool needsUniv,
860 ArrayRef<TensorLevel> tidLvls) {
861 bool tryParallel = shouldTryParallize(env, curr, tidLvls);
862 return genCoIteration(env, builder, tidLvls, numCases, tryParallel,
863 needsUniv);
864}
865
866/// Generates the induction structure for a while-loop.
867static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
868 bool needsUniv) {
869 Location loc = env.op().getLoc();
870 // Finalize each else branch of all if statements.
871 if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
872 while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
873 builder.getInsertionBlock()->getParentOp())) {
874 // Break on IfOp for slicing filtering.
875 if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ==
876 StringAttr::get(ifOp->getContext(), "slice"))
877 break;
878
879 unsigned y = 0;
880 SmallVector<Value> yields;
881 if (env.isReduc()) {
882 yields.push_back(Elt: env.getReduc());
883 env.updateReduc(val: ifOp.getResult(y++));
884 if (env.isValidLexInsert()) {
885 yields.push_back(Elt: env.getValidLexInsert());
886 env.updateValidLexInsert(val: ifOp.getResult(y++));
887 }
888 }
889 if (env.isExpand()) {
890 yields.push_back(Elt: env.getExpandCount());
891 env.updateExpandCount(count: ifOp->getResult(y++));
892 }
893 if (env.getInsertionChain()) {
894 yields.push_back(Elt: env.getInsertionChain());
895 env.updateInsertionChain(chain: ifOp->getResult(y++));
896 }
897 assert(y == yields.size());
898 builder.create<scf::YieldOp>(loc, yields);
899 builder.setInsertionPointAfter(ifOp);
900 }
901 }
902 // No need to set the insertion point here as LoopEmitter keeps track of the
903 // basic block where scf::Yield should be inserted.
904}
905
906/// Generates a case region in the coiterate operation.
907static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
908 unsigned caseIdx, LatPointId allCase,
909 LatPointId curCase,
910 MutableArrayRef<Value> reduc) {
911 assert(allCase == curCase || env.merger().latGT(allCase, curCase));
912 const BitVector &allCaseBits = env.merger().lat(p: allCase).simple;
913 const BitVector &curCaseBits = env.merger().lat(p: curCase).simple;
914
915 /// Computes the subset of iterators that are valid in the current case being
916 /// generated.
917 I64BitSet caseBit(0);
918 for (auto [idx, set] : llvm::enumerate(First: allCaseBits.set_bits()))
919 if (curCaseBits.test(Idx: set))
920 caseBit.set(idx);
921
922 env.emitter().enterCurrentCoIterationCase(builder, loc: env.op().getLoc(), caseBit,
923 caseIdx, reduc);
924}
925
926/// Generates a single if-statement within a while-loop.
927static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
928 LatPointId p) {
929 Location loc = env.op().getLoc();
930 SmallVector<Type> types;
931 Value cond;
932 env.merger().foreachTensorLoopId(
933 p, /*simple=*/true,
934 callback: [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt,
935 bool isIdxRed) {
936 if (isIdxRed) {
937 // Since there is no 1:1 mapping from loop to level (multiple loops
938 // are required to resolve one level with non-trivial index
939 // expression), we need to reconstruct the tensor level types if this
940 // loop requires index reduction condition.
941 assert(lvl.has_value() && isUndefLT(lt));
942 auto stt = getSparseTensorType(env.op().getInputs()[tid]);
943 lt = stt.getLvlType(*lvl);
944 }
945 assert(curr == env.merger().loop(b));
946 Value clause;
947 if (lt.hasSparseSemantic()) {
948 assert(lvl.has_value());
949 const Value crd = env.emitter().getCoord(tid, lvl: *lvl);
950 const Value lvar = env.getLoopVar(i: curr);
951 clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
952 crd, lvar);
953 } else {
954 assert(lt.hasDenseSemantic() || isUndefLT(lt));
955 clause = constantI1(builder, loc, b: true);
956 }
957 cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
958 });
959 if (env.isReduc()) {
960 types.push_back(Elt: env.getReduc().getType());
961 if (env.isValidLexInsert())
962 types.push_back(Elt: env.getValidLexInsert().getType());
963 }
964 if (env.isExpand())
965 types.push_back(builder.getIndexType());
966 if (env.getInsertionChain())
967 types.push_back(Elt: env.getInsertionChain().getType());
968 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
969 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
970 return ifOp;
971}
972
973/// Generates end of true branch of if-statement within a while-loop.
974static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
975 Value redInput, Value cntInput, Value insInput,
976 Value validIns) {
977 SmallVector<Value> operands;
978 if (env.isReduc()) {
979 operands.push_back(Elt: env.getReduc());
980 env.updateReduc(val: redInput);
981 if (env.isValidLexInsert()) {
982 // Any overlapping indices during a reduction creates a valid lex insert.
983 operands.push_back(Elt: constantI1(builder, env.op().getLoc(), true));
984 env.updateValidLexInsert(val: validIns);
985 }
986 }
987 if (env.isExpand()) {
988 operands.push_back(Elt: env.getExpandCount());
989 env.updateExpandCount(count: cntInput);
990 }
991 if (env.getInsertionChain()) {
992 operands.push_back(Elt: env.getInsertionChain());
993 env.updateInsertionChain(chain: insInput);
994 }
995 if (!operands.empty())
996 builder.create<scf::YieldOp>(env.op().getLoc(), operands);
997 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
998}
999
1000//===----------------------------------------------------------------------===//
1001// Sparsifier synthesis methods (loop sequence).
1002//===----------------------------------------------------------------------===//
1003
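/// Invokes `callback` for every (tensor, level) pair that participates in
/// lattice point `li` at loop `curr`; a non-null affine expression is passed
/// for dense levels whose compound affine address becomes computable at this
/// loop. Returns true iff a single loop condition suffices, so that a
/// for-loop (or `sparse_tensor.iterate`) can be emitted instead of a
/// while-loop.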
1004static bool getAllTidLvlsInLatPoints(
1005 CodegenEnv &env, LatPointId li, LoopId curr,
1006 llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
1007 const BitVector &simple = env.lat(l: li).simple;
1008 const TensorId outTid = env.merger().getOutTensorID();
1009 const std::optional<Level> outLvl = env.merger().getLvl(t: outTid, i: curr);
1010
1011 unsigned numloopCond = 0;
1012 bool hasNonUnique = false;
1013 env.merger().foreachTensorLoopId(
1014 p: li, callback: [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
1015 LevelType lt, bool isIdxReduc) {
1016 if (simple[b]) {
1017 if (isIdxReduc) {
1018 callback(env.makeTensorLevel(t: tid, l: *lvl), nullptr);
1019 numloopCond++;
1020 return;
1021 }
1022 if (isUndefLT(lt)) {
            // An undefined lt in the lattices means we probably intend to
            // generate a dense loop according to the synthetic tensor (for
            // invariants and the sparse output tensor).
1026 if (env.merger().getSynTensorID() == tid) {
1027 // Coiterating with an invariant
1028 // e.g., out = prod(in[i][j] op invariant);
1029 // or a broadcast
1030 // e.g., out[i][j] = in[i] (j is undef for input)
1031 //
            // The level of the synthetic tensor is the current loop depth;
            // the rank of the synthetic tensor equals the number of loops.
1034 assert(curr == env.getCurrentDepth());
1035 lvl = curr;
1036 } else if (!lvl) {
1037 // Skips invalid lvl (e.g., when this is a zero ranked tensor).
1038 return;
1039 }
1040 }
1041 hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
1042 callback(env.makeTensorLevel(t: tid, l: *lvl), nullptr);
1043 numloopCond++;
1044 } else if (lt.hasDenseSemantic() || isIdxReduc) {
1045 callback(env.makeTensorLevel(t: tid, l: *lvl), nullptr);
1046 } else {
1047 assert(isUndefLT(lt));
1048 linalg::GenericOp op = env.op();
1049 if (tid >= op.getNumDpsInputs())
1050 // We only handle affine expression on input tensors (for now).
1051 return;
1052 OpOperand *operand = &op->getOpOperand(tid);
1053 const auto stt = getSparseTensorType(val: operand->get());
          // Non-annotated dense tensors require no special handling.
1055 if (!stt.hasEncoding())
1056 return;
1057
1058 ArrayRef<AffineExpr> affines =
1059 op.getMatchingIndexingMap(operand).getResults();
1060 const Level lvlRank = stt.getLvlRank();
1061 assert(affines.size() == static_cast<size_t>(lvlRank));
1062 for (Level l = 0; l < lvlRank; l++) {
1063 AffineExpr exp = affines[l];
1064 // Skip simple affine expression and non-dense levels (which
1065 // have their own filter loop).
1066 LevelType lt = stt.getLvlType(l);
1067 if (isa<AffineDimExpr>(Val: exp) || !lt.hasDenseSemantic())
1068 continue;
1069
            // Constant affine expressions are handled in genLoop.
1071 if (!isa<AffineConstantExpr>(Val: exp)) {
1072 bool isCurrentLoop = false;
1073 assert(curr == env.getCurrentDepth());
1074 if (isInvariantAffine(a: exp, curr: curr + 1, /*out*/ isCurrentLoop) &&
1075 isCurrentLoop) {
                // If the compound affine expression is invariant and we are
                // right at the level, we need to generate the address
                // according to the affine expression. This is also the best
                // place to do it, to avoid putting it inside inner loops.
1080 callback(env.makeTensorLevel(t: tid, l), exp);
1081 }
1082 }
1083 }
1084 }
1085 });
1086
1087 if (isDenseLT(lt: env.lt(t: outTid, i: curr))) {
1088 auto stt = getSparseTensorType(env.op().getOutputs().front());
1089 // Note that we generate dense indices of the output tensor unconditionally,
1090 // since they may not appear in the lattice, but may be needed for
1091 // linearized env.
1092 // TODO: we should avoid introducing corner cases for all-dense sparse
1093 // tensors.
1094 if (stt.hasEncoding() && stt.isAllDense())
1095 callback(env.makeTensorLevel(t: outTid, l: *outLvl), nullptr);
1096 }
1097
1098 if (numloopCond == 0) {
    // Corner case where the loop bound is defined by an *unused* operand.
    // In this case, we just generate a dense "fake" loop by iterating over
    // the synthetic tensor.
1102 callback(env.makeTensorLevel(t: env.merger().getSynTensorID(), l: curr), nullptr);
1103 numloopCond++;
1104 }
  // If we just need one loop condition and the condition is not imposed on a
  // non-unique level, the loop can be generated by a for loop.
  // Or, if we are generating sparse-iterator-based loops, we always generate
  // `sparse_tensor.iterate` regardless of whether the level is unique or not.
1109 return numloopCond == 1 &&
1110 (!hasNonUnique || env.options().sparseEmitStrategy ==
1111 SparseEmitStrategy::kSparseIterator);
1112}
1113
1114/// Starts a loop sequence at given level. Returns true if
1115/// the universal loop index must be maintained at this level.
1116static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
1117 LoopId curr, LatSetId lts) {
1118 assert(!env.getLoopVar(curr));
1119 // Emit invariants at this loop sequence level.
1120 genInvariants(env, builder, exp, curr, /*isStart=*/true);
1121 // Emit access pattern expansion for sparse tensor output.
1122 genExpand(env, builder, curr, /*isStart=*/true);
1123 // Emit further initialization at this loop sequence level.
1124 const LatPointId l0 = env.set(lts)[0];
1125
1126 SmallVector<TensorLevel> tidLvls;
1127 getAllTidLvlsInLatPoints(env, li: l0, curr, callback: [&](TensorLevel tl, AffineExpr) {
    // TODO: remove this! The same tensor level might be added multiple times
    // due to the special handling for all-dense "sparse" output tensor
    // (see L1038).
1131 if (llvm::is_contained(Range&: tidLvls, Element: tl))
1132 return;
1133 tidLvls.emplace_back(Args&: tl);
1134 });
1135
1136 env.emitter().enterNewLoopSeq(builder, loc: env.op().getLoc(), tidLvls);
1137
1138 // Maintain the universal index only if it is actually
1139 // consumed by a subsequent lattice point.
1140 for (const LatPointId li : env.set(lts).drop_front())
1141 if (!env.merger().hasAnySparse(bits: env.lat(l: li).simple))
1142 return true;
1143
1144 return false;
1145}
1146
1147// Generates dense affine address for encoding.
1148static void genConstantDenseAddressFromLevel(CodegenEnv &env,
1149 OpBuilder &builder, TensorId tid,
1150 Level startLvl) {
1151 // TODO: Handle affine expression on output tensor.
1152 linalg::GenericOp op = env.op();
1153 assert(tid < op.getNumDpsInputs());
1154 OpOperand *input = op.getDpsInputOperands()[tid];
1155 const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
1156 const auto enc = getSparseTensorEncoding(input->get().getType());
1157 if (enc) {
1158 const Location loc = op.getLoc();
1159 const TensorId tid = env.makeTensorId(t: input->getOperandNumber());
1160 const Level lvlRank = enc.getLvlRank();
1161 assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
1162 for (Level l = startLvl; l < lvlRank; l++) {
1163 AffineExpr lvlExpr = lvlExprs[l];
1164 if (enc.getLvlType(l).hasDenseSemantic() &&
1165 isa<AffineConstantExpr>(Val: lvlExpr))
1166 env.emitter().locateLvlAtAffineAddress(
1167 builder, loc, tidLvl: env.makeTensorLevel(t: tid, l), lvlExpr);
1168 else
1169 return; // break on first non-dense non-constant level
1170 }
1171 }
1172}
1173
// We can generate addresses for constant affine expressions before any loops,
// starting from the first level, as they do not depend on anything.
// E.g., for [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first
// two levels can be determined before the loops.
1178static void genInitConstantDenseAddress(CodegenEnv &env,
1179 RewriterBase &rewriter) {
1180 for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
1181 genConstantDenseAddressFromLevel(env, builder&: rewriter, tid, startLvl: 0);
1182}
1183
1184/// Returns true if the lattice bit can be iterated by a for loop.
1185static bool translateBitsToTidLvlPairs(
1186 CodegenEnv &env, LatPointId li, LoopId curr,
1187 SmallVectorImpl<TensorLevel> &tidLvls,
1188 SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
1189 return getAllTidLvlsInLatPoints(env, li, curr,
1190 callback: [&](TensorLevel tl, AffineExpr exp) {
1191 if (exp)
1192 affineTidLvls.emplace_back(Args&: tl, Args&: exp);
1193 else
1194 tidLvls.emplace_back(Args&: tl);
1195 });
1196}
1197
1198/// Starts a single loop in current sequence.
1199static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
1200 OpBuilder &builder, LoopId curr,
1201 LatPointId li, unsigned numCases,
1202 bool needsUniv) {
  // TODO: numCases is only used when generating iterator-based loops. Clean
  // up after the migration is complete.
1205 // The set of tensors + lvls to generate loops on
1206 SmallVector<TensorLevel> tidLvls;
1207
  // The set of dense tensors with a non-trivial affine expression that just
  // became invariant; their addresses are generated at the current level.
1210 SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
1211 bool isSingleCond =
1212 translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);
1213
1214 // Emit the for/while-loop control.
1215 Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls);
1216 Location loc = env.op().getLoc();
1217 for (auto [tidLvl, exp] : affineTidLvls) {
1218 env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, lvlExpr: exp);
1219 }
1220
  // Until now, we have entered every <tid, lvl> pair in {cond, extra,
  // affine}Tids/Lvls. The addresses of the upcoming levels which depend on
  // constant affine expressions may now be determined.
1224 auto allTidLvls =
1225 llvm::concat<TensorLevel>(Ranges&: tidLvls, Ranges: llvm::make_first_range(c&: affineTidLvls));
1226 for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
1227 if (tid != env.merger().getOutTensorID() &&
1228 tid != env.merger().getSynTensorID())
1229 genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
1230 }
1231
1232 return std::make_pair(x&: loop, y&: isSingleCond);
1233}
1234
/// Ends a single loop in current sequence. Returns the new value for needsUniv.
1236static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
1237 LatPointId li, bool needsUniv, bool isSingleCond) {
1238 // Either a for-loop or a while-loop that iterates over a slice.
1239 if (isSingleCond) {
1240 // Any iteration creates a valid lex insert.
1241 if (env.isReduc() && env.isValidLexInsert())
1242 env.updateValidLexInsert(val: constantI1(rewriter, env.op().getLoc(), true));
1243 } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
1244 // End a while-loop.
1245 finalizeWhileOp(env, builder&: rewriter, needsUniv);
1246 } else {
1247 needsUniv = false;
1248 }
1249 env.genLoopBoundary(callback: [&](MutableArrayRef<Value> reduc) {
1250 env.emitter().exitCurrentLoop(rewriter, loc: env.op().getLoc(), reduc);
1251 return std::nullopt;
1252 });
1253 return needsUniv;
1254}
1255
1256/// Ends a loop sequence at given level.
1257static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
1258 unsigned at) {
1259 assert(!env.getLoopVar(at));
1260 env.emitter().exitCurrentLoopSeq(builder, loc: env.op().getLoc());
1261 // Unmark bookkeeping of invariants and loop index.
1262 genInvariants(env, builder, exp, curr: at, /*isStart=*/false);
1263 // Finalize access pattern expansion for sparse tensor output.
1264 genExpand(env, builder, curr: at, /*isStart=*/false);
1265}
1266
/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
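///
/// As a purely illustrative example, for x(i) = a(i) + b(i) with both a and b
/// sparse, the lattice for loop i contains the points {a /\ b, a, b}: a
/// while-loop co-iterates the overlap of a and b, followed by for-loops over
/// the remaining entries of a alone and of b alone.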
1270static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
1271 LoopId curr) {
1272 assert(curr == env.getCurrentDepth());
1273
1274 // At each leaf, assign remaining tensor (sub)expression to output tensor.
1275 if (curr == env.getLoopNum()) {
1276 Value rhs = genExp(env, rewriter, e: exp);
1277 genTensorStore(env, builder&: rewriter, exp, rhs);
1278 return;
1279 }
1280
1281 // Construct iteration lattices for current loop index.
1282 const LatSetId lts =
1283 env.merger().optimizeSet(s: env.merger().buildLattices(e: exp, i: curr));
1284
1285 // Start a loop sequence.
1286 bool needsUniv = startLoopSeq(env, builder&: rewriter, exp, curr, lts);
1287
  // When using sparse-iterator-based loops, we only need one loop, as
  // opposed to a loop sequence, to cover all the iteration spaces.
1290 const unsigned lsize = env.set(lts).size();
1291 if (env.generatingSparseIterator()) {
1292 // Get the largest lattice point and start a loop.
1293 const LatPointId li = env.set(lts)[0];
1294 auto [loop, isSingleCond] =
1295 startLoop(env, builder&: rewriter, curr, li, numCases: lsize, needsUniv);
1296 assert(isSingleCond == llvm::isa<IterateOp>(loop));
1297 // We cannot change this to `for (const LatPointId li : env.set(lts))`
1298 // because the loop body causes data-movement which invalidates
1299 // the iterator.
1300 for (unsigned j = 0; j < lsize; j++) {
1301 const LatPointId lj = env.set(lts)[j];
1302 const ExprId ej = env.lat(l: lj).exp;
1303 // Recurse into body of each branch.
1304 if (!isSingleCond) {
1305 env.genLoopBoundary(callback: [&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
1306 genCoIterationCase(env, builder&: rewriter, /*caseIdx*/ j, allCase: li, curCase: lj, reduc);
1307 genStmt(env, rewriter, exp: ej, curr: curr + 1);
1308 // TODO: handle yield values.
1309 assert(reduc.empty() && "Not Implemented");
1310 rewriter.create<sparse_tensor::YieldOp>(env.op().getLoc());
1311 return std::nullopt;
1312 });
1313 // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1314 } else {
1315 genStmt(env, rewriter, exp: ej, curr: curr + 1);
1316 }
1317 }
1318 // End a loop.
1319 needsUniv = endLoop(env, rewriter, loop, li: curr, needsUniv, isSingleCond);
1320 } else {
1321 // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1322 for (unsigned i = 0; i < lsize; i++) {
1323 const LatPointId li = env.set(lts)[i];
1324 // Start a loop.
1325 auto [loop, isSingleCond] =
1326 startLoop(env, builder&: rewriter, curr, li, numCases: lsize, needsUniv);
1327
1328 // Visit all lattices points with Li >= Lj to generate the
1329 // loop-body, possibly with if statements for coiteration.
1330 Value redInput = env.getReduc();
1331 Value cntInput = env.getExpandCount();
1332 Value insInput = env.getInsertionChain();
1333 Value validIns = env.getValidLexInsert();
1334 // We cannot change this to `for (const LatPointId lj : env.set(lts))`
1335 // because the loop body causes data-movement which invalidates the
1336 // iterator.
1337 for (unsigned j = 0; j < lsize; j++) {
1338 const LatPointId lj = env.set(lts)[j];
1339 const ExprId ej = env.lat(l: lj).exp;
1340 if (li == lj || env.merger().latGT(p0: li, p1: lj)) {
1341 // Recurse into body of each branch.
1342 if (!isSingleCond) {
1343 scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
1344 genStmt(env, rewriter, exp: ej, curr: curr + 1);
1345 endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1346 } else {
1347 genStmt(env, rewriter, exp: ej, curr: curr + 1);
1348 }
1349 }
1350 }
1351
1352 // End a loop.
1353 needsUniv = endLoop(env, rewriter, loop, li: curr, needsUniv, isSingleCond);
1354 }
1355 }
1356
1357 // End a loop sequence.
1358 endLoopSeq(env, builder&: rewriter, exp, at: curr);
1359 assert(curr == env.getCurrentDepth());
1360}
1361
1362/// Converts the result computed by the sparse kernel into the required form.
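/// That is, a sparse result is rematerialized with `sparse_tensor.load`
/// (with hasInserts set when an insertion chain was built), while a dense
/// result is rematerialized from its buffer with `bufferization.to_tensor`.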
1363static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
1364 linalg::GenericOp op = env.op();
1365 OpOperand *lhs = op.getDpsInitOperand(0);
1366 Value tensor = lhs->get();
1367 Type resType = tensor.getType();
1368 if (getSparseTensorEncoding(type: resType)) {
1369 // The sparse tensor rematerializes from the original sparse tensor's
1370 // underlying sparse storage format. For an insertion chain, the
1371 // tensor materializes from the chain with 'hasInserts' enabled.
1372 bool hasInserts = false;
1373 if (Value chain = env.getInsertionChain()) {
1374 hasInserts = true;
1375 tensor = chain;
1376 }
1377 rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
1378 } else {
    // To rematerialize a non-annotated tensor, simply load it
    // from the bufferized value.
1381 Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
1382 rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
1383 }
1384}
1385
1386//===----------------------------------------------------------------------===//
1387// Sparsifier rewriting methods.
1388//===----------------------------------------------------------------------===//
1389
1390namespace {
1391
/// Sparse rewriting rule for a generic Linalg operation.
1393struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1394public:
1395 GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1396 : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1397
1398 LogicalResult matchAndRewrite(linalg::GenericOp op,
1399 PatternRewriter &rewriter) const override {
1400 // Only accept single output operations with pure tensor semantics.
1401 if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics())
1402 return failure();
1403
1404 // Only accept trivial affine indices.
1405 if (hasNonTrivialAffineOnSparseOut(op))
1406 return failure();
1407
1408 // Only accept scheduled loops.
1409 if (!op->hasAttr("sorted")) {
1410 return rewriter.notifyMatchFailure(
1411 op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
1412 "before sparsification.");
1413 }
1414
1415 // Must have been demapped as well if the generic op is sorted.
1416 assert(!hasAnyNonIdentityOperandsOrResults(op));
1417
1418 // Sets up a code generation environment.
1419 const unsigned numTensors = op->getNumOperands();
1420 const unsigned numLoops = op.getNumLoops();
1421 bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0;
    // If we have an indexing map like (d0) -> (0, d0), there might be more
    // levels than loops because of the constant index; that means we cannot
    // use numLoops as the upper bound for the ranks of all tensors.
    // TODO: Constant indices are currently not supported on sparse tensors,
    // but are allowed in non-annotated dense tensors. Supporting them would
    // also be required for rank-reducing sparse tensor slices.
1428 Level maxLvlRank = 0;
1429 for (auto operand : op.getOperands()) {
1430 if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
1431 maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
1432 }
1433 }
1434
1435 // Detects sparse annotations and translates the per-level sparsity
1436 // information for all tensors to loop indices in the kernel.
1437 CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
1438 if (!findSparseAnnotations(env, idxReducBased: needIdxRed))
1439 return failure();
1440
1441 // Only standard reduction operations (add, sub, or, xor) that can be
1442 // sparsified by merely reducing the stored values are admissible. More
1443 // elaborate reduction operations (such as mul, and, min, max) would need
1444 // to know whether implicit zeros occur as well. They can still be
1445 // implemented with a custom reduction operation, accepted here as well.
1446 if (op.getNumReductionLoops() > 0) {
1447 Operation *yield = op.getRegion().front().getTerminator();
1448 assert(isa<linalg::YieldOp>(yield));
1449 Operation *redop = yield->getOperand(idx: 0).getDefiningOp();
1450 if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) &&
1451 !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) &&
1452 !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) &&
1453 !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) &&
1454 !isa<ReduceOp>(redop)) {
1455 return failure();
1456 }
1457 }
1458
    // Constructs the tensor expression tree from `op`; returns failure if the
    // tree cannot be built or the tensor expression is inadmissible.
1461 if (failed(Result: env.initTensorExp()))
1462 return failure();
1463
1464 // Recursively generates code if admissible.
1465 env.startEmit(emitStrategy: options.sparseEmitStrategy);
1466 genBuffers(env, builder&: rewriter);
    // TODO: Constant affine expressions should be handled differently when
    // using slice-based codegen; it does not matter now because we already
    // reject constant expressions at an earlier stage.
1470 genInitConstantDenseAddress(env, rewriter);
1471 genStmt(env, rewriter, exp: env.getExprId(), curr: 0);
1472 genResult(env, rewriter);
1473 return success();
1474 }
1475
1476private:
1477 /// Options to control sparse code generation.
1478 SparsificationOptions options;
1479};
1480
1481} // namespace
1482
1483/// Populates the given patterns list with rewriting rules required for
1484/// the sparsification of linear algebra operations.
1485void mlir::populateSparsificationPatterns(
1486 RewritePatternSet &patterns, const SparsificationOptions &options) {
1487 patterns.add<GenericOpSparsifier>(arg: patterns.getContext(), args: options);
1488}
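
// For reference, a typical driver (an illustrative sketch only; the names
// outside this file are assumptions) applies these patterns greedily:
//
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, SparsificationOptions());
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     signalPassFailure();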
1489
