//===- LoopEmitter.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "LoopEmitter.h"
#include "CodegenUtils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// File local shorthand macros
//===----------------------------------------------------------------------===//

#define CMPI(p, l, r)                                                          \
  (builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::p, (l), (r))       \
       .getResult())

#define C_IDX(v) (constantIndex(builder, loc, (v)))
#define YIELD(vs) (builder.create<scf::YieldOp>(loc, (vs)))
#define ADDI(lhs, rhs) (builder.create<arith::AddIOp>(loc, (lhs), (rhs)))
#define ANDI(lhs, rhs) (builder.create<arith::AndIOp>(loc, (lhs), (rhs)))
#define SUBI(lhs, rhs) (builder.create<arith::SubIOp>(loc, (lhs), (rhs)))
#define MULI(lhs, rhs) (builder.create<arith::MulIOp>(loc, (lhs), (rhs)))
#define REMUI(lhs, rhs) (builder.create<arith::RemUIOp>(loc, (lhs), (rhs)))
#define DIVUI(lhs, rhs) (builder.create<arith::DivUIOp>(loc, (lhs), (rhs)))
#define SELECT(c, l, r) (builder.create<arith::SelectOp>(loc, (c), (l), (r)))

//===----------------------------------------------------------------------===//
// Debugging utils
//===----------------------------------------------------------------------===//

#ifndef NDEBUG
LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder,
                                                  Location loc, Value memref) {
  memref = builder.create<memref::CastOp>(
      loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref);
  createFuncCall(builder, loc, "printMemrefInd", TypeRange{},
                 ValueRange{memref}, EmitCInterface::On);
}
#endif

//===----------------------------------------------------------------------===//
// File local helper functions.
//===----------------------------------------------------------------------===//

// For index reduction loops, since the tensors are sliced into non-contiguous
// fragments, we need a triple [pLo, pHi, pPtr], in which the pair (pLo, pHi)
// specifies the range of the fragment, and pPtr specifies the index of the
// corresponding fragment in the child level (i.e., a pointer to the sliced
// position array).
static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
                            Level lvl) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  return createOrFoldSliceOffsetOp(builder, loc, tensor, toDim(enc, lvl));
}

static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
                            Level lvl) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
}

//===----------------------------------------------------------------------===//
// Sparse tensor loop emitter class implementations
//===----------------------------------------------------------------------===//

LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
                         bool isSparseOut, unsigned numLoops,
                         DependentLvlGetter dimGetter,
                         SparseEmitStrategy emitStrategy) {
  initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter,
             emitStrategy);
}

void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
                             bool isSparseOut, unsigned numLoops,
                             DependentLvlGetter dimGetter,
                             SparseEmitStrategy emitStrategy) {
  // First initialize the top-level type of the fields.
  this->loopTag = loopTag;
  this->hasOutput = hasOutput;
  this->isSparseOut = isSparseOut;
  this->emitStrategy = emitStrategy;

  const unsigned numManifestTensors = ts.size();
  const unsigned synTensorId = numManifestTensors;
  const unsigned numTensors = numManifestTensors + 1;
  // tensors array (len == numManifestTensors).
  this->tensors.assign(ts.begin(), ts.end());
  // Arrays with len == numTensors.
  this->valBuffer.assign(numTensors, nullptr);
  this->lvls.resize(numTensors);
  this->iters.resize(numTensors);

  // These zeros will be overwritten below, but we need to initialize
  // them to something since we'll need random-access assignment.
  this->loopStack.reserve(numLoops);
  this->loopSeqStack.reserve(numLoops);

  // Index-reduction related fields.
  this->dependentLvlMap.assign(
      numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
  this->sliceMeta.assign(
      numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
  this->levelReducedDep.assign(numTensors, std::vector<unsigned>());

  // Initialize nested types of `TensorId`-indexed fields.
  for (TensorId tid = 0; tid < numTensors; tid++) {
    Level lvlRank;
    if (tid == synTensorId) {
      // The synthetic tensor (conceptually) is an all-dense tensor with rank
      // equal to the total number of loops (each level can potentially be
      // mapped to one of the loops being generated).
      lvlRank = numLoops;
    } else {
      const Value t = tensors[tid];
      // A scalar or 0-dimension tensor.
      if (isZeroRankedTensorOrScalar(t.getType()))
        continue;

      auto rtp = getRankedTensorType(t);
      const SparseTensorType stt(rtp);
      lvlRank = stt.getLvlRank();
    }

    lvls[tid].resize(lvlRank);
    iters[tid].resize(lvlRank);
    loopHighs.assign(numLoops, nullptr);

    // Slice-driven loops related initialization.
    levelReducedDep[tid].assign(lvlRank, 0);
    dependentLvlMap[tid].assign(
        lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
    sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
    if (dimGetter && !isSynTensor(tid)) {
      for (Level l = 0; l < lvlRank; l++) {
        std::vector<std::pair<LoopId, unsigned>> deps = dimGetter(tid, l);
        // Sort the dependent loops in loop order.
        std::sort(deps.begin(), deps.end(),
                  [](auto &lhs, auto &rhs) { return lhs.first < rhs.first; });

        dependentLvlMap[tid][l] = std::move(deps);
        unsigned depends = dependentLvlMap[tid][l].size();
        if (depends == 0)
          continue;
        sliceMeta[tid][l].reserve(depends);
      }
    }
  }
}

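// Creates a simple level iterator for tensor `t` at level `l`; if the tensor
// is a slice, the iterator is wrapped to account for the slice offset and
// stride of the corresponding dimension.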
std::unique_ptr<SparseIterator>
LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
                               Level l) {
  auto it = makeSimpleIterator(*lvls[t][l], emitStrategy);
  auto stt = getSparseTensorType(tensors[t]);
  if (stt.hasEncoding() && stt.getEncoding().isSlice()) {
    Value offset = genSliceOffset(builder, loc, tensors[t], l);
    Value stride = genSliceStride(builder, loc, tensors[t], l);
    auto slicedIt = makeSlicedLevelIterator(
        std::move(it), offset, stride, lvls[t][l]->getSize(), emitStrategy);
    return slicedIt;
  }
  return it;
}

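// Sets up all the per-tensor state needed before any loop is emitted: level
// bounds for the synthetic tensor, sparse tensor levels and their iterators,
// and the (bufferized) value buffers for both dense and sparse tensors.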
void LoopEmitter::initializeLoopEmit(
    OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater,
    LoopEmitter::SynTensorBoundSetter synSetter) {
  // For the synthetic tensor, set the high bound of each loop by calling the
  // callback.
  if (synSetter) {
    TensorId synId = getSynTensorId();
    for (unsigned i = 0, e = loopHighs.size(); i < e; i++) {
      Value sz = loopHighs[i] = synSetter(builder, loc, i);
      auto [stl, it] = makeSynLevelAndIterator(sz, synId, i, emitStrategy);
      lvls[synId][i] = std::move(stl);
      iters[synId][i].emplace_back(std::move(it));
    }
  }

  // For every manifest tensor:
  // * get the values buffer.
  // * For every level:
  //   * get the positions and coordinates buffers
  //   * get/compute the level-size, which is also used as the upper-bound
  //     on positions.
  for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
       t++) {
    const Value tensor = tensors[t];
    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
    if (!rtp)
      // Skips only scalars; zero-ranked tensors still need to be bufferized
      // and (probably) filled with zeros by users.
      continue;
    // FIXME: the definition of `lvlRank` looks more like a dim-rank;
    // but the variable is used as a level everywhere below, which
    // suggests there may be some dim/lvl confusion going on here.
    auto stt = getSparseTensorType(tensor);
    const Level lvlRank = stt.getLvlRank();
    const auto shape = rtp.getShape();

    SmallVector<Value> lvlSzs;
    for (Level l = 0; l < stt.getLvlRank(); l++) {
      if (stt.hasEncoding())
        lvlSzs.push_back(builder.create<LvlOp>(loc, tensor, l));
      else
        lvlSzs.push_back(builder.create<tensor::DimOp>(loc, tensor, l));
    }

    // Scan all levels of current tensor.
    for (Level l = 0; l < lvlRank; l++) {
      // Find upper bound in current dimension.
      lvls[t][l] = makeSparseTensorLevel(builder, loc, tensor, t, l);
      if (!dependentLvlMap[t][l].empty())
        continue;

      auto it = makeLevelIterator(builder, loc, t, l);
      iters[t][l].emplace_back(std::move(it));
    }

    // Perform the required bufferization. Dense inputs materialize from the
    // input tensors. Sparse inputs use sparse primitives to obtain the values.
    // Delegates extra output initialization to clients.
    bool isOutput = isOutputTensor(t);
    Type elementType = stt.getElementType();
    if (!stt.hasEncoding()) {
      // Non-annotated dense tensors.
      BaseMemRefType denseTp = MemRefType::get(shape, elementType);

      // TODO: if we unconditionally use fully dynamic layout here, it breaks
      // some vectorization passes which require static stride == 1.
      // Is it possible to call the vectorization pass after bufferization?
      if (llvm::isa_and_nonnull<tensor::ExtractSliceOp>(tensor.getDefiningOp()))
        denseTp = bufferization::getMemRefTypeWithFullyDynamicLayout(rtp);

      Value denseVal =
          builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
      // Dense outputs need special handling.
      if (isOutput && updater)
        denseVal = updater(builder, loc, denseVal, tensor);

      valBuffer[t] = denseVal;
    } else {
      // Annotated sparse tensors.
      // We also need the value buffer for all-dense annotated "sparse"
      // tensors.
      valBuffer[t] = builder.create<ToValuesOp>(loc, tensor);
    }
    // NOTE: we can also prepare for 0 lvl here in advance; this will hoist
    // some loop preparation from tensor iteration, but will also (undesirably)
    // hoist the code outside of if-conditions.
  }
  // TODO: avoid treating subsection iterator as a special case.
  initSubSectIterator(builder, loc);
}

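// Builds the subsection iterators required by slice-driven loops: for every
// tensor level with non-trivial loop dependencies, the dependencies are
// reduced in loop order, wrapping the plain level iterator into non-empty
// subsection and subsection-traversal iterators as needed.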
void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
  Value c0 = C_IDX(0);
  for (TensorId t = 0, e = tensors.size(); t < e; t++) {
    auto rtp = dyn_cast<RankedTensorType>(tensors[t].getType());
    if (!rtp)
      continue;

    Level lvlRank = SparseTensorType(rtp).getLvlRank();

    // Compute the dependency reduction order.
    auto remDepStack = dependentLvlMap;
    std::vector<std::tuple<LoopId, TensorId, Level>> depRedOrder;
    for (Level lvl = 0; lvl < lvlRank; lvl++) {
      // Reverse queue into a stack.
      std::reverse(remDepStack[t][lvl].begin(), remDepStack[t][lvl].end());
      for (auto [loop, coeff] : dependentLvlMap[t][lvl])
        depRedOrder.emplace_back(std::make_tuple(loop, t, lvl));
    }

    if (depRedOrder.empty())
      continue;

    std::sort(depRedOrder.begin(), depRedOrder.end(),
              [](auto &l, auto &r) { return std::get<0>(l) < std::get<0>(r); });

    SmallVector<SparseIterator *> lastIter(tensors.size(), nullptr);
    for (auto [loop, t, lvl] : depRedOrder) {
      std::pair<LoopId, unsigned> curDep = remDepStack[t][lvl].back();
      assert(curDep.first == loop);
      remDepStack[t][lvl].pop_back();

      auto lvlIt = makeLevelIterator(builder, loc, t, lvl);
      const SparseIterator *parent = lastIter[t];
      if (!parent && lvl > 0) {
        if (dependentLvlMap[t][lvl - 1].empty()) {
          parent = iters[t][lvl - 1].back().get();
        }
      }

      std::unique_ptr<SparseIterator> it;
      if (!remDepStack[t][lvl].empty()) {
        // Compute the subsection size.
        Value size = c0;
        for (auto [loop, stride] : remDepStack[t][lvl]) {
          Value idxMax = SUBI(loopHighs[loop], C_IDX(1));
          size = ADDI(size, ADDI(MULI(idxMax, C_IDX(stride)), C_IDX(1)));
        }
        it = makeNonEmptySubSectIterator(builder, loc, parent, loopHighs[loop],
                                         std::move(lvlIt), size, curDep.second,
                                         emitStrategy);
      } else {
        const SparseIterator &subSectIter = *iters[t][lvl].back();
        it = makeTraverseSubSectIterator(builder, loc, subSectIter, *parent,
                                         std::move(lvlIt), loopHighs[loop],
                                         curDep.second, emitStrategy);
      }
      lastIter[t] = it.get();
      iters[t][lvl].emplace_back(std::move(it));
    }
  }
}

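// Splits the iterators of the given tensor levels into random-accessible
// (dense) and sparse iterators; the sparse iterators are stable-sorted by
// kind (AffineUnRed > Affine > Slice > Trivial).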
void LoopEmitter::categorizeIterators(
    ArrayRef<TensorLevel> tidLvls, SmallVectorImpl<SparseIterator *> &raIters,
    SmallVectorImpl<SparseIterator *> &spIters) {
  // Finds out the tensor level that we should use to generate loops. Among all
  // the tensor levels, there is at most one sparse tensor level.
  for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
    SparseIterator *it = &getCurIterator(t, l);
    if (it->randomAccessible())
      raIters.push_back(it);
    else
      spIters.push_back(it);
  }

  std::stable_sort(spIters.begin(), spIters.end(), [](auto lhs, auto rhs) {
    // AffineUnRed > Affine > Slice > Trivial
    return static_cast<uint8_t>(lhs->kind) > static_cast<uint8_t>(rhs->kind);
  });
}

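// Enters a new loop sequence: prepares the iterators of all tensor levels
// involved in the sequence and pushes a fresh universal index (starting at 0)
// onto the loop-sequence stack.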
void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
                                  ArrayRef<TensorLevel> tidLvls) {
  // TODO: sort
  assert(loopSeqStack.size() == loopStack.size());
  // Prepares for all the tensors used in the current loop sequence.
  for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
    levelReducedDep[tid][lvl]++;
    prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
  }

  // Universal Index starts from 0.
  loopSeqStack.emplace_back(C_IDX(0), tidLvls.vec());
}

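// Exits the current loop sequence: undoes the dependent-level bookkeeping
// done in enterNewLoopSeq and pops the loop-sequence stack.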
void LoopEmitter::exitCurrentLoopSeq(OpBuilder &builder, Location loc) {
  assert(loopSeqStack.size() == loopStack.size() + 1);

  // Depending on whether the slices are resolved at the current loop sequence,
  // they are ended in different ways.
  for (auto [tid, lvl] : unpackTensorLevelRange(loopSeqStack.back().second))
    levelReducedDep[tid][lvl]--;

  loopSeqStack.pop_back();
}

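// Generates code that computes the value of the given affine expression by
// recursively expanding dim, add, mul, and constant expressions; loop indices
// are resolved to the induction variables on the loop stack.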
Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    // FIXME: since the one callsite in Sparsification passes in a
    // level-expression, the `getPosition` must in fact be a `Dimension`.
    // However, elsewhere we have been led to expect that `loopIdToOrd`
    // should be indexed by `LoopId`...
    const auto loopId = cast<AffineDimExpr>(a).getPosition();
    return loopStack[loopId].iv;
  }
  case AffineExprKind::Add: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return ADDI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Mul: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return MULI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Constant: {
    int64_t c = cast<AffineConstantExpr>(a).getValue();
    return C_IDX(c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}

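// Emits a scf.for (or scf.parallel) loop that iterates over a single tensor
// level using the given iterator; returns the loop operation together with
// the coordinate of the current iteration.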
std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
    OpBuilder &builder, Location loc, SparseIterator &iter,
    MutableArrayRef<Value> reduc, bool isParallel) {

  // TODO: support dynamic slices.
  // Uses the first dimension here to build the loop bound (which is also the
  // biggest range).

  Value step = C_IDX(1);
  auto [lo, hi] = iter.genForCond(builder, loc);
  Operation *loop = nullptr;
  Value iv;
  if (isParallel) {
    scf::ParallelOp parOp =
        builder.create<scf::ParallelOp>(loc, lo, hi, step, reduc);
    builder.setInsertionPointToStart(parOp.getBody());
    assert(parOp.getNumReductions() == reduc.size());
    iv = parOp.getInductionVars()[0];

    // In-place update on the reduction variable vector.
    // Note that the init vals are not the actual reduction variables but are
    // instead used as a "special handle" to (temporarily) represent them. The
    // expression on the init vals will be moved into scf.reduce and replaced
    // with the block arguments when exiting the loop (see exitForLoop). This
    // is needed as we cannot build the actual reduction block and get the
    // actual reduction variable before users fill the parallel loop body.
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = parOp.getInitVals()[i];
    loop = parOp;
  } else {
    scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
    builder.setInsertionPointToStart(forOp.getBody());
    iv = forOp.getInductionVar();

    // In-place update on the reduction variable vector.
    assert(forOp.getNumRegionIterArgs() == reduc.size());
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = forOp.getRegionIterArg(i);
    loop = forOp;
  }
  assert(loop && iv);

  Value crd = iv;
  if (!iter.randomAccessible()) {
    iter.linkNewScope(iv);
    crd = iter.deref(builder, loc);
  } else {
    iter.locate(builder, loc, iv);
  }

  return {loop, crd};
}

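// Emits a scf.while loop that co-iterates over the given sparse iterators;
// the loop carries the cursor of every iterator, the user reduction values,
// and (optionally) a universal index. Returns the loop operation and the
// minimum coordinate among all iterators.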
std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
    MutableArrayRef<Value> reduc, bool needsUniv) {
  // NOTE: the slice driven tensor-related reduction variable must
  // appear before normal tensors.

  // The set of induction variables for the while loop.
  SmallVector<Value> ivs;

  // Construct the while-loop with a parameter for each coordinate.
  for (SparseIterator *it : spIters) {
    ValueRange itVals = it->getCursor();
    ivs.append(itVals.begin(), itVals.end());
  }

  // The position where user-supplied reduction variables start.
  ivs.append(reduc.begin(), reduc.end());
  // Update universal index.
  if (needsUniv)
    ivs.push_back(loopSeqStack.back().first);

  // Ensures all operands are valid.
  assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
  TypeRange types = ValueRange(ivs).getTypes();
  auto whileOp = builder.create<scf::WhileOp>(loc, types, ivs);

  SmallVector<Location> locs(types.size(), loc);
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);

  // Generates loop conditions.
  builder.setInsertionPointToStart(before);
  ValueRange bArgs = before->getArguments();
  Value whileCond = nullptr; // bool values for loop condition.

  for (SparseIterator *it : spIters) {
    auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
    whileCond = !whileCond ? cond : ANDI(whileCond, cond);
    bArgs = remArgs;
  }
  // The remaining block arguments are user-provided reduction values and an
  // optional universal index. Make sure their sizes match.
  assert(bArgs.size() == reduc.size() + (needsUniv ? 1 : 0));
  builder.create<scf::ConditionOp>(loc, whileCond, before->getArguments());

  // Generates loop body.
  builder.setInsertionPointToStart(after);
  ValueRange aArgs = after->getArguments();
  // Since some LoopCondKind might need extra checks to filter out invalid
  // iterations, we maintain another array to hold the iteration arguments to
  // yield if the checks fail.
  SmallVector<Value> nextArgs(aArgs.begin(), aArgs.end());

  for (SparseIterator *it : spIters) {
    aArgs = it->linkNewScope(aArgs);
    // Dereference the iterator to cache the coordinate.
    it->deref(builder, loc);
  }

  // In-place update on reduction variable.
  assert(aArgs.size() == reduc.size() + (needsUniv ? 1 : 0));
  for (unsigned i = 0, e = reduc.size(); i < e; i++)
    reduc[i] = aArgs[i];

  Value min;
  // Finds the minimum coordinate.
  if (!needsUniv) {
    for (SparseIterator *it : spIters) {
      if (min) {
        Value cmp = CMPI(ult, it->getCrd(), min);
        min = SELECT(cmp, it->getCrd(), min);
      } else {
        min = it->getCrd();
      }
    }
  } else {
    // Otherwise, the universal index is the minimal coordinate.
    min = whileOp.getAfterArguments().back();
  }

  return {whileOp, min};
}

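// Returns true if the loop can be emitted as a simple scf.for, i.e., when at
// most one sparse iterator is involved and that iterator supports for-style
// iteration.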
bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) {
  // If we need to co-iterate over two sparse tensors, we need a while loop.
  if (spIters.size() > 1)
    return false;

  if (spIters.size() == 1)
    return spIters.front()->iteratableByFor();

  return true;
}

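// Emits the loop (scf.for, scf.parallel, or scf.while) that co-iterates over
// the given tensor levels: the iterators are categorized first, then either a
// single-level for-loop or a co-iterating while-loop is emitted; dense
// (random-accessible) iterators are simply located at the loop coordinate.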
Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
    MutableArrayRef<Value> reduc, bool tryParallel, bool needsUniv) {

  // TODO: support multiple return on parallel for?
  tryParallel = tryParallel && reduc.size() <= 1;

  SmallVector<SparseIterator *> raIters;
  SmallVector<SparseIterator *> spIters;
  categorizeIterators(tidLvls, raIters, spIters);

  // Only when there is at least one sparse condition do we really need the
  // universal index.
  // TODO: Maybe we should instead require the merger to pass in a valid value
  // in the first place instead of adjusting it in LoopEmitter?
  needsUniv = !spIters.empty() && needsUniv;
  // The TensorLevel used for loop conditions.
  // If there is any sparse level, we need to use the sparse condition.
  // If all levels are dense, we can pick an arbitrary one (a dense
  // slice-driven loop can be generated using a simple ForOp as well).
  Operation *l = nullptr;
  Value iv = nullptr;
  SmallVector<TensorLevel> tls;

  // Generates loops differently depending on whether we need a slice-driven
  // loop or a simple level traversal loop.
  if (shouldIteratedByForLoop(spIters) && !needsUniv) {
    assert(spIters.size() <= 1);
    SparseIterator &it = spIters.empty() ? *raIters.front() : *spIters.front();
    std::tie(l, iv) =
        emitForLoopOverTensorAtLvl(builder, loc, it, reduc, tryParallel);
    tls.push_back(makeTensorLevel(it.tid, it.lvl));
  } else {
    for (auto *it : spIters) {
      tls.push_back(makeTensorLevel(it->tid, it->lvl));
    }

    if (needsUniv)
      for (auto *it : raIters)
        tls.push_back(makeTensorLevel(it->tid, it->lvl));

    std::tie(l, iv) =
        emitWhileLoopOverTensorsAtLvls(builder, loc, spIters, reduc, needsUniv);
  }

  // Enter dense tensor levels.
  for (SparseIterator *it : raIters)
    it->locate(builder, loc, iv);

  // NOTE: we can also prepare for the next dim here in advance.
  // Pushes the loop into the stack.
  loopStack.emplace_back(tls, l, builder.getInsertionBlock(), iv, loopTag);
  return l;
}

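// Initializes the iterator of the given tensor level and locates it at the
// coordinate computed from the given affine expression; only valid for
// trivial, random-accessible (dense) levels.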
void LoopEmitter::locateLvlAtAffineAddress(OpBuilder &builder, Location loc,
                                           TensorLevel tidLvl,
                                           AffineExpr lvlExpr) {
  auto [tid, lvl] = unpackTensorLevel(tidLvl);

  const SparseIterator *parent =
      lvl == 0 ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);

  assert(it.kind == IterKind::kTrivial && it.randomAccessible());
  Value lvlCrd = genAffine(builder, loc, lvlExpr);
  it.locate(builder, loc, lvlCrd);
}

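// Initializes the current iterator of the given tensor level (passing the
// parent iterator when one exists) so that it is ready to drive a loop; a
// random-accessible iterator is additionally located at coordinate 0.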
void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                                             TensorId tid, Level lvl) {
  // If this is the first level, there is no parent iterator for the current
  // iterator.
  // If the current iterator is a subsection-based iterator, the parent
  // iterator is memorized by the iterator itself.
  bool hasParent = lvl == 0 || !dependentLvlMap[tid][lvl].empty();

  const SparseIterator *parent =
      hasParent ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);

  // Locates the random-accessible iterator to 0.
  if (it.randomAccessible())
    it.locate(builder, loc, C_IDX(0));
}

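// Exits a scf.for/scf.parallel loop: emits the terminating yield (or builds
// the scf.reduce block for a parallel reduction) and updates the user
// reduction variables with the loop results.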
void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
                              MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
    if (!reduc.empty()) {
      assert(reduc.size() == forOp.getNumResults());
      rewriter.create<scf::YieldOp>(loc, reduc);
    }
    // Exit the loop.
    rewriter.setInsertionPointAfter(forOp);
    // In-place update reduction variables.
    for (unsigned i = 0, e = forOp.getResults().size(); i < e; i++)
      reduc[i] = forOp.getResult(i);
  } else {
    auto parOp = llvm::cast<scf::ParallelOp>(loopInfo.loop);
    if (!reduc.empty()) {
      assert(reduc.size() == parOp.getInitVals().size() && reduc.size() == 1);
      Operation *redExp = reduc.front().getDefiningOp();
      // Reduction expression should have no use.
      assert(redExp->getUses().empty());
      // This must be a binary operation.
      // NOTE: It is the users' responsibility to ensure the operation is
      // commutative.
      assert(redExp->getNumOperands() == 2 && redExp->getNumResults() == 1);

      Value redVal = parOp.getInitVals().front();
      Value curVal;
      if (redExp->getOperand(0) == redVal)
        curVal = redExp->getOperand(1);
      else if (redExp->getOperand(1) == redVal)
        curVal = redExp->getOperand(0);
      // One of the operands must be the init value (which is also the
      // previous reduction value).
      assert(curVal);
#ifndef NDEBUG
      // The reduction expression should be the only user of the reduction val
      // inside the parallel for.
      unsigned numUsers = 0;
      for (Operation *op : redVal.getUsers()) {
        if (op->getParentOp() == parOp)
          numUsers++;
      }
      assert(numUsers == 1);
#endif // NDEBUG

      rewriter.setInsertionPointAfter(redExp);
      auto redOp = rewriter.create<scf::ReduceOp>(loc, curVal);
      // Attach to the reduction op.
      Block *redBlock = &redOp.getReductions().front().front();
      rewriter.setInsertionPointToEnd(redBlock);
      Operation *newRed = rewriter.clone(*redExp);
      // Replaces arguments of the reduction expression by using the block
      // arguments from scf.reduce.
      rewriter.modifyOpInPlace(
          newRed, [&]() { newRed->setOperands(redBlock->getArguments()); });
      // Erases the out-dated reduction expression.
      rewriter.eraseOp(redExp);
      rewriter.setInsertionPointToEnd(redBlock);
      rewriter.create<scf::ReduceReturnOp>(loc, newRed->getResult(0));
    }
    rewriter.setInsertionPointAfter(parOp);
    // In-place update reduction variables.
    for (unsigned i = 0, e = parOp.getResults().size(); i < e; i++)
      reduc[i] = parOp.getResult(i);
  }
}

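// Exits a co-iterating scf.while loop: forwards the sparse iterators whose
// coordinate matches the loop induction value, re-locates dense iterators at
// the universal index, increments the (optional) universal index, and yields
// the resulting operands while updating the user reduction variables.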
void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
                                MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
  Value iv = loopInfo.iv;
  Value one = C_IDX(1);

  // Finalize the induction. Note that the induction could be performed
  // in the individual if-branches to avoid re-evaluating the conditions.
  // However, that would result in a rather elaborate forest of yield
  // instructions during code generation. Moreover, performing the induction
  // after the if-statements more closely resembles code generated by TACO.
  SmallVector<Value> operands;
  ValueRange whileRes = whileOp.getResults();

  for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.tidLvls)) {
    SparseIterator &it = getCurIterator(tid, lvl);
    if (!it.randomAccessible()) {
      // Forward the sparse iterator.
      Value cmp = CMPI(eq, it.getCrd(), iv);
      it.forwardIf(builder, loc, cmp);
      operands.append(it.getCursor().begin(), it.getCursor().end());
      // const Value newPos = whileOp->getResult(o++);
      // Following loops continue iteration from the break point of the
      // current while loop.
      whileRes = it.linkNewScope(whileRes);
    } else {
      // Make sure the randomly accessible (dense) iterator is set to the right
      // position according to the universal index.
      Value uniIdx = whileOp.getResults().back();
      it.locate(builder, loc, uniIdx);
    }
  }

  // Reduction value from users.
  for (auto &i : reduc) {
    operands.push_back(i);
    // Update user reduction variables.
    i = whileRes.front();
    whileRes = whileRes.drop_front();
  }

  // An (optional) universal index.
  if (operands.size() < whileOp.getNumResults()) {
    assert(operands.size() + 1 == whileOp.getNumResults());
    // The last one is the universal index.
    operands.push_back(ADDI(iv, one));
    // Update the loop starting point of the current loop sequence.
    loopSeqStack.back().first = whileOp->getResults().back();
  }

  if (!operands.empty())
    YIELD(operands);

  builder.setInsertionPointAfter(whileOp);
}

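// Exits the innermost loop on the loop stack, dispatching to exitWhileLoop or
// exitForLoop, and pops the loop from the stack.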
void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
                                  MutableArrayRef<Value> reduc) {
  // Clean up the values; it would help us discover potential bugs at an
  // earlier stage (instead of silently using a wrong value).
  const LoopInfo &loopInfo = loopStack.back();

  // Sets the insertion point to the right position.
  rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);
  if (!loopInfo.userCodeBlock->empty() &&
      llvm::isa<scf::YieldOp>(&loopInfo.userCodeBlock->back())) {
    // scf::While/For inserts an implicit yield op when there are no loop
    // iter args. In this case, we need to insert the code before the yield.
    assert(loopInfo.userCodeBlock->back().getNumResults() == 0);
    rewriter.setInsertionPoint(&loopInfo.userCodeBlock->back());
  }

  if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
    exitWhileLoop(rewriter, loc, reduc);
  } else {
    exitForLoop(rewriter, loc, reduc);
  }

  assert(loopStack.size() == loopSeqStack.size());
  loopStack.pop_back();
}

#undef CMPI
#undef C_IDX
#undef YIELD
#undef ADDI
#undef ANDI
#undef SUBI
#undef MULI
#undef REMUI
#undef DIVUI
#undef SELECT