1//===- Sparsification.cpp - Implementation of sparsification --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements converting sparse tensor types to actual sparse code.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Utils/CodegenEnv.h"
14#include "Utils/CodegenUtils.h"
15#include "Utils/LoopEmitter.h"
16
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
19#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/Dialect/Linalg/IR/Linalg.h"
22#include "mlir/Dialect/Linalg/Utils/Utils.h"
23#include "mlir/Dialect/MemRef/IR/MemRef.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/Dialect/SCF/Transforms/Transforms.h"
26#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
27#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
28#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
29#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
30#include "mlir/Dialect/Tensor/IR/Tensor.h"
31
32#include <optional>
33
34using namespace mlir;
35using namespace mlir::sparse_tensor;
36
37//===----------------------------------------------------------------------===//
38// Sparsifier analysis methods.
39//===----------------------------------------------------------------------===//
40
41/// Returns true iff affine expression is invariant. Sets the
42/// parameter `isCurrentLoop` when expression just became invariant.
43static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) {
44 switch (a.getKind()) {
45 case AffineExprKind::DimId: {
46 const LoopId i = cast<AffineDimExpr>(Val&: a).getPosition();
47 if (i + 1 == curr) {
48 isCurrentLoop = true;
49 return true; // becomes invariant at current loop
50 }
51 return i < curr; // invariant when already generated
52 }
53 case AffineExprKind::Add:
54 case AffineExprKind::Mul: {
55 auto binOp = cast<AffineBinaryOpExpr>(Val&: a);
56 return isInvariantAffine(a: binOp.getLHS(), curr, isCurrentLoop) &&
57 isInvariantAffine(a: binOp.getRHS(), curr, isCurrentLoop);
58 }
59 default: {
60 assert(isa<AffineConstantExpr>(a));
61 return true;
62 }
63 }
64}
65
/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used more than once. Also rejects compound affine
/// expressions in sparse dimensions.
static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
                       LevelType lt, bool setLvlFormat = true) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
    // A loop index may subscript at most one level of a given tensor.
    if (!isUndefLT(merger.getLvlType(tid, idx)))
      return false; // used more than once
    if (setLvlFormat)
      merger.setLevelAndType(tid, idx, lvl, lt);
    return true;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul:
  case AffineExprKind::Constant: {
    // Compound/constant expressions are only admissible on levels with
    // dense semantics, where random access can locate the element.
    assert(lt.hasDenseSemantic());
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
      // We do not set dim level format for affine expression like d0 + d1 on
      // either loop index at d0 or d1. We continue the recursion merely to
      // check whether current affine is admissible or not.
      return findAffine(merger, tid, lvl, binOp.getLHS(), lt, false) &&
             findAffine(merger, tid, lvl, binOp.getRHS(), lt, false);
    }
    // Falls through when it is a constant Affine
    return true;
  }
  default:
    // Any other affine kind (e.g. mod, div) is inadmissible.
    return false;
  }
}
98
/// Helper method to inspect affine expressions for index variable reduction
/// based codegen. It finds the dependent index set for all tensor levels in the
/// current expression we are generating.
///
/// For example, when handling A[i+j][j+k], we build the two way mapping in
/// merger between (tensor, level) pairs and their dependent index variable set:
/// A_0 <=> [i, j] and A_1 <=> [j, k]
///
/// It rejects cases (returns false)
/// 1st, when the same index is used more than once, e.g., A[i+j][i]
/// 2nd, when multiplication is used in the non-trivial index expression.
/// 3rd, when a constant operand is used in the non-trivial index expression.
///
/// TODO: constant should be easy to handle.
static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
                          AffineExpr a, LevelType lt, bool isSubExp = false,
                          int64_t coefficient = 1) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    // Only allow positive coefficients on AffineDimExpr.
    if (coefficient <= 0)
      return false;

    const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
    if (!isUndefLT(merger.getLvlType(tensor, idx)))
      return false; // used more than once, e.g., A[i][i]

    // TODO: Generalizes the following two cases. A[i] (with trivial index
    // expression) can be treated as a special affine index expression. We do
    // not necessarily need to differentiate them.
    if (!isSubExp) {
      // Trivial index expression (a bare `d_i`): record level type directly.
      assert(coefficient == 1);
      merger.setLevelAndType(tensor, idx, lvl, lt);
    }

    if (isSubExp) {
      // The current loops appears in more than one affine expressions on the
      // same tensor. We can not handle this case. e.g., A[i+j][i+k], `i` is
      // used twice.
      if (merger.hasDependentLvl(idx, tensor)) {
        // TODO: This can be supported by coiterate slices if the loop idx is
        // appeared on affine index for different tensor, or take slice on
        // multiple dimensions when it is on the same tensor.
        // E.g.,
        // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
        // d0_1 = getNextSliceOffset t0 along lvl0
        // d0_2 = getNextSliceOffset t1 along lvl0
        // if d0_1 == d0_2 then d0 = d0_1 = d0_1
        // else increase min(d0_1, d0_2).
        return false;
      }
      // Record the (loop, tensor-level) dependency together with the
      // coefficient carried down from an enclosing multiplication.
      merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
    }
    return true;
  }
  case AffineExprKind::Constant:
  case AffineExprKind::Mul: {
    // TODO: Support index expression like `2 * d0`, we now only support more
    // complicated cases like `2 * d0 + d1`.
    if (!isSubExp)
      return false;

    // TODO: Support Constant AffineExp for slice-based codegen
    if (isa<AffineConstantExpr>(a))
      llvm_unreachable("Not yet implemented");

    auto binOp = cast<AffineBinaryOpExpr>(a);
    auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
    // Canonicalize so the constant ends up on the lhs.
    if (isa<AffineConstantExpr>(rhs))
      std::swap(lhs, rhs);
    // Must be in form of `constant * d`.
    assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
    int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue();
    // Recurse on the bare dimension with the extracted coefficient.
    return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient);
  }
  case AffineExprKind::Add: {
    // Both addends are sub-expressions of a non-trivial index expression.
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt, true) &&
           findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt, true);
  }
  default:
    return false;
  }
}
183
184/// Gets the total number of compound affine expressions in the
185/// `getMatchingIndexingMap` for the given tensor. For the following inputs:
186///
187/// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
188///
189/// Returns 1 (because the first level is compressed and its corresponding
190/// indexing-expression is `d0 + d1`)
191static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
192 Value tensor) {
193 // The `tensor` is not guaranteed to have `RankedTensorType`, therefore
194 // we can't use `getRankedTensorType`/`getSparseTensorType` here.
195 // However, we don't need to handle `StorageSpecifierType`, so we
196 // can use `SparseTensorType` once we guard against non-tensors.
197 const auto rtp = dyn_cast<RankedTensorType>(Val: tensor.getType());
198 if (!rtp)
199 return 0;
200 const SparseTensorType stt(rtp);
201
202 const Level lvlRank = stt.getLvlRank();
203 const auto exprs = map.getResults();
204 assert(static_cast<Dimension>(exprs.size()) == lvlRank &&
205 "AffineMap does not have dimension-rank many results");
206 unsigned num = 0;
207 for (Level l = 0; l < lvlRank; l++) {
208 if (!isa<AffineDimExpr>(Val: exprs[l]) && !stt.getLvlType(l).hasDenseSemantic())
209 num++;
210 }
211 return num;
212}
213
214/// Gets the total number of sparse levels with compound affine
215/// expressions, summed over all operands of the `GenericOp`.
216static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
217 unsigned num = 0;
218 for (OpOperand &t : op->getOpOperands())
219 num += getNumNonTrivialIdxExpOnSparseLvls(map: op.getMatchingIndexingMap(opOperand: &t),
220 tensor: t.get());
221 return num;
222}
223
224// Returns true iff output has nontrivial affine indices.
225static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
226 OpOperand *out = op.getDpsInitOperand(i: 0);
227 if (getSparseTensorType(val: out->get()).isAllDense())
228 return false;
229 return getNumNonTrivialIdxExpOnSparseLvls(map: op.getMatchingIndexingMap(opOperand: out),
230 tensor: out->get());
231}
232
/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible. Returns false if
/// no annotations are found or inadmissible constructs occur.
/// We currently support two different ways to handle non-trivial index
/// expression on sparse tensors, and they accept different affine expressions.
/// When using dependent index reducton-based approach, it currently only
/// supports affine addition index expression.
static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
  bool annotated = false;
  for (OpOperand &t : env.op()->getOpOperands()) {
    const TensorId tid = env.makeTensorId(t.getOperandNumber());
    const auto map = env.op().getMatchingIndexingMap(&t);
    // NOTE(review): `enc` is null for dense (unannotated) tensors; the
    // `enc.getLvlType(l)` call below appears to rely on the attribute
    // interface tolerating a null encoding — confirm against the
    // SparseTensorEncodingAttr implementation.
    const auto enc = getSparseTensorEncoding(t.get().getType());
    if (enc)
      annotated = true;
    const Level lvlRank = map.getNumResults();
    assert(!enc || lvlRank == enc.getLvlRank());
    assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
    // We only need to do index reduction if there is at least one
    // non-trivial index expression on sparse levels. If all non-trivial
    // index expression is on dense levels, we can efficiently rely on
    // the random access to locate the element.
    bool needIdxReduc =
        enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0;
    // If then current tensor being inspected requires affine index, it need
    // to be sliced.
    for (Level l = 0; l < lvlRank; l++) {
      const AffineExpr a = map.getResult(l);
      const LevelType lt = enc.getLvlType(l);
      if (idxReducBased && needIdxReduc) {
        if (!findDepIdxSet(env.merger(), tid, l, a, lt))
          return false; // inadmissible affine expression
      } else {
        if (!findAffine(env.merger(), tid, l, a, lt))
          return false; // inadmissible affine expression
      }
    }
  }
  return annotated;
}
275
276//===----------------------------------------------------------------------===//
277// Sparsifier synthesis methods (statements and expressions).
278//===----------------------------------------------------------------------===//
279
/// Local bufferization of all dense and sparse data structures.
static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  // Exactly one output operand is expected (inputs + 1 init).
  assert(op.getNumOperands() == op.getNumDpsInputs() + 1);

  // Loop bounds for the synthetic tensor, derived from the linalg op.
  SmallVector<Range, 4> loopRange =
      llvm::cast<linalg::LinalgOp>(op.getOperation())
          .createLoopRanges(builder, loc);

  env.emitter().initializeLoopEmit(
      builder, loc,
      /// Generates buffer for the output tensor.
      /// Note that all sparse kernels assume that when all elements are written
      /// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized
      /// to all zeroes and only nonzeroes values are computed and written out.
      /// For updates (viz. x(i) += y(i) * z(i)), only nonzeroes values are used
      /// for the updates and no assumption on the original contents of the
      /// output buffer is necessary.
      [&op](OpBuilder &builder, Location loc, Value memref,
            Value tensor) -> Value {
        // Must not be a sparse tensor.
        assert(!getSparseTensorEncoding(tensor.getType()));
        // Two output tensor references should point to the same object.
        OpOperand *lhs = op.getDpsInitOperand(0);
        assert(lhs->get() == tensor);
        // An output tensor can simply materialize from the buffer of the tensor
        // that appears in the outs() clause. For updates, this has the
        // advantage that only the nonzero value are involved in the
        // computation, keeping the operation O(nnz). In all other cases, we are
        // forced to zero out the buffer to enforce the assumption above, which
        // may negatively impact running complexity (viz. O(n^2 + nnz) vs.
        // O(nnz) for matrices).
        // TODO: use better analysis to avoid zeroing out the buffer?
        bool isInit = op.isInitTensor(lhs);
        Value init = memref;
        if (!isInit) {
          // Zero-fill the output buffer before use.
          Value zero = constantZero(builder, loc,
                                    getElementTypeOrSelf(tensor.getType()));
          builder.create<linalg::FillOp>(loc, ValueRange{zero},
                                         ValueRange{init});
        }
        return init;
      },
      // Supplies the upper bound for the synthetic tensor at each level.
      [&loopRange](OpBuilder &b, Location loc, Level l) {
        assert(l < loopRange.size());
        return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size);
      });
}
329
330/// Generates index for load/store on sparse tensor.
331static Value genIndex(CodegenEnv &env, OpOperand *t) {
332 const auto map = env.op().getMatchingIndexingMap(opOperand: t);
333 const auto stt = getSparseTensorType(val: t->get());
334 const Level lvlRank = stt.getLvlRank();
335 assert(static_cast<Level>(map.getNumResults()) == lvlRank);
336 const AffineExpr a = map.getResult(idx: lvlRank - 1);
337 assert(a.getKind() == AffineExprKind::DimId);
338 const LoopId idx = env.makeLoopId(i: cast<AffineDimExpr>(Val: a).getPosition());
339 return env.getLoopVar(i: idx);
340}
341
342/// Generates subscript for load/store on a dense or sparse tensor.
343static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
344 SmallVectorImpl<Value> &args) {
345 const Location loc = env.op().getLoc();
346 const TensorId tid = env.makeTensorId(t: t->getOperandNumber());
347 const auto map = env.op().getMatchingIndexingMap(opOperand: t);
348 const auto stt = getSparseTensorType(val: t->get());
349 if (stt.hasEncoding()) {
350 // For sparse tensors we only push the last-level's position onto `args`.
351 const auto pos = env.emitter().getValPosits(tid);
352 assert(!pos.empty());
353 args.append(RHS: pos);
354 // Simply returns the tensor to extract value using iterators.
355 if (env.options().sparseEmitStrategy == SparseEmitStrategy::kSparseIterator)
356 return t->get();
357 } else {
358 // For dense tensors we push all level's coordinates onto `args`.
359 const Level lvlRank = stt.getLvlRank();
360 assert(static_cast<Level>(map.getNumResults()) == lvlRank);
361 for (Level l = 0; l < lvlRank; l++) {
362 const auto lvlExpr = map.getResult(idx: l);
363 const auto lvlCrd = env.emitter().genAffine(builder, loc, a: lvlExpr);
364 args.push_back(Elt: lvlCrd);
365 }
366 }
367 return env.emitter().getValBuffer()[tid];
368}
369
370/// Generates insertion code to implement dynamic tensor load.
371static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
372 OpOperand *t) {
373 linalg::GenericOp op = env.op();
374 Location loc = op.getLoc();
375 // Direct lexicographic coordinate order, tensor loads as zero.
376 if (!env.isExpand()) {
377 Type tp = getElementTypeOrSelf(type: t->get().getType());
378 return constantZero(builder, loc, tp);
379 }
380 // Load from expanded access pattern.
381 Value index = genIndex(env, t);
382 return builder.create<memref::LoadOp>(location: loc, args: env.getExpandValues(), args&: index);
383}
384
385/// Generates insertion code to implement dynamic tensor load for reduction.
386static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
387 OpOperand *t) {
388 linalg::GenericOp op = env.op();
389 Location loc = op.getLoc();
390 Value identity = env.getCustomRedId();
391 // Direct lexicographic coordinate order, tensor loads as identity.
392 if (!env.isExpand())
393 return identity;
394 // Load from expanded access pattern if filled, identity otherwise.
395 Value values = env.getExpandValues();
396 Value filled = env.getExpandFilled();
397 Value index = genIndex(env, t);
398 Value isFilled = builder.create<memref::LoadOp>(location: loc, args&: filled, args&: index);
399 Value valAtIndex = builder.create<memref::LoadOp>(location: loc, args&: values, args&: index);
400 return builder.create<arith::SelectOp>(location: loc, args&: isFilled, args&: valAtIndex, args&: identity);
401}
402
403static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
404 Value sparseOut, ValueRange ivs, Value v) {
405 scf::IfOp condInsert =
406 builder.create<scf::IfOp>(location: loc, args: sparseOut.getType(), args&: cond, args: true);
407 // True branch.
408 builder.setInsertionPointToStart(condInsert.thenBlock());
409 Value res = builder.create<tensor::InsertOp>(location: loc, args&: v, args&: sparseOut, args&: ivs);
410 builder.create<scf::YieldOp>(location: loc, args&: res);
411 // False branch.
412 builder.setInsertionPointToStart(condInsert.elseBlock());
413 builder.create<scf::YieldOp>(location: loc, args&: sparseOut);
414 // Value assignment.
415 builder.setInsertionPointAfter(condInsert);
416 return condInsert.getResult(i: 0);
417}
418
/// Generates insertion code to implement dynamic tensor store.
static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                              Value rhs) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  // Direct insertion in lexicographic coordinate order.
  if (!env.isExpand()) {
    const LoopId numLoops = op.getRank(t);
    // Retrieves the first `numLoop` induction variables.
    SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end(
        env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops));
    Value chain = env.getInsertionChain();
    if (env.isValidLexInsert()) {
      // Generates runtime check for a valid lex during reduction,
      // to avoid inserting the identity value for empty reductions.
      //   if (validLexInsert) then
      //     insert(rhs) into chain
      //     return updated chain
      //   else
      //     return unmodified chain
      Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
                                       chain, ivs, rhs);
      env.updateInsertionChain(out);
    } else {
      Value sparseOut;
      if (!hasAnySparseType(env.op().getInputs().getTypes())) {
        // This is an all-dense -> sparse kernel, test rhs != 0 before
        // insertion.
        Value nz = genIsNonzero(builder, loc, rhs);
        sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
      } else {
        sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
      }
      // Generates regular insertion chain.
      env.updateInsertionChain(sparseOut);
    }
    return;
  }
  // Generates insertion code along expanded access pattern.
  //   if (!expFilled[i]) then
  //     expFilled[i] = true
  //     expAdded[inserts++] = i
  //   endif
  //   values[i] = rhs
  Value values = env.getExpandValues();
  Value filled = env.getExpandFilled();
  Value added = env.getExpandAdded();
  Value count = env.getExpandCount();
  Value index = genIndex(env, t);
  Value fval = constantI1(builder, loc, false);
  Value tval = constantI1(builder, loc, true);
  // If statement: guard on "slot not yet filled".
  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                             isFilled, fval);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
                                             /*else=*/true);
  // True branch: mark filled, record the coordinate, bump the insert count.
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  builder.create<memref::StoreOp>(loc, tval, filled, index);
  builder.create<memref::StoreOp>(loc, index, added, count);
  Value one = constantIndex(builder, loc, 1);
  Value add = builder.create<arith::AddIOp>(loc, count, one);
  builder.create<scf::YieldOp>(loc, add);
  // False branch: count is unchanged.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, count);
  builder.setInsertionPointAfter(ifOp);
  // Value assignment: the store itself happens unconditionally.
  env.updateExpandCount(ifOp.getResult(0));
  builder.create<memref::StoreOp>(loc, rhs, values, index);
}
491
/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = env.exp(exp).val;
  if (val)
    return val;
  // Get tensor operand.
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
  // Fold binary-valued tensor into explicit value.
  const auto stt = getSparseTensorType(t->get());
  if (auto explVal = stt.getExplicitVal())
    return genValFromAttr(builder, loc, explVal);
  // Load during insertion.
  if (env.isSparseOutput(t)) {
    if (env.isCustomReduc())
      return genInsertionLoadReduce(env, builder, t);
    return genInsertionLoad(env, builder, t);
  }

  // Actual load.
  SmallVector<Value> args;
  Value ptr = genSubscript(env, builder, t, args);
  if (llvm::isa<TensorType>(ptr.getType())) {
    // genSubscript returned the tensor itself, which only happens under the
    // sparse-iterator strategy: extract the value through the iterator.
    assert(env.options().sparseEmitStrategy ==
           SparseEmitStrategy::kSparseIterator);
    return builder.create<ExtractValOp>(loc, ptr, llvm::getSingleElement(args));
  }
  return builder.create<memref::LoadOp>(loc, ptr, args);
}
523
/// Generates a store on a dense or sparse tensor.
static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                           Value rhs) {
  // Only unary and binary are allowed to return an uninitialized rhs
  // to indicate missing output. Or otherwise a custom reduction that
  // received no value to accumulate.
  if (!rhs) {
    assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
           env.exp(exp).kind == TensorExp::Kind::kBinary ||
           env.exp(exp).kind == TensorExp::Kind::kReduce);
    return;
  }
  // Test if this is a scalarized reduction.
  if (env.isReduc()) {
    env.updateReduc(rhs);
    return;
  }
  // Regular store.
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  OpOperand *t = op.getDpsInitOperand(0);
  if (!env.isSparseOutput(t)) {
    SmallVector<Value> args;
    Value ptr = genSubscript(env, builder, t, args);
    builder.create<memref::StoreOp>(loc, rhs, ptr, args);
    return;
  }
  // Store during sparse insertion.
  if (env.exp(exp).kind != TensorExp::Kind::kSelect) {
    genInsertionStore(env, builder, t, rhs);
    return;
  }
  // Select operation insertion: `rhs` is the select condition; the value to
  // insert was preserved in the expression (see kSelect handling in genExp).
  Value chain = env.getInsertionChain();
  scf::IfOp ifOp =
      builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  // Existing value was preserved to be used here.
  assert(env.exp(exp).val);
  Value v0 = env.exp(exp).val;
  genInsertionStore(env, builder, t, v0);
  env.merger().clearExprValue(exp);
  // Yield modified insertion chain along true branch.
  Value mchain = env.getInsertionChain();
  builder.create<scf::YieldOp>(op.getLoc(), mchain);
  // Yield original insertion chain along false branch.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, chain);
  // Done with if statement.
  env.updateInsertionChain(ifOp->getResult(0));
  builder.setInsertionPointAfter(ifOp);
}
576
577/// Generates an invariant value.
578inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
579 return env.exp(e: exp).val;
580}
581
/// Semi-ring branches are simply inlined by the sparsifier. Prior
/// analysis has verified that all computations are "local" to the inlined
/// branch or otherwise invariantly defined outside the loop nest, with the
/// exception of index computations, which need to be relinked to actual
/// inlined cloned code.
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
                          Value e) {
  if (auto arg = dyn_cast<BlockArgument>(e)) {
    // Direct arguments of the original linalg op must be converted
    // into dense tensor loads. Note that we should not encounter
    // anything else. This needs to be verified by semi-ring ops.
    linalg::GenericOp op = env.op();
    if (arg.getOwner()->getParentOp() == op) {
      const TensorId tid = env.makeTensorId(arg.getArgNumber());
      OpOperand *t = &op->getOpOperand(tid);
      assert(!getSparseTensorType(t->get()).hasEncoding()); // dense!
      SmallVector<Value> args;
      Value ptr = genSubscript(env, rewriter, t, args);
      return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
    }
  } else if (Operation *def = e.getDefiningOp()) {
    // Handle index computation.
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return env.getLoopVar(env.makeLoopId(indexOp.getDim()));
    // When still defined in new body, recurse into operands.
    if (def->getBlock() == block) {
      rewriter.setInsertionPoint(def);
      for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
        // Replace each operand in place so the rewriter tracks the change.
        rewriter.modifyOpInPlace(def, [&]() {
          def->setOperand(
              i, relinkBranch(env, rewriter, block, def->getOperand(i)));
        });
      }
    }
  }
  // Anything else is left untouched.
  return e;
}
619
/// Recursively generates tensor expression.
static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
  if (e == ::mlir::sparse_tensor::detail::kInvalidId)
    return Value();

  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  const TensorExp &exp = env.exp(e);
  const auto kind = exp.kind;
  // Leaf cases: tensor loads, cached invariants, and loop variables.
  if (kind == TensorExp::Kind::kTensor)
    return genTensorLoad(env, rewriter, e);
  if (kind == TensorExp::Kind::kInvariant)
    return genInvariantValue(env, e);
  if (kind == TensorExp::Kind::kLoopVar)
    return env.getLoopVar(exp.loop);

  if (kind == TensorExp::Kind::kReduce)
    env.startCustomReduc(e); // enter custom

  // If either lhs/rhs is a synthetic zero, we infer the type for the zero value
  // based on the type of the other operand.
  Value v0, v1;
  if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
      env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
    v1 = genExp(env, rewriter, exp.children.e1);
    v0 = constantZero(rewriter, loc, v1.getType());
  } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
             env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
    v0 = genExp(env, rewriter, exp.children.e0);
    v1 = constantZero(rewriter, loc, v0.getType());
  } else {
    v0 = genExp(env, rewriter, exp.children.e0);
    v1 = genExp(env, rewriter, exp.children.e1);
  }

  Value ee;
  if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) {
    // custom reduce did not receive a value
  } else {
    ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
    if (ee &&
        (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary ||
         kind == TensorExp::Kind::kBinaryBranch ||
         kind == TensorExp::Kind::kReduce ||
         kind == TensorExp::Kind::kSelect)) {
      // These kinds may carry inlined semi-ring blocks whose uses must be
      // relinked to the generated code (see relinkBranch).
      OpBuilder::InsertionGuard guard(rewriter);
      ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
    }
  }

  if (kind == TensorExp::Kind::kReduce)
    env.endCustomReduc(); // exit custom

  if (kind == TensorExp::Kind::kSelect)
    env.merger().setExprValue(e, v0); // Preserve value for later use.

  return ee;
}
678
/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                          LoopId curr, bool isStart) {
  if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
    return;
  if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
    // Inspect tensor indices.
    linalg::GenericOp op = env.op();
    OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
    const auto map = op.getMatchingIndexingMap(&t);
    const auto stt = getSparseTensorType(t.get());
    const Level lvlRank = stt.getLvlRank();
    assert(static_cast<Level>(map.getNumResults()) == lvlRank);
    bool isCurrentLoop = curr == 0; // for scalar tensors
    for (Level l = 0; l < lvlRank; l++) {
      const AffineExpr a = map.getResult(l);
      if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop))
        return; // still in play
    }
    // All exhausted at current level.
    if (!isCurrentLoop)
      return;
    // Generate code for a scalarized reduction or invariant. Note that
    // because custom reduction lhs may occur several times in the IR,
    // we have a built-in safety for only initializing and wrapping-up
    // the scalarized reduction once.
    OpOperand *lhs = op.getDpsInitOperand(0);
    if (lhs == &t) {
      // Start or end a scalarized reduction.
      if (isStart) {
        if (env.isCustomReduc()) {
          // Custom reductions start from the user-provided identity.
          if (!env.isReduc())
            env.startReduc(exp, env.getCustomRedId());
        } else {
          env.startReduc(exp, genTensorLoad(env, builder, exp));
        }
        if (env.hasSparseOutput())
          env.startValidLexInsert(
              constantI1(builder, env.op().getLoc(), false));
      } else {
        if (!env.isCustomReduc() || env.isReduc())
          genTensorStore(env, builder, exp, env.endReduc());
        if (env.hasSparseOutput())
          env.endValidLexInsert();
      }
    } else {
      // Start or end loop invariant hoisting of a tensor load.
      if (isStart) {
        env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
      } else {
        env.merger().clearExprValue(exp);
      }
    }
  } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
             env.exp(exp).kind != TensorExp::Kind::kLoopVar &&
             env.exp(exp).kind != TensorExp::Kind::kSynZero) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    if (env.exp(exp).kind == TensorExp::Kind::kReduce)
      env.startCustomReduc(exp); // enter custom
    const ExprId e0 = env.exp(exp).children.e0;
    const ExprId e1 = env.exp(exp).children.e1;
    genInvariants(env, builder, e0, curr, isStart);
    genInvariants(env, builder, e1, curr, isStart);
    if (env.exp(exp).kind == TensorExp::Kind::kReduce)
      env.endCustomReduc(); // exit custom
  }
}
748
/// Generates an expanded access pattern in innermost dimension. At the start
/// of a loop sequence this materializes the expanded working buffers
/// (values/filled/added/count); at the end it compresses the gathered
/// entries back into the sparse output and threads the insertion chain.
static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                      bool isStart) {
  linalg::GenericOp op = env.op();
  OpOperand *lhs = op.getDpsInitOperand(i: 0);
  // Expansion only applies at the innermost level of the output tensor.
  if (!env.atExpandLevel(o: lhs, rank: op.getRank(opOperand: lhs), n: curr))
    return; // not needed at current level
  assert(!env.isReduc());
  // Generate start or end of an expanded access pattern. Note that because
  // an expansion does not rely on the ongoing contents of the sparse storage
  // scheme, we can use the original tensor as incoming SSA value (which
  // simplifies codegen a bit). If expansion on the actual contents is ever
  // needed, we will need to use the SSA value in the insertion chain instead.
  Value tensor = lhs->get();
  Location loc = op.getLoc();
  if (isStart) {
    // ExpandOp yields the four buffers of the expanded access pattern:
    // dense values, "is filled" bitmap, added coordinates, and their count.
    auto dynShape = {ShapedType::kDynamic};
    Type etp = cast<ShapedType>(Val: tensor.getType()).getElementType();
    Type t1 = MemRefType::get(shape: dynShape, elementType: etp);
    Type t2 = MemRefType::get(shape: dynShape, elementType: builder.getI1Type());
    Type t3 = MemRefType::get(shape: dynShape, elementType: builder.getIndexType());
    Type t4 = builder.getIndexType();
    auto r = builder.create<ExpandOp>(location: loc, args: TypeRange({t1, t2, t3, t4}), args&: tensor);
    assert(r.getNumResults() == 4);
    env.startExpand(values: r.getResult(i: 0), filled: r.getResult(i: 1), added: r.getResult(i: 2),
                    count: r.getResult(i: 3));
  } else {
    // Collect the induction variables of all enclosing loops; they address
    // the outer position at which the compression takes place.
    SmallVector<Value> indices;
    for (LoopId i = 0; i < curr; i++)
      indices.push_back(Elt: env.emitter().getLoopIV(n: i));
    Value values = env.getExpandValues();
    Value filled = env.getExpandFilled();
    Value added = env.getExpandAdded();
    Value count = env.getExpandCount();
    Value chain = env.getInsertionChain();
    // CompressOp folds the expanded entries back into the sparse output and
    // produces a new SSA value for the insertion chain.
    Value compress = builder.create<CompressOp>(location: loc, args&: values, args&: filled, args&: added,
                                                args&: count, args&: chain, args&: indices);
    env.updateInsertionChain(chain: compress);
    env.endExpand();
  }
}
790
791/// Returns parallelization strategy. Any implicit loop in the Linalg
792/// operation that is marked "parallel" is a candidate. Whether it is actually
793/// converted to a parallel operation depends on the requested strategy.
794static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
795 // Reject parallelization of sparse output.
796 if (env.hasSparseOutput())
797 return false;
798 // Parallel loops on tensor expansion can cause data races.
799 if (env.isExpand())
800 return false;
801 // Inspect strategy.
802 switch (env.options().parallelizationStrategy) {
803 case SparseParallelizationStrategy::kNone:
804 return false;
805 case SparseParallelizationStrategy::kDenseOuterLoop:
806 return isOuter && !isSparse;
807 case SparseParallelizationStrategy::kAnyStorageOuterLoop:
808 return isOuter;
809 case SparseParallelizationStrategy::kDenseAnyLoop:
810 return !isSparse;
811 case SparseParallelizationStrategy::kAnyStorageAnyLoop:
812 return true;
813 }
814 llvm_unreachable("unexpected parallelization strategy");
815}
816
817/// Whether or not the current loop being generated should be parallized (if
818/// possible) according to the configuration.
819static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
820 ArrayRef<TensorLevel> tidLvls) {
821 linalg::GenericOp op = env.op();
822 auto iteratorTypes = op.getIteratorTypesArray();
823 bool isSparse = llvm::any_of(Range&: tidLvls, P: [curr, &env](TensorLevel tidLvl) {
824 // Queries the LT based on the tensor and loop id, as requested by
825 // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
826 // should be consistent with the LT indexed by <TensorId, Level>.
827 const auto lt = env.lt(t: env.unpackTensorLevel(tl: tidLvl).first, i: curr);
828 return lt.hasSparseSemantic();
829 });
830 return isParallelFor(env, /*isOuter=*/curr == 0, isSparse);
831}
832
833/// Emit a loop to coiterate over the list of tensor levels. The generated loop
834/// can either be a for loop or while loop depending on whether there is at most
835/// one sparse level in the list.
836static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
837 ArrayRef<TensorLevel> tidLvls,
838 unsigned numCases, bool tryParallel,
839 bool needsUniv) {
840 Operation *loop = *env.genLoopBoundary(callback: [&](MutableArrayRef<Value> reduc) {
841 // Construct while-loop with a parameter for each index.
842 return env.emitter().enterCoIterationOverTensorsAtLvls(
843 builder, loc: env.op().getLoc(), tidLvls, numCases, reduc, isParallel: tryParallel,
844 needsUniv);
845 });
846 assert(loop);
847 return loop;
848}
849
850/// Generates a for-loop or a while-loop, depending on whether it implements
851/// singleton iteration or co-iteration over the given conjunction.
852static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
853 unsigned numCases, bool needsUniv,
854 ArrayRef<TensorLevel> tidLvls) {
855 bool tryParallel = shouldTryParallize(env, curr, tidLvls);
856 return genCoIteration(env, builder, tidLvls, numCases, tryParallel,
857 needsUniv);
858}
859
/// Generates the induction structure for a while-loop: walks outwards through
/// the scf.if operations opened inside the loop body and terminates each
/// else-branch with a yield, so that the scalarized reduction, expansion
/// count, and insertion chain flow out of the if-statements.
/// NOTE(review): `needsUniv` is currently unused in this function — confirm
/// whether it can be dropped from the signature together with its callers.
static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
                            bool needsUniv) {
  Location loc = env.op().getLoc();
  // Finalize each else branch of all if statements.
  if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
    while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
               Val: builder.getInsertionBlock()->getParentOp())) {
      // Break on IfOp for slicing filtering.
      if (ifOp->getAttr(name: LoopEmitter::getLoopEmitterLoopAttrName()) ==
          StringAttr::get(context: ifOp->getContext(), bytes: "slice"))
        break;

      // The yield order below must mirror the result-type order established
      // in genIf: reduction, valid-lex-insert flag, expansion count, chain.
      unsigned y = 0;
      SmallVector<Value> yields;
      if (env.isReduc()) {
        yields.push_back(Elt: env.getReduc());
        env.updateReduc(val: ifOp.getResult(i: y++));
        if (env.isValidLexInsert()) {
          yields.push_back(Elt: env.getValidLexInsert());
          env.updateValidLexInsert(val: ifOp.getResult(i: y++));
        }
      }
      if (env.isExpand()) {
        yields.push_back(Elt: env.getExpandCount());
        env.updateExpandCount(count: ifOp->getResult(idx: y++));
      }
      if (env.getInsertionChain()) {
        yields.push_back(Elt: env.getInsertionChain());
        env.updateInsertionChain(chain: ifOp->getResult(idx: y++));
      }
      assert(y == yields.size());
      builder.create<scf::YieldOp>(location: loc, args&: yields);
      // Continue with the parent of this if, until the loop body is reached.
      builder.setInsertionPointAfter(ifOp);
    }
  }
  // No need to set the insertion point here as LoopEmitter keeps track of the
  // basic block where scf::Yield should be inserted.
}
899
900/// Generates a case region in the coiterate operation.
901static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
902 unsigned caseIdx, LatPointId allCase,
903 LatPointId curCase,
904 MutableArrayRef<Value> reduc) {
905 assert(allCase == curCase || env.merger().latGT(allCase, curCase));
906 const BitVector &allCaseBits = env.merger().lat(p: allCase).simple;
907 const BitVector &curCaseBits = env.merger().lat(p: curCase).simple;
908
909 /// Computes the subset of iterators that are valid in the current case being
910 /// generated.
911 I64BitSet caseBit(0);
912 for (auto [idx, set] : llvm::enumerate(First: allCaseBits.set_bits()))
913 if (curCaseBits.test(Idx: set))
914 caseBit.set(idx);
915
916 env.emitter().enterCurrentCoIterationCase(builder, loc: env.op().getLoc(), caseBit,
917 caseIdx, reduc);
918}
919
/// Generates a single if-statement within a while-loop, guarding the body of
/// lattice point `p`. The condition conjoins, for every sparse tensor level
/// in the (simplified) lattice point, a test that the level's current
/// coordinate matches the loop variable. The if carries the scalarized
/// reduction, expansion count, and insertion chain as results; their order
/// must match the yields produced by endIf and finalizeWhileOp.
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                       LatPointId p) {
  Location loc = env.op().getLoc();
  SmallVector<Type> types;
  Value cond;
  env.merger().foreachTensorLoopId(
      p, /*simple=*/true,
      callback: [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt,
          bool isIdxRed) {
        if (isIdxRed) {
          // Since there is no 1:1 mapping from loop to level (multiple loops
          // are required to resolve one level with non-trivial index
          // expression), we need to reconstruct the tensor level types if this
          // loop requires index reduction condition.
          assert(lvl.has_value() && isUndefLT(lt));
          auto stt = getSparseTensorType(val: env.op().getInputs()[tid]);
          lt = stt.getLvlType(l: *lvl);
        }
        assert(curr == env.merger().loop(b));
        Value clause;
        if (lt.hasSparseSemantic()) {
          // Sparse level: participate only when coordinate == loop variable.
          assert(lvl.has_value());
          const Value crd = env.emitter().getCoord(tid, lvl: *lvl);
          const Value lvar = env.getLoopVar(i: curr);
          clause = builder.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::eq,
                                                 args: crd, args: lvar);
        } else {
          // Dense (or undefined) level: always participates.
          assert(lt.hasDenseSemantic() || isUndefLT(lt));
          clause = constantI1(builder, loc, b: true);
        }
        cond = cond ? builder.create<arith::AndIOp>(location: loc, args&: cond, args&: clause) : clause;
      });
  // Result types of the if-statement, in the canonical order expected by
  // endIf/finalizeWhileOp.
  if (env.isReduc()) {
    types.push_back(Elt: env.getReduc().getType());
    if (env.isValidLexInsert())
      types.push_back(Elt: env.getValidLexInsert().getType());
  }
  if (env.isExpand())
    types.push_back(Elt: builder.getIndexType());
  if (env.getInsertionChain())
    types.push_back(Elt: env.getInsertionChain().getType());
  scf::IfOp ifOp = builder.create<scf::IfOp>(location: loc, args&: types, args&: cond, /*else=*/args: true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  return ifOp;
}
966
/// Generates end of true branch of if-statement within a while-loop: yields
/// the values computed in the then-branch and restores the incoming values
/// (captured before the branch was entered) as the current environment state,
/// so that codegen of the else-branch starts from the pre-branch state.
/// The operand order must match the result types set up in genIf.
static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
                  Value redInput, Value cntInput, Value insInput,
                  Value validIns) {
  SmallVector<Value> operands;
  if (env.isReduc()) {
    operands.push_back(Elt: env.getReduc());
    env.updateReduc(val: redInput);
    if (env.isValidLexInsert()) {
      // Any overlapping indices during a reduction creates a valid lex insert.
      operands.push_back(Elt: constantI1(builder, loc: env.op().getLoc(), b: true));
      env.updateValidLexInsert(val: validIns);
    }
  }
  if (env.isExpand()) {
    operands.push_back(Elt: env.getExpandCount());
    env.updateExpandCount(count: cntInput);
  }
  if (env.getInsertionChain()) {
    operands.push_back(Elt: env.getInsertionChain());
    env.updateInsertionChain(chain: insInput);
  }
  // Only emit a yield when the if-statement actually carries results.
  if (!operands.empty())
    builder.create<scf::YieldOp>(location: env.op().getLoc(), args&: operands);
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
}
993
994//===----------------------------------------------------------------------===//
995// Sparsifier synthesis methods (loop sequence).
996//===----------------------------------------------------------------------===//
997
/// Invokes `callback` for every (tensor, level) pair that the loop generated
/// for lattice point `li` at loop `curr` must visit. The callback receives a
/// non-null AffineExpr only for dense levels addressed by a non-trivial
/// (compound) affine expression that just became invariant at this loop.
/// Returns true iff the iteration can be implemented by a simple for-loop,
/// i.e. there is exactly one loop condition and it is not imposed on a
/// non-unique level (or iterator-based loop emission is requested).
static bool getAllTidLvlsInLatPoints(
    CodegenEnv &env, LatPointId li, LoopId curr,
    llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
  const BitVector &simple = env.lat(l: li).simple;
  const TensorId outTid = env.merger().getOutTensorID();
  const std::optional<Level> outLvl = env.merger().getLvl(t: outTid, i: curr);

  unsigned numloopCond = 0;
  bool hasNonUnique = false;
  env.merger().foreachTensorLoopId(
      p: li, callback: [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
                    LevelType lt, bool isIdxReduc) {
        if (simple[b]) {
          // This tensor-loop id contributes a genuine loop condition.
          if (isIdxReduc) {
            callback(env.makeTensorLevel(t: tid, l: *lvl), nullptr);
            numloopCond++;
            return;
          }
          if (isUndefLT(lt)) {
            // An undefined lt in the lattices, we probably mean to
            // generate a dense loop according to the synthetic tensor (for
            // invariants and sparse output tensor).
            if (env.merger().getSynTensorID() == tid) {
              // Coiterating with an invariant
              // e.g., out = prod(in[i][j] op invariant);
              // or a broadcast
              // e.g., out[i][j] = in[i] (j is undef for input)
              //
              // The level of the synthetic tensor is the current loop depth;
              // the rank of the synthetic tensor equals to number of loops.
              assert(curr == env.getCurrentDepth());
              lvl = curr;
            } else if (!lvl) {
              // Skips invalid lvl (e.g., when this is a zero ranked tensor).
              return;
            }
          }
          hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
          callback(env.makeTensorLevel(t: tid, l: *lvl), nullptr);
          numloopCond++;
        } else if (lt.hasDenseSemantic() || isIdxReduc) {
          // Dense levels are always visited but impose no loop condition.
          callback(env.makeTensorLevel(t: tid, l: *lvl), nullptr);
        } else {
          assert(isUndefLT(lt));
          linalg::GenericOp op = env.op();
          if (tid >= op.getNumDpsInputs())
            // We only handle affine expression on input tensors (for now).
            return;
          OpOperand *operand = &op->getOpOperand(idx: tid);
          const auto stt = getSparseTensorType(val: operand->get());
          // Non-annotated dense tensors requires no special handling.
          if (!stt.hasEncoding())
            return;

          ArrayRef<AffineExpr> affines =
              op.getMatchingIndexingMap(opOperand: operand).getResults();
          const Level lvlRank = stt.getLvlRank();
          assert(affines.size() == static_cast<size_t>(lvlRank));
          for (Level l = 0; l < lvlRank; l++) {
            AffineExpr exp = affines[l];
            // Skip simple affine expression and non-dense levels (which
            // have their own filter loop).
            LevelType lt = stt.getLvlType(l);
            if (isa<AffineDimExpr>(Val: exp) || !lt.hasDenseSemantic())
              continue;

            // Constant affine expression are handled in genLoop.
            if (!isa<AffineConstantExpr>(Val: exp)) {
              bool isCurrentLoop = false;
              assert(curr == env.getCurrentDepth());
              if (isInvariantAffine(a: exp, curr: curr + 1, /*out*/ isCurrentLoop) &&
                  isCurrentLoop) {
                // If the compound affine is invariant and we are right at the
                // level. We need to generate the address according to the
                // affine expression. This is also the best place we can do it
                // to avoid putting it inside inner loops.
                callback(env.makeTensorLevel(t: tid, l), exp);
              }
            }
          }
        }
      });

  if (isDenseLT(lt: env.lt(t: outTid, i: curr))) {
    auto stt = getSparseTensorType(val: env.op().getOutputs().front());
    // Note that we generate dense indices of the output tensor unconditionally,
    // since they may not appear in the lattice, but may be needed for
    // linearized env.
    // TODO: we should avoid introducing corner cases for all-dense sparse
    // tensors.
    if (stt.hasEncoding() && stt.isAllDense())
      callback(env.makeTensorLevel(t: outTid, l: *outLvl), nullptr);
  }

  if (numloopCond == 0) {
    // Corner cases where the loop bound is defined by a *unused* operand, in
    // this case, we just generate a dense "fake" loop by iterating over the
    // synthetic tensor.
    callback(env.makeTensorLevel(t: env.merger().getSynTensorID(), l: curr), nullptr);
    numloopCond++;
  }
  // If we just need one loop condition and the condition is not imposed on
  // a non-unique level, the loop can be generated by a for loop.
  // Or, if we are generating sparse-iterator-based loops, we always generate
  // `sparse_tensor.iterate` regardless whether the level is unique or not.
  return numloopCond == 1 &&
         (!hasNonUnique || env.options().sparseEmitStrategy ==
                               SparseEmitStrategy::kSparseIterator);
}
1107
/// Starts a loop sequence at given level: hoists invariants, opens the
/// access-pattern expansion for a sparse output, and prepares the emitter
/// for the union of tensor levels visited by this sequence. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                         LoopId curr, LatSetId lts) {
  assert(!env.getLoopVar(curr));
  // Emit invariants at this loop sequence level.
  genInvariants(env, builder, exp, curr, /*isStart=*/true);
  // Emit access pattern expansion for sparse tensor output.
  genExpand(env, builder, curr, /*isStart=*/true);
  // Emit further initialization at this loop sequence level.
  // The first lattice point covers all tensor levels in this sequence.
  const LatPointId l0 = env.set(lts)[0];

  SmallVector<TensorLevel> tidLvls;
  getAllTidLvlsInLatPoints(env, li: l0, curr, callback: [&](TensorLevel tl, AffineExpr) {
    // TODO: remove this! The same tensor level might be added for multiple
    // times due to the special handling for all-dense "sparse" output tensor
    // (see L1038).
    if (llvm::is_contained(Range&: tidLvls, Element: tl))
      return;
    tidLvls.emplace_back(Args&: tl);
  });

  env.emitter().enterNewLoopSeq(builder, loc: env.op().getLoc(), tidLvls);

  // Maintain the universal index only if it is actually
  // consumed by a subsequent lattice point.
  for (const LatPointId li : env.set(lts).drop_front())
    if (!env.merger().hasAnySparse(bits: env.lat(l: li).simple))
      return true;

  return false;
}
1140
// Generates dense affine address for encoding: locates, for input tensor
// `tid`, all leading dense levels (starting at `startLvl`) whose indexing
// expression is a constant, stopping at the first level that is either
// non-dense or non-constant.
static void genConstantDenseAddressFromLevel(CodegenEnv &env,
                                             OpBuilder &builder, TensorId tid,
                                             Level startLvl) {
  // TODO: Handle affine expression on output tensor.
  linalg::GenericOp op = env.op();
  assert(tid < op.getNumDpsInputs());
  OpOperand *input = op.getDpsInputOperands()[tid];
  const auto lvlExprs = op.getMatchingIndexingMap(opOperand: input).getResults();
  const auto enc = getSparseTensorEncoding(type: input->get().getType());
  if (enc) {
    const Location loc = op.getLoc();
    // NOTE(review): this shadows the `tid` parameter; the value should
    // coincide since `input` was selected by `tid` above — confirm.
    const TensorId tid = env.makeTensorId(t: input->getOperandNumber());
    const Level lvlRank = enc.getLvlRank();
    assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
    for (Level l = startLvl; l < lvlRank; l++) {
      AffineExpr lvlExpr = lvlExprs[l];
      if (enc.getLvlType(l).hasDenseSemantic() &&
          isa<AffineConstantExpr>(Val: lvlExpr))
        env.emitter().locateLvlAtAffineAddress(
            builder, loc, tidLvl: env.makeTensorLevel(t: tid, l), lvlExpr);
      else
        return; // break on first non-dense non-constant level
    }
  }
}
1167
1168// We can generate address for constant affine expression before any loops
1169// starting from the first level as they do not depend on anything.
1170// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
1171// levels can be determined before loops.
1172static void genInitConstantDenseAddress(CodegenEnv &env,
1173 RewriterBase &rewriter) {
1174 for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
1175 genConstantDenseAddressFromLevel(env, builder&: rewriter, tid, startLvl: 0);
1176}
1177
1178/// Returns true if the lattice bit can be iterated by a for loop.
1179static bool translateBitsToTidLvlPairs(
1180 CodegenEnv &env, LatPointId li, LoopId curr,
1181 SmallVectorImpl<TensorLevel> &tidLvls,
1182 SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
1183 return getAllTidLvlsInLatPoints(env, li, curr,
1184 callback: [&](TensorLevel tl, AffineExpr exp) {
1185 if (exp)
1186 affineTidLvls.emplace_back(Args&: tl, Args&: exp);
1187 else
1188 tidLvls.emplace_back(Args&: tl);
1189 });
1190}
1191
/// Starts a single loop in current sequence: translates the lattice point
/// into loop conditions and invariant affine addresses, emits the loop
/// control, and resolves addresses of subsequent constant-indexed dense
/// levels. Returns the loop operation and whether a single condition
/// suffices (for-loop instead of while-loop co-iteration).
static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                                              OpBuilder &builder, LoopId curr,
                                              LatPointId li, unsigned numCases,
                                              bool needsUniv) {
  // TODO: numCases only used when generating iterator-based loops. Cleanup
  // after fully migration.
  // The set of tensors + lvls to generate loops on
  SmallVector<TensorLevel> tidLvls;

  // The set of dense tensors with non-trivial affine expression that just
  // becomes invariant and the address are generated at the current level.
  SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
  bool isSingleCond =
      translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);

  // Emit the for/while-loop control.
  Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls);
  Location loc = env.op().getLoc();
  // Now that the loop is open, locate the invariant affine addresses.
  for (auto [tidLvl, exp] : affineTidLvls) {
    env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, lvlExpr: exp);
  }

  // Until now, we have entered every <tid, lvl> pair in {cond, extra,
  // affine}Tids/Lvls. The addresses of the upcoming levels which are dependent
  // on constant affines expression may now be determined.
  auto allTidLvls =
      llvm::concat<TensorLevel>(Ranges&: tidLvls, Ranges: llvm::make_first_range(c&: affineTidLvls));
  for (auto [tid, lvl] : env.unpackTensorLevelRange(c&: allTidLvls)) {
    if (tid != env.merger().getOutTensorID() &&
        tid != env.merger().getSynTensorID())
      genConstantDenseAddressFromLevel(env, builder, tid, startLvl: lvl + 1);
  }

  return std::make_pair(x&: loop, y&: isSingleCond);
}
1228
/// Ends a single loop in current sequence. Returns new value for needsUniv:
/// once a single-condition (non-while) loop without slice semantics has been
/// closed, the universal index is no longer needed.
/// NOTE(review): parameter `li` is currently unused here (callers pass the
/// loop id) — confirm whether it can be removed.
static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
                    LatPointId li, bool needsUniv, bool isSingleCond) {
  // Either a for-loop or a while-loop that iterates over a slice.
  if (isSingleCond) {
    // Any iteration creates a valid lex insert.
    if (env.isReduc() && env.isValidLexInsert())
      env.updateValidLexInsert(val: constantI1(builder&: rewriter, loc: env.op().getLoc(), b: true));
  } else if (auto whileOp = dyn_cast<scf::WhileOp>(Val: loop)) {
    // End a while-loop.
    finalizeWhileOp(env, builder&: rewriter, needsUniv);
  } else {
    // Single-condition parallel/other loop: universal index is exhausted.
    needsUniv = false;
  }
  // Exit the loop through a loop boundary so reductions are threaded out.
  env.genLoopBoundary(callback: [&](MutableArrayRef<Value> reduc) {
    env.emitter().exitCurrentLoop(rewriter, loc: env.op().getLoc(), reduc);
    return std::nullopt;
  });
  return needsUniv;
}
1249
/// Ends a loop sequence at given level: exits the emitter's loop sequence,
/// then unwinds the bookkeeping that startLoopSeq established (in that
/// order), releasing hoisted invariants and the expansion buffers.
static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
                       unsigned at) {
  assert(!env.getLoopVar(at));
  env.emitter().exitCurrentLoopSeq(builder, loc: env.op().getLoc());
  // Unmark bookkeeping of invariants and loop index.
  genInvariants(env, builder, exp, curr: at, /*isStart=*/false);
  // Finalize access pattern expansion for sparse tensor output.
  genExpand(env, builder, curr: at, /*isStart=*/false);
}
1260
/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iterations spaces. At the deepest loop level
/// the (sub)expression is evaluated and stored; otherwise one loop (or loop
/// sequence) per lattice point is emitted, recursing one loop level deeper.
static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
                    LoopId curr) {
  assert(curr == env.getCurrentDepth());

  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (curr == env.getLoopNum()) {
    Value rhs = genExp(env, rewriter, e: exp);
    genTensorStore(env, builder&: rewriter, exp, rhs);
    return;
  }

  // Construct iteration lattices for current loop index.
  const LatSetId lts =
      env.merger().optimizeSet(s: env.merger().buildLattices(e: exp, i: curr));

  // Start a loop sequence.
  bool needsUniv = startLoopSeq(env, builder&: rewriter, exp, curr, lts);

  // When using sparse-iterator-based loops, we only need one loop, as
  // opposed to a loop sequence, to cover all the iterator spaces.
  const unsigned lsize = env.set(lts).size();
  if (env.generatingSparseIterator()) {
    // Get the largest lattice point and start a loop.
    const LatPointId li = env.set(lts)[0];
    auto [loop, isSingleCond] =
        startLoop(env, builder&: rewriter, curr, li, numCases: lsize, needsUniv);
    assert(isSingleCond == llvm::isa<IterateOp>(loop));
    // We cannot change this to `for (const LatPointId li : env.set(lts))`
    // because the loop body causes data-movement which invalidates
    // the iterator.
    for (unsigned j = 0; j < lsize; j++) {
      const LatPointId lj = env.set(lts)[j];
      const ExprId ej = env.lat(l: lj).exp;
      // Recurse into body of each branch.
      if (!isSingleCond) {
        // Multiple cases: emit a coiteration case region per lattice point.
        env.genLoopBoundary(callback: [&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
          genCoIterationCase(env, builder&: rewriter, /*caseIdx*/ j, allCase: li, curCase: lj, reduc);
          genStmt(env, rewriter, exp: ej, curr: curr + 1);
          // TODO: handle yield values.
          assert(reduc.empty() && "Not Implemented");
          rewriter.create<sparse_tensor::YieldOp>(location: env.op().getLoc());
          return std::nullopt;
        });
        // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
      } else {
        genStmt(env, rewriter, exp: ej, curr: curr + 1);
      }
    }
    // End a loop.
    needsUniv = endLoop(env, rewriter, loop, li: curr, needsUniv, isSingleCond);
  } else {
    // Emit a loop for every lattice point L0 >= Li in this loop sequence.
    for (unsigned i = 0; i < lsize; i++) {
      const LatPointId li = env.set(lts)[i];
      // Start a loop.
      auto [loop, isSingleCond] =
          startLoop(env, builder&: rewriter, curr, li, numCases: lsize, needsUniv);

      // Visit all lattices points with Li >= Lj to generate the
      // loop-body, possibly with if statements for coiteration.
      // Capture the environment state before the branches so endIf can
      // restore it between cases.
      Value redInput = env.getReduc();
      Value cntInput = env.getExpandCount();
      Value insInput = env.getInsertionChain();
      Value validIns = env.getValidLexInsert();
      // We cannot change this to `for (const LatPointId lj : env.set(lts))`
      // because the loop body causes data-movement which invalidates the
      // iterator.
      for (unsigned j = 0; j < lsize; j++) {
        const LatPointId lj = env.set(lts)[j];
        const ExprId ej = env.lat(l: lj).exp;
        if (li == lj || env.merger().latGT(p0: li, p1: lj)) {
          // Recurse into body of each branch.
          if (!isSingleCond) {
            scf::IfOp ifOp = genIf(env, builder&: rewriter, curr, p: lj);
            genStmt(env, rewriter, exp: ej, curr: curr + 1);
            endIf(env, builder&: rewriter, ifOp, redInput, cntInput, insInput, validIns);
          } else {
            genStmt(env, rewriter, exp: ej, curr: curr + 1);
          }
        }
      }

      // End a loop.
      needsUniv = endLoop(env, rewriter, loop, li: curr, needsUniv, isSingleCond);
    }
  }

  // End a loop sequence.
  endLoopSeq(env, builder&: rewriter, exp, at: curr);
  assert(curr == env.getCurrentDepth());
}
1355
/// Converts the result computed by the sparse kernel into the required form,
/// replacing the original linalg.generic with either a sparse_tensor.load
/// (sparse output) or a bufferization.to_tensor (dense output).
static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
  linalg::GenericOp op = env.op();
  OpOperand *lhs = op.getDpsInitOperand(i: 0);
  Value tensor = lhs->get();
  Type resType = tensor.getType();
  if (getSparseTensorEncoding(type: resType)) {
    // The sparse tensor rematerializes from the original sparse tensor's
    // underlying sparse storage format. For an insertion chain, the
    // tensor materializes from the chain with 'hasInserts' enabled.
    bool hasInserts = false;
    if (Value chain = env.getInsertionChain()) {
      hasInserts = true;
      tensor = chain;
    }
    rewriter.replaceOpWithNewOp<LoadOp>(op, args&: resType, args&: tensor, args&: hasInserts);
  } else {
    // To rematerialize an non-annotated tensor, simply load it
    // from the bufferized value.
    Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
    rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, args&: resType, args&: val);
  }
}
1379
1380//===----------------------------------------------------------------------===//
1381// Sparsifier rewriting methods.
1382//===----------------------------------------------------------------------===//
1383
1384namespace {
1385
/// Sparse rewriting rule for generic Linalg operation: drives the full
/// sparsification pipeline (admissibility checks, environment setup, buffer
/// generation, recursive statement generation, result materialization).
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Only accept single output operations with pure tensor semantics.
    if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics())
      return failure();

    // Only accept trivial affine indices.
    if (hasNonTrivialAffineOnSparseOut(op))
      return failure();

    // Only accept scheduled loops.
    if (!op->hasAttr(name: "sorted")) {
      return rewriter.notifyMatchFailure(
          arg&: op, msg: "Loops not yet scheduled, try run --sparse-reinterpret-map "
                   "before sparsification.");
    }

    // Must have been demapped as well if the generic op is sorted.
    assert(!hasAnyNonIdentityOperandsOrResults(op));

    // Sets up a code generation environment.
    const unsigned numTensors = op->getNumOperands();
    const unsigned numLoops = op.getNumLoops();
    bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0;
    // If we have indexing map like (d0) -> (0, d0), there might be more
    // levels than loops because of the constant index, that means we cannot
    // use numLoops as the upper bound for ranks of all tensors.
    // TODO: Constant indices are currently not supported on sparse tensor,
    // but are allowed in non-annotated dense tensor. Support it, it would be
    // required for sparse tensor slice rank reducing too.
    Level maxLvlRank = 0;
    for (auto operand : op.getOperands()) {
      if (auto rtp = dyn_cast<RankedTensorType>(Val: operand.getType())) {
        maxLvlRank = std::max(a: maxLvlRank, b: SparseTensorType(rtp).getLvlRank());
      }
    }

    // Detects sparse annotations and translates the per-level sparsity
    // information for all tensors to loop indices in the kernel.
    CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
    if (!findSparseAnnotations(env, idxReducBased: needIdxRed))
      return failure();

    // Only standard reduction operations (add, sub, or, xor) that can be
    // sparsified by merely reducing the stored values are admissible. More
    // elaborate reduction operations (such as mul, and, min, max) would need
    // to know whether implicit zeros occur as well. They can still be
    // implemented with a custom reduction operation, accepted here as well.
    if (op.getNumReductionLoops() > 0) {
      Operation *yield = op.getRegion().front().getTerminator();
      assert(isa<linalg::YieldOp>(yield));
      Operation *redop = yield->getOperand(idx: 0).getDefiningOp();
      if (!isa<arith::AddFOp>(Val: redop) && !isa<complex::AddOp>(Val: redop) &&
          !isa<arith::AddIOp>(Val: redop) && !isa<arith::SubFOp>(Val: redop) &&
          !isa<complex::SubOp>(Val: redop) && !isa<arith::SubIOp>(Val: redop) &&
          !isa<arith::OrIOp>(Val: redop) && !isa<arith::XOrIOp>(Val: redop) &&
          !isa<ReduceOp>(Val: redop)) {
        return failure();
      }
    }

    // Constructs the tensor expressions tree from `op`, returns failure if the
    // tree can not be built or the tensor expression is inadmissible.
    if (failed(Result: env.initTensorExp()))
      return failure();

    // Recursively generates code if admissible.
    env.startEmit(emitStrategy: options.sparseEmitStrategy);
    genBuffers(env, builder&: rewriter);
    // TODO: Constant affine expression should be handled differently when using
    // slice-based codegen, it does not matter now because we already reject the
    // constant expression at an earlier stage.
    genInitConstantDenseAddress(env, rewriter);
    genStmt(env, rewriter, exp: env.getExprId(), curr: 0);
    genResult(env, rewriter);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};
1474
1475} // namespace
1476
/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations. Currently registers a
/// single pattern (GenericOpSparsifier) parameterized by `options`.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(arg: patterns.getContext(), args: options);
}
1483

// Source: mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp