1 | |
2 | #include "Utils/CodegenUtils.h" |
3 | #include "Utils/LoopEmitter.h" |
4 | #include "Utils/SparseTensorIterator.h" |
5 | |
6 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
7 | #include "mlir/Dialect/SCF/IR/SCF.h" |
8 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
9 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
10 | #include "mlir/Transforms/DialectConversion.h" |
11 | |
12 | using namespace mlir; |
13 | using namespace mlir::sparse_tensor; |
14 | |
15 | static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl, |
16 | SmallVectorImpl<Type> &fields) { |
17 | // Position and coordinate buffer in the sparse structure. |
18 | if (enc.getLvlType(lvl).isWithPosLT()) |
19 | fields.push_back(Elt: enc.getPosMemRefType()); |
20 | if (enc.getLvlType(lvl).isWithCrdLT()) |
21 | fields.push_back(Elt: enc.getCrdMemRefType()); |
22 | // One index for shape bound (result from lvlOp). |
23 | fields.push_back(IndexType::get(enc.getContext())); |
24 | } |
25 | |
26 | static std::optional<LogicalResult> |
27 | convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) { |
28 | |
29 | auto idxTp = IndexType::get(itSp.getContext()); |
30 | for (Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++) |
31 | convertLevelType(itSp.getEncoding(), l, fields); |
32 | |
33 | // Two indices for lower and upper bound (we only need one pair for the last |
34 | // iteration space). |
35 | fields.append({idxTp, idxTp}); |
36 | return success(); |
37 | } |
38 | |
39 | static std::optional<LogicalResult> |
40 | convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) { |
41 | // The actually Iterator Values (that are updated every iteration). |
42 | auto idxTp = IndexType::get(itTp.getContext()); |
43 | // TODO: handle batch dimension. |
44 | assert(itTp.getEncoding().getBatchLvlRank() == 0); |
45 | if (!itTp.isUnique()) { |
46 | // Segment high for non-unique iterator. |
47 | fields.push_back(Elt: idxTp); |
48 | } |
49 | fields.push_back(Elt: idxTp); |
50 | return success(); |
51 | } |
52 | |
53 | static ValueRange |
54 | genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op, |
55 | Value loopCrd, |
56 | ArrayRef<std::unique_ptr<SparseIterator>> iters, |
57 | ArrayRef<Block *> newBlocks, ArrayRef<Block *> oldBlocks, |
58 | ArrayRef<Value> userReduc) { |
59 | if (newBlocks.empty()) |
60 | return userReduc; |
61 | |
62 | // The current branch that we are handling. |
63 | Block *newBlock = newBlocks.front(); |
64 | Block *oldBlock = oldBlocks.front(); |
65 | Value casePred = constantI1(builder&: rewriter, loc, b: true); |
66 | I64BitSet caseBits = |
67 | op.getRegionDefinedSpace(newBlock->getParent()->getRegionNumber()); |
68 | for (unsigned i : caseBits.bits()) { |
69 | SparseIterator *it = iters[i].get(); |
70 | Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, |
71 | it->getCrd(), loopCrd); |
72 | casePred = rewriter.create<arith::AndIOp>(loc, casePred, pred); |
73 | } |
74 | scf::IfOp ifOp = rewriter.create<scf::IfOp>( |
75 | loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true); |
76 | rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
77 | |
78 | // Erase the empty block. |
79 | rewriter.eraseBlock(block: &ifOp.getThenRegion().front()); |
80 | // Set up block arguments: user-provided values -> loop coord -> iterators. |
81 | SmallVector<Value> blockArgs(userReduc); |
82 | blockArgs.push_back(Elt: loopCrd); |
83 | for (unsigned idx : caseBits.bits()) |
84 | llvm::append_range(blockArgs, iters[idx]->getCursor()); |
85 | |
86 | // Map the old block arguments, because the dialect conversion driver does |
87 | // not immediately perform SSA value replacements. This function is still |
88 | // seeing the old uses. |
89 | IRMapping mapping; |
90 | for (auto [from, to] : llvm::zip_equal(t: oldBlock->getArguments(), u&: blockArgs)) { |
91 | mapping.map(from, to); |
92 | } |
93 | |
94 | // Clone the region, we can not erase the region now because the same region |
95 | // might be a subcase for multiple lattice point. |
96 | rewriter.cloneRegionBefore(*newBlock->getParent(), ifOp.getThenRegion(), |
97 | ifOp.getThenRegion().begin(), mapping); |
98 | // Remove the block arguments, they were already replaced via `mapping`. |
99 | ifOp.getThenRegion().front().eraseArguments(0, blockArgs.size()); |
100 | |
101 | // replace sparse_tensor::YieldOp -> scf::YieldOp |
102 | auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back()); |
103 | ValueRange yields = spY.getResults(); |
104 | rewriter.eraseOp(op: spY); |
105 | rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front()); |
106 | rewriter.create<scf::YieldOp>(loc, yields); |
107 | |
108 | // Generates remaining case recursively. |
109 | rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
110 | ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters, |
111 | newBlocks.drop_front(), |
112 | oldBlocks.drop_front(), userReduc); |
113 | if (!res.empty()) |
114 | rewriter.create<scf::YieldOp>(loc, res); |
115 | |
116 | rewriter.setInsertionPointAfter(ifOp); |
117 | return ifOp.getResults(); |
118 | } |
119 | |
120 | static ValueRange genLoopWithIterator( |
121 | PatternRewriter &rewriter, Location loc, SparseIterator *it, |
122 | ValueRange reduc, |
123 | function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc, |
124 | Region &loopBody, SparseIterator *it, |
125 | ValueRange reduc)> |
126 | bodyBuilder) { |
127 | if (it->iteratableByFor()) { |
128 | auto [lo, hi] = it->genForCond(b&: rewriter, l: loc); |
129 | Value step = constantIndex(builder&: rewriter, loc, i: 1); |
130 | scf::ForOp forOp = rewriter.create<scf::ForOp>( |
131 | loc, lo, hi, step, reduc, |
132 | [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) { |
133 | // Empty builder function to ensure that no terminator is created. |
134 | }); |
135 | { |
136 | OpBuilder::InsertionGuard guard(rewriter); |
137 | it->linkNewScope(pos: forOp.getInductionVar()); |
138 | rewriter.setInsertionPointToStart(forOp.getBody()); |
139 | SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(), |
140 | it, forOp.getRegionIterArgs()); |
141 | |
142 | rewriter.setInsertionPointToEnd(forOp.getBody()); |
143 | rewriter.create<scf::YieldOp>(loc, ret); |
144 | } |
145 | return forOp.getResults(); |
146 | } |
147 | |
148 | SmallVector<Value> ivs(reduc); |
149 | llvm::append_range(C&: ivs, R: it->getCursor()); |
150 | |
151 | TypeRange types = ValueRange(ivs).getTypes(); |
152 | auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs); |
153 | { |
154 | OpBuilder::InsertionGuard guard(rewriter); |
155 | // Generates loop conditions. |
156 | SmallVector<Location> l(types.size(), loc); |
157 | Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l); |
158 | rewriter.setInsertionPointToStart(before); |
159 | ValueRange bArgs = before->getArguments(); |
160 | auto [whileCond, remArgs] = it->genWhileCond(b&: rewriter, l: loc, vs: bArgs); |
161 | rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments()); |
162 | |
163 | // Delegates loop body generation. |
164 | Region &dstRegion = whileOp.getAfter(); |
165 | Block *after = rewriter.createBlock(parent: &dstRegion, insertPt: {}, argTypes: types, locs: l); |
166 | ValueRange aArgs = whileOp.getAfterArguments(); |
167 | it->linkNewScope(pos: aArgs.drop_front(n: reduc.size())); |
168 | aArgs = aArgs.take_front(n: reduc.size()); |
169 | |
170 | rewriter.setInsertionPointToStart(after); |
171 | SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs); |
172 | rewriter.setInsertionPointToEnd(after); |
173 | |
174 | // Forward loops |
175 | SmallVector<Value> yields; |
176 | llvm::append_range(C&: yields, R&: ret); |
177 | llvm::append_range(C&: yields, R: it->forward(b&: rewriter, l: loc)); |
178 | rewriter.create<scf::YieldOp>(loc, yields); |
179 | } |
180 | return whileOp.getResults().drop_front(it->getCursor().size()); |
181 | } |
182 | |
183 | namespace { |
184 | |
185 | /// Sparse codegen rule for number of entries operator. |
186 | class |
187 | : public OpConversionPattern<ExtractIterSpaceOp> { |
188 | public: |
189 | using OpConversionPattern::OpConversionPattern; |
190 | LogicalResult |
191 | matchAndRewrite(ExtractIterSpaceOp op, OneToNOpAdaptor adaptor, |
192 | ConversionPatternRewriter &rewriter) const override { |
193 | Location loc = op.getLoc(); |
194 | |
195 | // Construct the iteration space. |
196 | SparseIterationSpace space(loc, rewriter, |
197 | llvm::getSingleElement(adaptor.getTensor()), 0, |
198 | op.getLvlRange(), adaptor.getParentIter()); |
199 | |
200 | SmallVector<Value> result = space.toValues(); |
201 | rewriter.replaceOpWithMultiple(op, {result}); |
202 | return success(); |
203 | } |
204 | }; |
205 | |
206 | /// Sparse codegen rule for number of entries operator. |
207 | class : public OpConversionPattern<ExtractValOp> { |
208 | public: |
209 | using OpConversionPattern::OpConversionPattern; |
210 | LogicalResult |
211 | matchAndRewrite(ExtractValOp op, OneToNOpAdaptor adaptor, |
212 | ConversionPatternRewriter &rewriter) const override { |
213 | Location loc = op.getLoc(); |
214 | Value pos = adaptor.getIterator().back(); |
215 | Value valBuf = rewriter.create<ToValuesOp>( |
216 | loc, llvm::getSingleElement(adaptor.getTensor())); |
217 | rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos); |
218 | return success(); |
219 | } |
220 | }; |
221 | |
222 | class SparseIterateOpConverter : public OpConversionPattern<IterateOp> { |
223 | public: |
224 | using OpConversionPattern::OpConversionPattern; |
225 | LogicalResult |
226 | matchAndRewrite(IterateOp op, OneToNOpAdaptor adaptor, |
227 | ConversionPatternRewriter &rewriter) const override { |
228 | if (!op.getCrdUsedLvls().empty()) |
229 | return rewriter.notifyMatchFailure( |
230 | op, "non-empty coordinates list not implemented." ); |
231 | |
232 | Location loc = op.getLoc(); |
233 | |
234 | auto iterSpace = SparseIterationSpace::fromValues( |
235 | op.getIterSpace().getType(), adaptor.getIterSpace(), 0); |
236 | |
237 | std::unique_ptr<SparseIterator> it = |
238 | iterSpace.extractIterator(rewriter, loc); |
239 | |
240 | SmallVector<Value> ivs; |
241 | for (ValueRange inits : adaptor.getInitArgs()) |
242 | llvm::append_range(ivs, inits); |
243 | |
244 | // Type conversion on iterate op block. |
245 | unsigned numOrigArgs = op.getBody()->getArgumentTypes().size(); |
246 | TypeConverter::SignatureConversion signatureConversion(numOrigArgs); |
247 | if (failed(typeConverter->convertSignatureArgs( |
248 | op.getBody()->getArgumentTypes(), signatureConversion))) |
249 | return rewriter.notifyMatchFailure( |
250 | op, "failed to convert iterate region argurment types" ); |
251 | |
252 | Block *block = rewriter.applySignatureConversion( |
253 | block: op.getBody(), conversion&: signatureConversion, converter: getTypeConverter()); |
254 | ValueRange ret = genLoopWithIterator( |
255 | rewriter, loc, it: it.get(), reduc: ivs, |
256 | bodyBuilder: [block](PatternRewriter &rewriter, Location loc, Region &loopBody, |
257 | SparseIterator *it, ValueRange reduc) -> SmallVector<Value> { |
258 | SmallVector<Value> blockArgs(reduc); |
259 | // TODO: Also appends coordinates if used. |
260 | // blockArgs.push_back(it->deref(rewriter, loc)); |
261 | llvm::append_range(C&: blockArgs, R: it->getCursor()); |
262 | |
263 | Block *dstBlock = &loopBody.getBlocks().front(); |
264 | rewriter.inlineBlockBefore(source: block, dest: dstBlock, before: dstBlock->end(), |
265 | argValues: blockArgs); |
266 | auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back()); |
267 | // We can not use ValueRange as the operation holding the values will |
268 | // be destroyed. |
269 | SmallVector<Value> result(yield.getResults()); |
270 | rewriter.eraseOp(op: yield); |
271 | return result; |
272 | }); |
273 | |
274 | rewriter.replaceOp(op, ret); |
275 | return success(); |
276 | } |
277 | }; |
278 | |
279 | class SparseCoIterateOpConverter : public OpConversionPattern<CoIterateOp> { |
280 | using OpConversionPattern::OpConversionPattern; |
281 | |
282 | LogicalResult |
283 | matchAndRewrite(CoIterateOp op, OneToNOpAdaptor adaptor, |
284 | ConversionPatternRewriter &rewriter) const override { |
285 | assert(op.getSpaceDim() == 1 && "Not implemented" ); |
286 | Location loc = op.getLoc(); |
287 | |
288 | I64BitSet denseBits(0); |
289 | for (auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes())) |
290 | if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(), isDenseLT)) |
291 | denseBits.set(idx); |
292 | |
293 | // If there exists a case that only contains dense spaces. I.e., case |
294 | // bits is a subset of dense bits, or when there is a full empty case (due |
295 | // to complements), we need a universal pointer to forward the coiteration |
296 | // loop. |
297 | bool needUniv = |
298 | any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) { |
299 | // A case for complement. |
300 | if (caseBits.count() == 0) |
301 | return true; |
302 | // An all-dense case. |
303 | return caseBits.isSubSetOf(p: denseBits); |
304 | }); |
305 | assert(!needUniv && "Not implemented" ); |
306 | (void)needUniv; |
307 | |
308 | SmallVector<Block *> newBlocks; |
309 | DenseMap<Block *, Block *> newToOldBlockMap; |
310 | for (Region ®ion : op.getCaseRegions()) { |
311 | // Do a one-shot type conversion on all region blocks, since the same |
312 | // region might be used multiple time. |
313 | Block *block = ®ion.getBlocks().front(); |
314 | TypeConverter::SignatureConversion blockTypeMapping( |
315 | block->getArgumentTypes().size()); |
316 | if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), |
317 | blockTypeMapping))) { |
318 | return rewriter.notifyMatchFailure( |
319 | op, "failed to convert coiterate region argurment types" ); |
320 | } |
321 | |
322 | newBlocks.push_back(rewriter.applySignatureConversion( |
323 | block, blockTypeMapping, getTypeConverter())); |
324 | newToOldBlockMap[newBlocks.back()] = block; |
325 | } |
326 | |
327 | SmallVector<SparseIterationSpace> spaces; |
328 | SmallVector<std::unique_ptr<SparseIterator>> iters; |
329 | for (auto [spaceTp, spaceVals] : llvm::zip_equal( |
330 | op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) { |
331 | // TODO: do we really need tid? |
332 | spaces.push_back(SparseIterationSpace::fromValues( |
333 | cast<IterSpaceType>(spaceTp), spaceVals, /*tid=*/0)); |
334 | // Extract the iterator. |
335 | iters.push_back(spaces.back().extractIterator(rewriter, loc)); |
336 | } |
337 | |
338 | auto getFilteredIters = [&iters](I64BitSet caseBits) { |
339 | // Retrives a vector of pointers to the iterators used in the case. |
340 | SmallVector<SparseIterator *> validIters; |
341 | for (auto idx : caseBits.bits()) |
342 | validIters.push_back(Elt: iters[idx].get()); |
343 | return validIters; |
344 | }; |
345 | |
346 | // Get a flattened user-provided loop reduction values. |
347 | SmallVector<Value> userReduc; |
348 | for (ValueRange r : adaptor.getInitArgs()) |
349 | llvm::append_range(userReduc, r); |
350 | |
351 | // TODO: we need to sort the cases such that they appears in lexical order. |
352 | // Although sparsification always generates cases in that order, it might |
353 | // not be the case for human-written code. |
354 | |
355 | // Generates a loop sequence, one loop per case. |
356 | for (auto [r, caseBits] : |
357 | llvm::zip_equal(newBlocks, op.getRegionDefinedSpaces())) { |
358 | assert(caseBits.count() > 0 && "Complement space not implemented" ); |
359 | |
360 | // Retrives a vector of pointers to the iterators used in the case. |
361 | SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits); |
362 | |
363 | if (validIters.size() > 1) { |
364 | auto [loop, loopCrd] = |
365 | genCoIteration(rewriter, loc, validIters, userReduc, |
366 | /*uniIdx=*/nullptr, /*userReducFirst=*/true); |
367 | |
368 | // 1st. find all the cases that is a strict subset of the current case |
369 | // condition, for which we generate one branch per case inside the loop. |
370 | // The subcases are never empty, it must contains at least the current |
371 | // region itself. |
372 | // TODO: these cases should be sorted. |
373 | SmallVector<Region *> subCases = |
374 | op.getSubCasesOf(r->getParent()->getRegionNumber()); |
375 | SmallVector<Block *> newBlocks, oldBlocks; |
376 | for (Region *r : subCases) { |
377 | newBlocks.push_back(&r->front()); |
378 | oldBlocks.push_back(newToOldBlockMap[newBlocks.back()]); |
379 | } |
380 | assert(!subCases.empty()); |
381 | |
382 | ValueRange res = genCoIterateBranchNest( |
383 | rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks, userReduc); |
384 | |
385 | SmallVector<Value> nextIterYields(res); |
386 | // 2nd. foward the loop. |
387 | for (SparseIterator *it : validIters) { |
388 | Value cmp = rewriter.create<arith::CmpIOp>( |
389 | loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd); |
390 | it->forwardIf(rewriter, loc, cmp); |
391 | llvm::append_range(nextIterYields, it->getCursor()); |
392 | } |
393 | rewriter.create<scf::YieldOp>(loc, nextIterYields); |
394 | |
395 | // Exit the loop, relink the iterator SSA value. |
396 | rewriter.setInsertionPointAfter(loop); |
397 | ValueRange iterVals = loop->getResults().drop_front(userReduc.size()); |
398 | for (SparseIterator *it : validIters) |
399 | iterVals = it->linkNewScope(iterVals); |
400 | assert(iterVals.empty()); |
401 | |
402 | ValueRange curResult = loop->getResults().take_front(userReduc.size()); |
403 | userReduc.assign(curResult.begin(), curResult.end()); |
404 | } else { |
405 | // This is a simple iteration loop. |
406 | assert(caseBits.count() == 1); |
407 | |
408 | Block *block = r; |
409 | ValueRange curResult = genLoopWithIterator( |
410 | rewriter, loc, validIters.front(), userReduc, |
411 | /*bodyBuilder=*/ |
412 | [block](PatternRewriter &rewriter, Location loc, Region &dstRegion, |
413 | SparseIterator *it, |
414 | ValueRange reduc) -> SmallVector<Value> { |
415 | SmallVector<Value> blockArgs(reduc); |
416 | blockArgs.push_back(it->deref(rewriter, loc)); |
417 | llvm::append_range(blockArgs, it->getCursor()); |
418 | |
419 | Block *dstBlock = &dstRegion.getBlocks().front(); |
420 | rewriter.inlineBlockBefore( |
421 | block, dstBlock, rewriter.getInsertionPoint(), blockArgs); |
422 | auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back()); |
423 | SmallVector<Value> result(yield.getResults()); |
424 | rewriter.eraseOp(yield); |
425 | return result; |
426 | }); |
427 | |
428 | userReduc.assign(curResult.begin(), curResult.end()); |
429 | } |
430 | } |
431 | |
432 | rewriter.replaceOp(op, userReduc); |
433 | return success(); |
434 | } |
435 | }; |
436 | |
437 | } // namespace |
438 | |
439 | mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() { |
440 | addConversion(callback: [](Type type) { return type; }); |
441 | addConversion(convertIteratorType); |
442 | addConversion(convertIterSpaceType); |
443 | |
444 | addSourceMaterialization(callback: [](OpBuilder &builder, IterSpaceType spTp, |
445 | ValueRange inputs, Location loc) -> Value { |
446 | return builder |
447 | .create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs) |
448 | .getResult(0); |
449 | }); |
450 | } |
451 | |
452 | void mlir::populateLowerSparseIterationToSCFPatterns( |
453 | const TypeConverter &converter, RewritePatternSet &patterns) { |
454 | |
455 | IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext()); |
456 | patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter, |
457 | SparseIterateOpConverter, SparseCoIterateOpConverter>( |
458 | arg: converter, args: patterns.getContext()); |
459 | } |
460 | |