1 | //===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
10 | #include "mlir/IR/IRMapping.h" |
11 | #include "mlir/Transforms/Passes.h" |
12 | |
13 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
14 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
15 | |
16 | namespace mlir { |
17 | #define GEN_PASS_DEF_SPARSESPACECOLLAPSE |
18 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" |
19 | } // namespace mlir |
20 | |
21 | #define DEBUG_TYPE "sparse-space-collapse" |
22 | |
23 | using namespace mlir; |
24 | using namespace sparse_tensor; |
25 | |
26 | namespace { |
27 | |
28 | struct CollapseSpaceInfo { |
29 | ExtractIterSpaceOp space; |
30 | IterateOp loop; |
31 | }; |
32 | |
33 | bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) { |
34 | auto pIterArgs = parent.getRegionIterArgs(); |
35 | auto nInitArgs = node.getInits(); |
36 | if (pIterArgs.size() != nInitArgs.size()) |
37 | return false; |
38 | |
39 | // Two loops are collapsable if they are perfectly nested. |
40 | auto pYields = parent.getYieldedValues(); |
41 | auto nResult = node.getLoopResults().value(); |
42 | |
43 | bool yieldEq = |
44 | llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) { |
45 | return std::get<0>(zipped) == std::get<1>(zipped); |
46 | }); |
47 | |
48 | // Parent iter_args should be passed directly to the node's init_args. |
49 | bool iterArgEq = |
50 | llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) { |
51 | return std::get<0>(zipped) == std::get<1>(zipped); |
52 | }); |
53 | |
54 | return yieldEq && iterArgEq; |
55 | } |
56 | |
57 | bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse, |
58 | ExtractIterSpaceOp curSpace) { |
59 | |
60 | auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp { |
61 | Value spaceVal = space.getExtractedSpace(); |
62 | if (spaceVal.hasOneUse()) |
63 | return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin()); |
64 | return nullptr; |
65 | }; |
66 | |
67 | if (toCollapse.empty()) { |
68 | // Collapse root. |
69 | if (auto itOp = getIterateOpOverSpace(curSpace)) { |
70 | CollapseSpaceInfo &info = toCollapse.emplace_back(); |
71 | info.space = curSpace; |
72 | info.loop = itOp; |
73 | return true; |
74 | } |
75 | return false; |
76 | } |
77 | |
78 | auto parent = toCollapse.back().space; |
79 | auto pItOp = toCollapse.back().loop; |
80 | auto nItOp = getIterateOpOverSpace(curSpace); |
81 | |
82 | // Can only collapse spaces extracted from the same tensor. |
83 | if (parent.getTensor() != curSpace.getTensor()) { |
84 | LLVM_DEBUG({ |
85 | llvm::dbgs() |
86 | << "failed to collpase spaces extracted from different tensors." ; |
87 | }); |
88 | return false; |
89 | } |
90 | |
91 | // Can only collapse consecutive simple iteration on one tensor (i.e., no |
92 | // coiteration). |
93 | if (!nItOp || nItOp->getBlock() != curSpace->getBlock() || |
94 | pItOp.getIterator() != curSpace.getParentIter() || |
95 | curSpace->getParentOp() != pItOp.getOperation()) { |
96 | LLVM_DEBUG( |
97 | { llvm::dbgs() << "failed to collapse non-consecutive IterateOps." ; }); |
98 | return false; |
99 | } |
100 | |
101 | if (pItOp && !isCollapsableLoops(pItOp, nItOp)) { |
102 | LLVM_DEBUG({ |
103 | llvm::dbgs() |
104 | << "failed to collapse IterateOps that are not perfectly nested." ; |
105 | }); |
106 | return false; |
107 | } |
108 | |
109 | CollapseSpaceInfo &info = toCollapse.emplace_back(); |
110 | info.space = curSpace; |
111 | info.loop = nItOp; |
112 | return true; |
113 | } |
114 | |
115 | void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) { |
116 | if (toCollapse.size() < 2) |
117 | return; |
118 | |
119 | ExtractIterSpaceOp root = toCollapse.front().space; |
120 | ExtractIterSpaceOp leaf = toCollapse.back().space; |
121 | Location loc = root.getLoc(); |
122 | |
123 | assert(root->hasOneUse() && leaf->hasOneUse()); |
124 | |
125 | // Insert collapsed operation at the same scope as root operation. |
126 | OpBuilder builder(root); |
127 | |
128 | // Construct the collapsed iteration space. |
129 | auto collapsedSpace = builder.create<ExtractIterSpaceOp>( |
130 | loc, root.getTensor(), root.getParentIter(), root.getLoLvl(), |
131 | leaf.getHiLvl()); |
132 | |
133 | auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin()); |
134 | auto innermost = toCollapse.back().loop; |
135 | |
136 | IRMapping mapper; |
137 | mapper.map(leaf, collapsedSpace.getExtractedSpace()); |
138 | for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs())) |
139 | mapper.map(std::get<0>(z), std::get<1>(z)); |
140 | |
141 | auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper)); |
142 | builder.setInsertionPointToStart(cloned.getBody()); |
143 | |
144 | I64BitSet crdUsedLvls; |
145 | unsigned shift = 0, argIdx = 1; |
146 | for (auto info : toCollapse.drop_back()) { |
147 | I64BitSet set = info.loop.getCrdUsedLvls(); |
148 | crdUsedLvls |= set.lshift(shift); |
149 | shift += info.loop.getSpaceDim(); |
150 | for (BlockArgument crd : info.loop.getCrds()) { |
151 | BlockArgument collapsedCrd = cloned.getBody()->insertArgument( |
152 | argIdx++, builder.getIndexType(), crd.getLoc()); |
153 | crd.replaceAllUsesWith(collapsedCrd); |
154 | } |
155 | } |
156 | crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift); |
157 | cloned.getIterator().setType(collapsedSpace.getType().getIteratorType()); |
158 | cloned.setCrdUsedLvls(crdUsedLvls); |
159 | |
160 | rItOp.replaceAllUsesWith(cloned.getResults()); |
161 | // Erase collapsed loops. |
162 | rItOp.erase(); |
163 | root.erase(); |
164 | } |
165 | |
166 | struct SparseSpaceCollapsePass |
167 | : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> { |
168 | SparseSpaceCollapsePass() = default; |
169 | |
170 | void runOnOperation() override { |
171 | func::FuncOp func = getOperation(); |
172 | |
173 | // A naive (experimental) implementation to collapse consecutive sparse |
174 | // spaces. It does NOT handle complex cases where multiple spaces are |
175 | // extracted in the same basic block. E.g., |
176 | // |
177 | // %space1 = extract_space %t1 ... |
178 | // %space2 = extract_space %t2 ... |
179 | // sparse_tensor.iterate(%sp1) ... |
180 | // |
181 | SmallVector<CollapseSpaceInfo> toCollapse; |
182 | func->walk([&](ExtractIterSpaceOp op) { |
183 | if (!legalToCollapse(toCollapse, op)) { |
184 | // if not legal to collapse one more space, collapse the existing ones |
185 | // and clear. |
186 | collapseSparseSpace(toCollapse); |
187 | toCollapse.clear(); |
188 | } |
189 | }); |
190 | |
191 | collapseSparseSpace(toCollapse); |
192 | } |
193 | }; |
194 | |
195 | } // namespace |
196 | |
197 | std::unique_ptr<Pass> mlir::createSparseSpaceCollapsePass() { |
198 | return std::make_unique<SparseSpaceCollapsePass>(); |
199 | } |
200 | |