1//===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements rewriting rules that are specific to sparse tensors.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Utils/CodegenUtils.h"
14#include "Utils/LoopEmitter.h"
15
16#include "mlir/Dialect/Affine/IR/AffineOps.h"
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
19#include "mlir/Dialect/Linalg/IR/Linalg.h"
20#include "mlir/Dialect/Linalg/Utils/Utils.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/Dialect/SCF/IR/SCF.h"
23#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
24#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
25#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
26#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
27#include "mlir/Dialect/Tensor/IR/Tensor.h"
28#include "mlir/Dialect/Vector/IR/VectorOps.h"
29#include "mlir/IR/AffineMap.h"
30#include "mlir/IR/Matchers.h"
31#include "mlir/Support/LLVM.h"
32
33using namespace mlir;
34using namespace mlir::bufferization;
35using namespace mlir::linalg;
36using namespace mlir::sparse_tensor;
37
38//===---------------------------------------------------------------------===//
39// Helper methods for the actual rewriting rules.
40//===---------------------------------------------------------------------===//
41
// Helper method to match any typed zero.
static bool isZeroValue(Value val) {
  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
}
46
47// Helper to detect a sparse tensor type operand.
48static bool isSparseTensor(Value v) {
49 auto enc = getSparseTensorEncoding(v.getType());
50 return enc && !llvm::all_of(enc.getLvlTypes(),
51 [](auto lt) { return lt == LevelFormat::Dense; });
52}
static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
54
55// Helper method to find zero/uninitialized tensor materialization.
56static bool isMaterializing(OpOperand *op, bool isZero) {
57 Value val = op->get();
58 // Check allocation, with zero alloc when required.
59 if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
60 Value copy = alloc.getCopy();
61 if (isZero)
      return copy && isZeroValue(copy);
63 return !copy;
64 }
65 // Check for empty tensor materialization.
66 if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
67 return !isZero;
68 // Last resort for zero alloc: the whole value is zero.
69 return isZero && isZeroValue(val);
70}
71
72// Helper to detect sampling operation.
73static bool isSampling(GenericOp op) {
74 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
75 if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
76 if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
77 // Both scalar input arguments used exactly once.
78 Value s1 = op.getBlock()->getArgument(0);
79 Value s2 = op.getBlock()->getArgument(1);
80 return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
81 (def->getOperand(1) == s1 && def->getOperand(0) == s2);
82 }
83 }
84 return false;
85}
86
// Helper to detect a chain of multiplications that does not involve x.
static bool isMulChain(Value val, Value x) {
  if (auto arg = dyn_cast<BlockArgument>(val))
    return arg != x;
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
      return isMulChain(def->getOperand(0), x) &&
             isMulChain(def->getOperand(1), x);
  }
  return false;
}
98
99// Helper to detect x = x + <multiplications>.
100static bool isSumOfMul(GenericOp op) {
101 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
102 if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
103 if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
104 Value x = op.getBlock()->getArguments().back();
105 return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
106 (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
107 }
108 }
109 return false;
110}
111
112// Helper to detect direct yield of a zero value.
113static bool isZeroYield(GenericOp op) {
114 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
115 if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
116 if (arg.getOwner()->getParentOp() == op) {
117 return isZeroValue(op->getOperand(arg.getArgNumber()));
118 }
119 }
120 return isZeroValue(yieldOp.getOperand(0));
121}
122
123/// Populates given sizes array from type (for static sizes) and from
124/// the tensor (for dynamic sizes).
125static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
126 Location loc, ShapedType stp, Value tensor) {
127 for (const auto &d : enumerate(stp.getShape())) {
128 Value dim;
129 if (d.value() == ShapedType::kDynamic)
130 dim = builder.create<tensor::DimOp>(loc, tensor, d.index());
131 else
132 dim = constantIndex(builder, loc, d.value());
133 sizes.push_back(dim);
134 }
135}
136
137static RankedTensorType getBufferType(const SparseTensorType &stt,
138 bool needTmpCOO) {
139 return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
140 : stt.getRankedTensorType();
141}
142
143/// Collects the dynamic dimension sizes for `tp` with the assumption that
144/// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
145/// sizes to dynSizes.
146static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
147 SmallVectorImpl<Value> &dynSizes) {
148 for (const auto &d : enumerate(tp.getShape())) {
149 if (d.value() == ShapedType::kDynamic)
150 dynSizes.push_back(sizes[d.index()]);
151 }
152}
153
154static LogicalResult genForeachOnSparseConstant(ForeachOp op,
155 RewriterBase &rewriter,
156 SparseElementsAttr attr) {
157 auto loc = op.getLoc();
158 SmallVector<Value> reduc = op.getInitArgs();
159
160 // Foreach on constant.
161 foreachInSparseConstant(
162 rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
163 [&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable {
164 SmallVector<Value> args;
        args.append(cvs.begin(), cvs.end());
        args.push_back(v);
        args.append(reduc);
168 // Clones the foreach op to get a copy of the loop body.
169 auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
170 assert(args.size() == cloned.getBody()->getNumArguments());
171 Operation *yield = cloned.getBody()->getTerminator();
172 rewriter.inlineBlockBefore(cloned.getBody(), op, args);
        // Clean up the cloned op and its terminator.
        rewriter.eraseOp(cloned);
        reduc = yield->getOperands();
        rewriter.eraseOp(yield);
177 });
178
179 rewriter.replaceOp(op, reduc);
180 return success();
181}
182
183/// Populates the given sizes array for concatenation from types (for static
184/// sizes) and from the source tensors (for dynamic sizes).
185static void concatSizesFromInputs(OpBuilder &builder,
186 SmallVectorImpl<Value> &sizes, Location loc,
187 ShapedType dstTp, ValueRange srcs,
188 unsigned dim) {
189 auto dstShape = dstTp.getShape();
  sizesFromSrc(builder, sizes, loc, srcs[0]);
191
192 // Sum up on the `dim` if the dimension is dynamic.
193 if (dstShape[dim] != ShapedType::kDynamic) {
194 // Faithfully take the static size.
195 sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
196 } else {
197 // Else, compute the shape dynamically.
198 for (const auto &src : srcs.drop_front()) {
      Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim);
200 // Sum up all the sizes.
201 sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
202 }
203 }
204}
205
206//===---------------------------------------------------------------------===//
207// The actual sparse tensor rewriting rules.
208//===---------------------------------------------------------------------===//
209
210namespace {
211
212/// TODO: move it to tensor dialect instead.
213///
214/// Fold `tensor.concat` and `tensor.extract_slice`
215///
216/// %concat = tensor.concat dim(2) %t0, %t1
217/// : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
218/// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
219/// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
220/// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
221/// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
222///
223/// Becomes
224///
225/// %extract0, %extract1 = %t0, %t1
226struct FuseExtractSliceWithConcat
227 : public OpRewritePattern<tensor::ExtractSliceOp> {
228 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
229
230 LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
231 PatternRewriter &rewriter) const override {
232 auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
233 if (!concatOp)
234 return failure();
235
236 Location loc = extractOp.getLoc();
237 int64_t dim = concatOp.getDim();
238 int64_t rank = extractOp.getResultType().getRank();
239
240 SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));
241 SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));
242
243 // Compute the partial sums for the slice offsets.
    AffineExpr sum = rewriter.getAffineDimExpr(0);
245 SmallVector<AffineExpr> partialSums = {sum};
246 SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
247 for (auto [idx, input] :
248 llvm::enumerate(concatOp.getInputs().drop_back())) {
249 sum = sum + rewriter.getAffineDimExpr(idx + 1);
250 partialSums.push_back(sum);
251 offsetStrides.push_back(
252 rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
253 }
254 auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
255 partialSums, rewriter.getContext());
    SmallVector<OpFoldResult> dimOffsets =
        affine::makeComposedFoldedMultiResultAffineApply(
            rewriter, loc, partialSumMap, offsetStrides);
259
    auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
      for (auto [l, r] : llvm::zip(lhs, rhs)) {
        std::optional<int64_t> staticVal = getConstantIntValue(l);
        if (!staticVal.has_value() || staticVal != getConstantIntValue(r))
264 return false;
265 }
266 return lhs.size() == rhs.size();
267 };
268
269 for (auto [i, input, offset] :
270 llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
271 SmallVector<OpFoldResult> srcSizes =
272 tensor::getMixedSizes(rewriter, loc, input);
273 srcOffsets[dim] = offset;
274
275 SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
276 SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
277 SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();
278
279 if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
280 allEqual(srcStrides, dstStrides)) {
281 Value operand = concatOp.getOperand(i);
282 if (operand.getType() == extractOp.getResultType())
283 rewriter.replaceOp(extractOp, operand);
284 break;
285 }
286 }
287
288 return success();
289 }
290};
291
292/// Rewriting rule that fuses sparse_tensor.convert into producer.
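///
/// Illustrative sketch only (the #CSR encoding, shapes, and names below are
/// hypothetical):
///
///   %e = tensor.empty() : tensor<10x10xf32>
///   %d = linalg.generic ... outs(%e : tensor<10x10xf32>) -> tensor<10x10xf32>
///   %s = sparse_tensor.convert %d
///        : tensor<10x10xf32> to tensor<10x10xf32, #CSR>
///
/// is folded so that the empty tensor and the linalg.generic result are
/// retyped to the sparse type and the conversion disappears.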
293struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
294public:
295 using OpRewritePattern::OpRewritePattern;
296
297 LogicalResult matchAndRewrite(ConvertOp op,
298 PatternRewriter &rewriter) const override {
299 auto producer = op.getSource().getDefiningOp<GenericOp>();
300 if (!producer || producer.getDpsInits().size() != 1 ||
301 !isMaterializing(producer.getDpsInitOperand(0), false) ||
302 !producer.getResult(0).hasOneUse()) {
303 return failure();
304 }
305 // Clone the materialization operation, but update the result to sparse.
306 rewriter.setInsertionPoint(producer);
307 Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
    Operation *cloned = rewriter.clone(*init);
    cloned->getResult(0).setType(op.getResult().getType());
310
311 rewriter.modifyOpInPlace(producer, [&]() {
312 producer.getDpsInitsMutable().assign(cloned->getResults());
313 producer.getResult(0).setType(op.getResult().getType());
314 });
315
316 rewriter.replaceAllOpUsesWith(op, producer);
317 op->erase();
318
319 return success();
320 }
321};
322
/// Rewriting rule that replaces a direct yield of zero with the initial
/// allocation (or with a constant zero for static-shaped dense outputs).
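///
/// For example (schematic; the #CSR encoding and shapes are hypothetical),
/// a kernel that only yields the constant zero over a freshly materialized
/// sparse tensor
///
///   %0 = linalg.generic ... outs(%alloc : tensor<?x?xf64, #CSR>) {
///          ^bb0(...): linalg.yield %cst_0 : f64
///        }
///
/// is replaced by %alloc itself.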
324struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
325public:
326 using OpRewritePattern<GenericOp>::OpRewritePattern;
327
328 LogicalResult matchAndRewrite(GenericOp op,
329 PatternRewriter &rewriter) const override {
330 if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
331 !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
332 !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
333 return failure();
334 auto outputType = getRankedTensorType(op.getResult(0));
335 // Yielding zero on newly materialized sparse tensor can be
336 // optimized directly (regardless of dynamic or static size).
337 if (getSparseTensorEncoding(outputType)) {
338 rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
339 return success();
340 }
341 // Use static zero value directly instead of materialization.
342 if (!outputType.hasStaticShape())
343 return failure();
344 Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
345 rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
    rewriter.eraseOp(def);
347 return success();
348 }
349};
350
351/// Rewriting rule that converts two kernels:
352///
353/// T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
354/// X(i,j) = S(i,j) * T(i,j)
355///
356/// into a single kernel, using distributive law:
357///
358/// X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
359///
360/// This kind of fusion (merging two ops into one but using arithmetic
361/// equalities that may not hold for floating-point computations) would
362/// be undesirable in the dense case, since we distribute the multiplication
363/// into the reduction loop. However, for sparse sampling tensor S, such
364/// a fusion may actually reduce the asymptotic complexity of the kernel,
365/// since intermediate results may be nullified.
366struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
367public:
368 using OpRewritePattern<GenericOp>::OpRewritePattern;
369
370 LogicalResult matchAndRewrite(GenericOp op,
371 PatternRewriter &rewriter) const override {
372 // Check consumer.
373 if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
374 op.getNumResults() != 1 ||
375 op.getNumParallelLoops() != op.getNumLoops() ||
376 !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
377 !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
378 !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
379 return failure();
380 // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
381 // operand can be sparse or dense, since the point of this rewriting rule
382 // is detecting a situation in which *more* sparsity is introduced into
383 // a computation, be it already sparse or still dense.
384 unsigned other = 0;
385 if (isSparseTensor(op.getDpsInputOperand(0)))
386 other = 1;
387 else if (!isSparseTensor(op.getDpsInputOperand(1)))
388 return failure();
389 // Check producer.
390 auto prod = dyn_cast_or_null<GenericOp>(
391 op.getDpsInputOperand(other)->get().getDefiningOp());
392 if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
393 !prod.getResult(0).hasOneUse())
394 return failure();
395 // Sampling consumer and sum of multiplication chain producer.
396 if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
397 !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
398 !isSampling(op) || !isSumOfMul(prod))
399 return failure();
400 // Modify operand structure of producer and consumer.
401 Location loc = prod.getLoc();
402 SmallVector<Value> inputOps = prod.getInputs();
403 SmallVector<Value> outputOps = op.getOutputs();
404 SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
    inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
    fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
407 // Fuse producer and consumer into a new generic op.
408 auto fusedOp = rewriter.create<GenericOp>(
409 loc, op.getResult(0).getType(), inputOps, outputOps,
410 rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
411 /*doc=*/nullptr, /*library_call=*/nullptr);
412 Block &prodBlock = prod.getRegion().front();
413 Block &consBlock = op.getRegion().front();
414 IRMapping mapper;
415 Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
416 unsigned num = prodBlock.getNumArguments();
417 for (unsigned i = 0; i < num - 1; i++)
      addArg(mapper, fusedBlock, prodBlock.getArgument(i));
    addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
    addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
421 // Clone bodies of the producer and consumer in new evaluation order.
    auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
    auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
424 Value last;
425 for (auto &op : prodBlock.without_terminator())
426 if (&op != acc) {
427 last = op.getResult(0);
428 rewriter.clone(op, mapper);
429 }
    mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
431 mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
432 last = rewriter.clone(*acc, mapper)->getResult(0);
433 rewriter.create<linalg::YieldOp>(loc, last);
434 // Force initial value on merged allocation for dense outputs.
435 // TODO: deal with non alloc tensor here one day
436 if (!getSparseTensorEncoding(op.getResult(0).getType())) {
437 Value init = prod.getDpsInitOperand(0)
438 ->get()
439 .getDefiningOp<AllocTensorOp>()
440 .getCopy();
441 AllocTensorOp a =
442 op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
443 rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
444 }
445 // Replace consumer with fused operation. Old producer
446 // and consumer ops will be removed by DCE.
447 rewriter.replaceOp(op, fusedOp->getResults());
448 return success();
449 }
450
451private:
  // Helper to add an argument and record the mapping.
  static void addArg(IRMapping &mapper, Block *b, BlockArgument a) {
    mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
455 }
456};
457
// Fuses a tensor cast into the producing operation. Note that a tensor.cast
// should really not be used to convert between sparse encodings. Since the
// pattern currently appears as a result of some prior rewriting, we make an
// attempt to repair very obvious cases.
462// TODO: audit the pure tensor dialect rewriting rules
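//
// For example (schematic; #CSR is a placeholder encoding), a cast that only
// changes the encoding
//
//   %0 = tensor.cast %t : tensor<8x8xf32> to tensor<8x8xf32, #CSR>
//
// is repaired into the properly supported
//
//   %0 = sparse_tensor.convert %t : tensor<8x8xf32> to tensor<8x8xf32, #CSR>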
463struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
464public:
465 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
466
467 LogicalResult matchAndRewrite(tensor::CastOp op,
468 PatternRewriter &rewriter) const override {
469 Type srcType = op.getSource().getType();
470 Type dstType = op.getDest().getType();
471 // A nop cast simply folds away.
472 if (srcType == dstType) {
473 rewriter.replaceOp(op, op->getResults());
474 return success();
475 }
476 // See if a sparsity changing cast can be fused into producer.
    if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
      if (Operation *def = op.getSource().getDefiningOp()) {
        if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
          rewriter.modifyOpInPlace(def, [&]() {
            def->getResult(0).setType(op->getResultTypes()[0]);
          });
          rewriter.replaceOp(op, def->getResult(0));
484 return success();
485 }
486 }
487 }
    // Repair tensor casts with at least one sparse operand into the
    // properly supported sparse_tensor.convert.
490 if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
491 rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
492 return success();
493 }
494 // Fail otherwise.
495 return failure();
496 }
497};
498
/// Rewrites a sequence of operations for sparse tensor selections into
/// semi-ring operations so that they can be compiled correctly by the
/// sparsifier. E.g., transforming the following sequence
502///
503/// %sel = arith.select %cond, %sp1, %sp2
504///
505/// to
506///
507/// %sel = binary %sp1, %sp2:
508/// both (%l, %r) {yield select %cond, %l, %r}
509/// left (%l) {yield select %cond, %l, 0}
510/// right (%r) {yield select %cond, 0, %r}
511///
/// TODO: We require the tensor used for extracting conditions to be dense
/// in order to sparsify the code. To support a sparse condition tensor, we
/// need a ternary operation.
515struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
516public:
517 using OpRewritePattern<GenericOp>::OpRewritePattern;
518 LogicalResult matchAndRewrite(GenericOp op,
519 PatternRewriter &rewriter) const override {
520 // Rejects non sparse kernels.
521 if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op))
522 return failure();
523
524 Location loc = op.getLoc();
525 SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
526 for (Operation &inst : *op.getBody()) {
527 // Matches pattern.
528 auto matched = isRewritablePattern(op, &inst);
529 if (!matched.has_value())
530 continue;
531
532 rewriter.setInsertionPoint(&inst);
533 auto [c, t, f] = matched.value();
534 assert(t.getType() == f.getType());
535 auto selTp = t.getType();
536 auto c0 = constantZero(rewriter, loc, selTp);
537 auto binOp = rewriter.create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
538 // Initializes all the blocks.
539 rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
540 {t.getLoc(), f.getLoc()});
541 rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
542 rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
543
544 for (auto *r : binOp.getRegions()) {
545 Block *b = &r->front();
546 rewriter.setInsertionPointToStart(b);
547
548 IRMapping irMap;
549 // Clones the cmp operations into the region to make the binary op
550 // admissible.
551 Value newC = c;
552 if (auto *def = c.getDefiningOp())
553 newC = rewriter.clone(*def, irMap)->getResult(0);
554
555 irMap.map(c, newC);
556 if (r == &binOp.getLeftRegion()) {
557 irMap.map(t, b->getArgument(0));
558 irMap.map(f, c0);
559 } else if (r == &binOp.getRightRegion()) {
560 irMap.map(t, c0);
561 irMap.map(f, b->getArgument(0));
562 } else {
563 irMap.map(t, b->getArgument(0));
564 irMap.map(f, b->getArgument(1));
565 }
566 auto y = rewriter.clone(inst, irMap)->getResult(0);
567 rewriter.create<sparse_tensor::YieldOp>(loc, y);
568 }
569
      // We successfully rewrote an operation. We cannot do the replacement
      // here because it would invalidate the iterator used to traverse the
      // remaining instructions in the current loop.
573 semiRings.emplace_back(&inst, binOp);
574 }
575
576 // Finalizes the replacement.
577 for (auto [sel, semi] : semiRings)
578 rewriter.replaceOp(sel, semi->getResults());
579
580 return success(!semiRings.empty());
581 }
582
583private:
584 static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
585 isRewritablePattern(GenericOp op, Operation *v) {
586 auto sel = dyn_cast<arith::SelectOp>(v);
587 if (!sel)
588 return std::nullopt;
589
590 auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
591 auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
    // TODO: For simplicity, we only handle cases where both the true/false
    // values are directly loaded from the input tensor. We can probably admit
    // more cases in theory.
595 if (!tVal || !fVal)
596 return std::nullopt;
597
598 // Helper lambda to determine whether the value is loaded from a dense input
599 // or is a loop invariant.
600 auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
      if (auto bArg = dyn_cast<BlockArgument>(v);
602 bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
603 return true;
604 // If the value is defined outside the loop, it is a loop invariant.
605 return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
606 };
607
    // If the condition value is loaded directly from a dense tensor or is a
    // loop invariant, we can sparsify the kernel.
610 auto cond = sel.getCondition();
611 if (isValFromDenseInputOrInvariant(cond))
612 return std::make_tuple(cond, tVal, fVal);
613
614 Value cmpL, cmpR;
615 if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
616 matchers::m_Any(&cmpR))) ||
617 matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
618 matchers::m_Any(&cmpR)))) {
619 // TODO: we can do it recursively to check whether all the leaf values are
620 // loaded from dense tensors or are loop invariants.
621 if (isValFromDenseInputOrInvariant(cmpL) ||
622 isValFromDenseInputOrInvariant(cmpR))
623 return std::make_tuple(cond, tVal, fVal);
624 }
625
626 return std::nullopt;
627 };
628};
629
630/// Rewrites a sparse reduction that would not sparsify directly since
631/// doing so would only iterate over the stored elements, ignoring the
632/// implicit zeros, into a semi-ring. Applies to all prod/and/min/max
633/// (note that reductions like add/sub/or/xor can directly be sparsified
634/// since the implicit zeros do not contribute to the final result).
635/// Note that prod/and are still included since, even though they often
636/// are nullified in sparse data, they may still occur for special
637/// situations in which e.g. some rows in a sparse matrix are fully
638/// dense. For min/max, including the implicit zeros is a much more
639/// common situation.
640///
641/// TODO: this essentially "densifies" the operation; we want to implement
642/// this much more efficiently by performing the reduction over the
643/// stored values, and feed in the zero once if there were *any*
644/// implicit zeros as well; but for now, at least we provide
645/// the functionality
646///
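/// As a rough sketch (types, names, and the exact region syntax below are
/// illustrative only), the body of a min-reduction x = min(x, a) becomes
///
///   %u = sparse_tensor.unary %a : f64 to f64
///          present = { yield the stored value unchanged }
///          absent  = { yield an explicit zero }
///   %x = sparse_tensor.reduce %u, %x, %identity : f64 {
///          yield the minimum of the two region arguments
///        }
///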
647struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
648public:
649 using OpRewritePattern<GenericOp>::OpRewritePattern;
650
651 LogicalResult matchAndRewrite(GenericOp op,
652 PatternRewriter &rewriter) const override {
653 // Reject non-reductions.
654 if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
655 op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
656 return failure();
657 auto *inp = op.getDpsInputOperand(0);
658 auto *init = op.getDpsInitOperand(0);
659 if (!isSparseTensor(inp))
660 return failure();
661 // Look for direct x = x OP y for semi-ring ready reductions.
662 auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
663 .getOperand(0)
664 .getDefiningOp();
665 if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
666 arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
667 arith::MaxUIOp>(red))
668 return failure();
669 Value s0 = op.getBlock()->getArgument(0);
670 Value s1 = op.getBlock()->getArgument(1);
671 if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
672 (red->getOperand(0) != s1 || red->getOperand(1) != s0))
673 return failure();
674 // Identity.
675 Location loc = op.getLoc();
676 Value identity =
677 rewriter.create<tensor::ExtractOp>(loc, init->get(), ValueRange());
678 // Unary {
679 // present -> value
680 // absent -> zero.
681 // }
682 Type rtp = s0.getType();
683 rewriter.setInsertionPointToStart(&op.getRegion().front());
684 auto semiring = rewriter.create<sparse_tensor::UnaryOp>(loc, rtp, s0);
685 Block *present =
686 rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
687 rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
688 rewriter.create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
689 rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
690 rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
691 auto zero =
692 rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(rtp));
693 rewriter.create<sparse_tensor::YieldOp>(loc, zero);
694 rewriter.setInsertionPointAfter(semiring);
695 // CustomReduce {
696 // x = x REDUC y, identity
697 // }
698 auto custom = rewriter.create<sparse_tensor::ReduceOp>(
699 loc, rtp, semiring.getResult(), s1, identity);
700 Block *region =
701 rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
702 rewriter.setInsertionPointToStart(&custom.getRegion().front());
703 IRMapping irMap;
    irMap.map(red->getOperand(0), region->getArgument(0));
    irMap.map(red->getOperand(1), region->getArgument(1));
706 auto *cloned = rewriter.clone(*red, irMap);
707 rewriter.create<sparse_tensor::YieldOp>(loc, cloned->getResult(0));
708 rewriter.setInsertionPointAfter(custom);
709 rewriter.replaceOp(red, custom.getResult());
710 return success();
711 }
712};
713
/// Sparse rewriting rule for the print operator. This operation is mainly used
/// for debugging and testing. As such, it lowers to the vector.print operation,
/// which only requires very lightweight runtime support.
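///
/// For a typical CSR tensor, the emitted code prints output along the lines
/// of (all numbers below are illustrative only):
///
///   ---- Sparse Tensor ----
///   nse = 5
///   dim = ( 4, 6 )
///   lvl = ( 4, 6 )
///   pos[1] : ( 0, 2, 3, 3, 5 )
///   crd[1] : ( 0, 5, 1, 2, 4 )
///   values : ( 1, 2, 3, 4, 5 )
///   ----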
717struct PrintRewriter : public OpRewritePattern<PrintOp> {
718public:
719 using OpRewritePattern::OpRewritePattern;
720 LogicalResult matchAndRewrite(PrintOp op,
721 PatternRewriter &rewriter) const override {
722 Location loc = op.getLoc();
723 auto tensor = op.getTensor();
724 auto stt = getSparseTensorType(tensor);
725 // Header with NSE.
726 auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
727 rewriter.create<vector::PrintOp>(
728 loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
729 rewriter.create<vector::PrintOp>(loc, nse);
730 // Print run-time contents for dim/lvl sizes.
    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("dim = "));
    printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true);
    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("lvl = "));
    printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false);
735 // Use the "codegen" foreach loop construct to iterate over
736 // all typical sparse tensor components for printing.
737 foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
738 &stt](Type, FieldIndex,
739 SparseTensorFieldKind kind,
740 Level l, LevelType) {
741 switch (kind) {
742 case SparseTensorFieldKind::StorageSpec: {
743 break;
744 }
745 case SparseTensorFieldKind::PosMemRef: {
        auto lvl = constantIndex(rewriter, loc, l);
747 rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
748 rewriter.create<vector::PrintOp>(
749 loc, lvl, vector::PrintPunctuation::NoPunctuation);
750 rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
751 auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
        printContents(rewriter, loc, pos);
753 break;
754 }
755 case SparseTensorFieldKind::CrdMemRef: {
        auto lvl = constantIndex(rewriter, loc, l);
757 rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
758 rewriter.create<vector::PrintOp>(
759 loc, lvl, vector::PrintPunctuation::NoPunctuation);
760 rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
761 Value crd = nullptr;
        // For COO AoS storage, we want to print a single, linear view of
        // the full coordinate storage at this level. For any other storage,
        // we show the coordinate storage for every individual level.
765 if (stt.getAoSCOOStart() == l)
766 crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
767 else
768 crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
        printContents(rewriter, loc, crd);
770 break;
771 }
772 case SparseTensorFieldKind::ValMemRef: {
773 rewriter.create<vector::PrintOp>(loc,
774 rewriter.getStringAttr("values : "));
775 auto val = rewriter.create<ToValuesOp>(loc, tensor);
        printContents(rewriter, loc, val);
777 break;
778 }
779 }
780 return true;
781 });
782 rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
    rewriter.eraseOp(op);
784 return success();
785 }
786
787private:
788 // Helper to print contents of a single memref. For "push_back" vectors,
789 // we assume that the previous getters for pos/crd/val have added a
790 // slice-to-size view to make sure we just print the size and not the
791 // full capacity.
792 //
793 // Generates code to print (1-dim or higher):
794 // ( a0, a1, ... )
795 static void printContents(PatternRewriter &rewriter, Location loc,
796 Value vec) {
797 auto shape = cast<ShapedType>(vec.getType()).getShape();
798 SmallVector<Value> idxs;
    printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
800 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
801 }
802
803 // Helper to the helper.
804 static void printContentsLevel(PatternRewriter &rewriter, Location loc,
805 Value vec, unsigned i, ArrayRef<int64_t> shape,
806 SmallVectorImpl<Value> &idxs) {
807 // Open bracket.
808 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
809 // Generate for loop.
    auto zero = constantIndex(rewriter, loc, 0);
    auto index = constantIndex(rewriter, loc, i);
    auto size = rewriter.create<memref::DimOp>(loc, vec, index);
    auto step = constantIndex(rewriter, loc, 1);
    auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
    idxs.push_back(forOp.getInductionVar());
816 rewriter.setInsertionPointToStart(forOp.getBody());
817 if (i < shape.size() - 1) {
818 // Enter deeper loop nest.
      printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
820 } else {
821 // Actual contents printing.
822 auto val = rewriter.create<memref::LoadOp>(loc, vec, idxs);
823 if (llvm::isa<ComplexType>(val.getType())) {
824 // Since the vector dialect does not support complex types in any op,
825 // we split those into (real, imag) pairs here.
826 Value real = rewriter.create<complex::ReOp>(loc, val);
827 Value imag = rewriter.create<complex::ImOp>(loc, val);
828 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
829 rewriter.create<vector::PrintOp>(loc, real,
830 vector::PrintPunctuation::Comma);
831 rewriter.create<vector::PrintOp>(loc, imag,
832 vector::PrintPunctuation::Close);
833 } else {
834 rewriter.create<vector::PrintOp>(
835 loc, val, vector::PrintPunctuation::NoPunctuation);
836 }
837 // Terminating comma (except at end).
838 auto bound = rewriter.create<arith::AddIOp>(loc, idxs.back(), step);
839 Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
840 bound, size);
841 scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
842 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
843 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
844 }
845 idxs.pop_back();
846 rewriter.setInsertionPointAfter(forOp);
847 // Close bracket.
848 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
849 }
850
851 // Helper method to print run-time lvl/dim sizes.
852 static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
853 unsigned size, bool isDim) {
854 // Open bracket.
855 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
856 // Print unrolled contents (dimop requires constant value).
857 for (unsigned i = 0; i < size; i++) {
      auto idx = constantIndex(rewriter, loc, i);
859 Value val;
860 if (isDim)
861 val = rewriter.create<tensor::DimOp>(loc, tensor, idx);
862 else
863 val = rewriter.create<LvlOp>(loc, tensor, idx);
864 rewriter.create<vector::PrintOp>(
865 loc, val,
866 i != size - 1 ? vector::PrintPunctuation::Comma
867 : vector::PrintPunctuation::NoPunctuation);
868 }
869 // Close bracket and end of line.
870 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
871 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
872 }
873};
874
875/// Sparse rewriting rule for sparse-to-sparse reshape operator.
876struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
877public:
878 using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
879
880 LogicalResult matchAndRewrite(tensor::ReshapeOp op,
881 PatternRewriter &rewriter) const override {
882 Location loc = op.getLoc();
883 Value srcTensor = op.getSource();
884 const auto srcTp = tryGetSparseTensorType(srcTensor);
885 const auto dstTp = tryGetSparseTensorType(op.getResult());
886 if (!srcTp || !dstTp)
887 return failure();
888
889 if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
890 !dstTp->hasStaticDimShape())
891 return failure();
892
893 SmallVector<Value> srcSizes;
894 sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
895 SmallVector<Value> dstSizes;
896 for (Dimension d : dstTp->getDimShape())
897 dstSizes.push_back(constantIndex(rewriter, loc, d));
898
899 Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
900 // Only need an unordered COO buffer if input and output are not sorted
901 // in the same way.
902 Type bufferTp = getBufferType(
903 dstTp->withoutDimToLvl(),
904 !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
905 SmallVector<Value> dynSizes;
906 Value buffer = rewriter
907 .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
908 nnz, Attribute())
909 .getResult();
910
    // Convert src coordinates to dst coordinates by first collapsing them to
    // 1D and then expanding them to match the rank of the destination tensor.
913 // Implemented as follows:
914 // foreach srcCoords %srcTensor
915 // collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
916 // expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
917 // insert expandedCoords, %buffer
918 //
919 // followed by an optional
920 // %t = sparse_tensor.cast %tmp
921 // depending on whether the input/output are sorted in the same way.
922 const auto encSrc = srcTp->getEncoding();
923 ForeachOp foreachOp = rewriter.create<ForeachOp>(
924 loc, srcTensor, buffer,
925 [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
926 ValueRange reduc) {
927 const Dimension srcRank = srcTp->getDimRank();
928 SmallVector<Value> srcDcvs;
929 srcDcvs.reserve(srcRank);
930 for (Dimension d = 0; d < srcRank; d++) {
931 Level lvl = toLvl(encSrc, d);
932 srcDcvs.push_back(srcLcvs[lvl]);
933 }
934
935 Value collapseSize = constantIndex(builder, loc, 1);
936 for (Dimension d = 0; d < srcRank; d++)
937 collapseSize =
938 builder.create<arith::MulIOp>(loc, collapseSize, srcSizes[d]);
939 SmallVector<Value, 1> collapsedSizes = {collapseSize};
940
941 ReassociationIndices collapseIdx;
942 for (Dimension i = 0; i < srcRank; i++)
943 collapseIdx.push_back(i);
944 SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
945 SmallVector<Value, 1> collapsedDcvs;
946 reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
947 collapsedSizes, collapsedDcvs);
948
949 ReassociationIndices expandIdx;
950 for (Dimension i = 0; i < dstTp->getDimRank(); i++)
951 expandIdx.push_back(i);
952 SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
953 SmallVector<Value> dstDcvs;
954 reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
955 dstSizes, dstDcvs);
956
957 auto t =
958 builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
959 builder.create<sparse_tensor::YieldOp>(loc, t);
960 });
961
962 Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
963 if (bufferTp != *dstTp) {
964 auto dstRTT = dstTp->getRankedTensorType();
965 Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
966 rewriter.create<DeallocTensorOp>(loc, t);
967 t = converted;
968 }
969 rewriter.replaceOp(op, t);
970 return success();
971 }
972};
973
974/// Sparse rewriting rule for sparse-to-sparse reshape operator.
975template <typename ReshapeOp>
976struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
977public:
978 using OpRewritePattern<ReshapeOp>::OpRewritePattern;
979
980 LogicalResult matchAndRewrite(ReshapeOp op,
981 PatternRewriter &rewriter) const override {
982 Location loc = op.getLoc();
983 Value srcTensor = op.getSrc();
    const auto srcTp = getSparseTensorType(srcTensor);
985 const auto dstTp = getSparseTensorType(op.getResult());
986 if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
987 return failure();
988
989 // Generate code to represent the static dimension constants or compute
990 // the dynamic dimension values.
991 SmallVector<Value> srcSizes;
992 sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
993 SmallVector<Value> dstSizes;
994 SmallVector<Value> dstDynSizes;
995 if (dstTp.hasStaticDimShape()) {
996 for (Dimension d : dstTp.getDimShape())
        dstSizes.push_back(constantIndex(rewriter, loc, d));
998 } else {
999 ArrayRef<Size> dstShape = dstTp.getDimShape();
1000 genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
1001 op.getReassociationIndices());
      for (auto [idx, shape] : llvm::enumerate(dstShape)) {
        if (shape == ShapedType::kDynamic)
          dstDynSizes.push_back(dstSizes[idx]);
1005 }
1006 }
1007 Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
    // Only need an unordered COO buffer if input and output are not sorted
    // in the same way.
1010 Type bufferTp = getBufferType(
1011 dstTp.withoutDimToLvl(),
1012 !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
1013
1014 Value buffer =
1015 rewriter
1016 .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(),
1017 /*sizeHint=*/nnz, Attribute())
1018 .getResult();
1019
1020 // Implement the sparse2sparse reshape as follows:
1021 // foreach srcCoords %srcTensor
1022 // insert reshapeCvs(srcCoords), %buffer
1023 //
1024 // followed by an optional
1025 // %t = sparse_tensor.cast %tmp
1026 // depending on whether the input/output are sorted in the same way.
1027 const auto encSrc = srcTp.getEncoding();
1028 ForeachOp foreachOp = rewriter.create<ForeachOp>(
1029 loc, srcTensor, buffer,
1030 [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
1031 ValueRange reduc) {
1032 const Dimension dimRank = srcTp.getDimRank();
1033 SmallVector<Value> srcDcvs;
1034 srcDcvs.reserve(dimRank);
1035 for (Dimension d = 0; d < dimRank; d++) {
1036 Level lvl = toLvl(encSrc, d);
1037 srcDcvs.push_back(srcLcvs[lvl]);
1038 }
1039 SmallVector<Value> dstDcvs;
1040 reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
1041 srcDcvs, dstSizes, dstDcvs);
1042 auto t =
1043 builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
1044 builder.create<sparse_tensor::YieldOp>(loc, t);
1045 });
1046
1047 Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
1048 if (bufferTp != dstTp) {
1049 auto dstRTT = dstTp.getRankedTensorType();
1050 Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
1051 rewriter.create<DeallocTensorOp>(loc, t);
1052 t = converted;
1053 }
1054 rewriter.replaceOp(op, t);
1055 return success();
1056 }
1057};
1058
1059/// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
1060/// operator.
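///
/// For instance (schematic; #SV is a hypothetical sparse encoding), a
/// dense-to-sparse expansion
///
///   %0 = tensor.expand_shape %d [[0, 1]]
///        : tensor<12xf64> into tensor<3x4xf64, #SV>
///
/// is unfused into a purely dense expansion followed by a conversion:
///
///   %e = tensor.expand_shape %d [[0, 1]] : tensor<12xf64> into tensor<3x4xf64>
///   %0 = sparse_tensor.convert %e : tensor<3x4xf64> to tensor<3x4xf64, #SV>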
1061template <typename ReshapeOp>
1062struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
1063public:
1064 using OpRewritePattern<ReshapeOp>::OpRewritePattern;
1065
1066 LogicalResult matchAndRewrite(ReshapeOp op,
1067 PatternRewriter &rewriter) const override {
1068 Location loc = op->getLoc();
1069 auto encDst = getSparseTensorEncoding(op.getResult().getType());
1070 auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
1071 // Since a pure dense expansion is very cheap (change of view), for
1072 // a sparse2dense or dense2sparse, we can simply unfuse a sparse
1073 // conversion from the reshape operation itself.
1074 // All other cases are handled elsewhere.
1075 if (encDst && encSrc) {
1076 return failure();
1077 }
1078 if (encSrc) {
1079 auto rtp = getRankedTensorType(op.getSrc());
1080 auto denseTp =
1081 RankedTensorType::get(rtp.getShape(), rtp.getElementType());
1082 auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
1083 rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
1084 return success();
1085 }
1086 if (encDst) {
1087 auto rtp = getRankedTensorType(op.getResult());
1088 auto denseTp =
1089 RankedTensorType::get(rtp.getShape(), rtp.getElementType());
1090 ReshapeOp reshape;
1091 if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
1092 reshape = rewriter.create<ReshapeOp>(
1093 loc, denseTp, op.getSrc(), op.getReassociation(),
1094 op.getOutputShape(), op.getStaticOutputShape());
1095 } else {
1096 reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
1097 op.getReassociation());
1098 }
1099 Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
1100 rewriter.replaceOp(op, convert);
1101 return success();
1102 }
1103 return failure();
1104 }
1105};
1106
1107// A trivial wrapper to help generate different operations for dense/sparse
1108// tensors.
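//
// For a sparse result type the wrapper allocates an (initially empty) sparse
// tensor, threads tensor.insert ops through an SSA chain, and finalizes the
// chain with sparse_tensor.load; for a dense result type it allocates and
// zero-fills a buffer and inserts values directly. Schematically, with a
// placeholder #SP encoding:
//
//   %buf = bufferization.alloc_tensor() : tensor<8x8xf64, #SP>
//   %ins = tensor.insert %v into %buf[%i, %j] : tensor<8x8xf64, #SP>
//   %res = sparse_tensor.load %ins hasInserts : tensor<8x8xf64, #SP>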
1109struct TensorLike {
1110 TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
1111 ValueRange sizes) {
1112 SmallVector<Value> dynSzs;
1113 getDynamicSizes(rtt, sizes, dynSzs);
1114
1115 val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
1116 if (!isSparse()) {
1117 Value c0 = constantZero(builder, loc, rtt.getElementType());
1118 val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
1119 }
1120 }
1121
1122 void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
1123 val = builder.create<tensor::InsertOp>(loc, v, val, crds);
1124 }
1125
1126 Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
1127 if (isSparse())
1128 return builder.create<LoadOp>(loc, val, true);
1129 return val;
1130 }
1131
1132 bool isSparse() const {
1133 return getSparseTensorEncoding(val.getType()) != nullptr;
1134 }
1135
1136 Value val;
1137};
1138
1139struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
1140 using OpRewritePattern::OpRewritePattern;
1141 LogicalResult matchAndRewrite(tensor::DimOp op,
1142 PatternRewriter &rewriter) const override {
1143 std::optional<int64_t> dim = op.getConstantIndex();
1144 auto stt = tryGetSparseTensorType(op.getSource());
1145 if (!dim || !stt || !stt->hasEncoding())
1146 return failure();
1147
1148 if (stt->isPermutation()) {
1149 rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
1150 toLvl(stt->getEncoding(), *dim));
1151 return success();
1152 }
1153
1154 // Non-permutation dim2lvl/lvl2dim maps.
1155 // Compute as follows:
1156 // affine.apply #map (l0 - 1, l1 - 1, ...) + 1
    // Note that this is not the most efficient way (but a more general one)
    // for the lvl-to-dim translation, e.g., for BSR, the dimension size can
    // be computed simply by lvl_size * block_size.
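    //
    // For example, with a hypothetical 2x2 BSR lvl-to-dim map
    //   (i, j, k, l) -> (i * 2 + k, j * 2 + l)
    // the size of dimension 0 is computed as
    //   ((lvlSize(0) - 1) * 2 + (lvlSize(2) - 1)) + 1.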
1160 Location loc = op.getLoc();
1161 SmallVector<Value> maxLvlCrds;
1162 for (Level l = 0; l < stt->getLvlRank(); l++) {
1163 Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
1164 Value maxLvlCrd = rewriter.create<arith::SubIOp>(
1165 loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
      maxLvlCrds.push_back(maxLvlCrd);
1167 }
1168
1169 AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
1170 Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
1171 op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),
1172 maxLvlCrds);
1173
1174 Value dimSz = rewriter.create<arith::AddIOp>(
1175 loc, maxDimCrd, constantOne(rewriter, loc, rewriter.getIndexType()));
1176 rewriter.replaceOp(op, dimSz);
1177 return success();
1178 }
1179};
1180
1181struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
1182 using OpRewritePattern::OpRewritePattern;
1183 LogicalResult matchAndRewrite(ConcatenateOp op,
1184 PatternRewriter &rewriter) const override {
1185 if (op.needsExtraSort())
1186 op.emitError("ConcatenateOp not staged");
1187
1188 const Location loc = op.getLoc();
1189 const auto dstTp = getSparseTensorType(op);
1190 const Dimension conDim = op.getDimension();
1191 SmallVector<Value> sizes;
1192 concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);
1193
1194 // %t = concatenate %s1, %s2, %s3 {dim = 1}
1195 // ==>
1196 // if (isSparseDst)
1197 // if (allDense)
1198 // %tmp = bufferization.alloc_tensor dstTp
1199 // else
1200 // %tmp = bufferization.alloc_tensor : unordered COO
1201 // else
1202 // %tmp = memref.alloc : dense tensor
1203 // foreach in %s1 : insert d0, d1, %tmp
1204 // foreach in %s2 : insert d0, d1 + size(s1), %tmp
1205 // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp
1206
1207 TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
    Value offset = constantIndex(rewriter, loc, 0);
1209 Value iterArg = dstBuf.val;
1210
1211 ForeachOp foreachOp;
1212 for (Value input : op.getInputs()) {
1213 // Builds a for op for each input tensor to append new values into the
1214 // output tensor.
1215 foreachOp = rewriter.create<ForeachOp>(
1216 loc, input, iterArg,
1217 [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1218 ValueRange reduc) {
1219 SmallVector<Value> offDimCrd(dcvs);
1220 offDimCrd[conDim] =
1221 builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset);
1222
1223 // Enters foreach, updates the SSA chain.
1224 dstBuf.val = reduc.front();
1225 if (!dstTp.isAllDense()) {
1226 Value cond = genIsNonzero(builder, loc, v);
1227 auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1228 /*else*/ true);
1229 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1230 builder.create<scf::YieldOp>(loc, dstBuf.val);
1231
1232 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1233 dstBuf.insert(builder, loc, v, offDimCrd);
1234 builder.create<scf::YieldOp>(loc, dstBuf.val);
1235
1236 // Exits the ifOp, update the sparse tensor SSA value.
1237 builder.setInsertionPointAfter(ifOp);
1238 dstBuf.val = ifOp.getResult(0);
1239 } else {
1240 dstBuf.insert(builder, loc, v, offDimCrd);
1241 }
1242 builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1243 });
      // Accumulates the offset. Note that only static-shaped inputs are
      // allowed by the concatenate op verifier, which saves us from computing
      // the offset dynamically.
1247 const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
1248 assert(!ShapedType::isDynamic(sz));
1249 offset = rewriter.create<arith::AddIOp>(loc, offset,
1250 constantIndex(rewriter, loc, sz));
1251 iterArg = foreachOp.getResult(0);
1252 dstBuf.val = iterArg;
1253 }
1254
1255 dstBuf.val = iterArg;
    Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
1257 rewriter.replaceOp(op, ret);
1258 return success();
1259 }
1260};
1261
1262struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
1263 using OpRewritePattern::OpRewritePattern;
1264 LogicalResult matchAndRewrite(ConvertOp op,
1265 PatternRewriter &rewriter) const override {
1266 if (op.needsExtraSort())
1267 return op.emitError("ConvertOp not staged.");
1268
1269 // TODO: Maybe we want a different operation for this too.
1270 auto encDst = getSparseTensorEncoding(op.getType());
1271 auto encSrc = getSparseTensorEncoding(op.getSource().getType());
1272 if (encDst && encSrc && !encSrc.isSlice() &&
1273 encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
1274 // Trivial tensor conversion and simple element type conversion is handled
1275 // in codegen.
1276 return failure();
1277 }
1278
1279 Location loc = op.getLoc();
1280 Value src = op.getSource();
1281
1282 SparseTensorType srcStt = getSparseTensorType(op.getSource());
1283 SparseTensorType dstStt = getSparseTensorType(op.getDest());
1284
1285 bool fromSparseConst = false;
1286 if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
1287 if (isa<SparseElementsAttr>(constOp.getValue()))
1288 fromSparseConst = true;
1289
1290 const AffineMapAttr foreachOrder =
1291 (!dstStt.isIdentity() && fromSparseConst)
1292 ? AffineMapAttr::get(dstStt.getExpandedDimToLvl())
1293 : nullptr;
1294
1295 bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;
1296
1297 SmallVector<Value> sizes;
    sizesFromSrc(rewriter, sizes, loc, src);
1299 ValueRange vs;
1300 TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
1301
1302 auto foreachOp = rewriter.create<ForeachOp>(
1303 loc, src, dstBuf.val, foreachOrder,
1304 [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1305 ValueRange reduc) {
1306 // Enters the loop, update the SSA value for insertion chain.
1307 dstBuf.val = reduc.front();
1308 if (!skipZeroCheck) {
1309 Value cond = genIsNonzero(builder, loc, v);
1310 auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1311 /*else*/ true);
1312 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1313 builder.create<scf::YieldOp>(loc, dstBuf.val);
1314
1315 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1316 dstBuf.insert(builder, loc, v, dcvs);
1317 builder.create<scf::YieldOp>(loc, dstBuf.val);
1318
1319 // Exits the ifOp, update the sparse tensor SSA value.
1320 builder.setInsertionPointAfter(ifOp);
1321 dstBuf.val = ifOp.getResult(0);
1322 } else {
1323 dstBuf.insert(builder, loc, v, dcvs);
1324 }
1325 builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1326 });
1327
1328 rewriter.setInsertionPointAfter(foreachOp);
1329
1330 // Exits the for loop, links the SSA chain.
1331 dstBuf.val = foreachOp.getResult(0);
1332
    Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
1334 rewriter.replaceOp(op, ret);
1335 return success();
1336 }
1337};
1338
1339struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
1340 using OpRewritePattern::OpRewritePattern;
1341 LogicalResult matchAndRewrite(CrdTranslateOp op,
1342 PatternRewriter &rewriter) const override {
1343 AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
1344 ? op.getEncoder().getDimToLvl()
1345 : op.getEncoder().getLvlToDim();
1346
1347 SmallVector<Value> outCrds;
1348 for (AffineExpr result : map.getResults()) {
      // TODO: we should probably expand the affine map to IR using our own
      // rules, since affine.apply assumes signed values, while the coordinates
      // we provide must always be signless.
      Value trans = rewriter.create<affine::AffineApplyOp>(
          op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
          op.getInCrds());
      outCrds.push_back(trans);
1356 }
1357 rewriter.replaceOp(op, outCrds);
1358 return success();
1359 }
1360};
1361
1362/// Sparse rewriting rule for the foreach operator.
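///
/// Schematically (the exact loop structure depends on the level types), a
/// foreach over a CSR input is expanded by the loop emitter into a loop nest
/// with the user's block inlined into the innermost body, roughly:
///
///   scf.for %i = %c0 to %rows step %c1 {     // dense level
///     scf.for %p = %lo to %hi step %c1 {     // compressed level (positions)
///       %j = memref.load %crd[%p] : memref<?xindex>
///       %v = memref.load %val[%p] : memref<?xf64>
///       <inlined foreach body>(%i, %j, %v)
///     }
///   }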
1363struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
1364public:
1365 using OpRewritePattern::OpRewritePattern;
1366
1367 LogicalResult matchAndRewrite(ForeachOp op,
1368 PatternRewriter &rewriter) const override {
1369
1370 auto loc = op.getLoc();
1371 Value input = op.getTensor();
1372 SmallVector<Value> reduc = op.getInitArgs();
    const auto stt = getSparseTensorType(input);
1374 const Level lvlRank = stt.getLvlRank();
1375
1376 // Special-case: for each over a sparse constant uses its own rewriting
1377 // rule.
1378 if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
1379 if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
1380 return genForeachOnSparseConstant(op, rewriter, attr);
1381 }
1382 }
1383
1384 // Otherwise, use loop emitter to generate loops.
1385 const auto enc = stt.getEncoding();
1386
1387 // 1. Generates loop for the sparse input.
1388 LoopEmitter loopEmitter(
1389 ValueRange{input},
1390 StringAttr::get(getContext(), ForeachOp::getOperationName()));
    loopEmitter.initializeLoopEmit(rewriter, loc);
1392 for (Level l = 0; l < lvlRank; l++) {
      // TODO: provide a utility function for loop sequences that only contain
      // one for loop?
      const SmallVector<TensorLevel, 1> tidLvls{
          loopEmitter.makeTensorLevel(0, l)};
      loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
      // Note that reduc will be taken care of by the loop emitter and gets
      // updated in place.
      loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
                                                    reduc);
1402 }
1403
1404 SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
1405 if (op.getOrder()) {
1406 // TODO: Support it so that we can do direct conversion from CSR->BSR.
1407 llvm_unreachable(
1408 "Level order not yet implemented on non-constant input tensors.");
1409 }
1410
1411 Value vals = loopEmitter.getValBuffer()[0];
1412 SmallVector<Value> pos = loopEmitter.getValPosits(tid: 0);
1413 // Loads the value from sparse tensor using position-index;
1414 // loads the value from dense tensor using coords.
1415 Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
1416 : rewriter.create<memref::LoadOp>(loc, vals, lcvs);
1417
1418 // 2. Inline the block in the foreach operator.
1419 Block *srcBlock = op.getBody();
1420
1421 // Remap coordinates.
1422 SmallVector<Value> args =
1423 enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
1424
1425 // Remap value.
1426 args.push_back(Elt: val);
1427 // Remap reduction variables.
1428 args.append(RHS: reduc);
1429
1430 // Remove sparse_tensor.yield.
1431 SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
1432 rewriter.eraseOp(op: srcBlock->getTerminator());
1433
1434 Operation &last = rewriter.getBlock()->back();
1435 if (llvm::isa<scf::YieldOp>(last)) {
1436 // Because `scf.for` inserts an implicit yield op when there is no
1437 // reduction variable upon creation, we reset the insertion point such
1438 // that the block is inlined before *before* the yield op.
1439 rewriter.setInsertionPoint(&last);
1440 }
1441
1442 rewriter.inlineBlockBefore(source: srcBlock, dest: rewriter.getBlock(),
1443 before: rewriter.getInsertionPoint(), argValues: args);
1444 rewriter.setInsertionPointToEnd(rewriter.getBlock());
1445 for (Level l = 0; l < lvlRank; l++) {
1446 // Link the reduction chain. Note that loop emitter update the reducValue
1447 // in place.
1448 loopEmitter.exitCurrentLoop(rewriter, loc: loc, reduc: reducValue);
1449 loopEmitter.exitCurrentLoopSeq(builder&: rewriter, loc: loc);
1450 }
1451
1452 // Replace the foreach operator with the value returned by the outtermost
1453 // for loop.
1454 rewriter.replaceOp(op, reducValue);
1455 return success();
1456 }
1457};

/// Sparse rewriting rule for the new operator.
struct NewRewriter : public OpRewritePattern<NewOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(NewOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
      return failure();

    // Implement the NewOp as follows:
    //   %orderedCoo = sparse_tensor.new %filename
    //   %t = sparse_tensor.convert %orderedCoo
    // with enveloping reinterpret_map ops for non-permutations.
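    //
    // As an illustrative sketch only (the #SortedCOO/#CSR names below are
    // hypothetical aliases for the encodings actually involved):
    //   %coo = sparse_tensor.new %src : !llvm.ptr to tensor<?x?xf64, #SortedCOO>
    //   %csr = sparse_tensor.convert %coo
    //            : tensor<?x?xf64, #SortedCOO> to tensor<?x?xf64, #CSR>
    //   sparse_tensor.dealloc_tensor %coo : tensor<?x?xf64, #SortedCOO>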
    RankedTensorType dstTp = stt.getRankedTensorType();
    RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
    Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
    Value convert = cooTensor;
    auto enc = stt.getEncoding();
    if (!stt.isPermutation()) { // demap coo, demap dstTp
      auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
      convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
      dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
    }
    convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
    if (!stt.isPermutation()) // remap to original enc
      convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
    rewriter.replaceOp(op, convert);

    // Release the temporary ordered COO tensor.
    rewriter.setInsertionPointAfterValue(convert);
    rewriter.create<DeallocTensorOp>(loc, cooTensor);

    return success();
  }
};

/// Sparse rewriting rule for the out operator.
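/// The rule lowers the op to a sequence of sparse runtime library calls:
/// createSparseTensorWriter to open the destination,
/// outSparseTensorWriterMetaData for the rank, nnz, and dimension sizes,
/// one outSparseTensorWriterNext<type-suffix> call per stored element
/// (emitted via a sparse_tensor.foreach), and delSparseTensorWriter to
/// release the writer.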
struct OutRewriter : public OpRewritePattern<OutOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(OutOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Calculate NNZ.
    Value src = op.getTensor();
    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);

    // Allocate a temporary buffer for storing dimension-sizes/coordinates.
    const auto srcTp = getSparseTensorType(src);
    const Dimension dimRank = srcTp.getDimRank();
    Type indexTp = rewriter.getIndexType();
    Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);

    // Generate code to calculate dimension size values and store the values
    // to the buffer.
    SmallVector<Value> dims;
    sizesForTensor(rewriter, dims, loc, srcTp, src);
    for (Dimension d = 0; d < dimRank; d++) {
      rewriter.create<memref::StoreOp>(loc, dims[d], dimSizes,
                                       constantIndex(rewriter, loc, d));
    }

    // Create a sparse tensor writer and output metadata.
    Type opaqueTp = getOpaquePointerType(rewriter);
    Value writer =
        createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
                       {op.getDest()}, EmitCInterface::Off)
            .getResult(0);
    Value rankValue = constantIndex(rewriter, loc, dimRank);
    createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
                   {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);

    Value dimCoords = dimSizes; // Reuse the dimSizes buffer for dimCoords.
    Type eltTp = srcTp.getElementType();
    SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
                                    primaryTypeFunctionSuffix(eltTp)};
    Value value = genAllocaScalar(rewriter, loc, eltTp);
    ModuleOp module = op->getParentOfType<ModuleOp>();

    // For each element in the source tensor, output the element.
    rewriter.create<ForeachOp>(
        loc, src, std::nullopt,
        [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
            ValueRange reduc) {
          for (Dimension d = 0; d < dimRank; d++) {
            rewriter.create<memref::StoreOp>(loc, dcvs[d], dimCoords,
                                             constantIndex(builder, loc, d));
          }
          rewriter.create<memref::StoreOp>(loc, v, value);
          SmallVector<Value> operands{writer, rankValue, dimCoords, value};
          FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
                                         EmitCInterface::On);
          builder.create<func::CallOp>(loc, TypeRange(), fn, operands);
          builder.create<sparse_tensor::YieldOp>(loc);
        });

    // Release the writer.
    createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
                   EmitCInterface::Off);

    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
  patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
               FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
      patterns.getContext());
}

void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
                                                   bool enableRT,
                                                   bool enableConvert) {
  patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
               ReshapeRewriter<tensor::CollapseShapeOp>,
               Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
               Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
               SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
      patterns.getContext());

  if (enableConvert)
    patterns.add<DirectConvertRewriter>(patterns.getContext());
  if (!enableRT)
    patterns.add<NewRewriter>(patterns.getContext());
}

void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
  // Run CrdTranslateRewriter later in the pipeline so that the operation can
  // be folded before lowering to affine.apply.
  patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
}
1598
