//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/IterationGraphSorter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// File Local Helper classes.
//===----------------------------------------------------------------------===//

// CRTP helper to implement a rewriter that demaps all its inputs.
template <typename SubClass, typename SourceOp>
struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;
  using OpAdaptor = typename SourceOp::Adaptor;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Demaps non-trivial inputs.
    bool changed = false;
    SmallVector<Value> deMappedIns(op->getOperands());
    for (Value &in : deMappedIns) {
      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
        changed = true;
      }
    }

    // CRTP call.
    OpAdaptor adaptor(deMappedIns, op);
    LogicalResult status =
        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
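    // Report success if any operand was demapped (the IR already changed);
    // otherwise forward the status of the subclass rewrite.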
    return changed ? success() : status;
  }
};

// Collects the positions of all AffineDimExprs used in an affine expression.
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
  BitVector dims;
};

// Determines whether an affine expression is admissible for sparse code
// generation.
struct AffineExprAdmissibleVisitor
    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
  explicit AffineExprAdmissibleVisitor(bool isOutput)
      : admissible(true), isOutput(isOutput){};

  // We only allow AffineDimExpr on output.
  void visitAddExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }
  void visitMulExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }

  // We disallow mod, floor div and ceil div on inputs.
  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  operator bool() { return admissible; }

private:
  bool admissible;
  bool isOutput;
};

// The first BitVector stores levels where inadmissible exprs are used.
// The second BitVector stores the AffineDimExprs that are used by the
// inadmissible expressions.
using InadmissInfo = std::pair<BitVector, BitVector>;

} // namespace

//===----------------------------------------------------------------------===//
// File Local Helper methods.
//===----------------------------------------------------------------------===//

// Collects the inadmissible affine expressions imposed on levels.
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
  auto ret = std::make_pair(BitVector(map.getNumResults()),
                            BitVector(map.getNumDims()));
  AffineDimCollector collector(map.getNumDims());
  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
    AffineExprAdmissibleVisitor admissible(isOutput);
    admissible.walkPostOrder(map.getResult(lvl));
    if (!admissible) {
      // Record the inadmissible level.
      ret.first.set(lvl);
      // Record the AffineDimExpr that is used in the inadmissible expr.
      collector.walkPostOrder(map.getResult(lvl));
    }
  }
  ret.second = collector.dims;
  return ret;
}

// Builds the AffineMap to replace the idx in idxMap to lvl such that all the
// inadmissible affine expressions can be eliminated.
// For example, we can rewrite
//   idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// to
//   idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
// by composing inverse(idxMap), that is
//   inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
//                           -> ((l0 * 2 + l2) floordiv 2,
//                               (l1 * 3 + l3) floordiv 3,
//                               (l0 * 2 + l2) mod 2,
//                               (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
//
// This function builds the inverse(idxMap) that replaces every dimension used
// in `info` with a level, and updates the iterator type array `itTps` for the
// new index variables introduced.
//
// Note that the returned affine map does not retain the order of the input
// affine map. Instead, it always uses the first `info.inAdLvls.count()` dims
// for the replaced levels, and the remaining ones for unused dimensions.
// For example, to handle
//   idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
// which is a typical map for block_2to4, the function returns:
//   inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
// in which (l0, l1) together replace `d1`, yet they appear
// before `d0` in the resulting affine map.
// The index (loop) order can later be canonicalized by a topo sort.
static AffineMap
genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
                      SmallVector<utils::IteratorType> &itTps) {
  MLIRContext *ctx = idxMap.getContext();
  auto [inAdLvls, usedDims] = info;
  // Note that idxMap is not equal to the dim2Lvl map: it is computed by
  // composing idx2Dim(dim2Lvl). They are only equal when idx2Dim is an
  // ID map.
  // TODO: we might fail here, in those cases we should really return
  // failure instead of an assertion error.
  auto lvl2Idx = inferLvlToDim(idxMap, ctx);

  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
    // This could happen when some dimensions are projected.
    // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
    //   ==> lvl2Idx = (j, k) -> (j, k)
    // In this case, we append the unused dimension at the end.
    //   ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
    SmallVector<AffineExpr> results;
    AffineDimCollector usedInLvl(idxMap.getNumDims());
    for (auto e : idxMap.getResults())
      usedInLvl.walkPostOrder(e);

    unsigned curUsedDimID = 0;
    unsigned curUnusedDimID = lvl2Idx.getNumDims();

    BitVector unused = usedInLvl.dims.flip();
    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
      if (unused.test(i))
        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
      else
        results.push_back(lvl2Idx.getResult(curUsedDimID++));
    }
    lvl2Idx =
        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
  }
  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());

  // We do not need to replace the DimExpr that is not used in inadmissible
  // level expressions. We use the first inAdLvls.count() dims to represent the
  // replaced levels, the remaining ones are reserved for unchanged dims.
  // Note that the results from the inverse map computed previously do not
  // follow this convention, and we need to fix the mismatch below.
  unsigned curRepID = 0;
  unsigned curOriID = inAdLvls.count();
  SmallVector<AffineExpr> results;
  SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
  SmallVector<utils::IteratorType> transItTps;

  for (unsigned l : inAdLvls.set_bits()) {
    // By our convention, the inadmissible level `l` always appears in the
    // leading part (accumulated by curRepID) of the affine map's parameter
    // list. Record the mapping so that we can replace all the uses of `l` to
    // the correct position after the translation.
    dimRep[l] = getAffineDimExpr(curRepID++, ctx);
    // A new index variable is introduced for the inadmissible level; it
    // inherits the iterator type. E.g., if l0 = d0 floordiv 2, the
    // iterator type of l0 equals the iterator type of d0.
    AffineExpr lvlExp = idxMap.getResult(l);
    AffineDimCollector collector(idxMap.getNumDims());
    collector.walkPostOrder(lvlExp);
    // We assume a level can only be derived from one dimension.
    assert(collector.dims.count() == 1);
    transItTps.push_back(itTps[collector.dims.find_first()]);
  }

  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
    if (usedDims.test(d)) {
      // The dimension is used in some of the inadmissible levels, and it needs
      // to be inverted. Get the inversion from the inverse map, and fix the
      // mismatch captured by the above loop.
      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
    } else {
      // The dimension is not used in any of the inadmissible levels, and it
      // does not need to be inverted. Fix the mismatch by mapping it to the
      // trailing part of the affine map (accumulated by curOriID).
      results.push_back(getAffineDimExpr(curOriID++, ctx));
      transItTps.push_back(itTps[d]);
    }
  }
  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
  // Update iterator type.
  itTps.assign(transItTps.begin(), transItTps.end());
  return AffineMap::get(numDim, 0, results, ctx);
}

// Translates the index map in the linalg::GenericOp from an idx->dim map to
// an idx->lvl map. Returns std::nullopt if the index map cannot be translated
// to an admissible form.
// Returns the translated index map array and the iterator type array.
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
  // idxMap is an idx2dim map before reinterpretation.
  MLIRContext *ctx = op.getContext();
  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
    Value tensor = op->getOpOperand(i).get();
    auto stt = tryGetSparseTensorType(tensor);
    if (stt && !stt->isIdentity()) {
      AffineMap dim2Lvl = stt->getDimToLvl();
      // By composing the idx2dim(dim2lvl), we get an idx2lvl map.
      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
    }
  }

  // A naive way to handle common constant expressions that arise during dim2lvl
  // translation.
  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
                                  unsigned pos, int64_t lvlSz) {
    if (!ShapedType::isDynamic(lvlSz)) {
      auto c0 = getAffineConstantExpr(0, ctx);
      auto lvlExp = getAffineDimExpr(pos, ctx);
      auto szExp = getAffineConstantExpr(lvlSz, ctx);

      // lvl floordiv lvlSz = 0
      auto divExp =
          getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
      cstMapping.try_emplace(divExp, c0);

      // lvl mod lvlSz = lvl
      auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
      cstMapping.try_emplace(modExp, lvlExp);
    }
  };

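  // Number of leading dims in the index maps that were introduced to resolve
  // previously inadmissible expressions; they must not be rewritten again
  // (see the check below).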
  unsigned boundedNum = 0;
  // A fixed-point algorithm.
  bool changed = true;
  while (changed) {
    changed = false;
    for (OpOperand &operand : op->getOpOperands()) {
      auto stt = tryGetSparseTensorType(operand.get());
      // Skip on dense operands.
      if (!stt || !stt->getEncoding())
        continue;

      unsigned tid = operand.getOperandNumber();
      bool isOutput = &operand == op.getDpsInitOperand(0);
      AffineMap idxMap = idxMapArray[tid];
      InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
      auto [inAdLvls, dimExprs] = inAdInfo;
      for (unsigned d : dimExprs.set_bits()) {
        // The first `boundedNum` dims used in the AffineMap are introduced to
        // resolve previous inadmissible expressions. We cannot replace them
        // as it might bring back the inadmissible expressions.
        if (d < boundedNum)
          return std::nullopt;
      }

      if (inAdLvls.count() != 0) {
        // Naive constant propagation, should be sufficient to handle block
        // sparsity in our cases.
        SmallVector<int64_t> lvlShape = stt->getLvlShape();
        DenseMap<AffineExpr, AffineExpr> cstMapping;
        unsigned position = 0;
        for (unsigned lvl : inAdLvls.set_bits()) {
          int64_t lvlSz = lvlShape[lvl];
          populateCstMapping(cstMapping, position, lvlSz);
          position++;
        }

        AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
        // Compose the lvl2Idx map with all index maps to eliminate
        // inadmissible expressions.
        for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
          AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
          idxMapArray[tid] = transMap.replace(
              cstMapping, /*numResultDims=*/transMap.getNumDims(),
              /*numResultSyms=*/0);
        }
        changed = true;
        boundedNum += inAdLvls.count();
      }
    }
  };

  SmallVector<Attribute> iterAttr =
      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
        return linalg::IteratorTypeAttr::get(ctx, itTp);
      });

  return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
                        rewriter.getArrayAttr(iterAttr));
}

// Generates a "de"mapping reinterpretation of the map.
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
                                          val);
}

// Generates a "re"mapping reinterpretation of the map.
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
}

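// Reinterprets each value in `outs` to the corresponding type in `types`
// whenever the types differ.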
static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
                                          ValueRange outs) {
  SmallVector<Value> ret(outs);
  assert(outs.size() == types.size());
  for (auto [r, t] : llvm::zip(ret, types))
    if (r.getType() != t)
      r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
  return ret;
}

namespace {

//===----------------------------------------------------------------------===//
// Rewriting rules for linalg generic ops.
//===----------------------------------------------------------------------===//

/// Sparse rewriting rule for the generic `linalg` operation.
struct GenericOpReinterpretMap
    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
public:
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only rewrite single output operations with pure (sparse) tensor
    // semantics.
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        !hasAnySparseOperandOrResult(linalgOp) ||
        !hasAnyNonIdentityOperandsOrResults(linalgOp))
      return failure();

    // Try translating the index map.
    auto transMap = translateMap(linalgOp, rewriter);
    if (!transMap)
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be sparsified.");

    // On success, update the linalg operands and maps in place.
    Value res = linalgOp.getResult(0);
    auto stt = tryGetSparseTensorType(res);
    auto [idxMap, itTp] = *transMap;

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(idxMap);
    linalgOp.setIteratorTypesAttr(itTp);
    // Use demapped arguments.
    linalgOp.getInputsMutable().assign(adaptor.getInputs());
    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
    res.setType(adaptor.getOutputs()[0].getType());
    rewriter.finalizeOpModification(linalgOp);

    rewriter.setInsertionPointAfter(linalgOp);
    if (stt && stt->hasEncoding()) {
      Value t = genRemap(rewriter, stt->getEncoding(), res);
      rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
    }
    return success();
  }
};

struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
        !hasAnySparseOperandOrResult(linalgOp)) {
      return failure();
    }

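    // Marker attribute set below once the op has been scheduled, so that the
    // pattern does not match the same op again.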
    const StringRef sorted = "sorted";
    if (linalgOp->hasAttr(sorted))
      return failure();

    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
    bool isAdmissible = false;
    AffineMap order;
    // A const list of all masks that we use for iteration graph
    // computation. Must be ordered from more strict to less strict.
    // Ideally (though might not be guaranteed), the earlier a constraint mask
    // can be satisfied, the faster the generated kernel will be.
    const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
                           SortMask::kIncludeDenseInput,
                           SortMask::kIncludeDenseOutput,
                           SortMask::kSparseOnly};
    for (const SortMask mask : allMasks) {
      order = scheduler.sort(mask);
      if (order) {
        if (isAdmissibleOrder(linalgOp, order)) {
          isAdmissible = true;
          break;
        }
        // else try a set of less strict constraints.
      }
    }

    if (!order) {
      // Cycles detected.
      if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
        return rewriter.notifyMatchFailure(
            linalgOp, "the sparse kernel can not be scheduled: loop detected.");
      }
      return success();
    }

    if (!isAdmissible) {
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be scheduled.");
    }

    // Marks the GenericOp to avoid recursive matching.
    rewriter.modifyOpInPlace(linalgOp, [&]() {
      linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
    });

    // Already sorted.
    if (order.isIdentity())
      return success();

    assert(order.isPermutation());
    // `order` is the original loop -> sorted loop map.
    ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
    SmallVector<Attribute> curItTypes;
    curItTypes.reserve(preItTypes.size());
    for (AffineExpr expr : order.getResults()) {
      unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
      curItTypes.push_back(preItTypes[loopID]);
    }

    // Inverse `order` to get the sorted loop -> original loop map.
    order = inversePermutation(order);
    SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
    for (AffineMap &idxMap : idxMaps)
      idxMap = idxMap.compose(order); // sorted loop -> lvl map

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
    linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
    rewriter.finalizeOpModification(linalgOp);

    return success();
  }

private:
  /// Whether the loop order is admissible by sparsification.
  static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
    if (!hasAnySparseResult(linalgOp))
      return true;

    OpOperand *lhs = linalgOp.getDpsInitOperand(0);
    unsigned nest = 0;
    const auto iteratorTypes = linalgOp.getIteratorTypesArray();
    for (const AffineExpr l : order.getResults()) {
      unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
      auto itTp =
          cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]);
      if (linalg::isReductionIterator(itTp.getValue()))
        break; // terminate at first reduction
      nest++;
    }
    // Determine admissible dynamic insertion situations:
    // (1) fully injective, since there are no reductions,
    // (2) admissible 1-d expansion in innermost dimension.
    return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
  };

  // Last resort cycle resolution.
  static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
                                    linalg::LinalgOp linalgOp,
                                    PatternRewriter &rewriter) {
    // Compute a topological sort while leaving out every sparse input tensor
    // in succession until an acyclic iteration graph results.
    for (OpOperand *t : linalgOp.getDpsInputOperands()) {
      Value tval = t->get();
      auto srcEnc = getSparseTensorEncoding(tval.getType());
      // The constraints introduced by compound index expressions are
      // complicated. Skip them.
      AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
      bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
        return !llvm::isa<AffineDimExpr>(exp);
      });
      if (!srcEnc || hasCompExpr)
        continue;

      // Try scheduling the loops without constraints from `tval`.
      AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
      if (!order) // still cyclic
        continue;

      // Found an input tensor that resolves the cycle by inserting a
      // conversion into a sparse tensor that adheres to the iteration
      // graph order.
      auto stt = getSparseTensorType(tval);
      assert(stt.isIdentity());
      order = inversePermutation(order);
      // sorted loop -> lvl map.
      idxMap = idxMap.compose(order);

      // Found a permutation such that the results in `idxMap` are sorted.
      // For example,
      //   (d0, d1, d2, d3) -> (d2, d1, d0)
      // loops are scheduled in the order d0->d1->d2->d3; to resolve the cycle,
      // we find a permutation, perm(d2, d1, d0) -> (d0, d1, d2), such that the
      // transposed tensor's levels are visited in the same order as the loop
      // scheduling order.
      SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
      for (AffineExpr expr : idxMap.getResults()) {
        unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
        lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
      }
      std::sort(lvlSeq.begin(), lvlSeq.end(), [](auto &lhs, auto &rhs) -> bool {
        return lhs.first < rhs.first;
      });
      SmallVector<unsigned> perm =
          llvm::to_vector(llvm::make_second_range(lvlSeq));
      auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
      // The results of the idxMap must not already be sorted.
      assert(!dimToLvl.isIdentity());

      // Insert the transpose.
      rewriter.setInsertionPoint(linalgOp);
      RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
      Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
      rewriter.modifyOpInPlace(linalgOp, [&]() {
        linalgOp->setOperand(t->getOperandNumber(), dst);
      });

      // Release the transposed form afterwards.
      // TODO: CSE when used in more than one following op?
      rewriter.setInsertionPointAfter(linalgOp);
      rewriter.create<bufferization::DeallocTensorOp>(dst.getLoc(), dst);

      return success();
    }
    // Cannot be resolved with a single conversion.
    // TODO: convert more than one?
    return failure();
  }
};

//===----------------------------------------------------------------------===//
// Reinterpret Map Rewriters for operations other than linalg.generics
//===----------------------------------------------------------------------===//

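// Demaps sparse tensor allocations (e.g., bufferization.alloc_tensor or
// tensor.empty) with a non-identity dim2lvl map by translating the dynamic
// dimension sizes into dynamic level sizes and retyping the result to the
// demapped type.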
template <typename AllocOp>
struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
  using OpRewritePattern<AllocOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AllocOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());

    SmallVector<Value> maxDimCrds;
    maxDimCrds.reserve(stt.getDimRank());
    ValueRange dynSz = op.getDynamicSizes();
    for (int64_t dimSz : stt.getDimShape()) {
      if (ShapedType::isDynamic(dimSz)) {
        Value maxCrd = rewriter.create<arith::SubIOp>(
            loc, dynSz.front(), constantIndex(rewriter, loc, 1));
        maxDimCrds.push_back(maxCrd);
        dynSz = dynSz.drop_front();
      } else {
        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
      }
    }

    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
                                              CrdTransDirectionKind::dim2lvl);
    auto lvlShape = stt.getLvlShape();
    SmallVector<Value> dynLvlSzs;
    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
      if (ShapedType::isDynamic(lvlShape[i])) {
        Value sz = rewriter.create<arith::AddIOp>(
            loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
        dynLvlSzs.push_back(sz);
      }
    }

    assert(dynSz.empty()); // should have consumed all.
    rewriter.startOpModification(op);
    op->setOperands(dynLvlSzs);
    op.getResult().setType(stt.getDemappedType());
    rewriter.finalizeOpModification(op);
    rewriter.setInsertionPointAfter(op);

    Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
    return success();
  }
};

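// Demaps tensor.insert into a sparse tensor by translating the dimension
// coordinates into level coordinates.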
struct TensorInsertDemapper
    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnySparseResult(op) || !hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                          CrdTransDirectionKind::dim2lvl);
    auto insertOp = rewriter.create<tensor::InsertOp>(
        loc, op.getScalar(), adaptor.getDest(), lvlCrd);

    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
    rewriter.replaceOp(op, out);
    return success();
  }
};

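// Rewrites sparse_tensor.assemble to produce the demapped type; the result is
// reinterpreted back to the original encoding for all other users.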
struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AssembleOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    assert(hasAnySparseResult(op));
    auto stt = getSparseTensorType(op.getResult());
    rewriter.modifyOpInPlace(
        op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
    rewriter.setInsertionPointAfter(op);
    Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op, out, out.getDefiningOp());
    return success();
  }
};

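// Rewrites sparse_tensor.disassemble to consume the demapped input tensor.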
struct SparseDisassembleDemapper
    : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    assert(hasAnySparseOperandOrResult(op));
    rewriter.modifyOpInPlace(op, [&op, &adaptor]() {
      op.getTensorMutable().assign(adaptor.getTensor());
    });
    return success();
  }
};

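// Demaps sparse_tensor.foreach: the loop body is rewritten to receive level
// coordinates (the dimension coordinates are recovered through a lvl2dim
// translation) as well as demapped init arguments and results.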
struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only handle operations with sparse input/output with non-identity dim2lvl
    // maps.
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    // TODO: demap constant as well.
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
        return failure();

    Location loc = op.getLoc();
    // Cache the type information since we update the foreach op in-place.
    auto srcStt = getSparseTensorType(op.getTensor());
    SmallVector<Type> prevRetTps(op.getResultTypes());

    rewriter.startOpModification(op);
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    // Update results' types.
    for (auto r : op.getResults())
      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
        r.setType(stt->getDemappedType());

    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
    // Update the foreach body.
    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    // Block Args: [dimCrd, val, initArgs]
    unsigned preArgNum = body->getNumArguments();
    for (Type t : blockArgTps)
      body->addArgument(t, loc);

    // Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
    rewriter.setInsertionPointToStart(body);
    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);

    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
    rewriter.replaceAllUsesWith(
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    body->eraseArguments(0, srcStt.getDimRank());
    // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
    unsigned numInitArgs = op.getInitArgs().size();
    rewriter.replaceAllUsesWith(body->getArgument(0),
                                body->getArgument(lvlRank + numInitArgs + 1));
    body->eraseArgument(0);
    // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
    // Remap back before replacement.
    SmallVector<Value> reMappedArgs =
        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
    body->eraseArguments(0, numInitArgs);
    // Block Args: [lvlCrds, DemappedArgs] and we are done.

    // Update yield operations.
    if (numInitArgs != 0) {
      rewriter.setInsertionPointToEnd(body);
      auto yield = llvm::cast<YieldOp>(body->getTerminator());
      if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
          stt && !stt->isIdentity()) {
        Value y =
            genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
        rewriter.create<YieldOp>(loc, y);
        rewriter.eraseOp(yield);
      }
    }
    rewriter.finalizeOpModification(op);

    rewriter.setInsertionPointAfter(op);
    SmallVector<Value> outs =
        remapValueRange(rewriter, prevRetTps, op.getResults());

    // Replace all the uses of the foreach results, except the use in the
    // reinterpret_map used to remap the output.
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
      rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());

    return success();
  }
};

} // namespace

void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                        ReinterpretMapScope scope) {
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kGenericOnly) {
    patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
        patterns.getContext());
  }
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kExceptGeneric) {
    patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
                 TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
                 SparseDisassembleDemapper, TensorInsertDemapper,
                 ForeachOpDemapper>(patterns.getContext());
  }
}