1//===- SparseTensorInterfaces.cpp - SparseTensor interfaces impl ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
10#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
11#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
12#include "mlir/IR/PatternMatch.h"
13
14using namespace mlir;
15using namespace mlir::sparse_tensor;
16
17#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc"
18
19/// Stage the operations into a sequence of simple operations as follow:
20/// op -> unsorted_coo +
21/// unsorted_coo -> sorted_coo +
22/// sorted_coo -> dstTp.
23///
24/// return `tmpBuf` if a intermediate memory is allocated.
25LogicalResult sparse_tensor::detail::stageWithSortImpl(
26 StageWithSortSparseOp op, PatternRewriter &rewriter, Value &tmpBufs) {
27 if (!op.needsExtraSort())
28 return failure();
29
30 Location loc = op.getLoc();
31 Type finalTp = op->getOpResult(0).getType();
32 SparseTensorType dstStt(cast<RankedTensorType>(finalTp));
33 Type srcCOOTp = dstStt.getCOOType(/*ordered=*/false);
34
35 // Clones the original operation but changing the output to an unordered COO.
36 Operation *cloned = rewriter.clone(*op.getOperation());
37 rewriter.modifyOpInPlace(cloned, [cloned, srcCOOTp]() {
38 cloned->getOpResult(idx: 0).setType(srcCOOTp);
39 });
40 Value srcCOO = cloned->getOpResult(idx: 0);
41
42 // -> sort
43 Type dstCOOTp = dstStt.getCOOType(/*ordered=*/true);
44 Value dstCOO = rewriter.create<ReorderCOOOp>(
45 loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);
46
47 // -> dest.
48 if (dstCOO.getType() == finalTp) {
49 rewriter.replaceOp(op, dstCOO);
50 } else {
51 // Need an extra conversion if the target type is not COO.
52 auto c = rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO);
53 rewriter.setInsertionPointAfter(c);
54 // Informs the caller about the intermediate buffer we allocated. We can not
55 // create a bufferization::DeallocateTensorOp here because it would
56 // introduce cyclic dependency between the SparseTensorDialect and the
57 // BufferizationDialect. Besides, whether the buffer need to be deallocated
58 // by SparseTensorDialect or by BufferDeallocationPass is still TBD.
59 tmpBufs = dstCOO;
60 }
61
62 return success();
63}
64

source code of mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp