//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements converting sparse tensor types to actual sparse code.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenEnv.h"
#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Sparsifier analysis methods.
//===----------------------------------------------------------------------===//

/// Returns true iff affine expression is invariant. Sets the
/// parameter `isCurrentLoop` when expression just became invariant.
static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    const LoopId i = cast<AffineDimExpr>(a).getPosition();
    if (i + 1 == curr) {
      isCurrentLoop = true;
      return true; // becomes invariant at current loop
    }
    return i < curr; // invariant when already generated
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return isInvariantAffine(binOp.getLHS(), curr, isCurrentLoop) &&
           isInvariantAffine(binOp.getRHS(), curr, isCurrentLoop);
  }
  default: {
    assert(isa<AffineConstantExpr>(a));
    return true;
  }
  }
}
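
// For illustration only (hypothetical values, not part of the original
// source): with loops emitted in the order d0, d1 and `curr == 2`, the
// expression `d0 + d1` is invariant, since d0 (0 < 2) was already generated
// and d1 (1 + 1 == 2) just became invariant at the current loop, which also
// sets `isCurrentLoop` to true.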

/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used more than once. Also rejects compound affine
/// expressions in sparse dimensions.
static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
                       LevelType lt, bool setLvlFormat = true) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
    if (!isUndefLT(merger.getLvlType(tid, idx)))
      return false; // used more than once
    if (setLvlFormat)
      merger.setLevelAndType(tid, idx, lvl, lt);
    return true;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul:
  case AffineExprKind::Constant: {
    assert(lt.hasDenseSemantic());
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
      // We do not set the dim level format for an affine expression like
      // d0 + d1 on either loop index d0 or d1. We continue the recursion
      // merely to check whether the current affine expression is admissible.
      return findAffine(merger, tid, lvl, binOp.getLHS(), lt, false) &&
             findAffine(merger, tid, lvl, binOp.getRHS(), lt, false);
    }
    // Falls through when it is a constant affine expression.
    return true;
  }
  default:
    return false;
  }
}
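
// For illustration only (hypothetical level types, not part of the original
// source): for an operand with levels (dense, compressed) and indexing
// expressions (d0 + d1, d2), the compound expression d0 + d1 is admitted
// because it indexes a dense level (the recursion merely checks each term
// without recording a level format), while the trivial expression d2 records
// the compressed format for loop d2.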

/// Helper method to inspect affine expressions for index variable reduction
/// based codegen. It finds the dependent index set for all tensor levels in the
/// current expression we are generating.
///
/// For example, when handling A[i+j][j+k], we build the two way mapping in
/// merger between (tensor, level) pairs and their dependent index variable set:
/// A_0 <=> [i, j] and A_1 <=> [j, k]
///
/// It rejects cases (returns false)
/// 1st, when the same index is used more than once, e.g., A[i+j][i]
/// 2nd, when multiplication is used in the non-trivial index expression.
/// 3rd, when a constant operand is used in the non-trivial index expression.
///
/// TODO: constant should be easy to handle.
static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
                          AffineExpr a, LevelType lt, bool isSubExp = false,
                          int64_t coefficient = 1) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    // Only allow positive coefficients on AffineDimExpr.
    if (coefficient <= 0)
      return false;

    const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
    if (!isUndefLT(merger.getLvlType(tensor, idx)))
      return false; // used more than once, e.g., A[i][i]

    // TODO: Generalize the following two cases. A[i] (with trivial index
    // expression) can be treated as a special affine index expression. We do
    // not necessarily need to differentiate them.
    if (!isSubExp) {
      assert(coefficient == 1);
      merger.setLevelAndType(tensor, idx, lvl, lt);
    }

    if (isSubExp) {
      // The current loop appears in more than one affine expression on the
      // same tensor. We cannot handle this case, e.g., A[i+j][i+k], where `i`
      // is used twice.
      if (merger.hasDependentLvl(idx, tensor)) {
        // TODO: This can be supported by coiterating slices if the loop idx
        // appears in affine index expressions on different tensors, or by
        // taking a slice on multiple dimensions when it is on the same tensor.
        // E.g.,
        // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
        // d0_1 = getNextSliceOffset t0 along lvl0
        // d0_2 = getNextSliceOffset t1 along lvl0
        // if d0_1 == d0_2 then d0 = d0_1 = d0_2
        // else increase min(d0_1, d0_2).
        return false;
      }
      merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
    }
    return true;
  }
  case AffineExprKind::Constant:
  case AffineExprKind::Mul: {
    // TODO: Support index expressions like `2 * d0`; we currently only support
    // the more complicated cases like `2 * d0 + d1`.
    if (!isSubExp)
      return false;

    // TODO: Support constant AffineExpr for slice-based codegen.
    if (isa<AffineConstantExpr>(a))
      llvm_unreachable("Not yet implemented");

    auto binOp = cast<AffineBinaryOpExpr>(a);
    auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
    if (isa<AffineConstantExpr>(rhs))
      std::swap(lhs, rhs);
    // Must be in form of `constant * d`.
    assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
    int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue();
    return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient);
  }
  case AffineExprKind::Add: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt, true) &&
           findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt, true);
  }
  default:
    return false;
  }
}
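
// For illustration only (hypothetical access, not part of the original
// source): for an access A[2 * i + j] on a sparse level, the Add case
// recurses into `2 * i` and `j` with isSubExp set; the Mul case extracts the
// coefficient 2 and registers level 0 as dependent on loop i with that
// coefficient, while `j` is registered with the default coefficient 1.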

/// Gets the total number of compound affine expressions in the
/// `getMatchingIndexingMap` for the given tensor. For the following inputs:
///
/// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
///
/// Returns 1 (because the first level is compressed and its corresponding
/// indexing-expression is `d0 + d1`)
static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
                                                   Value tensor) {
  // The `tensor` is not guaranteed to have `RankedTensorType`, therefore
  // we can't use `getRankedTensorType`/`getSparseTensorType` here.
  // However, we don't need to handle `StorageSpecifierType`, so we
  // can use `SparseTensorType` once we guard against non-tensors.
  const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
  if (!rtp)
    return 0;
  const SparseTensorType stt(rtp);

  const Level lvlRank = stt.getLvlRank();
  const auto exprs = map.getResults();
  assert(static_cast<Dimension>(exprs.size()) == lvlRank &&
         "AffineMap does not have dimension-rank many results");
  unsigned num = 0;
  for (Level l = 0; l < lvlRank; l++) {
    if (!isa<AffineDimExpr>(exprs[l]) && !stt.getLvlType(l).hasDenseSemantic())
      num++;
  }
  return num;
}

/// Gets the total number of sparse levels with compound affine
/// expressions, summed over all operands of the `GenericOp`.
static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
  unsigned num = 0;
  for (OpOperand &t : op->getOpOperands())
    num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t),
                                              t.get());
  return num;
}

// Returns true iff output has nontrivial affine indices.
static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
  OpOperand *out = op.getDpsInitOperand(0);
  if (getSparseTensorType(out->get()).isAllDense())
    return false;
  return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out),
                                            out->get());
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible. Returns false if
/// no annotations are found or inadmissible constructs occur.
/// We currently support two different ways to handle non-trivial index
/// expressions on sparse tensors, and they accept different affine
/// expressions. The dependent index reduction-based approach currently only
/// supports affine addition index expressions.
static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
  bool annotated = false;
  for (OpOperand &t : env.op()->getOpOperands()) {
    const TensorId tid = env.makeTensorId(t.getOperandNumber());
    const auto map = env.op().getMatchingIndexingMap(&t);
    const auto enc = getSparseTensorEncoding(t.get().getType());
    if (enc)
      annotated = true;
    const Level lvlRank = map.getNumResults();
    assert(!enc || lvlRank == enc.getLvlRank());
    assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
    // We only need to do index reduction if there is at least one
    // non-trivial index expression on sparse levels. If all non-trivial
    // index expressions are on dense levels, we can efficiently rely on
    // the random access to locate the element.
    bool needIdxReduc =
        enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0;
    // If the current tensor being inspected requires an affine index, it
    // needs to be sliced.
    for (Level l = 0; l < lvlRank; l++) {
      const AffineExpr a = map.getResult(l);
      const LevelType lt = enc.getLvlType(l);
      if (idxReducBased && needIdxReduc) {
        if (!findDepIdxSet(env.merger(), tid, l, a, lt))
          return false; // inadmissible affine expression
      } else {
        if (!findAffine(env.merger(), tid, l, a, lt))
          return false; // inadmissible affine expression
      }
    }
  }
  return annotated;
}
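
// For illustration only (a schematic kernel, not part of the original
// source): an operand annotated as
//   #CSR = #sparse_tensor.encoding<{ map = (i, j) -> (i : dense, j : compressed) }>
// appearing in a linalg.generic with indexing map (d0, d1) -> (d0, d1) makes
// `annotated` true and records (dense, compressed) level types for loops d0
// and d1 in the merger.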

//===----------------------------------------------------------------------===//
// Sparsifier synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Local bufferization of all dense and sparse data structures.
static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  assert(op.getNumOperands() == op.getNumDpsInputs() + 1);

  SmallVector<Range, 4> loopRange =
      llvm::cast<linalg::LinalgOp>(op.getOperation())
          .createLoopRanges(builder, loc);

  env.emitter().initializeLoopEmit(
      builder, loc,
      /// Generates buffer for the output tensor.
      /// Note that all sparse kernels assume that when all elements are written
      /// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized
      /// to all zeroes and only nonzero values are computed and written out.
      /// For updates (viz. x(i) += y(i) * z(i)), only nonzero values are used
      /// for the updates and no assumption on the original contents of the
      /// output buffer is necessary.
      [&op](OpBuilder &builder, Location loc, Value memref,
            Value tensor) -> Value {
        // Must not be a sparse tensor.
        assert(!getSparseTensorEncoding(tensor.getType()));
        // Two output tensor references should point to the same object.
        OpOperand *lhs = op.getDpsInitOperand(0);
        assert(lhs->get() == tensor);
        // An output tensor can simply materialize from the buffer of the tensor
        // that appears in the outs() clause. For updates, this has the
        // advantage that only the nonzero values are involved in the
        // computation, keeping the operation O(nnz). In all other cases, we are
        // forced to zero out the buffer to enforce the assumption above, which
        // may negatively impact running complexity (viz. O(n^2 + nnz) vs.
        // O(nnz) for matrices).
        // TODO: use better analysis to avoid zeroing out the buffer?
        bool isInit = op.isInitTensor(lhs);
        Value init = memref;
        if (!isInit) {
          Value zero = constantZero(builder, loc,
                                    getElementTypeOrSelf(tensor.getType()));
          builder.create<linalg::FillOp>(loc, ValueRange{zero},
                                         ValueRange{init});
        }
        return init;
      },
      [&loopRange](OpBuilder &b, Location loc, Level l) {
        assert(l < loopRange.size());
        return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size);
      });
}

/// Generates index for load/store on sparse tensor.
static Value genIndex(CodegenEnv &env, OpOperand *t) {
  const auto map = env.op().getMatchingIndexingMap(t);
  const auto stt = getSparseTensorType(t->get());
  const Level lvlRank = stt.getLvlRank();
  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
  const AffineExpr a = map.getResult(lvlRank - 1);
  assert(a.getKind() == AffineExprKind::DimId);
  const LoopId idx = env.makeLoopId(cast<AffineDimExpr>(a).getPosition());
  return env.getLoopVar(idx);
}

/// Generates subscript for load/store on a dense or sparse tensor.
static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                          SmallVectorImpl<Value> &args) {
  const Location loc = env.op().getLoc();
  const TensorId tid = env.makeTensorId(t->getOperandNumber());
  const auto map = env.op().getMatchingIndexingMap(t);
  const auto stt = getSparseTensorType(t->get());
  if (stt.hasEncoding()) {
    // For sparse tensors we only push the last-level's position onto `args`.
    const auto pos = env.emitter().getValPosits(tid);
    assert(!pos.empty());
    args.append(pos);
  } else {
    // For dense tensors we push all levels' coordinates onto `args`.
    const Level lvlRank = stt.getLvlRank();
    assert(static_cast<Level>(map.getNumResults()) == lvlRank);
    for (Level l = 0; l < lvlRank; l++) {
      const auto lvlExpr = map.getResult(l);
      const auto lvlCrd = env.emitter().genAffine(builder, loc, lvlExpr);
      args.push_back(lvlCrd);
    }
  }
  return env.emitter().getValBuffer()[tid];
}
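
// For illustration only (schematic, not part of the original source): for a
// sparse operand the emitted access addresses the values buffer at the last
// level's position, roughly `memref.load %values[%pos]`, whereas for a dense
// operand all level coordinates are materialized from the affine results,
// roughly `memref.load %buffer[%i, %j]`.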

/// Generates insertion code to implement dynamic tensor load.
static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
                              OpOperand *t) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  // Direct lexicographic coordinate order, tensor loads as zero.
  if (!env.isExpand()) {
    Type tp = getElementTypeOrSelf(t->get().getType());
    return constantZero(builder, loc, tp);
  }
  // Load from expanded access pattern.
  Value index = genIndex(env, t);
  return builder.create<memref::LoadOp>(loc, env.getExpandValues(), index);
}

/// Generates insertion code to implement dynamic tensor load for reduction.
static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
                                    OpOperand *t) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  Value identity = env.getCustomRedId();
  // Direct lexicographic coordinate order, tensor loads as identity.
  if (!env.isExpand())
    return identity;
  // Load from expanded access pattern if filled, identity otherwise.
  Value values = env.getExpandValues();
  Value filled = env.getExpandFilled();
  Value index = genIndex(env, t);
  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
  Value valAtIndex = builder.create<memref::LoadOp>(loc, values, index);
  return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
}

/// Generates insertion code to implement dynamic tensor store.
static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                              Value rhs) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  // Direct insertion in lexicographic coordinate order.
  if (!env.isExpand()) {
    const LoopId numLoops = op.getRank(t);
    // Retrieves the first `numLoops` induction variables.
    SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end(
        env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops));
    Value chain = env.getInsertionChain();
    if (env.isValidLexInsert()) {
      // Generates runtime check for a valid lex during reduction,
      // to avoid inserting the identity value for empty reductions.
      //   if (validLexInsert) then
      //     insert(rhs) into chain
      //     return updated chain
      //   else
      //     return unmodified chain
      scf::IfOp ifValidLexInsert = builder.create<scf::IfOp>(
          loc, chain.getType(), env.getValidLexInsert(),
          /*else=*/true);
      // True branch.
      builder.setInsertionPointToStart(ifValidLexInsert.thenBlock());
      Value res = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
      builder.create<scf::YieldOp>(loc, res);
      // False branch.
      builder.setInsertionPointToStart(ifValidLexInsert.elseBlock());
      builder.create<scf::YieldOp>(loc, chain);
      // Value assignment.
      builder.setInsertionPointAfter(ifValidLexInsert);
      env.updateInsertionChain(ifValidLexInsert.getResult(0));
    } else {
      // Generates regular insertion chain.
      env.updateInsertionChain(
          builder.create<tensor::InsertOp>(loc, rhs, chain, ivs));
    }
    return;
  }
  // Generates insertion code along expanded access pattern.
  //   if (!expFilled[i]) then
  //     expFilled[i] = true
  //     expAdded[inserts++] = i
  //   endif
  //   values[i] = rhs
  Value values = env.getExpandValues();
  Value filled = env.getExpandFilled();
  Value added = env.getExpandAdded();
  Value count = env.getExpandCount();
  Value index = genIndex(env, t);
  Value fval = constantI1(builder, loc, false);
  Value tval = constantI1(builder, loc, true);
  // If statement.
  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                             isFilled, fval);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
                                             /*else=*/true);
  // True branch.
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  builder.create<memref::StoreOp>(loc, tval, filled, index);
  builder.create<memref::StoreOp>(loc, index, added, count);
  Value one = constantIndex(builder, loc, 1);
  Value add = builder.create<arith::AddIOp>(loc, count, one);
  builder.create<scf::YieldOp>(loc, add);
  // False branch.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, count);
  builder.setInsertionPointAfter(ifOp);
  // Value assignment.
  env.updateExpandCount(ifOp.getResult(0));
  builder.create<memref::StoreOp>(loc, rhs, values, index);
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = env.exp(exp).val;
  if (val)
    return val;
  // Load during insertion.
  linalg::GenericOp op = env.op();
  OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
  if (env.isSparseOutput(t)) {
    if (env.isCustomReduc())
      return genInsertionLoadReduce(env, builder, t);
    return genInsertionLoad(env, builder, t);
  }
  // Actual load.
  SmallVector<Value> args;
  Value ptr = genSubscript(env, builder, t, args);
  return builder.create<memref::LoadOp>(op.getLoc(), ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                           Value rhs) {
  // Only unary and binary are allowed to return an uninitialized rhs
  // to indicate missing output, or otherwise a custom reduction that
  // received no value to accumulate.
  if (!rhs) {
    assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
           env.exp(exp).kind == TensorExp::Kind::kBinary ||
           env.exp(exp).kind == TensorExp::Kind::kReduce);
    return;
  }
  // Test if this is a scalarized reduction.
  if (env.isReduc()) {
    env.updateReduc(rhs);
    return;
  }
  // Regular store.
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  OpOperand *t = op.getDpsInitOperand(0);
  if (!env.isSparseOutput(t)) {
    SmallVector<Value> args;
    Value ptr = genSubscript(env, builder, t, args);
    builder.create<memref::StoreOp>(loc, rhs, ptr, args);
    return;
  }
  // Store during sparse insertion.
  if (env.exp(exp).kind != TensorExp::Kind::kSelect) {
    genInsertionStore(env, builder, t, rhs);
    return;
  }
  // Select operation insertion.
  Value chain = env.getInsertionChain();
  scf::IfOp ifOp =
      builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  // Existing value was preserved to be used here.
  assert(env.exp(exp).val);
  Value v0 = env.exp(exp).val;
  genInsertionStore(env, builder, t, v0);
  env.merger().clearExprValue(exp);
  // Yield modified insertion chain along true branch.
  Value mchain = env.getInsertionChain();
  builder.create<scf::YieldOp>(op.getLoc(), mchain);
  // Yield original insertion chain along false branch.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, chain);
  // Done with if statement.
  env.updateInsertionChain(ifOp->getResult(0));
  builder.setInsertionPointAfter(ifOp);
}

/// Generates an invariant value.
inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
  return env.exp(exp).val;
}

/// Semi-ring branches are simply inlined by the sparsifier. Prior
/// analysis has verified that all computations are "local" to the inlined
/// branch or otherwise invariantly defined outside the loop nest, with the
/// exception of index computations, which need to be relinked to actual
/// inlined cloned code.
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
                          Value e) {
  if (auto arg = dyn_cast<BlockArgument>(e)) {
    // Direct arguments of the original linalg op must be converted
    // into dense tensor loads. Note that we should not encounter
    // anything else. This needs to be verified by semi-ring ops.
    linalg::GenericOp op = env.op();
    if (arg.getOwner()->getParentOp() == op) {
      const TensorId tid = env.makeTensorId(arg.getArgNumber());
      OpOperand *t = &op->getOpOperand(tid);
      assert(!getSparseTensorType(t->get()).hasEncoding()); // dense!
      SmallVector<Value> args;
      Value ptr = genSubscript(env, rewriter, t, args);
      return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
    }
  } else if (Operation *def = e.getDefiningOp()) {
    // Handle index computation.
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return env.getLoopVar(env.makeLoopId(indexOp.getDim()));
    // When still defined in new body, recurse into operands.
    if (def->getBlock() == block) {
      rewriter.setInsertionPoint(def);
      for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
        rewriter.modifyOpInPlace(def, [&]() {
          def->setOperand(
              i, relinkBranch(env, rewriter, block, def->getOperand(i)));
        });
      }
    }
  }
  return e;
}
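
// For illustration only (schematic, not part of the original source): inside
// an inlined sparse_tensor.unary/binary region, a cloned `linalg.index 0` is
// replaced by the loop variable of loop 0, and a block argument of the
// original linalg.generic that refers to a dense operand is replaced by a
// direct memref.load from that operand's buffer at the current coordinates.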

/// Recursively generates tensor expression.
static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
  if (e == ::mlir::sparse_tensor::detail::kInvalidId)
    return Value();

  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  const TensorExp &exp = env.exp(e);
  const auto kind = exp.kind;
  if (kind == TensorExp::Kind::kTensor)
    return genTensorLoad(env, rewriter, e);
  if (kind == TensorExp::Kind::kInvariant)
    return genInvariantValue(env, e);
  if (kind == TensorExp::Kind::kLoopVar)
    return env.getLoopVar(exp.loop);

  if (kind == TensorExp::Kind::kReduce)
    env.startCustomReduc(e); // enter custom

  // If either lhs/rhs is a synthetic zero, we infer the type for the zero value
  // based on the type of the other operand.
  Value v0, v1;
  if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
      env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
    v1 = genExp(env, rewriter, exp.children.e1);
    v0 = constantZero(rewriter, loc, v1.getType());
  } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
             env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
    v0 = genExp(env, rewriter, exp.children.e0);
    v1 = constantZero(rewriter, loc, v0.getType());
  } else {
    v0 = genExp(env, rewriter, exp.children.e0);
    v1 = genExp(env, rewriter, exp.children.e1);
  }

  Value ee;
  if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) {
    // custom reduce did not receive a value
  } else {
    ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
    if (ee &&
        (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary ||
         kind == TensorExp::Kind::kBinaryBranch ||
         kind == TensorExp::Kind::kReduce ||
         kind == TensorExp::Kind::kSelect)) {
      OpBuilder::InsertionGuard guard(rewriter);
      ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
    }
  }

  if (kind == TensorExp::Kind::kReduce)
    env.endCustomReduc(); // exit custom

  if (kind == TensorExp::Kind::kSelect)
    env.merger().setExprValue(e, v0); // Preserve value for later use.

  return ee;
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                          LoopId curr, bool isStart) {
  if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
    return;
  if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
    // Inspect tensor indices.
    linalg::GenericOp op = env.op();
    OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
    const auto map = op.getMatchingIndexingMap(&t);
    const auto stt = getSparseTensorType(t.get());
    const Level lvlRank = stt.getLvlRank();
    assert(static_cast<Level>(map.getNumResults()) == lvlRank);
    bool isCurrentLoop = curr == 0; // for scalar tensors
    for (Level l = 0; l < lvlRank; l++) {
      const AffineExpr a = map.getResult(l);
      if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop))
        return; // still in play
    }
    // All exhausted at current level.
    if (!isCurrentLoop)
      return;
    // Generate code for a scalarized reduction or invariant. Note that
    // because custom reduction lhs may occur several times in the IR,
    // we have a built-in safety for only initializing and wrapping-up
    // the scalarized reduction once.
    OpOperand *lhs = op.getDpsInitOperand(0);
    if (lhs == &t) {
      // Start or end a scalarized reduction.
      if (isStart) {
        if (env.isCustomReduc()) {
          if (!env.isReduc())
            env.startReduc(exp, env.getCustomRedId());
        } else {
          env.startReduc(exp, genTensorLoad(env, builder, exp));
        }
        if (env.hasSparseOutput())
          env.startValidLexInsert(
              constantI1(builder, env.op().getLoc(), false));
      } else {
        if (!env.isCustomReduc() || env.isReduc())
          genTensorStore(env, builder, exp, env.endReduc());
        if (env.hasSparseOutput())
          env.endValidLexInsert();
      }
    } else {
      // Start or end loop invariant hoisting of a tensor load.
      if (isStart) {
        env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
      } else {
        env.merger().clearExprValue(exp);
      }
    }
  } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
             env.exp(exp).kind != TensorExp::Kind::kLoopVar &&
             env.exp(exp).kind != TensorExp::Kind::kSynZero) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    if (env.exp(exp).kind == TensorExp::Kind::kReduce)
      env.startCustomReduc(exp); // enter custom
    const ExprId e0 = env.exp(exp).children.e0;
    const ExprId e1 = env.exp(exp).children.e1;
    genInvariants(env, builder, e0, curr, isStart);
    genInvariants(env, builder, e1, curr, isStart);
    if (env.exp(exp).kind == TensorExp::Kind::kReduce)
      env.endCustomReduc(); // exit custom
  }
}

/// Generates an expanded access pattern in innermost dimension.
static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                      bool isStart) {
  linalg::GenericOp op = env.op();
  OpOperand *lhs = op.getDpsInitOperand(0);
  if (!env.atExpandLevel(lhs, op.getRank(lhs), curr))
    return; // not needed at current level
  assert(!env.isReduc());
  // Generate start or end of an expanded access pattern. Note that because
  // an expansion does not rely on the ongoing contents of the sparse storage
  // scheme, we can use the original tensor as incoming SSA value (which
  // simplifies codegen a bit). If expansion on the actual contents is ever
  // needed, we will need to use the SSA value in the insertion chain instead.
  Value tensor = lhs->get();
  Location loc = op.getLoc();
  if (isStart) {
    auto dynShape = {ShapedType::kDynamic};
    Type etp = cast<ShapedType>(tensor.getType()).getElementType();
    Type t1 = MemRefType::get(dynShape, etp);
    Type t2 = MemRefType::get(dynShape, builder.getI1Type());
    Type t3 = MemRefType::get(dynShape, builder.getIndexType());
    Type t4 = builder.getIndexType();
    auto r = builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
    assert(r.getNumResults() == 4);
    env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
                    r.getResult(3));
  } else {
    SmallVector<Value> indices;
    for (LoopId i = 0; i < curr; i++)
      indices.push_back(env.emitter().getLoopIV(i));
    Value values = env.getExpandValues();
    Value filled = env.getExpandFilled();
    Value added = env.getExpandAdded();
    Value count = env.getExpandCount();
    Value chain = env.getInsertionChain();
    Value compress = builder.create<CompressOp>(loc, values, filled, added,
                                                count, chain, indices);
    env.updateInsertionChain(compress);
    env.endExpand();
  }
}

/// Returns true if the loop is a candidate for parallelization under the
/// requested strategy. Any implicit loop in the Linalg operation that is
/// marked "parallel" is a candidate. Whether it is actually converted to a
/// parallel operation depends on the requested strategy.
static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
  // Reject parallelization of sparse output.
  if (env.hasSparseOutput())
    return false;
  // Parallel loops on tensor expansion can cause data races.
  if (env.isExpand())
    return false;
  // Inspect strategy.
  switch (env.options().parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return true;
  }
  llvm_unreachable("unexpected parallelization strategy");
}
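
// For illustration only (hypothetical configuration, not part of the original
// source): under kDenseOuterLoop only an outermost loop that does not iterate
// over sparse storage is parallelized, while kAnyStorageAnyLoop admits every
// loop that passes the sparse-output and expansion checks above.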

/// Whether or not the current loop being generated should be parallelized
/// (if possible) according to the configuration.
static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
                               ArrayRef<TensorLevel> tidLvls) {
  linalg::GenericOp op = env.op();
  auto iteratorTypes = op.getIteratorTypesArray();
  bool isSparse = llvm::any_of(tidLvls, [curr, &env](TensorLevel tidLvl) {
    // Queries the LT based on the tensor and loop id, as requested by
    // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
    // should be consistent with the LT indexed by <TensorId, Level>.
    const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr);
    return lt.hasSparseSemantic();
  });
  return isParallelFor(env, /*isOuter=*/curr == 0, isSparse);
}

/// Emits a loop to coiterate over the list of tensor levels. The generated
/// loop can either be a for-loop or a while-loop, depending on whether there
/// is at most one sparse level in the list.
static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
                                 ArrayRef<TensorLevel> tidLvls,
                                 bool tryParallel, bool needsUniv) {
  Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
    // Construct while-loop with a parameter for each index.
    return env.emitter().enterCoIterationOverTensorsAtLvls(
        builder, env.op().getLoc(), tidLvls, reduc, tryParallel, needsUniv);
  });
  assert(loop);
  return loop;
}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                          bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
  bool tryParallel = shouldTryParallize(env, curr, tidLvls);
  return genCoIteration(env, builder, tidLvls, tryParallel, needsUniv);
}
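
// For illustration only (schematic, not part of the original source):
// iterating a single compressed level (or only dense levels) yields an
// scf.for (or scf.parallel) over positions, whereas coiterating two
// compressed levels yields an scf.while that advances the lagging iterator
// until both coordinates meet.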

/// Generates the induction structure for a while-loop.
static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
                            bool needsUniv) {
  Location loc = env.op().getLoc();
  // Finalize each else branch of all if statements.
  if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
    while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
               builder.getInsertionBlock()->getParentOp())) {
      // Break on IfOp for slicing filtering.
      if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ==
          StringAttr::get(ifOp->getContext(), "slice"))
        break;

      unsigned y = 0;
      SmallVector<Value> yields;
      if (env.isReduc()) {
        yields.push_back(env.getReduc());
        env.updateReduc(ifOp.getResult(y++));
        if (env.isValidLexInsert()) {
          yields.push_back(env.getValidLexInsert());
          env.updateValidLexInsert(ifOp.getResult(y++));
        }
      }
      if (env.isExpand()) {
        yields.push_back(env.getExpandCount());
        env.updateExpandCount(ifOp->getResult(y++));
      }
      if (env.getInsertionChain()) {
        yields.push_back(env.getInsertionChain());
        env.updateInsertionChain(ifOp->getResult(y++));
      }
      assert(y == yields.size());
      builder.create<scf::YieldOp>(loc, yields);
      builder.setInsertionPointAfter(ifOp);
    }
  }
  // No need to set the insertion point here as LoopEmitter keeps track of the
  // basic block where scf::Yield should be inserted.
}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                       LatPointId p) {
  Location loc = env.op().getLoc();
  SmallVector<Type> types;
  Value cond;
  env.merger().foreachTensorLoopId(
      p, /*simple=*/true,
      [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt,
          bool isIdxRed) {
        if (isIdxRed) {
          // Since there is no 1:1 mapping from loop to level (multiple loops
          // are required to resolve one level with a non-trivial index
          // expression), we need to reconstruct the tensor level types if this
          // loop requires an index reduction condition.
          assert(lvl.has_value() && isUndefLT(lt));
          auto stt = getSparseTensorType(env.op().getInputs()[tid]);
          lt = stt.getLvlType(*lvl);
        }
        assert(curr == env.merger().loop(b));
        Value clause;
        if (lt.hasSparseSemantic()) {
          assert(lvl.has_value());
          const Value crd = env.emitter().getCoord(tid, *lvl);
          const Value lvar = env.getLoopVar(curr);
          clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 crd, lvar);
        } else {
          assert(lt.hasDenseSemantic() || isUndefLT(lt));
          clause = constantI1(builder, loc, true);
        }
        cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
      });
  if (env.isReduc()) {
    types.push_back(env.getReduc().getType());
    if (env.isValidLexInsert())
      types.push_back(env.getValidLexInsert().getType());
  }
  if (env.isExpand())
    types.push_back(builder.getIndexType());
  if (env.getInsertionChain())
    types.push_back(env.getInsertionChain().getType());
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  return ifOp;
}
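
// For illustration only (schematic, not part of the original source): when
// coiterating sparse tensors inside a while-loop, the guard emitted here has
// the shape
//   %hit = arith.cmpi eq, %crd, %iv : index
//   scf.if %hit -> (...) { ... } else { ... }
// with one comparison per participating sparse tensor, combined with
// arith.andi, and dense/undefined level types contributing a constant true.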

/// Generates end of true branch of if-statement within a while-loop.
static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
                  Value redInput, Value cntInput, Value insInput,
                  Value validIns) {
  SmallVector<Value> operands;
  if (env.isReduc()) {
    operands.push_back(env.getReduc());
    env.updateReduc(redInput);
    if (env.isValidLexInsert()) {
      // Any overlapping indices during a reduction creates a valid lex insert.
      operands.push_back(constantI1(builder, env.op().getLoc(), true));
      env.updateValidLexInsert(validIns);
    }
  }
  if (env.isExpand()) {
    operands.push_back(env.getExpandCount());
    env.updateExpandCount(cntInput);
  }
  if (env.getInsertionChain()) {
    operands.push_back(env.getInsertionChain());
    env.updateInsertionChain(insInput);
  }
  if (!operands.empty())
    builder.create<scf::YieldOp>(env.op().getLoc(), operands);
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
}

//===----------------------------------------------------------------------===//
// Sparsifier synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//

static bool getAllTidLvlsInLatPoints(
    CodegenEnv &env, LatPointId li, LoopId curr,
    llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
  const BitVector &simple = env.lat(li).simple;
  const TensorId outTid = env.merger().getOutTensorID();
  const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);

  unsigned numloopCond = 0;
  bool hasNonUnique = false;
  env.merger().foreachTensorLoopId(
      li, [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
                    LevelType lt, bool isIdxReduc) {
        if (simple[b]) {
          if (isIdxReduc) {
            callback(env.makeTensorLevel(tid, *lvl), nullptr);
            numloopCond++;
            return;
          }
          if (isUndefLT(lt)) {
            // An undefined lt in the lattices means we probably intend to
            // generate a dense loop according to the synthetic tensor (for
            // invariants and sparse output tensor).
            if (env.merger().getSynTensorID() == tid) {
              // Coiterating with an invariant
              //   e.g., out = prod(in[i][j] op invariant);
              // or a broadcast
              //   e.g., out[i][j] = in[i] (j is undef for input)
              //
              // The level of the synthetic tensor is the current loop depth;
              // the rank of the synthetic tensor equals the number of loops.
              assert(curr == env.getCurrentDepth());
              lvl = curr;
            } else if (!lvl) {
              // Skips invalid lvl (e.g., when this is a zero ranked tensor).
              return;
            }
          }
          hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
          callback(env.makeTensorLevel(tid, *lvl), nullptr);
          numloopCond++;
        } else if (lt.hasDenseSemantic() || isIdxReduc) {
          callback(env.makeTensorLevel(tid, *lvl), nullptr);
        } else {
          assert(isUndefLT(lt));
          linalg::GenericOp op = env.op();
          if (tid >= op.getNumDpsInputs())
            // We only handle affine expressions on input tensors (for now).
            return;
          OpOperand *operand = &op->getOpOperand(tid);
          const auto stt = getSparseTensorType(operand->get());
          // Non-annotated dense tensors require no special handling.
          if (!stt.hasEncoding())
            return;

          ArrayRef<AffineExpr> affines =
              op.getMatchingIndexingMap(operand).getResults();
          const Level lvlRank = stt.getLvlRank();
          assert(affines.size() == static_cast<size_t>(lvlRank));
          for (Level l = 0; l < lvlRank; l++) {
            AffineExpr exp = affines[l];
            // Skip simple affine expressions and non-dense levels (which
            // have their own filter loop).
            LevelType lt = stt.getLvlType(l);
            if (isa<AffineDimExpr>(exp) || !lt.hasDenseSemantic())
              continue;

            // Constant affine expressions are handled in genLoop.
            if (!isa<AffineConstantExpr>(exp)) {
              bool isCurrentLoop = false;
              assert(curr == env.getCurrentDepth());
              if (isInvariantAffine(exp, curr + 1, /*out*/ isCurrentLoop) &&
                  isCurrentLoop) {
                // If the compound affine is invariant and we are right at the
                // level, we need to generate the address according to the
                // affine expression. This is also the best place we can do it
                // to avoid putting it inside inner loops.
                callback(env.makeTensorLevel(tid, l), exp);
              }
            }
          }
        }
      });

  if (isDenseLT(env.lt(outTid, curr))) {
    auto stt = getSparseTensorType(env.op().getOutputs().front());
    // Note that we generate dense indices of the output tensor unconditionally,
    // since they may not appear in the lattice, but may be needed for
    // linearized env.
    // TODO: we should avoid introducing corner cases for all-dense sparse
    // tensors.
    if (stt.hasEncoding() && stt.isAllDense())
      callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
  }

  if (numloopCond == 0) {
    // Corner cases where the loop bound is defined by an *unused* operand; in
    // this case, we just generate a dense "fake" loop by iterating over the
    // synthetic tensor.
    callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
    numloopCond++;
  }
  // If we only need one loop condition and the condition is not imposed on a
  // non-unique level, the loop can be generated by a for loop.
  return numloopCond == 1 && !hasNonUnique;
}

/// Starts a loop sequence at given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                         LoopId curr, LatSetId lts) {
  assert(!env.getLoopVar(curr));
  // Emit invariants at this loop sequence level.
  genInvariants(env, builder, exp, curr, /*isStart=*/true);
  // Emit access pattern expansion for sparse tensor output.
  genExpand(env, builder, curr, /*isStart=*/true);
  // Emit further initialization at this loop sequence level.
  const LatPointId l0 = env.set(lts)[0];

  SmallVector<TensorLevel> tidLvls;
  getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
    // TODO: remove this! The same tensor level might be added multiple
    // times due to the special handling for the all-dense "sparse" output
    // tensor (see L1038).
    if (llvm::find(tidLvls, tl) != tidLvls.end())
      return;
    tidLvls.emplace_back(tl);
  });

  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);

  // Maintain the universal index only if it is actually
  // consumed by a subsequent lattice point.
  for (const LatPointId li : env.set(lts).drop_front())
    if (!env.merger().hasAnySparse(env.lat(li).simple))
      return true;

  return false;
}

// Generates dense affine address for encoding.
static void genConstantDenseAddressFromLevel(CodegenEnv &env,
                                             OpBuilder &builder, TensorId tid,
                                             Level startLvl) {
  // TODO: Handle affine expression on output tensor.
  linalg::GenericOp op = env.op();
  assert(tid < op.getNumDpsInputs());
  OpOperand *input = op.getDpsInputOperands()[tid];
  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
  const auto enc = getSparseTensorEncoding(input->get().getType());
  if (enc) {
    const Location loc = op.getLoc();
    const TensorId tid = env.makeTensorId(input->getOperandNumber());
    const Level lvlRank = enc.getLvlRank();
    assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
    for (Level l = startLvl; l < lvlRank; l++) {
      AffineExpr lvlExpr = lvlExprs[l];
      if (enc.getLvlType(l).hasDenseSemantic() &&
          isa<AffineConstantExpr>(lvlExpr))
        env.emitter().locateLvlAtAffineAddress(
            builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
      else
        return; // break on first non-dense non-constant level
    }
  }
}

// We can generate addresses for constant affine expressions before any loops,
// starting from the first level, since they do not depend on anything.
// E.g., [Dense, Dense, Sparse] -> (1, 2, d0); the addresses for the first two
// levels can be determined before the loops.
static void genInitConstantDenseAddress(CodegenEnv &env,
                                        RewriterBase &rewriter) {
  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
    genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
}

/// Returns true if the lattice bit can be iterated by a for loop.
static bool translateBitsToTidLvlPairs(
    CodegenEnv &env, LatPointId li, LoopId curr,
    SmallVectorImpl<TensorLevel> &tidLvls,
    SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
  return getAllTidLvlsInLatPoints(env, li, curr,
                                  [&](TensorLevel tl, AffineExpr exp) {
                                    if (exp)
                                      affineTidLvls.emplace_back(tl, exp);
                                    else
                                      tidLvls.emplace_back(tl);
                                  });
}

/// Starts a single loop in current sequence.
static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                                              OpBuilder &builder, LoopId curr,
                                              LatPointId li, bool needsUniv) {
  // The set of tensors + lvls to generate loops on.
  SmallVector<TensorLevel> tidLvls;

  // The set of dense tensors with non-trivial affine expressions that just
  // became invariant; their addresses are generated at the current level.
  SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
  bool isSingleCond =
      translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);

  // Emit the for/while-loop control.
  Operation *loop = genLoop(env, builder, curr, needsUniv, tidLvls);
  Location loc = env.op().getLoc();
  for (auto [tidLvl, exp] : affineTidLvls) {
    env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
  }

  // Until now, we have entered every <tid, lvl> pair in {cond, extra,
  // affine}Tids/Lvls. The addresses of the upcoming levels which are dependent
  // on constant affine expressions may now be determined.
  auto allTidLvls =
      llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
  for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
    if (tid != env.merger().getOutTensorID() &&
        tid != env.merger().getSynTensorID())
      genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
  }

  return std::make_pair(loop, isSingleCond);
}

/// Ends a single loop in current sequence. Returns the new value for
/// needsUniv.
static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
                    LatPointId li, bool needsUniv, bool isSingleCond) {
  // Either a for-loop or a while-loop that iterates over a slice.
  if (isSingleCond) {
    // Any iteration creates a valid lex insert.
    if (env.isReduc() && env.isValidLexInsert())
      env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
  } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
    // End a while-loop.
    finalizeWhileOp(env, rewriter, needsUniv);
  } else {
    needsUniv = false;
  }
  env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
    env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
    return std::nullopt;
  });
  return needsUniv;
}

/// Ends a loop sequence at given level.
static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
                       unsigned at) {
  assert(!env.getLoopVar(at));
  env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
  // Unmark bookkeeping of invariants and loop index.
  genInvariants(env, builder, exp, at, /*isStart=*/false);
  // Finalize access pattern expansion for sparse tensor output.
  genExpand(env, builder, at, /*isStart=*/false);
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
                    LoopId curr) {
  assert(curr == env.getCurrentDepth());

  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (curr == env.getLoopNum()) {
    Value rhs = genExp(env, rewriter, exp);
    genTensorStore(env, rewriter, exp, rhs);
    return;
  }

  // Construct iteration lattices for current loop index.
  const LatSetId lts =
      env.merger().optimizeSet(env.merger().buildLattices(exp, curr));

  // Start a loop sequence.
  bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);

  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
  // We cannot change this to `for (const LatPointId li : env.set(lts))`
  // because the loop body causes data-movement which invalidates
  // the iterator.
  const unsigned lsize = env.set(lts).size();
  for (unsigned i = 0; i < lsize; i++) {
    const LatPointId li = env.set(lts)[i];
    // Start a loop.
    auto [loop, isSingleCond] = startLoop(env, rewriter, curr, li, needsUniv);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for coiteration.
    Value redInput = env.getReduc();
    Value cntInput = env.getExpandCount();
    Value insInput = env.getInsertionChain();
    Value validIns = env.getValidLexInsert();
    // We cannot change this to `for (const LatPointId lj : env.set(lts))`
    // because the loop body causes data-movement which invalidates the
    // iterator.
    for (unsigned j = 0; j < lsize; j++) {
      const LatPointId lj = env.set(lts)[j];
      const ExprId ej = env.lat(lj).exp;
      if (li == lj || env.merger().latGT(li, lj)) {
        // Recurse into body of each branch.
        if (!isSingleCond) {
          scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
          genStmt(env, rewriter, ej, curr + 1);
          endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
        } else {
          genStmt(env, rewriter, ej, curr + 1);
        }
      }
    }

    // End a loop.
    needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
  }

  // End a loop sequence.
  endLoopSeq(env, rewriter, exp, curr);
  assert(curr == env.getCurrentDepth());
}
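
// For illustration only (schematic, not part of the original source): for a
// sparse vector addition x(i) = a(i) + b(i), the recursion emits one loop
// sequence for loop i, a while-loop that coiterates a and b with guarded
// bodies for the overlap and the two singleton cases, followed by loops that
// handle the remaining tail of whichever operand is left.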

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
  linalg::GenericOp op = env.op();
  OpOperand *lhs = op.getDpsInitOperand(0);
  Value tensor = lhs->get();
  Type resType = tensor.getType();
  if (getSparseTensorEncoding(resType)) {
    // The sparse tensor rematerializes from the original sparse tensor's
    // underlying sparse storage format. For an insertion chain, the
    // tensor materializes from the chain with 'hasInserts' enabled.
    bool hasInserts = false;
    if (Value chain = env.getInsertionChain()) {
      hasInserts = true;
      tensor = chain;
    }
    rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
  } else {
    // To rematerialize a non-annotated tensor, simply load it
    // from the bufferized value.
    Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
    rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
  }
}

//===----------------------------------------------------------------------===//
// Sparsifier rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Only accept single output operations with pure tensor semantics.
    if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics())
      return failure();

    // Only accept trivial affine indices.
    if (hasNonTrivialAffineOnSparseOut(op))
      return failure();

    // Only accept scheduled loops.
    if (!op->hasAttr("sorted")) {
      return rewriter.notifyMatchFailure(
          op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
              "before sparsification.");
    }

    // Must have been demapped as well if the generic op is sorted.
    assert(!hasAnyNonIdentityOperandsOrResults(op));

    // Sets up a code generation environment.
    const unsigned numTensors = op->getNumOperands();
    const unsigned numLoops = op.getNumLoops();
    bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0;
    // If we have an indexing map like (d0) -> (0, d0), there might be more
    // levels than loops because of the constant index, which means we cannot
    // use numLoops as the upper bound for the ranks of all tensors.
    // TODO: Constant indices are currently not supported on sparse tensors,
    // but are allowed in non-annotated dense tensors. Supporting them would
    // also be required for rank-reducing sparse tensor slices.
    Level maxLvlRank = 0;
    for (auto operand : op.getOperands()) {
      if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
        maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
      }
    }

    // Detects sparse annotations and translates the per-level sparsity
    // information for all tensors to loop indices in the kernel.
    CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
    if (!findSparseAnnotations(env, needIdxRed))
      return failure();

    // Only standard reduction operations (add, sub, or, xor) that can be
    // sparsified by merely reducing the stored values are admissible. More
    // elaborate reduction operations (such as mul, and, min, max) would need
    // to know whether implicit zeros occur as well. They can still be
    // implemented with a custom reduction operation, accepted here as well.
    if (op.getNumReductionLoops() > 0) {
      Operation *yield = op.getRegion().front().getTerminator();
      assert(isa<linalg::YieldOp>(yield));
      Operation *redop = yield->getOperand(0).getDefiningOp();
      if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) &&
          !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) &&
          !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) &&
          !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) &&
          !isa<ReduceOp>(redop)) {
        return failure();
      }
    }

    // Constructs the tensor expression tree from `op`; returns failure if the
    // tree cannot be built or the tensor expression is inadmissible.
    if (failed(env.initTensorExp()))
      return failure();

    // Recursively generates code if admissible.
    env.startEmit(options.sparseEmitStrategy);
    genBuffers(env, rewriter);
    // TODO: Constant affine expressions should be handled differently when
    // using slice-based codegen; it does not matter now because we already
    // reject the constant expression at an earlier stage.
    genInitConstantDenseAddress(env, rewriter);
    genStmt(env, rewriter, env.getExprId(), 0);
    genResult(env, rewriter);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
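
// For illustration only (a minimal sketch of how a pass body might drive this
// entry point; the surrounding pass and the greedy driver call are assumed,
// not part of this file):
//
//   RewritePatternSet patterns(&getContext());
//   populateSparsificationPatterns(patterns, SparsificationOptions());
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));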