1//===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements rewriting rules that are specific to sparse tensors.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Utils/CodegenUtils.h"
14#include "Utils/LoopEmitter.h"
15
16#include "mlir/Dialect/Affine/IR/AffineOps.h"
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
19#include "mlir/Dialect/Linalg/IR/Linalg.h"
20#include "mlir/Dialect/Linalg/Utils/Utils.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/Dialect/SCF/IR/SCF.h"
23#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
24#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
25#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
26#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
27#include "mlir/Dialect/Tensor/IR/Tensor.h"
28#include "mlir/Dialect/Vector/IR/VectorOps.h"
29#include "mlir/IR/AffineMap.h"
30#include "mlir/IR/Matchers.h"
31#include "mlir/Support/LLVM.h"
32
33using namespace mlir;
34using namespace mlir::bufferization;
35using namespace mlir::linalg;
36using namespace mlir::sparse_tensor;
37
38//===---------------------------------------------------------------------===//
39// Helper methods for the actual rewriting rules.
40//===---------------------------------------------------------------------===//
41
42// Helper method to match any typed zero.
43static bool isZeroValue(Value val) {
  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
45}
46
47// Helper to detect a sparse tensor type operand.
48static bool isSparseTensor(Value v) {
49 auto enc = getSparseTensorEncoding(v.getType());
50 return enc && !llvm::all_of(enc.getLvlTypes(),
51 [](auto lt) { return lt == LevelFormat::Dense; });
52}
static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
54
55// Helper method to find zero/uninitialized tensor materialization.
56static bool isMaterializing(OpOperand *op, bool isZero) {
57 Value val = op->get();
58 // Check allocation, with zero alloc when required.
59 if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
60 Value copy = alloc.getCopy();
61 if (isZero)
      return copy && isZeroValue(copy);
63 return !copy;
64 }
65 // Check for empty tensor materialization.
66 if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
67 return !isZero;
68 // Last resort for zero alloc: the whole value is zero.
69 return isZero && isZeroValue(val);
70}
71
72// Helper to detect sampling operation.
73static bool isSampling(GenericOp op) {
74 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
75 if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
76 if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
77 // Both scalar input arguments used exactly once.
78 Value s1 = op.getBlock()->getArgument(0);
79 Value s2 = op.getBlock()->getArgument(1);
80 return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
81 (def->getOperand(1) == s1 && def->getOperand(0) == s2);
82 }
83 }
84 return false;
85}
86
87// Helper to detect chain of multiplications that do not involve x.
88static bool isMulChain(Value val, Value x) {
  if (auto arg = dyn_cast<BlockArgument>(val))
90 return arg != x;
91 if (auto *def = val.getDefiningOp()) {
92 if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
      return isMulChain(def->getOperand(0), x) &&
             isMulChain(def->getOperand(1), x);
95 }
96 return false;
97}
98
99// Helper to detect x = x + <multiplications>.
100static bool isSumOfMul(GenericOp op) {
101 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
102 if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
103 if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
104 Value x = op.getBlock()->getArguments().back();
105 return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
106 (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
107 }
108 }
109 return false;
110}
111
112// Helper to detect direct yield of a zero value.
113static bool isZeroYield(GenericOp op) {
114 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
115 if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
116 if (arg.getOwner()->getParentOp() == op) {
117 return isZeroValue(op->getOperand(arg.getArgNumber()));
118 }
119 }
120 return isZeroValue(yieldOp.getOperand(0));
121}
122
123/// Populates given sizes array from type (for static sizes) and from
124/// the tensor (for dynamic sizes).
125static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
126 Location loc, ShapedType stp, Value tensor) {
127 for (const auto &d : enumerate(stp.getShape())) {
128 Value dim;
129 if (d.value() == ShapedType::kDynamic)
130 dim = builder.create<tensor::DimOp>(loc, tensor, d.index());
131 else
132 dim = constantIndex(builder, loc, d.value());
133 sizes.push_back(dim);
134 }
135}
136
137static RankedTensorType getBufferType(const SparseTensorType &stt,
138 bool needTmpCOO) {
139 return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
140 : stt.getRankedTensorType();
141}
142
143/// Collects the dynamic dimension sizes for `tp` with the assumption that
144/// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
145/// sizes to dynSizes.
146static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
147 SmallVectorImpl<Value> &dynSizes) {
148 for (const auto &d : enumerate(tp.getShape())) {
149 if (d.value() == ShapedType::kDynamic)
150 dynSizes.push_back(sizes[d.index()]);
151 }
152}
153
154static LogicalResult genForeachOnSparseConstant(ForeachOp op,
155 RewriterBase &rewriter,
156 SparseElementsAttr attr) {
157 auto loc = op.getLoc();
158 SmallVector<Value> reduc = op.getInitArgs();
159
160 // Foreach on constant.
161 foreachInSparseConstant(
162 rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
163 [&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable {
164 SmallVector<Value> args;
        args.append(cvs.begin(), cvs.end());
        args.push_back(v);
        args.append(reduc);
168 // Clones the foreach op to get a copy of the loop body.
169 auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
170 assert(args.size() == cloned.getBody()->getNumArguments());
171 Operation *yield = cloned.getBody()->getTerminator();
172 rewriter.inlineBlockBefore(cloned.getBody(), op, args);
        // Clean up.
        rewriter.eraseOp(cloned);
        reduc = yield->getOperands();
        rewriter.eraseOp(yield);
177 });
178
179 rewriter.replaceOp(op, reduc);
180 return success();
181}
182
183/// Populates the given sizes array for concatenation from types (for static
184/// sizes) and from the source tensors (for dynamic sizes).
185static void concatSizesFromInputs(OpBuilder &builder,
186 SmallVectorImpl<Value> &sizes, Location loc,
187 ShapedType dstTp, ValueRange srcs,
188 unsigned dim) {
189 auto dstShape = dstTp.getShape();
  sizesFromSrc(builder, sizes, loc, srcs[0]);
191
192 // Sum up on the `dim` if the dimension is dynamic.
193 if (dstShape[dim] != ShapedType::kDynamic) {
194 // Faithfully take the static size.
195 sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
196 } else {
197 // Else, compute the shape dynamically.
198 for (const auto &src : srcs.drop_front()) {
      Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim);
200 // Sum up all the sizes.
201 sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
202 }
203 }
204}
205
206//===---------------------------------------------------------------------===//
207// The actual sparse tensor rewriting rules.
208//===---------------------------------------------------------------------===//
209
210namespace {
211
212/// Rewriting rule that converts direct yield of zero with initial allocation.
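///
/// A minimal sketch of the IR this rule targets (illustrative only; the
/// linalg.generic syntax is abbreviated and the tensor types are made up):
///
///   %0 = bufferization.alloc_tensor() : tensor<8x8xf64, #CSR>
///   %1 = linalg.generic ... outs(%0 : tensor<8x8xf64, #CSR>) {
///     ^bb0(...):
///       linalg.yield %cst_zero : f64
///   }
///
/// Since the body only ever yields zero into a freshly materialized tensor,
/// %1 can be replaced by %0 (or, for a static-shaped dense output, by a
/// constant zero), as implemented below.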
213struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
214public:
215 using OpRewritePattern<GenericOp>::OpRewritePattern;
216
217 LogicalResult matchAndRewrite(GenericOp op,
218 PatternRewriter &rewriter) const override {
219 if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
220 !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
221 !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
222 return failure();
223 auto outputType = getRankedTensorType(op.getResult(0));
224 // Yielding zero on newly materialized sparse tensor can be
225 // optimized directly (regardless of dynamic or static size).
226 if (getSparseTensorEncoding(outputType)) {
227 rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
228 return success();
229 }
230 // Use static zero value directly instead of materialization.
231 if (!outputType.hasStaticShape())
232 return failure();
233 Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
234 rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
    rewriter.eraseOp(def);
236 return success();
237 }
238};
239
240/// Rewriting rule that converts two kernels:
241///
242/// T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
243/// X(i,j) = S(i,j) * T(i,j)
244///
245/// into a single kernel, using distributive law:
246///
247/// X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
248///
249/// This kind of fusion (merging two ops into one but using arithmetic
250/// equalities that may not hold for floating-point computations) would
251/// be undesirable in the dense case, since we distribute the multiplication
252/// into the reduction loop. However, for sparse sampling tensor S, such
253/// a fusion may actually reduce the asymptotic complexity of the kernel,
254/// since intermediate results may be nullified.
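///
/// A typical instance is sampled dense-dense matrix multiplication (SDDMM),
/// where the producer computes a dense product and the consumer samples it
/// with a sparse matrix S. After fusion, only the entries of the product that
/// survive the sampling need to be computed (sketch of the fused kernel, in
/// the same notation as above):
///
///   X(i,j) = SUM(k, S(i,j) * A(i,k) * B(k,j))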
255struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
256public:
257 using OpRewritePattern<GenericOp>::OpRewritePattern;
258
259 LogicalResult matchAndRewrite(GenericOp op,
260 PatternRewriter &rewriter) const override {
261 // Check consumer.
262 if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
263 op.getNumResults() != 1 ||
264 op.getNumParallelLoops() != op.getNumLoops() ||
265 !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
266 !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
267 !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
268 return failure();
269 // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
270 // operand can be sparse or dense, since the point of this rewriting rule
271 // is detecting a situation in which *more* sparsity is introduced into
272 // a computation, be it already sparse or still dense.
273 unsigned other = 0;
274 if (isSparseTensor(op.getDpsInputOperand(0)))
275 other = 1;
276 else if (!isSparseTensor(op.getDpsInputOperand(1)))
277 return failure();
278 // Check producer.
279 auto prod = dyn_cast_or_null<GenericOp>(
280 op.getDpsInputOperand(other)->get().getDefiningOp());
281 if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
282 !prod.getResult(0).hasOneUse())
283 return failure();
284 // Sampling consumer and sum of multiplication chain producer.
285 if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
286 !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
287 !isSampling(op) || !isSumOfMul(prod))
288 return failure();
289 // Modify operand structure of producer and consumer.
290 Location loc = prod.getLoc();
291 SmallVector<Value> inputOps = prod.getInputs();
292 SmallVector<Value> outputOps = op.getOutputs();
293 SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
    inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
    fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
296 // Fuse producer and consumer into a new generic op.
297 auto fusedOp = rewriter.create<GenericOp>(
298 loc, op.getResult(0).getType(), inputOps, outputOps,
299 rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
300 /*doc=*/nullptr, /*library_call=*/nullptr);
301 Block &prodBlock = prod.getRegion().front();
302 Block &consBlock = op.getRegion().front();
303 IRMapping mapper;
304 Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
305 unsigned num = prodBlock.getNumArguments();
306 for (unsigned i = 0; i < num - 1; i++)
      addArg(mapper, fusedBlock, prodBlock.getArgument(i));
    addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
    addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
310 // Clone bodies of the producer and consumer in new evaluation order.
    auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
    auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
313 Value last;
314 for (auto &op : prodBlock.without_terminator())
315 if (&op != acc) {
316 last = op.getResult(0);
317 rewriter.clone(op, mapper);
318 }
    mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
320 mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
321 last = rewriter.clone(*acc, mapper)->getResult(0);
322 rewriter.create<linalg::YieldOp>(loc, last);
323 // Force initial value on merged allocation for dense outputs.
324 // TODO: deal with non alloc tensor here one day
325 if (!getSparseTensorEncoding(op.getResult(0).getType())) {
326 Value init = prod.getDpsInitOperand(0)
327 ->get()
328 .getDefiningOp<AllocTensorOp>()
329 .getCopy();
330 AllocTensorOp a =
331 op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
332 rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
333 }
334 // Replace consumer with fused operation. Old producer
335 // and consumer ops will be removed by DCE.
336 rewriter.replaceOp(op, fusedOp->getResults());
337 return success();
338 }
339
340private:
341 // Helper to add argument and record the mapping.
342 static void addArg(IRMapping &mapper, Block *b, BlockArgument a) {
    mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
344 }
345};
346
// Fuse a tensor cast into the producing operation. Note that a tensor.cast
// should really not be used to convert between sparse encodings. Since
// the pattern currently appears as a result of some prior rewriting,
// we make an attempt to repair very obvious cases.
// TODO: audit the pure tensor dialect rewriting rules
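//
// For example (illustrative; tensor types abbreviated), a cast that only
// changes the sparse encoding is repaired into an explicit conversion:
//
//   %1 = tensor.cast %0 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
//   ==>
//   %1 = sparse_tensor.convert %0 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>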
352struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
353public:
354 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
355
356 LogicalResult matchAndRewrite(tensor::CastOp op,
357 PatternRewriter &rewriter) const override {
358 Type srcType = op.getSource().getType();
359 Type dstType = op.getDest().getType();
360 // A nop cast simply folds away.
361 if (srcType == dstType) {
362 rewriter.replaceOp(op, op->getResults());
363 return success();
364 }
365 // See if a sparsity changing cast can be fused into producer.
    if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
367 if (Operation *def = op.getSource().getDefiningOp()) {
        if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
          rewriter.modifyOpInPlace(def, [&]() {
            def->getResult(0).setType(op->getResultTypes()[0]);
          });
          rewriter.replaceOp(op, def->getResult(0));
373 return success();
374 }
375 }
376 }
    // Repair tensor casts with at least one sparse operand into the
    // properly supported sparse_tensor.convert.
379 if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
380 rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
381 return success();
382 }
383 // Fail otherwise.
384 return failure();
385 }
386};
387
/// Rewrites a sequence of operations for sparse tensor selections into
/// semi-ring operations such that they can be compiled correctly by the
/// sparsifier. E.g., transforming the following sequence
391///
392/// %sel = arith.select %cond, %sp1, %sp2
393///
394/// to
395///
396/// %sel = binary %sp1, %sp2:
397/// both (%l, %r) {yield select %cond, %l, %r}
398/// left (%l) {yield select %cond, %l, 0}
399/// right (%r) {yield select %cond, 0, %r}
400///
401/// TODO: We require that the tensor used for extracting conditions to be dense
402/// to sparsify the code. To support a sparse condition tensor, we need a
403/// tri-nary operation.
404struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
405public:
406 using OpRewritePattern<GenericOp>::OpRewritePattern;
407 LogicalResult matchAndRewrite(GenericOp op,
408 PatternRewriter &rewriter) const override {
409 // Rejects non sparse kernels.
410 if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op))
411 return failure();
412
413 Location loc = op.getLoc();
414 SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
415 for (Operation &inst : *op.getBody()) {
416 // Matches pattern.
417 auto matched = isRewritablePattern(op, &inst);
418 if (!matched.has_value())
419 continue;
420
421 rewriter.setInsertionPoint(&inst);
422 auto [c, t, f] = matched.value();
423 assert(t.getType() == f.getType());
424 auto selTp = t.getType();
425 auto c0 = constantZero(rewriter, loc, selTp);
426 auto binOp = rewriter.create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
427 // Initializes all the blocks.
428 rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
429 {t.getLoc(), f.getLoc()});
430 rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
431 rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
432
433 for (auto *r : binOp.getRegions()) {
434 Block *b = &r->front();
435 rewriter.setInsertionPointToStart(b);
436
437 IRMapping irMap;
438 // Clones the cmp operations into the region to make the binary op
439 // admissible.
440 Value newC = c;
441 if (auto *def = c.getDefiningOp())
442 newC = rewriter.clone(*def, irMap)->getResult(0);
443
444 irMap.map(c, newC);
445 if (r == &binOp.getLeftRegion()) {
446 irMap.map(t, b->getArgument(0));
447 irMap.map(f, c0);
448 } else if (r == &binOp.getRightRegion()) {
449 irMap.map(t, c0);
450 irMap.map(f, b->getArgument(0));
451 } else {
452 irMap.map(t, b->getArgument(0));
453 irMap.map(f, b->getArgument(1));
454 }
455 auto y = rewriter.clone(inst, irMap)->getResult(0);
456 rewriter.create<sparse_tensor::YieldOp>(loc, y);
457 }
458
      // We successfully rewrote an operation. We cannot do the replacement
      // here because it would invalidate the iterator used by the current
      // loop to traverse the instructions.
462 semiRings.emplace_back(&inst, binOp);
463 }
464
465 // Finalizes the replacement.
466 for (auto [sel, semi] : semiRings)
467 rewriter.replaceOp(sel, semi->getResults());
468
469 return success(!semiRings.empty());
470 }
471
472private:
473 static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
474 isRewritablePattern(GenericOp op, Operation *v) {
475 auto sel = dyn_cast<arith::SelectOp>(v);
476 if (!sel)
477 return std::nullopt;
478
479 auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
480 auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
    // TODO: For simplicity, we only handle cases where both true/false values
    // are directly loaded from the input tensor. We can probably admit more
    // cases in theory.
484 if (!tVal || !fVal)
485 return std::nullopt;
486
487 // Helper lambda to determine whether the value is loaded from a dense input
488 // or is a loop invariant.
489 auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
      if (auto bArg = dyn_cast<BlockArgument>(v);
491 bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
492 return true;
493 // If the value is defined outside the loop, it is a loop invariant.
494 return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
495 };
496
    // If the condition value is loaded directly from a dense tensor or is a
    // loop invariant, we can sparsify the kernel.
499 auto cond = sel.getCondition();
500 if (isValFromDenseInputOrInvariant(cond))
501 return std::make_tuple(cond, tVal, fVal);
502
503 Value cmpL, cmpR;
504 if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
505 matchers::m_Any(&cmpR))) ||
506 matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
507 matchers::m_Any(&cmpR)))) {
508 // TODO: we can do it recursively to check whether all the leaf values are
509 // loaded from dense tensors or are loop invariants.
510 if (isValFromDenseInputOrInvariant(cmpL) ||
511 isValFromDenseInputOrInvariant(cmpR))
512 return std::make_tuple(cond, tVal, fVal);
513 }
514
515 return std::nullopt;
516 };
517};
518
519/// Rewrites a sparse reduction that would not sparsify directly since
520/// doing so would only iterate over the stored elements, ignoring the
521/// implicit zeros, into a semi-ring. Applies to all prod/and/min/max
522/// (note that reductions like add/sub/or/xor can directly be sparsified
523/// since the implicit zeros do not contribute to the final result).
524/// Note that prod/and are still included since, even though they often
525/// are nullified in sparse data, they may still occur for special
526/// situations in which e.g. some rows in a sparse matrix are fully
527/// dense. For min/max, including the implicit zeros is a much more
528/// common situation.
529///
530/// TODO: this essentially "densifies" the operation; we want to implement
531/// this much more efficiently by performing the reduction over the
532/// stored values, and feed in the zero once if there were *any*
533/// implicit zeros as well; but for now, at least we provide
/// the functionality.
535///
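/// A minimal sketch of the rewrite (illustrative; exact op syntax and operand
/// names are approximate). For a reduction body x = min(x, y) over a sparse
/// input, the loaded element y is wrapped in a unary op and the combining
/// step becomes a custom reduction:
///
///   %u = sparse_tensor.unary %y : f64 to f64
///          present = { yield %y }
///          absent = { yield %zero }
///   %r = sparse_tensor.reduce %u, %x, %identity : f64
///          { ^bb0(%a, %b): yield min(%a, %b) }
///
/// so that the implicit zeros of the sparse operand still reach the reduction
/// through the absent branch.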
536struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
537public:
538 using OpRewritePattern<GenericOp>::OpRewritePattern;
539
540 LogicalResult matchAndRewrite(GenericOp op,
541 PatternRewriter &rewriter) const override {
542 // Reject non-reductions.
543 if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
544 op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
545 return failure();
546 auto *inp = op.getDpsInputOperand(0);
547 auto *init = op.getDpsInitOperand(0);
548 if (!isSparseTensor(inp))
549 return failure();
550 // Look for direct x = x OP y for semi-ring ready reductions.
551 auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
552 .getOperand(0)
553 .getDefiningOp();
554 if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
555 arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
556 arith::MaxUIOp>(red))
557 return failure();
558 Value s0 = op.getBlock()->getArgument(0);
559 Value s1 = op.getBlock()->getArgument(1);
560 if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
561 (red->getOperand(0) != s1 || red->getOperand(1) != s0))
562 return failure();
563 // Identity.
564 Location loc = op.getLoc();
565 Value identity =
566 rewriter.create<tensor::ExtractOp>(loc, init->get(), ValueRange());
567 // Unary {
568 // present -> value
569 // absent -> zero.
570 // }
571 Type rtp = s0.getType();
572 rewriter.setInsertionPointToStart(&op.getRegion().front());
573 auto semiring = rewriter.create<sparse_tensor::UnaryOp>(loc, rtp, s0);
574 Block *present =
575 rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
576 rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
577 rewriter.create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
578 rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
579 rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
580 auto zero =
581 rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(rtp));
582 rewriter.create<sparse_tensor::YieldOp>(loc, zero);
583 rewriter.setInsertionPointAfter(semiring);
584 // CustomReduce {
585 // x = x REDUC y, identity
586 // }
587 auto custom = rewriter.create<sparse_tensor::ReduceOp>(
588 loc, rtp, semiring.getResult(), s1, identity);
589 Block *region =
590 rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
591 rewriter.setInsertionPointToStart(&custom.getRegion().front());
592 IRMapping irMap;
    irMap.map(red->getOperand(0), region->getArgument(0));
    irMap.map(red->getOperand(1), region->getArgument(1));
595 auto *cloned = rewriter.clone(*red, irMap);
596 rewriter.create<sparse_tensor::YieldOp>(loc, cloned->getResult(0));
597 rewriter.setInsertionPointAfter(custom);
598 rewriter.replaceOp(red, custom.getResult());
599 return success();
600 }
601};
602
/// Sparse rewriting rule for the print operator. This operation is mainly used
/// for debugging and testing. As such, it lowers to the vector.print
/// operation, which only requires very lightweight runtime support.
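///
/// For a small CSR matrix, the emitted code prints roughly the following
/// (values shown are hypothetical; the exact layout is illustrative only):
///
///   ---- Sparse Tensor ----
///   nse = 4
///   dim = ( 3, 4 )
///   lvl = ( 3, 4 )
///   pos[1] : ( 0, 2, 3, 4 )
///   crd[1] : ( 0, 2, 1, 3 )
///   values : ( 1, 2, 3, 4 )
///   ----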
606struct PrintRewriter : public OpRewritePattern<PrintOp> {
607public:
608 using OpRewritePattern::OpRewritePattern;
609 LogicalResult matchAndRewrite(PrintOp op,
610 PatternRewriter &rewriter) const override {
611 Location loc = op.getLoc();
612 auto tensor = op.getTensor();
613 auto stt = getSparseTensorType(tensor);
614 // Header with NSE.
615 auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
616 rewriter.create<vector::PrintOp>(
617 loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
618 rewriter.create<vector::PrintOp>(loc, nse);
619 // Print run-time contents for dim/lvl sizes.
620 rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("dim = "));
    printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true);
    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("lvl = "));
    printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false);
624 // Use the "codegen" foreach loop construct to iterate over
625 // all typical sparse tensor components for printing.
626 foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
627 &stt](Type, FieldIndex,
628 SparseTensorFieldKind kind,
629 Level l, LevelType) {
630 switch (kind) {
631 case SparseTensorFieldKind::StorageSpec: {
632 break;
633 }
634 case SparseTensorFieldKind::PosMemRef: {
        auto lvl = constantIndex(rewriter, loc, l);
636 rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
637 rewriter.create<vector::PrintOp>(
638 loc, lvl, vector::PrintPunctuation::NoPunctuation);
639 rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
640 auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
        printContents(rewriter, loc, pos);
642 break;
643 }
644 case SparseTensorFieldKind::CrdMemRef: {
        auto lvl = constantIndex(rewriter, loc, l);
646 rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
647 rewriter.create<vector::PrintOp>(
648 loc, lvl, vector::PrintPunctuation::NoPunctuation);
649 rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
650 Value crd = nullptr;
651 // For COO AoS storage, we want to print a single, linear view of
652 // the full coordinate storage at this level. For any other storage,
        // we show the coordinate storage for every individual level.
654 if (stt.getAoSCOOStart() == l)
655 crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
656 else
657 crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
        printContents(rewriter, loc, crd);
659 break;
660 }
661 case SparseTensorFieldKind::ValMemRef: {
662 rewriter.create<vector::PrintOp>(loc,
663 rewriter.getStringAttr("values : "));
664 auto val = rewriter.create<ToValuesOp>(loc, tensor);
        printContents(rewriter, loc, val);
666 break;
667 }
668 }
669 return true;
670 });
671 rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
    rewriter.eraseOp(op);
673 return success();
674 }
675
676private:
677 // Helper to print contents of a single memref. Note that for the "push_back"
678 // vectors, this prints the full capacity, not just the size. This is done
679 // on purpose, so that clients see how much storage has been allocated in
680 // total. Contents of the extra capacity in the buffer may be uninitialized
681 // (unless the flag enable-buffer-initialization is set to true).
682 //
683 // Generates code to print:
684 // ( a0, a1, ... )
685 static void printContents(PatternRewriter &rewriter, Location loc,
686 Value vec) {
687 // Open bracket.
688 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
689 // For loop over elements.
    auto zero = constantIndex(rewriter, loc, 0);
    auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
    auto step = constantIndex(rewriter, loc, 1);
693 auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
694 rewriter.setInsertionPointToStart(forOp.getBody());
695 auto idx = forOp.getInductionVar();
696 auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
697 if (llvm::isa<ComplexType>(val.getType())) {
698 // Since the vector dialect does not support complex types in any op,
699 // we split those into (real, imag) pairs here.
700 Value real = rewriter.create<complex::ReOp>(loc, val);
701 Value imag = rewriter.create<complex::ImOp>(loc, val);
702 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
703 rewriter.create<vector::PrintOp>(loc, real,
704 vector::PrintPunctuation::Comma);
705 rewriter.create<vector::PrintOp>(loc, imag,
706 vector::PrintPunctuation::Close);
707 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
708 } else {
709 rewriter.create<vector::PrintOp>(loc, val,
710 vector::PrintPunctuation::Comma);
711 }
712 rewriter.setInsertionPointAfter(forOp);
713 // Close bracket and end of line.
714 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
715 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
716 }
717
718 // Helper method to print run-time lvl/dim sizes.
719 static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
720 unsigned size, bool isDim) {
721 // Open bracket.
722 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
723 // Print unrolled contents (dimop requires constant value).
724 for (unsigned i = 0; i < size; i++) {
      auto idx = constantIndex(rewriter, loc, i);
726 Value val;
727 if (isDim)
728 val = rewriter.create<tensor::DimOp>(loc, tensor, idx);
729 else
730 val = rewriter.create<LvlOp>(loc, tensor, idx);
731 rewriter.create<vector::PrintOp>(
732 loc, val,
733 i != size - 1 ? vector::PrintPunctuation::Comma
734 : vector::PrintPunctuation::NoPunctuation);
735 }
736 // Close bracket and end of line.
737 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
738 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
739 }
740};
741
742/// Sparse rewriting rule for sparse-to-sparse reshape operator.
743struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
744public:
745 using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
746
747 LogicalResult matchAndRewrite(tensor::ReshapeOp op,
748 PatternRewriter &rewriter) const override {
749 Location loc = op.getLoc();
750 Value srcTensor = op.getSource();
    const auto srcTp = getSparseTensorType(srcTensor);
752 const auto dstTp = getSparseTensorType(op.getResult());
753
754 if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
755 !dstTp.hasStaticDimShape())
756 return failure();
757
758 SmallVector<Value> srcSizes;
759 sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
760 SmallVector<Value> dstSizes;
761 for (Dimension d : dstTp.getDimShape())
762 dstSizes.push_back(constantIndex(rewriter, loc, d));
763
764 Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
765 // Only need an unordered COO buffer if input and output are not sorted
766 // in the same way.
767 Type bufferTp = getBufferType(
768 dstTp.withoutDimToLvl(),
769 !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
770 SmallVector<Value> dynSizes;
771 Value buffer = rewriter
772 .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
773 nnz, Attribute())
774 .getResult();
775
    // Convert src coordinates to dst coordinates by first collapsing them to
    // 1D and then expanding them to match the rank of the destination tensor.
778 // Implemented as follows:
779 // foreach srcCoords %srcTensor
780 // collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
781 // expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
782 // insert expandedCoords, %buffer
783 //
784 // followed by an optional
785 // %t = sparse_tensor.cast %tmp
786 // depending on whether the input/output are sorted in the same way.
787 const auto encSrc = srcTp.getEncoding();
788 ForeachOp foreachOp = rewriter.create<ForeachOp>(
789 loc, srcTensor, buffer,
790 [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
791 ValueRange reduc) {
792 const Dimension srcRank = srcTp.getDimRank();
793 SmallVector<Value> srcDcvs;
794 srcDcvs.reserve(srcRank);
795 for (Dimension d = 0; d < srcRank; d++) {
796 Level lvl = toLvl(encSrc, d);
797 srcDcvs.push_back(srcLcvs[lvl]);
798 }
799
800 Value collapseSize = constantIndex(builder, loc, 1);
801 for (Dimension d = 0; d < srcRank; d++)
802 collapseSize =
803 builder.create<arith::MulIOp>(loc, collapseSize, srcSizes[d]);
804 SmallVector<Value, 1> collapsedSizes = {collapseSize};
805
806 ReassociationIndices collapseIdx;
807 for (Dimension i = 0; i < srcRank; i++)
808 collapseIdx.push_back(i);
809 SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
810 SmallVector<Value, 1> collapsedDcvs;
811 reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
812 collapsedSizes, collapsedDcvs);
813
814 ReassociationIndices expandIdx;
815 for (Dimension i = 0; i < dstTp.getDimRank(); i++)
816 expandIdx.push_back(i);
817 SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
818 SmallVector<Value> dstDcvs;
819 reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
820 dstSizes, dstDcvs);
821
822 auto t =
823 builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
824 builder.create<sparse_tensor::YieldOp>(loc, t);
825 });
826
827 Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
828 if (bufferTp != dstTp) {
829 auto dstRTT = dstTp.getRankedTensorType();
830 Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
831 rewriter.create<DeallocTensorOp>(loc, t);
832 t = converted;
833 }
834 rewriter.replaceOp(op, t);
835 return success();
836 }
837};
838
839/// Sparse rewriting rule for sparse-to-sparse reshape operator.
840template <typename ReshapeOp>
841struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
842public:
843 using OpRewritePattern<ReshapeOp>::OpRewritePattern;
844
845 LogicalResult matchAndRewrite(ReshapeOp op,
846 PatternRewriter &rewriter) const override {
847 Location loc = op.getLoc();
848 Value srcTensor = op.getSrc();
    const auto srcTp = getSparseTensorType(srcTensor);
850 const auto dstTp = getSparseTensorType(op.getResult());
851 if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
852 return failure();
853
854 // Generate code to represent the static dimension constants or compute
855 // the dynamic dimension values.
856 SmallVector<Value> srcSizes;
857 sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
858 SmallVector<Value> dstSizes;
859 SmallVector<Value> dstDynSizes;
860 if (dstTp.hasStaticDimShape()) {
861 for (Dimension d : dstTp.getDimShape())
        dstSizes.push_back(constantIndex(rewriter, loc, d));
863 } else {
864 ArrayRef<Size> dstShape = dstTp.getDimShape();
865 genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
866 op.getReassociationIndices());
      for (auto [idx, shape] : llvm::enumerate(dstShape)) {
        if (shape == ShapedType::kDynamic)
          dstDynSizes.push_back(dstSizes[idx]);
870 }
871 }
872 Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
    // Only need an unordered COO buffer if input and output are not sorted
874 // in the same way.
875 Type bufferTp = getBufferType(
876 dstTp.withoutDimToLvl(),
877 !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
878
879 Value buffer =
880 rewriter
881 .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(),
882 /*sizeHint=*/nnz, Attribute())
883 .getResult();
884
885 // Implement the sparse2sparse reshape as follows:
886 // foreach srcCoords %srcTensor
887 // insert reshapeCvs(srcCoords), %buffer
888 //
889 // followed by an optional
890 // %t = sparse_tensor.cast %tmp
891 // depending on whether the input/output are sorted in the same way.
892 const auto encSrc = srcTp.getEncoding();
893 ForeachOp foreachOp = rewriter.create<ForeachOp>(
894 loc, srcTensor, buffer,
895 [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
896 ValueRange reduc) {
897 const Dimension dimRank = srcTp.getDimRank();
898 SmallVector<Value> srcDcvs;
899 srcDcvs.reserve(dimRank);
900 for (Dimension d = 0; d < dimRank; d++) {
901 Level lvl = toLvl(encSrc, d);
902 srcDcvs.push_back(srcLcvs[lvl]);
903 }
904 SmallVector<Value> dstDcvs;
905 reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
906 srcDcvs, dstSizes, dstDcvs);
907 auto t =
908 builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
909 builder.create<sparse_tensor::YieldOp>(loc, t);
910 });
911
912 Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
913 if (bufferTp != dstTp) {
914 auto dstRTT = dstTp.getRankedTensorType();
915 Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
916 rewriter.create<DeallocTensorOp>(loc, t);
917 t = converted;
918 }
919 rewriter.replaceOp(op, t);
920 return success();
921 }
922};
923
924/// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
925/// operator.
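///
/// For instance (illustrative; types and reassociation abbreviated), an
/// expansion of a sparse source is unfused into a conversion followed by a
/// dense expansion:
///
///   %1 = tensor.expand_shape %sp ... : tensor<12xf64, #SV> into ...
///   ==>
///   %d = sparse_tensor.convert %sp : tensor<12xf64, #SV> to tensor<12xf64>
///   %1 = tensor.expand_shape %d ... : tensor<12xf64> into ...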
926template <typename ReshapeOp>
927struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
928public:
929 using OpRewritePattern<ReshapeOp>::OpRewritePattern;
930
931 LogicalResult matchAndRewrite(ReshapeOp op,
932 PatternRewriter &rewriter) const override {
933 Location loc = op->getLoc();
934 auto encDst = getSparseTensorEncoding(op.getResult().getType());
935 auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
936 // Since a pure dense expansion is very cheap (change of view), for
937 // a sparse2dense or dense2sparse, we can simply unfuse a sparse
938 // conversion from the reshape operation itself.
939 // All other cases are handled elsewhere.
940 if (encDst && encSrc) {
941 return failure();
942 }
943 if (encSrc) {
944 auto rtp = getRankedTensorType(op.getSrc());
945 auto denseTp =
946 RankedTensorType::get(rtp.getShape(), rtp.getElementType());
947 auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
948 rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
949 return success();
950 }
951 if (encDst) {
952 auto rtp = getRankedTensorType(op.getResult());
953 auto denseTp =
954 RankedTensorType::get(rtp.getShape(), rtp.getElementType());
955 auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
956 op.getReassociation());
957 Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
958 rewriter.replaceOp(op, convert);
959 return success();
960 }
961 return failure();
962 }
963};
964
965// A trivial wrapper to help generate different operations for dense/sparse
966// tensors.
967struct TensorLike {
968 TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
969 ValueRange sizes) {
970 SmallVector<Value> dynSzs;
971 getDynamicSizes(rtt, sizes, dynSzs);
972
973 val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
974 if (!isSparse()) {
975 Value c0 = constantZero(builder, loc, rtt.getElementType());
976 val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
977 }
978 }
979
980 void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
981 val = builder.create<tensor::InsertOp>(loc, v, val, crds);
982 }
983
984 Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
985 if (isSparse())
986 return builder.create<LoadOp>(loc, val, true);
987 return val;
988 }
989
990 bool isSparse() const {
991 return getSparseTensorEncoding(val.getType()) != nullptr;
992 }
993
994 Value val;
995};
996
997struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
998 using OpRewritePattern::OpRewritePattern;
999 LogicalResult matchAndRewrite(tensor::DimOp op,
1000 PatternRewriter &rewriter) const override {
1001 std::optional<int64_t> dim = op.getConstantIndex();
1002 auto stt = getSparseTensorType(op.getSource());
1003 if (!dim || !stt.hasEncoding())
1004 return failure();
1005
1006 if (stt.isPermutation()) {
1007 rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
1008 toLvl(stt.getEncoding(), *dim));
1009 return success();
1010 }
1011
1012 // Non-permutation dim2lvl/lvl2dim maps.
1013 // Compute as follows:
1014 // affine.apply #map (l0 - 1, l1 - 1, ...) + 1
    // Note that this is not the most efficient way (but a more general one)
    // for the lvl-to-dim translation, e.g., for BSR, the dimension size can
    // be computed simply by lvl_size * block_size.
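    //
    // For example (illustrative), for a 2x2 BSR encoding whose lvl-to-dim
    // expression for d0 is l0 * 2 + l2, a level-0 size of 4 and a block size
    // of 2 give
    //   maxDimCrd = (4 - 1) * 2 + (2 - 1) = 7,  dimSz = 7 + 1 = 8,
    // which is what the generic code below computes.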
1018 Location loc = op.getLoc();
1019 SmallVector<Value> maxLvlCrds;
1020 for (Level l = 0; l < stt.getLvlRank(); l++) {
1021 Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
1022 Value maxLvlCrd = rewriter.create<arith::SubIOp>(
1023 loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
      maxLvlCrds.push_back(maxLvlCrd);
1025 }
1026
1027 AffineExpr lvl2DimExp = stt.getLvlToDim().getResult(*dim);
1028 Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
1029 op.getLoc(), AffineMap::get(stt.getLvlRank(), 0, lvl2DimExp),
1030 maxLvlCrds);
1031
1032 Value dimSz = rewriter.create<arith::AddIOp>(
1033 loc, maxDimCrd, constantOne(rewriter, loc, rewriter.getIndexType()));
1034 rewriter.replaceOp(op, dimSz);
1035 return success();
1036 }
1037};
1038
1039struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
1040 using OpRewritePattern::OpRewritePattern;
1041 LogicalResult matchAndRewrite(ConcatenateOp op,
1042 PatternRewriter &rewriter) const override {
1043 if (op.needsExtraSort())
1044 op.emitError("ConcatenateOp not staged");
1045
1046 const Location loc = op.getLoc();
1047 const auto dstTp = getSparseTensorType(op);
1048 const Dimension conDim = op.getDimension();
1049 SmallVector<Value> sizes;
1050 concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);
1051
1052 // %t = concatenate %s1, %s2, %s3 {dim = 1}
1053 // ==>
1054 // if (isSparseDst)
1055 // if (allDense)
1056 // %tmp = bufferization.alloc_tensor dstTp
1057 // else
1058 // %tmp = bufferization.alloc_tensor : unordered COO
1059 // else
1060 // %tmp = memref.alloc : dense tensor
1061 // foreach in %s1 : insert d0, d1, %tmp
1062 // foreach in %s2 : insert d0, d1 + size(s1), %tmp
1063 // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp
1064
1065 TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
    Value offset = constantIndex(rewriter, loc, 0);
1067 Value iterArg = dstBuf.val;
1068
1069 ForeachOp foreachOp;
1070 for (Value input : op.getInputs()) {
1071 // Builds a for op for each input tensor to append new values into the
1072 // output tensor.
1073 foreachOp = rewriter.create<ForeachOp>(
1074 loc, input, iterArg,
1075 [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1076 ValueRange reduc) {
1077 SmallVector<Value> offDimCrd(dcvs);
1078 offDimCrd[conDim] =
1079 builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset);
1080
1081 // Enters foreach, updates the SSA chain.
1082 dstBuf.val = reduc.front();
1083 if (!dstTp.isAllDense()) {
1084 Value cond = genIsNonzero(builder, loc, v);
1085 auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1086 /*else*/ true);
1087 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1088 builder.create<scf::YieldOp>(loc, dstBuf.val);
1089
1090 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1091 dstBuf.insert(builder, loc, v, offDimCrd);
1092 builder.create<scf::YieldOp>(loc, dstBuf.val);
1093
1094 // Exits the ifOp, update the sparse tensor SSA value.
1095 builder.setInsertionPointAfter(ifOp);
1096 dstBuf.val = ifOp.getResult(0);
1097 } else {
1098 dstBuf.insert(builder, loc, v, offDimCrd);
1099 }
1100 builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1101 });
1102 // Accumulates the offset. Note that only static-shaped inputs are allowed
1103 // by concatenate op verifier, which saves us from computing the offset
1104 // dynamically.
1105 const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
1106 assert(!ShapedType::isDynamic(sz));
1107 offset = rewriter.create<arith::AddIOp>(loc, offset,
1108 constantIndex(rewriter, loc, sz));
1109 iterArg = foreachOp.getResult(0);
1110 dstBuf.val = iterArg;
1111 }
1112
1113 dstBuf.val = iterArg;
    Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
1115 rewriter.replaceOp(op, ret);
1116 return success();
1117 }
1118};
1119
1120struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
1121 using OpRewritePattern::OpRewritePattern;
1122 LogicalResult matchAndRewrite(ConvertOp op,
1123 PatternRewriter &rewriter) const override {
1124 if (op.needsExtraSort())
1125 return op.emitError("ConvertOp not staged.");
1126
1127 // TODO: Maybe we want a different operation for this too.
1128 auto encDst = getSparseTensorEncoding(op.getType());
1129 auto encSrc = getSparseTensorEncoding(op.getSource().getType());
1130 if (encDst && encSrc && !encSrc.isSlice() &&
1131 encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
      // Trivial tensor conversion and simple element type conversion are
      // handled in codegen.
1134 return failure();
1135 }
1136
1137 Location loc = op.getLoc();
1138 Value src = op.getSource();
1139
1140 SparseTensorType srcStt = getSparseTensorType(op.getSource());
1141 SparseTensorType dstStt = getSparseTensorType(op.getDest());
1142
1143 bool fromSparseConst = false;
1144 if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
1145 if (dyn_cast<SparseElementsAttr>(constOp.getValue()))
1146 fromSparseConst = true;
1147
1148 const AffineMapAttr foreachOrder =
1149 (!dstStt.isIdentity() && fromSparseConst)
1150 ? AffineMapAttr::get(dstStt.getExpandedDimToLvl())
1151 : nullptr;
1152
1153 bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;
1154
1155 SmallVector<Value> sizes;
    sizesFromSrc(rewriter, sizes, loc, src);
1157 ValueRange vs;
1158 TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
1159
1160 auto foreachOp = rewriter.create<ForeachOp>(
1161 loc, src, dstBuf.val, foreachOrder,
1162 [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1163 ValueRange reduc) {
1164 // Enters the loop, update the SSA value for insertion chain.
1165 dstBuf.val = reduc.front();
1166 if (!skipZeroCheck) {
1167 Value cond = genIsNonzero(builder, loc, v);
1168 auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1169 /*else*/ true);
1170 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1171 builder.create<scf::YieldOp>(loc, dstBuf.val);
1172
1173 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1174 dstBuf.insert(builder, loc, v, dcvs);
1175 builder.create<scf::YieldOp>(loc, dstBuf.val);
1176
1177 // Exits the ifOp, update the sparse tensor SSA value.
1178 builder.setInsertionPointAfter(ifOp);
1179 dstBuf.val = ifOp.getResult(0);
1180 } else {
1181 dstBuf.insert(builder, loc, v, dcvs);
1182 }
1183 builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1184 });
1185
1186 rewriter.setInsertionPointAfter(foreachOp);
1187
1188 // Exits the for loop, links the SSA chain.
1189 dstBuf.val = foreachOp.getResult(0);
1190
1191 Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
1192 rewriter.replaceOp(op, ret);
1193 return success();
1194 }
1195};
1196
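/// Sparse rewriting rule for the crd_translate operator: expands the
/// dim-to-lvl (or lvl-to-dim) map of the encoding into one affine.apply per
/// result expression. For example (illustrative), a 2x2 block dim-to-lvl map
/// (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2) expands
/// into four affine.apply ops over the incoming coordinates.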
1197struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
1198 using OpRewritePattern::OpRewritePattern;
1199 LogicalResult matchAndRewrite(CrdTranslateOp op,
1200 PatternRewriter &rewriter) const override {
1201 AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
1202 ? op.getEncoder().getDimToLvl()
1203 : op.getEncoder().getLvlToDim();
1204
1205 SmallVector<Value> outCrds;
1206 for (AffineExpr result : map.getResults()) {
      // TODO: we should probably expand the affine map to IR using our own
      // rules, since affine.apply assumes signed values, while the coordinates
      // we provide must always be signless.
1210 Value trans = rewriter.create<affine::AffineApplyOp>(
          op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
          op.getInCrds());
      outCrds.push_back(trans);
1214 }
1215 rewriter.replaceOp(op, outCrds);
1216 return success();
1217 }
1218};
1219
1220/// Sparse rewriting rule for the foreach operator.
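///
/// A rough sketch (illustrative only; op syntax abbreviated) of the lowering
/// for a foreach over a sparse vector with an init argument:
///
///   %r = sparse_tensor.foreach in %sv init(%a) ... {
///     ^bb0(%i, %v, %acc): ... sparse_tensor.yield %new
///   }
///   ==>
///   %pos = sparse_tensor.positions %sv ...
///   %crd = sparse_tensor.coordinates %sv ...
///   %val = sparse_tensor.values %sv ...
///   %r = scf.for %p = %lo to %hi step %c1 iter_args(%acc = %a) {
///     %i = memref.load %crd[%p]
///     %v = memref.load %val[%p]
///     ... inlined body ...
///     scf.yield %new
///   }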
1221struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
1222public:
1223 using OpRewritePattern::OpRewritePattern;
1224
1225 LogicalResult matchAndRewrite(ForeachOp op,
1226 PatternRewriter &rewriter) const override {
1227
1228 auto loc = op.getLoc();
1229 Value input = op.getTensor();
1230 SmallVector<Value> reduc = op.getInitArgs();
    const auto stt = getSparseTensorType(input);
1232 const Level lvlRank = stt.getLvlRank();
1233
1234 // Special-case: for each over a sparse constant uses its own rewriting
1235 // rule.
1236 if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
1237 if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
1238 return genForeachOnSparseConstant(op, rewriter, attr);
1239 }
1240 }
1241
1242 // Otherwise, use loop emitter to generate loops.
1243 const auto enc = stt.getEncoding();
1244
1245 // 1. Generates loop for the sparse input.
1246 LoopEmitter loopEmitter(
1247 ValueRange{input},
1248 StringAttr::get(getContext(), ForeachOp::getOperationName()));
    loopEmitter.initializeLoopEmit(rewriter, loc);
1250 for (Level l = 0; l < lvlRank; l++) {
1251 // TODO: provide utility function for loop sequences that only contains
1252 // one for loop?
      const SmallVector<TensorLevel, 1> tidLvls{
          loopEmitter.makeTensorLevel(0, l)};
      loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
      // Note that reduc will be taken care of by the loop emitter and gets
      // updated in place.
      loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls,
                                                    reduc);
1260 }
1261
1262 SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
1263 if (op.getOrder()) {
1264 // TODO: Support it so that we can do direct conversion from CSR->BSR.
1265 llvm_unreachable(
1266 "Level order not yet implemented on non-constant input tensors.");
1267 }
1268
1269 Value vals = loopEmitter.getValBuffer()[0];
    SmallVector<Value> pos = loopEmitter.getValPosits(0);
1271 // Loads the value from sparse tensor using position-index;
1272 // loads the value from dense tensor using coords.
1273 Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
1274 : rewriter.create<memref::LoadOp>(loc, vals, lcvs);
1275
1276 // 2. Inline the block in the foreach operator.
1277 Block *srcBlock = op.getBody();
1278
1279 // Remap coordinates.
1280 SmallVector<Value> args =
1281 enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
1282
1283 // Remap value.
    args.push_back(val);
    // Remap reduction variables.
    args.append(reduc);
1287
1288 // Remove sparse_tensor.yield.
1289 SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
    rewriter.eraseOp(srcBlock->getTerminator());
1291
1292 Operation &last = rewriter.getBlock()->back();
1293 if (llvm::isa<scf::YieldOp>(last)) {
      // Because `scf.for` inserts an implicit yield op when there is no
      // reduction variable upon creation, we reset the insertion point such
      // that the block is inlined *before* the yield op.
      rewriter.setInsertionPoint(&last);
1298 }
1299
    rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
                               rewriter.getInsertionPoint(), args);
1302 rewriter.setInsertionPointToEnd(rewriter.getBlock());
1303 for (Level l = 0; l < lvlRank; l++) {
      // Link the reduction chain. Note that the loop emitter updates
      // reducValue in place.
      loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
      loopEmitter.exitCurrentLoopSeq(rewriter, loc);
1308 }
1309
    // Replace the foreach operator with the value returned by the outermost
    // for loop.
1312 rewriter.replaceOp(op, reducValue);
1313 return success();
1314 }
1315};
1316
1317/// Sparse rewriting rule for the new operator.
1318struct NewRewriter : public OpRewritePattern<NewOp> {
1319 using OpRewritePattern::OpRewritePattern;
1320 LogicalResult matchAndRewrite(NewOp op,
1321 PatternRewriter &rewriter) const override {
1322 Location loc = op.getLoc();
1323 auto stt = getSparseTensorType(op.getResult());
1324 if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
1325 return failure();
1326
1327 // Implement the NewOp as follows:
1328 // %orderedCoo = sparse_tensor.new %filename
1329 // %t = sparse_tensor.convert %orderedCoo
1330 // with enveloping reinterpreted_map ops for non-permutations.
1331 RankedTensorType dstTp = stt.getRankedTensorType();
1332 RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
1333 Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
1334 Value convert = cooTensor;
1335 auto enc = stt.getEncoding();
1336 if (!stt.isPermutation()) { // demap coo, demap dstTp
      auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
      convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
      dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
1340 }
1341 convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
1342 if (!stt.isPermutation()) // remap to original enc
1343 convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
1344 rewriter.replaceOp(op, convert);
1345
1346 // Release the temporary ordered COO tensor.
1347 rewriter.setInsertionPointAfterValue(convert);
1348 rewriter.create<DeallocTensorOp>(loc, cooTensor);
1349
1350 return success();
1351 }
1352};
1353
1354/// Sparse rewriting rule for the out operator.
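///
/// A sketch of the runtime-call sequence generated below (illustrative; the
/// <suffix> denotes the element-type suffix appended to the function name):
///
///   %w = call @createSparseTensorWriter(%dest)
///   call @outSparseTensorWriterMetaData(%w, %rank, %nnz, %dimSizes)
///   sparse_tensor.foreach in %src {
///     call @outSparseTensorWriterNext<suffix>(%w, %rank, %dimCoords, %value)
///   }
///   call @delSparseTensorWriter(%w)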
1355struct OutRewriter : public OpRewritePattern<OutOp> {
1356 using OpRewritePattern::OpRewritePattern;
1357 LogicalResult matchAndRewrite(OutOp op,
1358 PatternRewriter &rewriter) const override {
1359 Location loc = op.getLoc();
1360 // Calculate NNZ.
1361 Value src = op.getTensor();
1362 Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
1363
1364 // Allocate a temporary buffer for storing dimension-sizes/coordinates.
    const auto srcTp = getSparseTensorType(src);
1366 const Dimension dimRank = srcTp.getDimRank();
1367 Type indexTp = rewriter.getIndexType();
    Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);
1369
1370 // Generate code to calculate dimension size values and store the values to
1371 // the buffer.
1372 SmallVector<Value> dims;
1373 sizesForTensor(rewriter, dims, loc, srcTp, src);
1374 for (Dimension d = 0; d < dimRank; d++) {
1375 rewriter.create<memref::StoreOp>(loc, dims[d], dimSizes,
1376 constantIndex(rewriter, loc, d));
1377 }
1378
1379 // Create a sparse tensor writer and output meta data.
    Type opaqueTp = getOpaquePointerType(rewriter);
1381 Value writer =
1382 createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
1383 {op.getDest()}, EmitCInterface::Off)
1384 .getResult(0);
    Value rankValue = constantIndex(rewriter, loc, dimRank);
    createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
                   {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
1388
1389 Value dimCoords = dimSizes; // Reuse the dimSizes buffer for dimCoords.
1390 Type eltTp = srcTp.getElementType();
1391 SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
                                    primaryTypeFunctionSuffix(eltTp)};
    Value value = genAllocaScalar(rewriter, loc, eltTp);
1394 ModuleOp module = op->getParentOfType<ModuleOp>();
1395
1396 // For each element in the source tensor, output the element.
1397 rewriter.create<ForeachOp>(
1398 loc, src, std::nullopt,
1399 [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1400 ValueRange reduc) {
1401 for (Dimension d = 0; d < dimRank; d++) {
1402 rewriter.create<memref::StoreOp>(loc, dcvs[d], dimCoords,
1403 constantIndex(builder, loc, d));
1404 }
1405 rewriter.create<memref::StoreOp>(loc, v, value);
1406 SmallVector<Value> operands{writer, rankValue, dimCoords, value};
1407 FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
1408 EmitCInterface::On);
1409 builder.create<func::CallOp>(loc, TypeRange(), fn, operands);
1410 builder.create<sparse_tensor::YieldOp>(loc);
1411 });
1412
1413 // Release the writer.
    createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
                   EmitCInterface::Off);
1416
    rewriter.eraseOp(op);
1418 return success();
1419 }
1420};
1421
1422} // namespace
1423
1424//===---------------------------------------------------------------------===//
1425// Methods that add patterns described in this file to a pattern list.
1426//===---------------------------------------------------------------------===//
1427
1428void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
1429 patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
1430 GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
      patterns.getContext());
1432}
1433
1434void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
1435 bool enableRT,
1436 bool enableConvert) {
1437 patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
1438 ReshapeRewriter<tensor::CollapseShapeOp>,
1439 Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
1440 Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
1441 SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
1442 patterns.getContext());
1443
1444 if (enableConvert)
    patterns.add<DirectConvertRewriter>(patterns.getContext());
1446 if (!enableRT)
    patterns.add<NewRewriter>(patterns.getContext());
1448}
1449
1450void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
  // Run CrdTranslateRewriter later in the pipeline so that the operation can
  // be folded before lowering to affine.apply.
  patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
1454}