//===- IterationGraphSorter.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "IterationGraphSorter.h"

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// A helper class that visits an affine expression and tries to find
/// an AffineDimExpr to which the corresponding iterator from a GenericOp
/// matches the desired iterator type. If there is no matched iterator
/// type, the method returns the first DimExpr in the expression.
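/// For instance, walking d0 + d1 in post order with iterator types
/// (parallel, reduction) and a desired type of "reduction" picks d1;
/// if no dimension carries the desired iterator type, the first
/// dimension visited (here d0) is kept.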
class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
public:
  explicit AffineDimFinder(ArrayRef<utils::IteratorType> itTypes)
      : iterTypes(itTypes) {}

  /// Overrides the visit method from AffineExprVisitor.
  void visitDimExpr(AffineDimExpr expr) {
    if (pickedDim == nullptr || pickIterType == iterTypes[expr.getPosition()])
      pickedDim = expr;
  }

  /// Sets the desired iterator type that we want to pick.
  void setPickedIterType(utils::IteratorType iterType) {
    pickIterType = iterType;
  }

  /// Gets the desired AffineDimExpr.
  AffineDimExpr getDimExpr() const {
    return llvm::cast<AffineDimExpr>(pickedDim);
  }

  /// Walks the graph in post order to find dim expr.
  void walkPostOrder(AffineExpr expr) {
    pickedDim = nullptr;
    AffineExprVisitor<AffineDimFinder>::walkPostOrder(expr);
  }

private:
  /// The picked AffineDimExpr after visit.
  AffineExpr pickedDim;
  /// The iterator type that we want.
  utils::IteratorType pickIterType;
  /// The mapping between levels and iterator types.
  ArrayRef<utils::IteratorType> iterTypes;
};

/// Flattens an affine expression into a list of AffineDimExprs.
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
  // Overrides method from AffineExprVisitor.
  void visitDimExpr(AffineDimExpr expr) { dims.push_back(expr); }
  SmallVector<AffineDimExpr> dims;
};

} // namespace

inline static bool includesAny(SortMask mask1, SortMask mask2) {
  return static_cast<unsigned>(mask1) & static_cast<unsigned>(mask2);
}

inline static bool includesDenseInput(SortMask mask) {
  return includesAny(mask, SortMask::kIncludeDenseInput);
}

inline static bool includesDenseOutput(SortMask mask) {
  return includesAny(mask, SortMask::kIncludeDenseOutput);
}

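// Illustrative example (not tied to a specific kernel): with three loops
// (d0: parallel, d1: reduction, d2: parallel) and the constraints "d0 before
// d2" and "d1 before d2", both d0 and d1 start with in-degree zero. Parallel
// loops are preferred, so d0 is scheduled first, then d1 (the only remaining
// zero in-degree loop), and finally d2, yielding the permutation (d0, d1, d2).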
AffineMap IterationGraphSorter::topoSort() {
  // The sorted result puts reduction iterators at the latest possible
  // positions.
  std::vector<unsigned> redIt; // reduction iterators with in-degree 0
  std::vector<unsigned> parIt; // parallel iterators with in-degree 0
  const unsigned numLoops = getNumLoops();
  for (unsigned i = 0; i < numLoops; i++) {
    if (inDegree[i] == 0) {
      if (iterTypes[i] == utils::IteratorType::reduction)
        redIt.push_back(i);
      else
        parIt.push_back(i);
    }
  }

  SmallVector<unsigned> loopOrder;
  while (!redIt.empty() || !parIt.empty()) {
    // We always prefer a parallel loop over a reduction loop because putting
    // a reduction loop early might make the loop sequence inadmissible.
    auto &it = !parIt.empty() ? parIt : redIt;
    auto src = it.back();
    loopOrder.push_back(src);
    it.pop_back();
    // Update in-degrees, and push 0-degree nodes into the worklist.
    for (unsigned dst = 0; dst < numLoops; dst++) {
      if (itGraph[src][dst] && --inDegree[dst] == 0) {
        if (iterTypes[dst] == utils::IteratorType::reduction)
          redIt.push_back(dst);
        else
          parIt.push_back(dst);
      }
    }
  }

  // Return the topological sort on success.
  if (loopOrder.size() == numLoops)
    return AffineMap::getPermutationMap(loopOrder, out.getContext());

  // Cycle detected.
  return AffineMap();
}

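// Illustrative sketch of the kind of kernel this expects (a demapped sparse
// matmul-like linalg.generic; tensor types and attribute syntax elided):
//
//   %0 = linalg.generic {
//          indexing_maps = [affine_map<(i, j, k) -> (i, k)>,   // input A
//                           affine_map<(i, j, k) -> (k, j)>,   // input B
//                           affine_map<(i, j, k) -> (i, j)>],  // init/output C
//          iterator_types = ["parallel", "parallel", "reduction"] }
//          ins(%A, %B : ...) outs(%C : ...) { ... }
//
// Here loop2InsLvl receives the first two maps, loop2OutLvl the last one, and
// iterTypes becomes (parallel, parallel, reduction).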
IterationGraphSorter
IterationGraphSorter::fromGenericOp(linalg::GenericOp genericOp) {
  // Must be a demapped sparse kernel.
  assert(!hasAnyNonIdentityOperandsOrResults(genericOp) &&
         hasAnySparseOperandOrResult(genericOp) &&
         genericOp.getNumDpsInits() == 1);

  SmallVector<AffineMap> loopMap = genericOp.getIndexingMapsArray();
  SmallVector<Value> ins = genericOp.getDpsInputs();

  AffineMap outMap = loopMap.back();
  loopMap.pop_back();

  Value out = genericOp.getDpsInitOperand(0)->get();
  SmallVector<utils::IteratorType> iterTypes =
      genericOp.getIteratorTypesArray();

  return IterationGraphSorter(std::move(ins), std::move(loopMap), out, outMap,
                              std::move(iterTypes));
}

IterationGraphSorter::IterationGraphSorter(
    SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl, Value out,
    AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes)
    : ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out),
      loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)) {
  // One map per tensor.
  assert(loop2InsLvl.size() == ins.size());
  // All the affine maps have the same number of dimensions (loops).
  assert(llvm::all_equal(llvm::map_range(
      loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
  // The number of results of the map should match the rank of the tensor.
  assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) {
    auto [m, v] = mvPair;
    return m.getNumResults() == cast<ShapedType>(v.getType()).getRank();
  }));

  itGraph.resize(getNumLoops(), std::vector<bool>(getNumLoops(), false));
  inDegree.resize(getNumLoops());
}

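// Usage sketch (hypothetical caller; the actual in-tree driver may differ):
// build a sorter from a demapped kernel and retry with a weaker mask when a
// cycle is detected.
//
//   IterationGraphSorter sorter = IterationGraphSorter::fromGenericOp(op);
//   // Let dense inputs contribute constraints as well as sparse operands.
//   AffineMap order = sorter.sort(SortMask::kIncludeDenseInput, Value());
//   if (!order) // empty map: cycle detected
//     order = sorter.sort(/*weaker mask=*/..., Value()); // retry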
AffineMap IterationGraphSorter::sort(SortMask mask, Value ignored) {
  // Reset the adjacency matrix that represents the iteration graph.
  for (auto &row : itGraph)
    std::fill(row.begin(), row.end(), false);

  // Reset in-degree.
  std::fill(inDegree.begin(), inDegree.end(), 0);

  // Add the constraints for the loop to level map.
  for (auto [in, map] : llvm::zip(ins, loop2InsLvl)) {
    // Get map and encoding.
    const auto enc = getSparseTensorEncoding(in.getType());
    // Skip dense inputs when not requested.
    if ((!enc && !includesDenseInput(mask)) || in == ignored)
      continue;
    addConstraints(in, map);
  }

  // Add the constraints for the output map.
  const auto enc = getSparseTensorEncoding(out.getType());
  if ((enc || includesDenseOutput(mask)) && out != ignored)
    addConstraints(out, loop2OutLvl);

  // Return the topological sort (empty for cyclic).
  return topoSort();
}

void IterationGraphSorter::addConstraints(Value t, AffineMap loop2LvlMap) {
  auto addIterOrdering = [this](unsigned f, unsigned t) {
    if (!itGraph[f][t] && f != t) {
      itGraph[f][t] = true;
      inDegree[t]++;
    }
  };
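  // Note: an edge f -> t in itGraph means loop f must be scheduled before
  // loop t in the final loop order.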

  // Set up a reduction finder.
  AffineDimFinder finder(iterTypes);
  finder.setPickedIterType(utils::IteratorType::reduction);

  // To compute iteration graph for tensor[d0 + d1 + d3, d4 + d5 + d6],
  // we require there exist d_x \in {d0, d1, d3} and d_y \in {d4, d5, d6},
  // and d_x > d_y && {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
  const Level lvlRank = loop2LvlMap.getNumResults();
  for (Level lvl = 1; lvl < lvlRank; lvl++) {
    const AffineExpr fa = loop2LvlMap.getResult(lvl - 1);
    const AffineExpr ta = loop2LvlMap.getResult(lvl);

    if (llvm::isa<AffineDimExpr>(fa) || llvm::isa<AffineDimExpr>(ta)) {
      // Special case when at least one loop2LvlExpr is a simple AffineDimExpr
      // (say, d0) and we require d0 > {d1, d2, ...} or {d1, d2, ...} > d0.
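      // E.g., for tensor[d0, d1 + d2] this adds the orderings d0 -> d1 and
      // d0 -> d2 (d0 must be scheduled before both d1 and d2).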
      AffineDimCollector fCollector;
      fCollector.walkPostOrder(fa);
      AffineDimCollector tCollector;
      tCollector.walkPostOrder(ta);

      for (auto fd : fCollector.dims) {
        for (auto td : tCollector.dims) {
          const unsigned f = fd.getPosition();
          const unsigned t = td.getPosition();
          addIterOrdering(f, t);
        }
      }
      continue;
    }

    // When both loop2LvlExprs are compound, we pick an arbitrary reduction
    // loop from lhs and rhs and use them as d_x and d_y.
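    // E.g., for tensor[d0 + d1, d2 + d3] with d1 and d3 being reduction loops,
    // we pick d_x = d1 and d_y = d3, and add the orderings d1 -> d3 (below),
    // d0 -> d1 and d2 -> d3 (d_x and d_y come last within their own level),
    // and d0 -> d2 (the remaining dims keep the level order).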
    finder.walkPostOrder(fa);
    const AffineDimExpr fexp = finder.getDimExpr();
    const unsigned fldx = fexp.getPosition();

    finder.walkPostOrder(ta);
    const AffineDimExpr texp = finder.getDimExpr();
    const unsigned tldx = texp.getPosition();

    // d_x > d_y
    addIterOrdering(fldx, tldx);

    AffineDimCollector fCollector;
    fCollector.walkPostOrder(fa);
    AffineDimCollector tCollector;
    tCollector.walkPostOrder(ta);

    // Make sure d_x and d_y are the last.
    for (auto fd : fCollector.dims) {
      const unsigned f = fd.getPosition();
      addIterOrdering(f, fldx);
    }
    for (auto td : tCollector.dims) {
      const unsigned t = td.getPosition();
      addIterOrdering(t, tldx);
    }
    // {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
    // This is to ensure that the affine expressions are reduced in sparse
    // tensor level ordering.
    for (auto fd : fCollector.dims) {
      const unsigned f = fd.getPosition();
      if (f == fldx) // skip d_x
        continue;
      for (auto td : tCollector.dims) {
        const unsigned t = td.getPosition();
        if (t == tldx) // skip d_y
          continue;
        addIterOrdering(f, t);
      }
    }
  }
}

source code of mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp