1//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Func/IR/FuncOps.h"
10#include "mlir/IR/IRMapping.h"
11#include "mlir/Transforms/Passes.h"
12
13#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
15
16namespace mlir {
17#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
18#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
19} // namespace mlir
20
21#define DEBUG_TYPE "sparse-space-collapse"
22
23using namespace mlir;
24using namespace sparse_tensor;
25
26namespace {
27
28struct CollapseSpaceInfo {
29 ExtractIterSpaceOp space;
30 IterateOp loop;
31};
32
33bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
34 auto pIterArgs = parent.getRegionIterArgs();
35 auto nInitArgs = node.getInits();
36 if (pIterArgs.size() != nInitArgs.size())
37 return false;
38
39 // Two loops are collapsable if they are perfectly nested.
40 auto pYields = parent.getYieldedValues();
41 auto nResult = node.getLoopResults().value();
42
43 bool yieldEq =
44 llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) {
45 return std::get<0>(zipped) == std::get<1>(zipped);
46 });
47
48 // Parent iter_args should be passed directly to the node's init_args.
49 bool iterArgEq =
50 llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
51 return std::get<0>(zipped) == std::get<1>(zipped);
52 });
53
54 return yieldEq && iterArgEq;
55}
56
57bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
58 ExtractIterSpaceOp curSpace) {
59
60 auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
61 Value spaceVal = space.getExtractedSpace();
62 if (spaceVal.hasOneUse())
63 return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
64 return nullptr;
65 };
66
67 if (toCollapse.empty()) {
68 // Collapse root.
69 if (auto itOp = getIterateOpOverSpace(curSpace)) {
70 CollapseSpaceInfo &info = toCollapse.emplace_back();
71 info.space = curSpace;
72 info.loop = itOp;
73 return true;
74 }
75 return false;
76 }
77
78 auto parent = toCollapse.back().space;
79 auto pItOp = toCollapse.back().loop;
80 auto nItOp = getIterateOpOverSpace(curSpace);
81
82 // Can only collapse spaces extracted from the same tensor.
83 if (parent.getTensor() != curSpace.getTensor()) {
84 LLVM_DEBUG({
85 llvm::dbgs()
86 << "failed to collpase spaces extracted from different tensors.";
87 });
88 return false;
89 }
90
91 // Can only collapse consecutive simple iteration on one tensor (i.e., no
92 // coiteration).
93 if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
94 pItOp.getIterator() != curSpace.getParentIter() ||
95 curSpace->getParentOp() != pItOp.getOperation()) {
96 LLVM_DEBUG(
97 { llvm::dbgs() << "failed to collapse non-consecutive IterateOps."; });
98 return false;
99 }
100
101 if (pItOp && !isCollapsableLoops(pItOp, nItOp)) {
102 LLVM_DEBUG({
103 llvm::dbgs()
104 << "failed to collapse IterateOps that are not perfectly nested.";
105 });
106 return false;
107 }
108
109 CollapseSpaceInfo &info = toCollapse.emplace_back();
110 info.space = curSpace;
111 info.loop = nItOp;
112 return true;
113}
114
115void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
116 if (toCollapse.size() < 2)
117 return;
118
119 ExtractIterSpaceOp root = toCollapse.front().space;
120 ExtractIterSpaceOp leaf = toCollapse.back().space;
121 Location loc = root.getLoc();
122
123 assert(root->hasOneUse() && leaf->hasOneUse());
124
125 // Insert collapsed operation at the same scope as root operation.
126 OpBuilder builder(root);
127
128 // Construct the collapsed iteration space.
129 auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
130 loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
131 leaf.getHiLvl());
132
133 auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
134 auto innermost = toCollapse.back().loop;
135
136 IRMapping mapper;
137 mapper.map(leaf, collapsedSpace.getExtractedSpace());
138 for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
139 mapper.map(std::get<0>(z), std::get<1>(z));
140
141 auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
142 builder.setInsertionPointToStart(cloned.getBody());
143
144 I64BitSet crdUsedLvls;
145 unsigned shift = 0, argIdx = 1;
146 for (auto info : toCollapse.drop_back()) {
147 I64BitSet set = info.loop.getCrdUsedLvls();
148 crdUsedLvls |= set.lshift(shift);
149 shift += info.loop.getSpaceDim();
150 for (BlockArgument crd : info.loop.getCrds()) {
151 BlockArgument collapsedCrd = cloned.getBody()->insertArgument(
152 argIdx++, builder.getIndexType(), crd.getLoc());
153 crd.replaceAllUsesWith(collapsedCrd);
154 }
155 }
156 crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift);
157 cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
158 cloned.setCrdUsedLvls(crdUsedLvls);
159
160 rItOp.replaceAllUsesWith(cloned.getResults());
161 // Erase collapsed loops.
162 rItOp.erase();
163 root.erase();
164}
165
166struct SparseSpaceCollapsePass
167 : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
168 SparseSpaceCollapsePass() = default;
169
170 void runOnOperation() override {
171 func::FuncOp func = getOperation();
172
173 // A naive (experimental) implementation to collapse consecutive sparse
174 // spaces. It does NOT handle complex cases where multiple spaces are
175 // extracted in the same basic block. E.g.,
176 //
177 // %space1 = extract_space %t1 ...
178 // %space2 = extract_space %t2 ...
179 // sparse_tensor.iterate(%sp1) ...
180 //
181 SmallVector<CollapseSpaceInfo> toCollapse;
182 func->walk([&](ExtractIterSpaceOp op) {
183 if (!legalToCollapse(toCollapse, op)) {
184 // if not legal to collapse one more space, collapse the existing ones
185 // and clear.
186 collapseSparseSpace(toCollapse);
187 toCollapse.clear();
188 }
189 });
190
191 collapseSparseSpace(toCollapse);
192 }
193};
194
195} // namespace
196
197std::unique_ptr<Pass> mlir::createSparseSpaceCollapsePass() {
198 return std::make_unique<SparseSpaceCollapsePass>();
199}
200

source code of mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp