1 | //===- SparseTensorInterfaces.cpp - SparseTensor interfaces impl ----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h" |
10 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
11 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
12 | #include "mlir/IR/PatternMatch.h" |
13 | |
14 | using namespace mlir; |
15 | using namespace mlir::sparse_tensor; |
16 | |
17 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc" |
18 | |
19 | /// Stage the operations into a sequence of simple operations as follow: |
20 | /// op -> unsorted_coo + |
21 | /// unsorted_coo -> sorted_coo + |
22 | /// sorted_coo -> dstTp. |
23 | /// |
24 | /// return `tmpBuf` if a intermediate memory is allocated. |
25 | LogicalResult sparse_tensor::detail::stageWithSortImpl( |
26 | StageWithSortSparseOp op, PatternRewriter &rewriter, Value &tmpBufs) { |
27 | if (!op.needsExtraSort()) |
28 | return failure(); |
29 | |
30 | Location loc = op.getLoc(); |
31 | Type finalTp = op->getOpResult(0).getType(); |
32 | SparseTensorType dstStt(cast<RankedTensorType>(finalTp)); |
33 | Type srcCOOTp = dstStt.getCOOType(/*ordered=*/false); |
34 | |
35 | // Clones the original operation but changing the output to an unordered COO. |
36 | Operation *cloned = rewriter.clone(*op.getOperation()); |
37 | rewriter.modifyOpInPlace(cloned, [cloned, srcCOOTp]() { |
38 | cloned->getOpResult(idx: 0).setType(srcCOOTp); |
39 | }); |
40 | Value srcCOO = cloned->getOpResult(idx: 0); |
41 | |
42 | // -> sort |
43 | Type dstCOOTp = dstStt.getCOOType(/*ordered=*/true); |
44 | Value dstCOO = rewriter.create<ReorderCOOOp>( |
45 | loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort); |
46 | |
47 | // -> dest. |
48 | if (dstCOO.getType() == finalTp) { |
49 | rewriter.replaceOp(op, dstCOO); |
50 | } else { |
51 | // Need an extra conversion if the target type is not COO. |
52 | auto c = rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO); |
53 | rewriter.setInsertionPointAfter(c); |
54 | // Informs the caller about the intermediate buffer we allocated. We can not |
55 | // create a bufferization::DeallocateTensorOp here because it would |
56 | // introduce cyclic dependency between the SparseTensorDialect and the |
57 | // BufferizationDialect. Besides, whether the buffer need to be deallocated |
58 | // by SparseTensorDialect or by BufferDeallocationPass is still TBD. |
59 | tmpBufs = dstCOO; |
60 | } |
61 | |
62 | return success(); |
63 | } |
64 | |