//===- SparseIterationToScf.cpp -------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"
#include "Utils/SparseTensorIterator.h"

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
                             SmallVectorImpl<Type> &fields) {
  // Position and coordinate buffers in the sparse structure.
  if (enc.getLvlType(lvl).isWithPosLT())
    fields.push_back(enc.getPosMemRefType());
  if (enc.getLvlType(lvl).isWithCrdLT())
    fields.push_back(enc.getCrdMemRefType());
  // One index for the shape bound (the result of a lvlOp).
  fields.push_back(IndexType::get(enc.getContext()));
}
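// A sketch of the resulting field layout (assuming index-typed position and
// coordinate buffers; the exact memref element types follow the encoding's
// posWidth/crdWidth):
//
//   compressed level -> { memref<?xindex>  (positions),
//                         memref<?xindex>  (coordinates),
//                         index            (level size) }
//   dense level      -> { index            (level size) }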

static std::optional<LogicalResult>
convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
  auto idxTp = IndexType::get(itSp.getContext());
  for (Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
    convertLevelType(itSp.getEncoding(), l, fields);

  // Two indices for lower and upper bound (we only need one pair for the last
  // iteration space).
  fields.append({idxTp, idxTp});
  return success();
}
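// E.g., an iteration space over a single compressed level flattens to
// (a sketch; the buffer fields come from convertLevelType above):
//
//   { posMemRef, crdMemRef, lvlSize, loBound, hiBound }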
38
39static std::optional<LogicalResult>
40convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
41 // The actually Iterator Values (that are updated every iteration).
42 auto idxTp = IndexType::get(itTp.getContext());
43 // TODO: handle batch dimension.
44 assert(itTp.getEncoding().getBatchLvlRank() == 0);
45 if (!itTp.isUnique()) {
46 // Segment high for non-unique iterator.
47 fields.push_back(Elt: idxTp);
48 }
49 fields.push_back(Elt: idxTp);
50 return success();
51}
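// I.e., a unique iterator lowers to a single position index, while a
// non-unique iterator additionally carries a segment-high index (field names
// illustrative):
//
//   unique     -> { pos }
//   non-unique -> { segHi, pos }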

static ValueRange
genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
                       Value loopCrd,
                       ArrayRef<std::unique_ptr<SparseIterator>> iters,
                       ArrayRef<Block *> newBlocks, ArrayRef<Block *> oldBlocks,
                       ArrayRef<Value> userReduc) {
  if (newBlocks.empty())
    return userReduc;

  // The current branch that we are handling.
  Block *newBlock = newBlocks.front();
  Block *oldBlock = oldBlocks.front();
  Value casePred = constantI1(rewriter, loc, true);
  I64BitSet caseBits =
      op.getRegionDefinedSpace(newBlock->getParent()->getRegionNumber());
  for (unsigned i : caseBits.bits()) {
    SparseIterator *it = iters[i].get();
    Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                it->getCrd(), loopCrd);
    casePred = rewriter.create<arith::AndIOp>(loc, casePred, pred);
  }
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(
      loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true);
  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());

  // Erase the empty block.
  rewriter.eraseBlock(&ifOp.getThenRegion().front());
  // Set up block arguments: user-provided values -> loop coord -> iterators.
  SmallVector<Value> blockArgs(userReduc);
  blockArgs.push_back(loopCrd);
  for (unsigned idx : caseBits.bits())
    llvm::append_range(blockArgs, iters[idx]->getCursor());

  // Map the old block arguments, because the dialect conversion driver does
  // not immediately perform SSA value replacements. This function still sees
  // the old uses.
  IRMapping mapping;
  for (auto [from, to] : llvm::zip_equal(oldBlock->getArguments(), blockArgs))
    mapping.map(from, to);

  // Clone the region; we cannot erase it yet, because the same region might
  // be a subcase of multiple lattice points.
  rewriter.cloneRegionBefore(*newBlock->getParent(), ifOp.getThenRegion(),
                             ifOp.getThenRegion().begin(), mapping);
  // Remove the block arguments; they were already replaced via `mapping`.
  ifOp.getThenRegion().front().eraseArguments(0, blockArgs.size());

  // Replace sparse_tensor::YieldOp with scf::YieldOp.
  auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
  ValueRange yields = spY.getResults();
  rewriter.eraseOp(spY);
  rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front());
  rewriter.create<scf::YieldOp>(loc, yields);

  // Generate the remaining cases recursively.
  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
  ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
                                          newBlocks.drop_front(),
                                          oldBlocks.drop_front(), userReduc);
  if (!res.empty())
    rewriter.create<scf::YieldOp>(loc, res);

  rewriter.setInsertionPointAfter(ifOp);
  return ifOp.getResults();
}
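// The emitted branch nest looks roughly like this (a sketch for two cases,
// with %u being the incoming user reduction values):
//
//   %r = scf.if %case0_pred -> (...) {
//     ... cloned case-0 region ...
//     scf.yield %yields0
//   } else {
//     %r1 = scf.if %case1_pred -> (...) {
//       ... cloned case-1 region ...
//       scf.yield %yields1
//     } else {
//       scf.yield %u  // no case matched
//     }
//     scf.yield %r1
//   }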

static ValueRange genLoopWithIterator(
    PatternRewriter &rewriter, Location loc, SparseIterator *it,
    ValueRange reduc,
    function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
                                    Region &loopBody, SparseIterator *it,
                                    ValueRange reduc)>
        bodyBuilder) {
  if (it->iteratableByFor()) {
    auto [lo, hi] = it->genForCond(rewriter, loc);
    Value step = constantIndex(rewriter, loc, 1);
    scf::ForOp forOp = rewriter.create<scf::ForOp>(
        loc, lo, hi, step, reduc,
        [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
          // Empty builder function to ensure that no terminator is created.
        });
    {
      OpBuilder::InsertionGuard guard(rewriter);
      it->linkNewScope(forOp.getInductionVar());
      rewriter.setInsertionPointToStart(forOp.getBody());
      SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
                                           it, forOp.getRegionIterArgs());

      rewriter.setInsertionPointToEnd(forOp.getBody());
      rewriter.create<scf::YieldOp>(loc, ret);
    }
    return forOp.getResults();
  }

  SmallVector<Value> ivs(reduc);
  llvm::append_range(ivs, it->getCursor());

  TypeRange types = ValueRange(ivs).getTypes();
  auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
  {
    OpBuilder::InsertionGuard guard(rewriter);
    // Generate the loop condition.
    SmallVector<Location> l(types.size(), loc);
    Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
    rewriter.setInsertionPointToStart(before);
    ValueRange bArgs = before->getArguments();
    auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
    rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());

    // Delegate loop body generation.
    Region &dstRegion = whileOp.getAfter();
    Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
    ValueRange aArgs = whileOp.getAfterArguments();
    it->linkNewScope(aArgs.drop_front(reduc.size()));
    aArgs = aArgs.take_front(reduc.size());

    rewriter.setInsertionPointToStart(after);
    SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
    rewriter.setInsertionPointToEnd(after);

    // Forward the iterator before yielding.
    SmallVector<Value> yields;
    llvm::append_range(yields, ret);
    llvm::append_range(yields, it->forward(rewriter, loc));
    rewriter.create<scf::YieldOp>(loc, yields);
  }
  return whileOp.getResults().drop_front(it->getCursor().size());
}
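// Shape of the generated loops (sketches; the cursor layout depends on the
// concrete SparseIterator):
//
//   // Simple range (iteratableByFor):
//   %res:N = scf.for %iv = %lo to %hi step %c1 iter_args(%r = %reduc) {
//     <body>
//     scf.yield <new reductions>
//   }
//
//   // General case:
//   %res:M = scf.while (%r = %reduc, %cur = %cursor) {
//     %cond = <iterator-specific not-at-end test>
//     scf.condition(%cond) %r, %cur
//   } do {
//   ^bb0(%r: ..., %cur: ...):
//     <body>
//     scf.yield <new reductions>, <forwarded cursor>
//   }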

namespace {

/// Sparse codegen rule for the extract iteration space operator.
class ExtractIterSpaceConverter
    : public OpConversionPattern<ExtractIterSpaceOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExtractIterSpaceOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Construct the iteration space.
    SparseIterationSpace space(loc, rewriter,
                               llvm::getSingleElement(adaptor.getTensor()), 0,
                               op.getLvlRange(), adaptor.getParentIter());

    SmallVector<Value> result = space.toValues();
    rewriter.replaceOpWithMultiple(op, {result});
    return success();
  }
};
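// E.g. (a sketch), an op like
//
//   %space = sparse_tensor.extract_iteration_space %t lvls = 0
//
// is replaced by the flattened SSA values from toValues(): the pos/crd
// buffers, the level size, and the loop bounds of the space.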

/// Sparse codegen rule for the extract value operator.
class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExtractValOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value pos = adaptor.getIterator().back();
    Value valBuf = rewriter.create<ToValuesOp>(
        loc, llvm::getSingleElement(adaptor.getTensor()));
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
    return success();
  }
};
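// In other words (a sketch), `%v = sparse_tensor.extract_value %t at %it`
// becomes a load from the values buffer at the iterator's position:
//
//   %vals = sparse_tensor.values %t
//   %v = memref.load %vals[%pos]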

class SparseIterateOpConverter : public OpConversionPattern<IterateOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(IterateOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getCrdUsedLvls().empty())
      return rewriter.notifyMatchFailure(
          op, "non-empty coordinates list not implemented.");

    Location loc = op.getLoc();

    auto iterSpace = SparseIterationSpace::fromValues(
        op.getIterSpace().getType(), adaptor.getIterSpace(), 0);

    std::unique_ptr<SparseIterator> it =
        iterSpace.extractIterator(rewriter, loc);

    SmallVector<Value> ivs;
    for (ValueRange inits : adaptor.getInitArgs())
      llvm::append_range(ivs, inits);

    // Type conversion on the iterate op block.
    unsigned numOrigArgs = op.getBody()->getArgumentTypes().size();
    TypeConverter::SignatureConversion signatureConversion(numOrigArgs);
    if (failed(typeConverter->convertSignatureArgs(
            op.getBody()->getArgumentTypes(), signatureConversion)))
      return rewriter.notifyMatchFailure(
          op, "failed to convert iterate region argument types");

    Block *block = rewriter.applySignatureConversion(
        op.getBody(), signatureConversion, getTypeConverter());
    ValueRange ret = genLoopWithIterator(
        rewriter, loc, it.get(), ivs,
        [block](PatternRewriter &rewriter, Location loc, Region &loopBody,
                SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
          SmallVector<Value> blockArgs(reduc);
          // TODO: also append coordinates if used.
          // blockArgs.push_back(it->deref(rewriter, loc));
          llvm::append_range(blockArgs, it->getCursor());

          Block *dstBlock = &loopBody.getBlocks().front();
          rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
                                     blockArgs);
          auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
          // We cannot use a ValueRange here, as the operation holding the
          // values will be destroyed.
          SmallVector<Value> result(yield.getResults());
          rewriter.eraseOp(yield);
          return result;
        });

    rewriter.replaceOp(op, ret);
    return success();
  }
};
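// A rough before/after sketch (single reduction, no coordinates used):
//
//   %r = sparse_tensor.iterate %it in %space iter_args(%sum = %init) {
//     ...
//     sparse_tensor.yield %new_sum
//   }
//
// lowers to an scf.for over [lo, hi) when the iterator admits a simple range,
// and otherwise to an scf.while carrying (%sum, <iterator cursor>), with the
// original region inlined into the loop body.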

class SparseCoIterateOpConverter : public OpConversionPattern<CoIterateOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CoIterateOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    assert(op.getSpaceDim() == 1 && "Not implemented");
    Location loc = op.getLoc();

    I64BitSet denseBits(0);
    for (auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes()))
      if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(), isDenseLT))
        denseBits.set(idx);

    // If there exists a case that only contains dense spaces (i.e., the case
    // bits are a subset of the dense bits), or a fully empty case (due to
    // complements), we need a universal index to forward the co-iteration
    // loop.
    bool needUniv =
        any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) {
          // A complement case.
          if (caseBits.count() == 0)
            return true;
          // An all-dense case.
          return caseBits.isSubSetOf(denseBits);
        });
    assert(!needUniv && "Not implemented");
    (void)needUniv;

    SmallVector<Block *> newBlocks;
    DenseMap<Block *, Block *> newToOldBlockMap;
    for (Region &region : op.getCaseRegions()) {
      // Do a one-shot type conversion on all region blocks, since the same
      // region might be used multiple times.
      Block *block = &region.getBlocks().front();
      TypeConverter::SignatureConversion blockTypeMapping(
          block->getArgumentTypes().size());
      if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
                                                     blockTypeMapping))) {
        return rewriter.notifyMatchFailure(
            op, "failed to convert coiterate region argument types");
      }

      newBlocks.push_back(rewriter.applySignatureConversion(
          block, blockTypeMapping, getTypeConverter()));
      newToOldBlockMap[newBlocks.back()] = block;
    }

    SmallVector<SparseIterationSpace> spaces;
    SmallVector<std::unique_ptr<SparseIterator>> iters;
    for (auto [spaceTp, spaceVals] : llvm::zip_equal(
             op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
      // TODO: do we really need tid?
      spaces.push_back(SparseIterationSpace::fromValues(
          cast<IterSpaceType>(spaceTp), spaceVals, /*tid=*/0));
      // Extract the iterator.
      iters.push_back(spaces.back().extractIterator(rewriter, loc));
    }

    auto getFilteredIters = [&iters](I64BitSet caseBits) {
      // Retrieves a vector of pointers to the iterators used in the case.
      SmallVector<SparseIterator *> validIters;
      for (auto idx : caseBits.bits())
        validIters.push_back(iters[idx].get());
      return validIters;
    };

    // Get the flattened user-provided loop reduction values.
    SmallVector<Value> userReduc;
    for (ValueRange r : adaptor.getInitArgs())
      llvm::append_range(userReduc, r);

    // TODO: we need to sort the cases such that they appear in lexical order.
    // Although sparsification always generates cases in that order, it might
    // not be the case for human-written code.

    // Generate a loop sequence, one loop per case.
    for (auto [r, caseBits] :
         llvm::zip_equal(newBlocks, op.getRegionDefinedSpaces())) {
      assert(caseBits.count() > 0 && "Complement space not implemented");

      // Retrieve the iterators used in the case.
      SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits);

      if (validIters.size() > 1) {
        auto [loop, loopCrd] =
            genCoIteration(rewriter, loc, validIters, userReduc,
                           /*uniIdx=*/nullptr, /*userReducFirst=*/true);

        // 1st, find all the cases that are strict subsets of the current case
        // condition; we generate one branch per such case inside the loop.
        // The set of subcases is never empty, since it contains at least the
        // current region itself.
        // TODO: these cases should be sorted.
        SmallVector<Region *> subCases =
            op.getSubCasesOf(r->getParent()->getRegionNumber());
        SmallVector<Block *> newBlocks, oldBlocks;
        for (Region *r : subCases) {
          newBlocks.push_back(&r->front());
          oldBlocks.push_back(newToOldBlockMap[newBlocks.back()]);
        }
        assert(!subCases.empty());

        ValueRange res = genCoIterateBranchNest(
            rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks, userReduc);

        SmallVector<Value> nextIterYields(res);
        // 2nd, forward the loop.
        for (SparseIterator *it : validIters) {
          Value cmp = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
          it->forwardIf(rewriter, loc, cmp);
          llvm::append_range(nextIterYields, it->getCursor());
        }
        rewriter.create<scf::YieldOp>(loc, nextIterYields);

        // Exit the loop and relink the iterator SSA values.
        rewriter.setInsertionPointAfter(loop);
        ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
        for (SparseIterator *it : validIters)
          iterVals = it->linkNewScope(iterVals);
        assert(iterVals.empty());

        ValueRange curResult = loop->getResults().take_front(userReduc.size());
        userReduc.assign(curResult.begin(), curResult.end());
      } else {
        // This is a simple iteration loop.
        assert(caseBits.count() == 1);

        Block *block = r;
        ValueRange curResult = genLoopWithIterator(
            rewriter, loc, validIters.front(), userReduc,
            /*bodyBuilder=*/
            [block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
                    SparseIterator *it,
                    ValueRange reduc) -> SmallVector<Value> {
              SmallVector<Value> blockArgs(reduc);
              blockArgs.push_back(it->deref(rewriter, loc));
              llvm::append_range(blockArgs, it->getCursor());

              Block *dstBlock = &dstRegion.getBlocks().front();
              rewriter.inlineBlockBefore(
                  block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
              auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
              SmallVector<Value> result(yield.getResults());
              rewriter.eraseOp(yield);
              return result;
            });

        userReduc.assign(curResult.begin(), curResult.end());
      }
    }

    rewriter.replaceOp(op, userReduc);
    return success();
  }
};
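// Overall shape of the lowering (a sketch for two spaces with cases
// {A and B} followed by {A only}):
//
//   <while-loop co-iterating A and B>        // genCoIteration
//     <branch nest from genCoIterateBranchNest>
//     <forward the iterators whose coordinate matches the loop coordinate>
//   <loop over the remaining entries of A>   // genLoopWithIterator
//     <inlined case region>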

} // namespace

mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertIteratorType);
  addConversion(convertIterSpaceType);

  addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
                              ValueRange inputs, Location loc) -> Value {
    return builder
        .create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs)
        .getResult(0);
  });
}

void mlir::populateLowerSparseIterationToSCFPatterns(
    const TypeConverter &converter, RewritePatternSet &patterns) {
  IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
  patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
               SparseIterateOpConverter, SparseCoIterateOpConverter>(
      converter, patterns.getContext());
}
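// A minimal usage sketch (hypothetical pass body; the in-tree pass that
// drives these patterns is defined elsewhere):
//
//   MLIRContext *ctx = module->getContext();
//   SparseIterationTypeConverter converter;
//   RewritePatternSet patterns(ctx);
//   populateLowerSparseIterationToSCFPatterns(converter, patterns);
//   ConversionTarget target(*ctx);
//   target.addIllegalOp<ExtractIterSpaceOp, ExtractValOp, IterateOp,
//                       CoIterateOp>();
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();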