//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ---------===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "Utils/CodegenUtils.h" |
10 | #include "Utils/IterationGraphSorter.h" |
11 | |
12 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
13 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
14 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
15 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
16 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
17 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
18 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
19 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
20 | #include "mlir/IR/AffineExprVisitor.h" |
21 | #include "mlir/IR/AffineMap.h" |
22 | |
23 | using namespace mlir; |
24 | using namespace mlir::sparse_tensor; |
25 | |
26 | namespace { |
27 | |
28 | //===----------------------------------------------------------------------===// |
29 | // File Local Helper classes. |
30 | //===----------------------------------------------------------------------===// |
31 | |
// CRTP helper to implement a rewriter that demaps all its inputs, i.e.,
// rewrites non-identity sparse operands through reinterpret_map ops before
// dispatching to the subclass rewrite.
33 | template <typename SubClass, typename SourceOp> |
34 | struct DemapInsRewriter : public OpRewritePattern<SourceOp> { |
35 | using OpRewritePattern<SourceOp>::OpRewritePattern; |
36 | using OpAdaptor = typename SourceOp::Adaptor; |
37 | |
38 | LogicalResult matchAndRewrite(SourceOp op, |
39 | PatternRewriter &rewriter) const override { |
40 | Location loc = op.getLoc(); |
41 | |
42 | // Demaps non-trivial inputs. |
43 | bool changed = false; |
44 | SmallVector<Value> deMappedIns(op->getOperands()); |
45 | for (Value &in : deMappedIns) { |
46 | if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) { |
47 | in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in); |
48 | changed = true; |
49 | } |
50 | } |
51 | |
52 | // CRTP call. |
53 | OpAdaptor adaptor(deMappedIns, op); |
54 | LogicalResult status = |
55 | static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter); |
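    // If any input was demapped, the IR was already modified, so report
    // success even if the subclass rewrite made no further changes.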
56 | return changed ? success() : status; |
57 | } |
58 | }; |
59 | |
60 | // Flattens an affine expression into a list of AffineDimExprs. |
61 | struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> { |
62 | explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){}; |
63 | void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); } |
64 | BitVector dims; |
65 | }; |
66 | |
// Checks whether an affine expression is admissible for sparsification.
68 | struct AffineExprAdmissibleVisitor |
69 | : public AffineExprVisitor<AffineExprAdmissibleVisitor> { |
70 | explicit AffineExprAdmissibleVisitor(bool isOutput) |
71 | : admissible(true), isOutput(isOutput){}; |
72 | |
73 | // We only allow AffineDimExpr on output. |
74 | void visitAddExpr(AffineBinaryOpExpr expr) { |
75 | if (isOutput) |
76 | admissible = false; |
77 | } |
78 | void visitMulExpr(AffineBinaryOpExpr expr) { |
79 | if (isOutput) |
80 | admissible = false; |
81 | } |
82 | |
83 | // We disallow mod, floor div and ceil div on inputs. |
84 | void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; } |
85 | void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; } |
86 | void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; } |
87 | operator bool() { return admissible; } |
88 | |
89 | private: |
90 | bool admissible; |
91 | bool isOutput; |
92 | }; |
93 | |
94 | // The first BitVector stores levels where inadmissible exprs are used. |
// The second BitVector stores the AffineDimExprs that are used by the
96 | // inadmissible expressions. |
97 | using InadmissInfo = std::pair<BitVector, BitVector>; |
98 | |
99 | } // namespace |
100 | |
101 | //===----------------------------------------------------------------------===// |
102 | // File Local Helper methods. |
103 | //===----------------------------------------------------------------------===// |
104 | |
// Collects the inadmissible affine expressions imposed on levels.
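// For example, for an input tensor with
//   map = (d0, d1) -> (d0 floordiv 2, d1, d0 mod 2)
// levels 0 and 2 carry inadmissible expressions, and d0 is the dimension
// used by those expressions.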
106 | static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) { |
  auto ret = std::make_pair(BitVector(map.getNumResults()),
                            BitVector(map.getNumDims()));
109 | AffineDimCollector collector(map.getNumDims()); |
110 | for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) { |
111 | AffineExprAdmissibleVisitor admissible(isOutput); |
    admissible.walkPostOrder(map.getResult(lvl));
113 | if (!admissible) { |
114 | // Record the inadmissible level. |
115 | ret.first.set(lvl); |
116 | // Record the AffineDimExpr that is used in the inadmissible expr. |
      collector.walkPostOrder(map.getResult(lvl));
118 | } |
119 | } |
120 | ret.second = collector.dims; |
121 | return ret; |
122 | } |
123 | |
// Builds the AffineMap to replace the dims in idxMap with lvls such that all
// the inadmissible affine expressions can be eliminated.
126 | // For example, we can rewrite |
127 | // idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3) |
128 | // to |
129 | // idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3) |
130 | // by composing inverse(idxMap), that is |
131 | // inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3) |
132 | // -> ((l0 * 2 + l2) floordiv 2, |
133 | // (l1 * 3 + l3) floordiv 3, |
134 | // (l0 * 2 + l2) mod 2, |
135 | // (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3) |
136 | // |
// This function builds the inverse(idxMap) that replaces every dimension used
// in `info` with levels, and updates the iterator type array `itTps` for the
// new index variables introduced.
140 | // |
141 | // Note that the returned affine map does not retain the order of the input |
// affine map. Instead, it always uses the first `info.inAdLvls.count()` dims
// for the replaced levels, and the remaining ones for unused dimensions.
144 | // For example, to handle |
//   idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
146 | // which is a typical map for block_2to4. The function returns: |
147 | // inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1) |
148 | // in which, (l0, l1) together replaces `d1`, yet they appear |
149 | // before `d0` in the resulting affine map. |
150 | // The index (loop) order can later be canonicalized by a topo sort. |
151 | static AffineMap |
152 | genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap, |
153 | SmallVector<utils::IteratorType> &itTps) { |
154 | MLIRContext *ctx = idxMap.getContext(); |
155 | auto [inAdLvls, usedDims] = info; |
  // Note that idxMap does not equal the dim2Lvl map; it is computed by
  // composing idx2Dim(dim2Lvl). They are only equal when idx2Dim is an
  // identity map.
  // TODO: we might fail here; in those cases we should really return
  // failure instead of an assertion error.
  auto lvl2Idx = inferLvlToDim(idxMap, ctx);
162 | |
163 | assert(lvl2Idx.getNumResults() <= idxMap.getNumDims()); |
164 | if (lvl2Idx.getNumResults() != idxMap.getNumDims()) { |
165 | // This could happen when some dimensions are projected. |
166 | // E.g., idx2Lvl = (*i*, j, k) -> (j, k) |
167 | // ==> lvl2Idx = (j, k) -> (j, k) |
    // In this case, we append the unused dimension at the end.
169 | // ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k) |
170 | SmallVector<AffineExpr> results; |
171 | AffineDimCollector usedInLvl(idxMap.getNumDims()); |
172 | for (auto e : idxMap.getResults()) |
      usedInLvl.walkPostOrder(e);
174 | |
175 | unsigned curUsedDimID = 0; |
176 | unsigned curUnusedDimID = lvl2Idx.getNumDims(); |
177 | |
178 | BitVector unused = usedInLvl.dims.flip(); |
179 | for (unsigned i = 0; i < idxMap.getNumDims(); i++) { |
      if (unused.test(i))
        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
      else
        results.push_back(lvl2Idx.getResult(curUsedDimID++));
184 | } |
    lvl2Idx =
        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
187 | } |
188 | assert(lvl2Idx.getNumResults() == idxMap.getNumDims()); |
189 | |
  // We do not need to replace the DimExprs that are not used in inadmissible
  // level expressions. We use the first inAdLvls.count() dims to represent the
  // replaced levels; the remaining ones are reserved for unchanged dimensions.
  // Note that the results from the inverse map computed previously do not
  // follow this convention, and we need to fix the mismatch below.
195 | unsigned curRepID = 0; |
196 | unsigned curOriID = inAdLvls.count(); |
197 | SmallVector<AffineExpr> results; |
198 | SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr()); |
199 | SmallVector<utils::IteratorType> transItTps; |
200 | |
201 | for (unsigned l : inAdLvls.set_bits()) { |
202 | // By our convention, the inadmissible level `l` always appears in the |
203 | // leading part (accumulated by curRepID) of the affine map's parameter |
204 | // list. Record the mapping so that we can replace all the uses of `l` to |
205 | // the correct position after the translation. |
    dimRep[l] = getAffineDimExpr(curRepID++, ctx);
    // A new index variable is introduced for the inadmissible level; it
    // inherits the iterator type. E.g., if l0 = d0 floordiv 2, the
    // iterator type of l0 equals the iterator type of d0.
    AffineExpr lvlExp = idxMap.getResult(l);
211 | AffineDimCollector collector(idxMap.getNumDims()); |
    collector.walkPostOrder(lvlExp);
    // We assume a level can only be derived from one dimension.
214 | assert(collector.dims.count() == 1); |
    transItTps.push_back(itTps[collector.dims.find_first()]);
216 | } |
217 | |
218 | for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) { |
    if (usedDims.test(d)) {
      // The dimension is used in some of the inadmissible levels, and it
      // needs to be inverted. Get the inversion from the inverse map, and fix
      // the mismatch captured by the above loop.
      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
224 | } else { |
      // The dimension is not used in any of the inadmissible levels, and it
      // does not need to be inverted. Fix the mismatch by mapping it to the
      // trailing part of the affine map (accumulated by curOriID).
      results.push_back(getAffineDimExpr(curOriID++, ctx));
      transItTps.push_back(itTps[d]);
230 | } |
231 | } |
232 | unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count(); |
233 | // Update iterator type. |
  itTps.assign(transItTps.begin(), transItTps.end());
  return AffineMap::get(numDim, 0, results, ctx);
236 | } |
237 | |
// Translates the index maps in the linalg::GenericOp from idx->dim maps to
// idx->lvl maps. Returns the translated index map array and the iterator
// type array, or std::nullopt if the index maps cannot be translated to an
// admissible form.
242 | static std::optional<std::pair<ArrayAttr, ArrayAttr>> |
243 | translateMap(linalg::GenericOp op, PatternRewriter &rewriter) { |
  // idxMap is an idx2dim map before reinterpretation.
245 | MLIRContext *ctx = op.getContext(); |
246 | SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray(); |
247 | SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray(); |
248 | for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) { |
249 | Value tensor = op->getOpOperand(i).get(); |
250 | auto stt = tryGetSparseTensorType(tensor); |
251 | if (stt && !stt->isIdentity()) { |
252 | AffineMap dim2Lvl = stt->getDimToLvl(); |
      // By composing idx2dim(dim2lvl), we get an idx2lvl map.
      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
255 | } |
256 | } |
257 | |
258 | // A naive way to handle common constant expressions that arise during dim2lvl |
259 | // translation. |
260 | auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping, |
261 | unsigned pos, int64_t lvlSz) { |
262 | if (!ShapedType::isDynamic(lvlSz)) { |
      auto c0 = getAffineConstantExpr(0, ctx);
      auto lvlExp = getAffineDimExpr(pos, ctx);
      auto szExp = getAffineConstantExpr(lvlSz, ctx);
266 | |
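      // The two folding rules below hold because the new level variable is
      // bounded by the static level size, i.e., 0 <= lvl < lvlSz.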
267 | // lvl floordiv lvlSz = 0 |
268 | auto divExp = |
269 | getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp); |
270 | cstMapping.try_emplace(divExp, c0); |
271 | |
272 | // lvl mod lvlSz = lvl |
273 | auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp); |
274 | cstMapping.try_emplace(modExp, lvlExp); |
275 | } |
276 | }; |
277 | |
278 | unsigned boundedNum = 0; |
  // A fixed-point algorithm: keep translating the index maps until no operand
  // carries an inadmissible affine expression anymore.
280 | bool changed = true; |
281 | while (changed) { |
282 | changed = false; |
283 | for (OpOperand &operand : op->getOpOperands()) { |
284 | auto stt = tryGetSparseTensorType(operand.get()); |
285 | // Skip on dense operands. |
286 | if (!stt || !stt->getEncoding()) |
287 | continue; |
288 | |
289 | unsigned tid = operand.getOperandNumber(); |
290 | bool isOutput = &operand == op.getDpsInitOperand(0); |
291 | AffineMap idxMap = idxMapArray[tid]; |
292 | InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput); |
293 | auto [inAdLvls, dimExprs] = inAdInfo; |
294 | for (unsigned d : dimExprs.set_bits()) { |
295 | // The first `boundedNum` used in the AffineMap is introduced to |
        // resolve previous inadmissible expressions. We cannot replace them
297 | // as it might bring back the inadmissible expressions. |
298 | if (d < boundedNum) |
299 | return std::nullopt; |
300 | } |
301 | |
302 | if (inAdLvls.count() != 0) { |
        // Naive constant propagation; should be sufficient to handle block
304 | // sparsity in our cases. |
305 | SmallVector<int64_t> lvlShape = stt->getLvlShape(); |
306 | DenseMap<AffineExpr, AffineExpr> cstMapping; |
307 | unsigned position = 0; |
308 | for (unsigned lvl : inAdLvls.set_bits()) { |
309 | int64_t lvlSz = lvlShape[lvl]; |
310 | populateCstMapping(cstMapping, position, lvlSz); |
311 | position++; |
312 | } |
313 | |
314 | AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps); |
        // Compose lvl2Idx with all index maps to eliminate the
        // inadmissible expressions.
317 | for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) { |
318 | AffineMap transMap = idxMapArray[tid].compose(lvl2Idx); |
319 | idxMapArray[tid] = transMap.replace( |
320 | cstMapping, /*numResultDims=*/transMap.getNumDims(), |
321 | /*numResultSyms=*/0); |
322 | } |
323 | changed = true; |
324 | boundedNum += inAdLvls.count(); |
325 | } |
326 | } |
327 | }; |
328 | |
329 | SmallVector<Attribute> iterAttr = |
      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
331 | return linalg::IteratorTypeAttr::get(ctx, itTp); |
332 | }); |
333 | |
  return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
                        rewriter.getArrayAttr(iterAttr));
336 | } |
337 | |
338 | // Generates a "de"mapping reinterpretation of the map. |
339 | static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc, |
340 | Value val) { |
341 | return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(), |
342 | val); |
343 | } |
344 | |
345 | // Generates a "re"mapping reinterpretation of the map. |
346 | static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc, |
347 | Value val) { |
348 | return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val); |
349 | } |
350 | |
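// Reinterprets each value in `outs` to the corresponding type in `types`
// whenever the two disagree, leaving matching values untouched.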
351 | static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types, |
352 | ValueRange outs) { |
353 | SmallVector<Value> ret(outs); |
354 | assert(outs.size() == types.size()); |
355 | for (auto [r, t] : llvm::zip(ret, types)) |
356 | if (r.getType() != t) |
357 | r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r); |
358 | return ret; |
359 | } |
360 | |
361 | namespace { |
362 | |
363 | //===----------------------------------------------------------------------===// |
364 | // Rewriting rules for linalg generic ops. |
365 | //===----------------------------------------------------------------------===// |
366 | |
367 | /// Sparse rewriting rule for the generic `linalg` operation. |
368 | struct GenericOpReinterpretMap |
369 | : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> { |
370 | public: |
371 | using DemapInsRewriter::DemapInsRewriter; |
372 | LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor, |
373 | PatternRewriter &rewriter) const { |
374 | // Only rewrite single output operations with pure (sparse) tensor |
375 | // semantics. |
376 | if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() || |
377 | !hasAnySparseOperandOrResult(linalgOp) || |
378 | !hasAnyNonIdentityOperandsOrResults(linalgOp)) |
379 | return failure(); |
380 | |
381 | // Try translating the index map. |
382 | auto transMap = translateMap(linalgOp, rewriter); |
383 | if (!transMap) |
384 | return rewriter.notifyMatchFailure( |
385 | linalgOp, "the sparse kernel can not be sparsified." ); |
386 | |
    // On success, update the linalg operands and maps in place.
388 | Value res = linalgOp.getResult(0); |
389 | auto stt = tryGetSparseTensorType(res); |
390 | auto [idxMap, itTp] = *transMap; |
391 | |
    rewriter.startOpModification(linalgOp);
393 | linalgOp.setIndexingMapsAttr(idxMap); |
394 | linalgOp.setIteratorTypesAttr(itTp); |
395 | // Use demapped arguments. |
396 | linalgOp.getInputsMutable().assign(adaptor.getInputs()); |
397 | linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs()); |
398 | res.setType(adaptor.getOutputs()[0].getType()); |
    rewriter.finalizeOpModification(linalgOp);
400 | |
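    // If the result is a mapped sparse tensor, cast it back to its original
    // encoding so that all remaining uses keep observing the original type.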
401 | rewriter.setInsertionPointAfter(linalgOp); |
402 | if (stt && stt->hasEncoding()) { |
403 | Value t = genRemap(rewriter, stt->getEncoding(), res); |
      rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
405 | } |
406 | return success(); |
407 | } |
408 | }; |
409 | |
410 | struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> { |
411 | using OpRewritePattern::OpRewritePattern; |
412 | LogicalResult matchAndRewrite(linalg::GenericOp linalgOp, |
413 | PatternRewriter &rewriter) const override { |
414 | if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() || |
415 | hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first |
416 | !hasAnySparseOperandOrResult(linalgOp)) { |
417 | return failure(); |
418 | } |
419 | |
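    // Skip operations that were already scheduled; these are marked with the
    // "sorted" attribute further below.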
420 | const StringRef sorted = "sorted" ; |
421 | if (linalgOp->hasAttr(sorted)) |
422 | return failure(); |
423 | |
    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
425 | bool isAdmissible = false; |
426 | AffineMap order; |
    // A const list of all masks that we use for iteration graph
    // computation. Must be ordered from more strict to less strict.
429 | // Ideally (though might not be guaranteed), the earlier a constraint mask |
430 | // can be satisfied, the faster the generated kernel will be. |
431 | const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense, |
432 | SortMask::kIncludeDenseInput, |
433 | SortMask::kIncludeDenseOutput, |
434 | SortMask::kSparseOnly}; |
435 | for (const SortMask mask : allMasks) { |
436 | order = scheduler.sort(mask); |
437 | if (order) { |
        if (isAdmissibleOrder(linalgOp, order)) {
439 | isAdmissible = true; |
440 | break; |
441 | } |
442 | // else try a set of less strict constraints. |
443 | } |
444 | } |
445 | |
446 | if (!order) { |
447 | // Cycles detected. |
      if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
449 | return rewriter.notifyMatchFailure( |
450 | linalgOp, "the sparse kernel can not be scheduled: loop detected." ); |
451 | } |
452 | return success(); |
453 | } |
454 | |
455 | if (!isAdmissible) { |
456 | return rewriter.notifyMatchFailure( |
457 | linalgOp, "the sparse kernel can not be scheduled." ); |
458 | } |
459 | |
460 | // Marks the GenericOp to avoid recursive matching. |
461 | rewriter.modifyOpInPlace(linalgOp, [&]() { |
      linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
463 | }); |
464 | |
465 | // Already sorted. |
466 | if (order.isIdentity()) |
467 | return success(); |
468 | |
469 | assert(order.isPermutation()); |
    // `order` is the original loop -> sorted loop map.
471 | ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr(); |
472 | SmallVector<Attribute> curItTypes; |
    curItTypes.reserve(preItTypes.size());
474 | for (AffineExpr expr : order.getResults()) { |
      unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
      curItTypes.push_back(preItTypes[loopID]);
477 | } |
478 | |
    // Invert `order` to get the sorted loop -> original loop map.
    order = inversePermutation(order);
481 | SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray(); |
482 | for (AffineMap &idxMap : idxMaps) |
483 | idxMap = idxMap.compose(order); // sorted loop -> lvl map |
484 | |
    rewriter.startOpModification(linalgOp);
486 | linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps)); |
487 | linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes)); |
    rewriter.finalizeOpModification(linalgOp);
489 | |
490 | return success(); |
491 | } |
492 | |
493 | private: |
494 | /// Whether the loop order is admissible by sparsification. |
495 | static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) { |
496 | if (!hasAnySparseResult(linalgOp)) |
497 | return true; |
498 | |
499 | OpOperand *lhs = linalgOp.getDpsInitOperand(0); |
500 | unsigned nest = 0; |
501 | const auto iteratorTypes = linalgOp.getIteratorTypesArray(); |
502 | for (const AffineExpr l : order.getResults()) { |
      unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
504 | auto itTp = |
505 | cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]); |
      if (linalg::isReductionIterator(itTp.getValue()))
507 | break; // terminate at first reduction |
508 | nest++; |
509 | } |
510 | // Determine admissible dynamic insertion situations: |
511 | // (1) fully injective, since there are no reductions, |
512 | // (2) admissible 1-d expansion in innermost dimension. |
513 | return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1; |
514 | }; |
515 | |
516 | // Last resort cycle resolution. |
517 | static LogicalResult resolveCycle(IterationGraphSorter &scheduler, |
518 | linalg::LinalgOp linalgOp, |
519 | PatternRewriter &rewriter) { |
    // Compute a topological sort while leaving out every sparse input tensor
    // in succession until an acyclic iteration graph results.
522 | for (OpOperand *t : linalgOp.getDpsInputOperands()) { |
523 | Value tval = t->get(); |
524 | auto srcEnc = getSparseTensorEncoding(tval.getType()); |
      // The constraints introduced by compound index expressions are
      // complicated. Skip them.
527 | AffineMap idxMap = linalgOp.getMatchingIndexingMap(t); |
528 | bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) { |
529 | return !llvm::isa<AffineDimExpr>(exp); |
530 | }); |
531 | if (!srcEnc || hasCompExpr) |
532 | continue; |
533 | |
534 | // Try scheduling loop without constraints from `tval`. |
535 | AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval); |
536 | if (!order) // still cyclic |
537 | continue; |
538 | |
539 | // Found an input tensor that resolves the cycle by inserting a |
540 | // conversion into a sparse tensor that adheres to the iteration |
541 | // graph order. |
542 | auto stt = getSparseTensorType(tval); |
543 | assert(stt.isIdentity()); |
544 | order = inversePermutation(order); |
545 | // sorted loop -> lvl map. |
546 | idxMap = idxMap.compose(order); |
547 | |
      // Found a permutation such that the results in `idxMap` are sorted.
549 | // For example, |
550 | // (d0, d1, d2, d3) -> (d2, d1, d0) |
551 | // loops are scheduled in order of d0->d1->d2->d3, to resolve the cycle, |
552 | // we find a permutation, perm(d2, d1, d0) -> (d0, d1, d2), such that the |
553 | // transposed tensor's levels are visited in the same order as the loop |
554 | // scheduling order. |
555 | SmallVector<std::pair<unsigned, unsigned>> lvlSeq; |
556 | for (AffineExpr expr : idxMap.getResults()) { |
557 | unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition(); |
558 | lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size())); |
559 | } |
560 | std::sort(lvlSeq.begin(), lvlSeq.end(), [](auto &lhs, auto &rhs) -> bool { |
561 | return lhs.first < rhs.first; |
562 | }); |
563 | SmallVector<unsigned> perm = |
564 | llvm::to_vector(llvm::make_second_range(lvlSeq)); |
565 | auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext()); |
      // The results of the idxMap must be unsorted.
567 | assert(!dimToLvl.isIdentity()); |
568 | |
      // Insert the transpose.
570 | rewriter.setInsertionPoint(linalgOp); |
571 | RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType(); |
572 | Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval); |
573 | rewriter.modifyOpInPlace(linalgOp, [&]() { |
574 | linalgOp->setOperand(t->getOperandNumber(), dst); |
575 | }); |
576 | |
577 | // Release the transposed form afterwards. |
578 | // TODO: CSE when used in more than one following op? |
579 | rewriter.setInsertionPointAfter(linalgOp); |
580 | rewriter.create<bufferization::DeallocTensorOp>(dst.getLoc(), dst); |
581 | |
582 | return success(); |
583 | } |
584 | // Cannot be resolved with a single conversion. |
585 | // TODO: convert more than one? |
586 | return failure(); |
587 | } |
588 | }; |
589 | |
590 | //===----------------------------------------------------------------------===// |
591 | // Reinterpret Map Rewriters for operations other than linalg.generics |
592 | //===----------------------------------------------------------------------===// |
593 | |
594 | template <typename AllocOp> |
595 | struct TensorAllocDemapper : public OpRewritePattern<AllocOp> { |
596 | using OpRewritePattern<AllocOp>::OpRewritePattern; |
597 | LogicalResult matchAndRewrite(AllocOp op, |
598 | PatternRewriter &rewriter) const override { |
599 | if (!hasAnyNonIdentityOperandsOrResults(op)) |
600 | return failure(); |
601 | |
602 | Location loc = op.getLoc(); |
603 | auto stt = getSparseTensorType(op.getResult()); |
604 | |
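    // Demap the dynamic sizes by translating the maximum coordinates
    // (size - 1) from dimension space to level space and adding 1 back,
    // since the dim2lvl translation is defined on coordinates, not sizes.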
605 | SmallVector<Value> maxDimCrds; |
    maxDimCrds.reserve(stt.getDimRank());
607 | ValueRange dynSz = op.getDynamicSizes(); |
608 | for (int64_t dimSz : stt.getDimShape()) { |
609 | if (ShapedType::isDynamic(dimSz)) { |
610 | Value maxCrd = rewriter.create<arith::SubIOp>( |
611 | loc, dynSz.front(), constantIndex(rewriter, loc, 1)); |
        maxDimCrds.push_back(maxCrd);
613 | dynSz = dynSz.drop_front(); |
614 | } else { |
        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
616 | } |
617 | } |
618 | |
619 | ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds, |
620 | CrdTransDirectionKind::dim2lvl); |
621 | auto lvlShape = stt.getLvlShape(); |
622 | SmallVector<Value> dynLvlSzs; |
623 | for (unsigned i = 0, e = lvlShape.size(); i < e; i++) { |
624 | if (ShapedType::isDynamic(lvlShape[i])) { |
625 | Value sz = rewriter.create<arith::AddIOp>( |
626 | loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1)); |
        dynLvlSzs.push_back(sz);
628 | } |
629 | } |
630 | |
631 | assert(dynSz.empty()); // should have consumed all. |
632 | rewriter.startOpModification(op); |
633 | op->setOperands(dynLvlSzs); |
634 | op.getResult().setType(stt.getDemappedType()); |
635 | rewriter.finalizeOpModification(op); |
636 | rewriter.setInsertionPointAfter(op); |
637 | |
638 | Value t = genRemap(rewriter, stt.getEncoding(), op.getResult()); |
639 | rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp()); |
640 | return success(); |
641 | } |
642 | }; |
643 | |
644 | struct TensorInsertDemapper |
645 | : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> { |
646 | using DemapInsRewriter::DemapInsRewriter; |
647 | LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor, |
648 | PatternRewriter &rewriter) const { |
649 | if (!hasAnySparseResult(op) || !hasAnyNonIdentityOperandsOrResults(op)) |
650 | return failure(); |
651 | |
652 | Location loc = op.getLoc(); |
653 | auto stt = getSparseTensorType(op.getResult()); |
654 | ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(), |
655 | CrdTransDirectionKind::dim2lvl); |
656 | auto insertOp = rewriter.create<tensor::InsertOp>( |
657 | loc, op.getScalar(), adaptor.getDest(), lvlCrd); |
658 | |
659 | Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult()); |
660 | rewriter.replaceOp(op, out); |
661 | return success(); |
662 | } |
663 | }; |
664 | |
665 | struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> { |
666 | using OpRewritePattern::OpRewritePattern; |
667 | LogicalResult matchAndRewrite(AssembleOp op, |
668 | PatternRewriter &rewriter) const override { |
669 | if (!hasAnyNonIdentityOperandsOrResults(op)) |
670 | return failure(); |
671 | |
672 | assert(hasAnySparseResult(op)); |
673 | auto stt = getSparseTensorType(op.getResult()); |
674 | rewriter.modifyOpInPlace( |
675 | op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); }); |
676 | rewriter.setInsertionPointAfter(op); |
677 | Value out = genRemap(rewriter, stt.getEncoding(), op.getResult()); |
678 | rewriter.replaceAllUsesExcept(op, out, out.getDefiningOp()); |
679 | return success(); |
680 | } |
681 | }; |
682 | |
683 | struct SparseDisassembleDemapper |
684 | : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> { |
685 | using DemapInsRewriter::DemapInsRewriter; |
686 | LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor, |
687 | PatternRewriter &rewriter) const { |
688 | if (!hasAnyNonIdentityOperandsOrResults(op)) |
689 | return failure(); |
690 | |
691 | assert(hasAnySparseOperandOrResult(op)); |
692 | rewriter.modifyOpInPlace(op, [&op, &adaptor]() { |
693 | op.getTensorMutable().assign(adaptor.getTensor()); |
694 | }); |
695 | return success(); |
696 | } |
697 | }; |
698 | |
699 | struct ForeachOpDemapper |
700 | : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> { |
701 | using DemapInsRewriter::DemapInsRewriter; |
702 | LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor, |
703 | PatternRewriter &rewriter) const { |
704 | // Only handle operations with sparse input/output with non-identity dim2lvl |
705 | // maps. |
706 | if (!hasAnyNonIdentityOperandsOrResults(op)) |
707 | return failure(); |
708 | |
709 | // TODO: demap constant as well. |
710 | if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>()) |
711 | if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) |
712 | return failure(); |
713 | |
714 | Location loc = op.getLoc(); |
715 | // Cache the type information since we update the foreach op in-place. |
716 | auto srcStt = getSparseTensorType(op.getTensor()); |
717 | SmallVector<Type> prevRetTps(op.getResultTypes()); |
718 | |
    rewriter.startOpModification(op);
720 | op.getTensorMutable().assign(adaptor.getTensor()); |
721 | op.getInitArgsMutable().assign(adaptor.getInitArgs()); |
722 | // Update results' types. |
723 | for (auto r : op.getResults()) |
724 | if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity()) |
725 | r.setType(stt->getDemappedType()); |
726 | |
727 | Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank(); |
728 | // Update the foreach body. |
729 | SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType()); |
    blockArgTps.push_back(srcStt.getElementType());
731 | blockArgTps.append(adaptor.getInitArgs().getTypes().begin(), |
732 | adaptor.getInitArgs().getTypes().end()); |
733 | Block *body = op.getBody(); |
734 | // Block Args: [dimCrd, val, initArgs] |
735 | unsigned preArgNum = body->getNumArguments(); |
736 | for (Type t : blockArgTps) |
737 | body->addArgument(t, loc); |
738 | |
739 | // Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs] |
740 | rewriter.setInsertionPointToStart(body); |
    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);
742 | |
743 | ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds, |
744 | CrdTransDirectionKind::lvl2dim); |
745 | rewriter.replaceAllUsesWith( |
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
747 | body->eraseArguments(0, srcStt.getDimRank()); |
748 | // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs] |
749 | unsigned numInitArgs = op.getInitArgs().size(); |
    rewriter.replaceAllUsesWith(body->getArgument(0),
                                body->getArgument(lvlRank + numInitArgs + 1));
    body->eraseArgument(0);
753 | // Block Args: [initArgs, lvlCrds, val, DemappedArgs] |
    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
756 | // Remap back before replacement. |
757 | SmallVector<Value> reMappedArgs = |
        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
    body->eraseArguments(0, numInitArgs);
761 | // Block Args: [lvlCrds, DemappedArgs] and we are done. |
762 | |
763 | // Update yield operations. |
764 | if (numInitArgs != 0) { |
765 | rewriter.setInsertionPointToEnd(body); |
766 | auto yield = llvm::cast<YieldOp>(body->getTerminator()); |
767 | if (auto stt = tryGetSparseTensorType(yield.getSingleResult()); |
768 | stt && !stt->isIdentity()) { |
769 | Value y = |
770 | genDemap(rewriter, stt->getEncoding(), yield.getSingleResult()); |
771 | rewriter.create<YieldOp>(loc, y); |
        rewriter.eraseOp(yield);
773 | } |
774 | } |
    rewriter.finalizeOpModification(op);
776 | |
777 | rewriter.setInsertionPointAfter(op); |
778 | SmallVector<Value> outs = |
779 | remapValueRange(rewriter, prevRetTps, op.getResults()); |
780 | |
    // Replace all the uses of the foreach results, except the use in the
    // reinterpret_map used to remap the output.
783 | for (auto [from, to] : llvm::zip(op.getResults(), outs)) |
784 | rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp()); |
785 | |
786 | return success(); |
787 | } |
788 | }; |
789 | |
790 | } // namespace |
791 | |
792 | void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns, |
793 | ReinterpretMapScope scope) { |
794 | if (scope == ReinterpretMapScope::kAll || |
795 | scope == ReinterpretMapScope::kGenericOnly) { |
796 | patterns.add<GenericOpReinterpretMap, GenericOpScheduler>( |
        patterns.getContext());
798 | } |
799 | if (scope == ReinterpretMapScope::kAll || |
800 | scope == ReinterpretMapScope::kExceptGeneric) { |
801 | patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>, |
802 | TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper, |
803 | SparseDisassembleDemapper, TensorInsertDemapper, |
804 | ForeachOpDemapper>(patterns.getContext()); |
805 | } |
806 | } |
807 | |