1 | //===- Sparsification.cpp - Implementation of sparsification --------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements converting sparse tensor types to actual sparse code. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "Utils/CodegenEnv.h" |
14 | #include "Utils/CodegenUtils.h" |
15 | #include "Utils/LoopEmitter.h" |
16 | |
17 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
18 | #include "mlir/Dialect/Arith/IR/Arith.h" |
19 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
20 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
21 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
22 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
23 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
24 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
25 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
26 | #include "mlir/Dialect/SCF/IR/SCF.h" |
27 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
28 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
29 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
30 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
31 | #include "mlir/Dialect/SparseTensor/Utils/Merger.h" |
32 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
33 | #include "mlir/IR/AffineExprVisitor.h" |
34 | #include "mlir/IR/Matchers.h" |
35 | #include "mlir/IR/TensorEncoding.h" |
36 | #include "llvm/ADT/SmallBitVector.h" |
37 | |
38 | #include <optional> |
39 | |
40 | using namespace mlir; |
41 | using namespace mlir::sparse_tensor; |
42 | |
43 | //===----------------------------------------------------------------------===// |
44 | // Sparsifier analysis methods. |
45 | //===----------------------------------------------------------------------===// |
46 | |
47 | /// Returns true iff the affine expression is invariant. Sets the output |
48 | /// parameter `isCurrentLoop` when the expression just becomes invariant. |
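/// E.g., with curr == 2, the expression d1 just becomes invariant (and sets
/// `isCurrentLoop`), d0 was already invariant at an earlier loop, and d2 is
/// not invariant yet.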
49 | static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) { |
50 | switch (a.getKind()) { |
51 | case AffineExprKind::DimId: { |
52 | const LoopId i = cast<AffineDimExpr>(a).getPosition(); |
53 | if (i + 1 == curr) { |
54 | isCurrentLoop = true; |
55 | return true; // becomes invariant at current loop |
56 | } |
57 | return i < curr; // invariant when already generated |
58 | } |
59 | case AffineExprKind::Add: |
60 | case AffineExprKind::Mul: { |
61 | auto binOp = cast<AffineBinaryOpExpr>(a); |
62 | return isInvariantAffine(binOp.getLHS(), curr, isCurrentLoop) && |
63 | isInvariantAffine(binOp.getRHS(), curr, isCurrentLoop); |
64 | } |
65 | default: { |
66 | assert(isa<AffineConstantExpr>(a)); |
67 | return true; |
68 | } |
69 | } |
70 | } |
71 | |
72 | /// Helper method to inspect affine expressions. Rejects cases where the |
73 | /// same index is used more than once. Also rejects compound affine |
74 | /// expressions in sparse dimensions. |
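/// E.g., for the map (d0, d1) -> (d0, d0), the second result is rejected
/// because loop d0 already has a level type assigned; compound results such
/// as d0 + d1 are only expected on levels with dense semantics.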
75 | static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a, |
76 | LevelType lt, bool setLvlFormat = true) { |
77 | switch (a.getKind()) { |
78 | case AffineExprKind::DimId: { |
79 | const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition()); |
80 | if (!isUndefLT(merger.getLvlType(tid, idx))) |
81 | return false; // used more than once |
82 | if (setLvlFormat) |
83 | merger.setLevelAndType(tid, idx, lvl, lt); |
84 | return true; |
85 | } |
86 | case AffineExprKind::Add: |
87 | case AffineExprKind::Mul: |
88 | case AffineExprKind::Constant: { |
89 | assert(lt.hasDenseSemantic()); |
90 | if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) { |
91 | // We do not set the dim level format for an affine expression like d0 + d1 |
92 | // on either loop index d0 or d1. We continue the recursion merely to check |
93 | // whether the current affine expression is admissible or not. |
94 | return findAffine(merger, tid, lvl, binOp.getLHS(), lt, false) && |
95 | findAffine(merger, tid, lvl, binOp.getRHS(), lt, false); |
96 | } |
97 | // Falls through when it is a constant Affine |
98 | return true; |
99 | } |
100 | default: |
101 | return false; |
102 | } |
103 | } |
104 | |
105 | /// Helper method to inspect affine expressions for index variable reduction |
106 | /// based codegen. It finds the dependent index set for all tensor levels in the |
107 | /// current expression we are generating. |
108 | /// |
109 | /// For example, when handling A[i+j][j+k], we build the two way mapping in |
110 | /// merger between (tensor, level) pairs and their dependent index variable set: |
111 | /// A_0 <=> [i, j] and A_1 <=> [j, k] |
112 | /// |
113 | /// It rejects cases (returns false) |
114 | /// 1st, when the same index is used more than once, e.g., A[i+j][i] |
115 | /// 2nd, when multiplication is used in the non-trivial index expression. |
116 | /// 3rd, when a constant operand is used in the non-trivial index expression. |
117 | /// |
118 | /// TODO: constant should be easy to handle. |
119 | static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl, |
120 | AffineExpr a, LevelType lt, bool isSubExp = false, |
121 | int64_t coefficient = 1) { |
122 | switch (a.getKind()) { |
123 | case AffineExprKind::DimId: { |
124 | // Only allow positive coefficients on AffineDimExpr. |
125 | if (coefficient <= 0) |
126 | return false; |
127 | |
128 | const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition()); |
129 | if (!isUndefLT(merger.getLvlType(tensor, idx))) |
130 | return false; // used more than once, e.g., A[i][i] |
131 | |
132 | // TODO: Generalize the following two cases. A[i] (with trivial index |
133 | // expression) can be treated as a special affine index expression. We do |
134 | // not necessarily need to differentiate them. |
135 | if (!isSubExp) { |
136 | assert(coefficient == 1); |
137 | merger.setLevelAndType(tensor, idx, lvl, lt); |
138 | } |
139 | |
140 | if (isSubExp) { |
141 | // The current loop appears in more than one affine expression on the |
142 | // same tensor. We cannot handle this case, e.g., A[i+j][i+k], where `i` |
143 | // is used twice. |
144 | if (merger.hasDependentLvl(idx, tensor)) { |
145 | // TODO: This can be supported by coiterating slices if the loop idx |
146 | // appears in the affine index of a different tensor, or by taking a |
147 | // slice along multiple dimensions when it is on the same tensor. |
148 | // E.g., |
149 | // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0] |
150 | // d0_1 = getNextSliceOffset t0 along lvl0 |
151 | // d0_2 = getNextSliceOffset t1 along lvl0 |
152 | // if d0_1 == d0_2 then d0 = d0_1 = d0_2 |
153 | // else increase min(d0_1, d0_2). |
154 | return false; |
155 | } |
156 | merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient); |
157 | } |
158 | return true; |
159 | } |
160 | case AffineExprKind::Constant: |
161 | case AffineExprKind::Mul: { |
162 | // TODO: Support an index expression like `2 * d0`; currently we only support |
163 | // more complicated cases like `2 * d0 + d1`. |
164 | if (!isSubExp) |
165 | return false; |
166 | |
167 | // TODO: Support Constant AffineExp for slice-based codegen |
168 | if (isa<AffineConstantExpr>(a)) |
169 | llvm_unreachable("Not yet implemented"); |
170 | |
171 | auto binOp = cast<AffineBinaryOpExpr>(a); |
172 | auto lhs = binOp.getLHS(), rhs = binOp.getRHS(); |
173 | if (isa<AffineConstantExpr>(rhs)) |
174 | std::swap(lhs, rhs); |
175 | // Must be in form of `constant * d`. |
176 | assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs)); |
177 | int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue(); |
178 | return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient); |
179 | } |
180 | case AffineExprKind::Add: { |
181 | auto binOp = cast<AffineBinaryOpExpr>(a); |
182 | return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt, true) && |
183 | findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt, true); |
184 | } |
185 | default: |
186 | return false; |
187 | } |
188 | } |
189 | |
190 | /// Gets the total number of compound affine expressions in the |
191 | /// `getMatchingIndexingMap` for the given tensor. For the following inputs: |
192 | /// |
193 | /// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed) |
194 | /// |
195 | /// Returns 1 (because the first level is compressed and its corresponding |
196 | /// indexing-expression is `d0 + d1`) |
197 | static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map, |
198 | Value tensor) { |
199 | // The `tensor` is not guaranteed to have `RankedTensorType`, therefore |
200 | // we can't use `getRankedTensorType`/`getSparseTensorType` here. |
201 | // However, we don't need to handle `StorageSpecifierType`, so we |
202 | // can use `SparseTensorType` once we guard against non-tensors. |
203 | const auto rtp = dyn_cast<RankedTensorType>(tensor.getType()); |
204 | if (!rtp) |
205 | return 0; |
206 | const SparseTensorType stt(rtp); |
207 | |
208 | const Level lvlRank = stt.getLvlRank(); |
209 | const auto exprs = map.getResults(); |
210 | assert(static_cast<Dimension>(exprs.size()) == lvlRank && |
211 | "AffineMap does not have dimension-rank many results" ); |
212 | unsigned num = 0; |
213 | for (Level l = 0; l < lvlRank; l++) { |
214 | if (!isa<AffineDimExpr>(Val: exprs[l]) && !stt.getLvlType(l).hasDenseSemantic()) |
215 | num++; |
216 | } |
217 | return num; |
218 | } |
219 | |
220 | /// Gets the total number of sparse levels with compound affine |
221 | /// expressions, summed over all operands of the `GenericOp`. |
222 | static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) { |
223 | unsigned num = 0; |
224 | for (OpOperand &t : op->getOpOperands()) |
225 | num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t), |
226 | t.get()); |
227 | return num; |
228 | } |
229 | |
230 | // Returns true iff output has nontrivial affine indices. |
231 | static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) { |
232 | OpOperand *out = op.getDpsInitOperand(0); |
233 | if (getSparseTensorType(val: out->get()).isAllDense()) |
234 | return false; |
235 | return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out), |
236 | out->get()); |
237 | } |
238 | |
239 | /// Helper method to inspect sparse encodings in the tensor types. |
240 | /// Fills the per-dimension sparsity information for all tensors. |
241 | /// Returns true if the sparse annotations and affine subscript |
242 | /// expressions of all tensors are admissible. Returns false if |
243 | /// no annotations are found or inadmissible constructs occur. |
244 | /// We currently support two different ways to handle non-trivial index |
245 | /// expressions on sparse tensors, and they accept different affine expressions. |
246 | /// When using the dependent index reduction-based approach, it currently only |
247 | /// supports affine addition index expressions. |
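/// E.g., an access A[2*i+j] on a compressed level is only admissible via the
/// index-reduction path (findDepIdxSet), while findAffine only admits such
/// compound expressions on levels with dense semantics.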
248 | static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) { |
249 | bool annotated = false; |
250 | for (OpOperand &t : env.op()->getOpOperands()) { |
251 | const TensorId tid = env.makeTensorId(t.getOperandNumber()); |
252 | const auto map = env.op().getMatchingIndexingMap(&t); |
253 | const auto enc = getSparseTensorEncoding(t.get().getType()); |
254 | if (enc) |
255 | annotated = true; |
256 | const Level lvlRank = map.getNumResults(); |
257 | assert(!enc || lvlRank == enc.getLvlRank()); |
258 | assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank); |
259 | // We only need to do index reduction if there is at least one |
260 | // non-trivial index expression on sparse levels. If all non-trivial |
261 | // index expressions are on dense levels, we can efficiently rely on |
262 | // random access to locate the elements. |
263 | bool needIdxReduc = |
264 | enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0; |
265 | // If the current tensor being inspected requires an affine index, it |
266 | // needs to be sliced. |
267 | for (Level l = 0; l < lvlRank; l++) { |
268 | const AffineExpr a = map.getResult(l); |
269 | const LevelType lt = enc.getLvlType(l); |
270 | if (idxReducBased && needIdxReduc) { |
271 | if (!findDepIdxSet(env.merger(), tid, l, a, lt)) |
272 | return false; // inadmissible affine expression |
273 | } else { |
274 | if (!findAffine(env.merger(), tid, l, a, lt)) |
275 | return false; // inadmissible affine expression |
276 | } |
277 | } |
278 | } |
279 | return annotated; |
280 | } |
281 | |
282 | //===----------------------------------------------------------------------===// |
283 | // Sparsifier synthesis methods (statements and expressions). |
284 | //===----------------------------------------------------------------------===// |
285 | |
286 | /// Local bufferization of all dense and sparse data structures. |
287 | static void genBuffers(CodegenEnv &env, OpBuilder &builder) { |
288 | linalg::GenericOp op = env.op(); |
289 | Location loc = op.getLoc(); |
290 | assert(op.getNumOperands() == op.getNumDpsInputs() + 1); |
291 | |
292 | SmallVector<Range, 4> loopRange = |
293 | llvm::cast<linalg::LinalgOp>(op.getOperation()) |
294 | .createLoopRanges(builder, loc); |
295 | |
296 | env.emitter().initializeLoopEmit( |
297 | builder, loc, |
298 | /// Generates buffer for the output tensor. |
299 | /// Note that all sparse kernels assume that when all elements are written |
300 | /// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized |
301 | /// to all zeroes and only nonzero values are computed and written out. |
302 | /// For updates (viz. x(i) += y(i) * z(i)), only nonzero values are used |
303 | /// for the updates and no assumption on the original contents of the |
304 | /// output buffer is necessary. |
305 | [&op](OpBuilder &builder, Location loc, Value memref, |
306 | Value tensor) -> Value { |
307 | // Must not be a sparse tensor. |
308 | assert(!getSparseTensorEncoding(tensor.getType())); |
309 | // Two output tensor references should point to the same object. |
310 | OpOperand *lhs = op.getDpsInitOperand(0); |
311 | assert(lhs->get() == tensor); |
312 | // An output tensor can simply materialize from the buffer of the tensor |
313 | // that appears in the outs() clause. For updates, this has the |
314 | // advantage that only the nonzero values are involved in the |
315 | // computation, keeping the operation O(nnz). In all other cases, we are |
316 | // forced to zero out the buffer to enforce the assumption above, which |
317 | // may negatively impact running complexity (viz. O(n^2 + nnz) vs. |
318 | // O(nnz) for matrices). |
319 | // TODO: use better analysis to avoid zeroing out the buffer? |
320 | bool isInit = op.isInitTensor(lhs); |
321 | Value init = memref; |
322 | if (!isInit) { |
323 | Value zero = constantZero(builder, loc, |
324 | tp: getElementTypeOrSelf(type: tensor.getType())); |
325 | builder.create<linalg::FillOp>(loc, ValueRange{zero}, |
326 | ValueRange{init}); |
327 | } |
328 | return init; |
329 | }, |
330 | [&loopRange](OpBuilder &b, Location loc, Level l) { |
331 | assert(l < loopRange.size()); |
332 | return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size); |
333 | }); |
334 | } |
335 | |
336 | /// Generates index for load/store on sparse tensor. |
337 | static Value genIndex(CodegenEnv &env, OpOperand *t) { |
338 | const auto map = env.op().getMatchingIndexingMap(t); |
339 | const auto stt = getSparseTensorType(val: t->get()); |
340 | const Level lvlRank = stt.getLvlRank(); |
341 | assert(static_cast<Level>(map.getNumResults()) == lvlRank); |
342 | const AffineExpr a = map.getResult(lvlRank - 1); |
343 | assert(a.getKind() == AffineExprKind::DimId); |
344 | const LoopId idx = env.makeLoopId(i: cast<AffineDimExpr>(Val: a).getPosition()); |
345 | return env.getLoopVar(i: idx); |
346 | } |
347 | |
348 | /// Generates subscript for load/store on a dense or sparse tensor. |
349 | static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t, |
350 | SmallVectorImpl<Value> &args) { |
351 | const Location loc = env.op().getLoc(); |
352 | const TensorId tid = env.makeTensorId(t: t->getOperandNumber()); |
353 | const auto map = env.op().getMatchingIndexingMap(t); |
354 | const auto stt = getSparseTensorType(val: t->get()); |
355 | if (stt.hasEncoding()) { |
356 | // For sparse tensors we only push the last-level's position onto `args`. |
357 | const auto pos = env.emitter().getValPosits(tid); |
358 | assert(!pos.empty()); |
359 | args.append(RHS: pos); |
360 | } else { |
361 | // For dense tensors we push all level's coordinates onto `args`. |
362 | const Level lvlRank = stt.getLvlRank(); |
363 | assert(static_cast<Level>(map.getNumResults()) == lvlRank); |
364 | for (Level l = 0; l < lvlRank; l++) { |
365 | const auto lvlExpr = map.getResult(l); |
366 | const auto lvlCrd = env.emitter().genAffine(builder, loc, a: lvlExpr); |
367 | args.push_back(Elt: lvlCrd); |
368 | } |
369 | } |
370 | return env.emitter().getValBuffer()[tid]; |
371 | } |
372 | |
373 | /// Generates insertion code to implement dynamic tensor load. |
374 | static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder, |
375 | OpOperand *t) { |
376 | linalg::GenericOp op = env.op(); |
377 | Location loc = op.getLoc(); |
378 | // Direct lexicographic coordinate order, tensor loads as zero. |
379 | if (!env.isExpand()) { |
380 | Type tp = getElementTypeOrSelf(type: t->get().getType()); |
381 | return constantZero(builder, loc, tp); |
382 | } |
383 | // Load from expanded access pattern. |
384 | Value index = genIndex(env, t); |
385 | return builder.create<memref::LoadOp>(loc, env.getExpandValues(), index); |
386 | } |
387 | |
388 | /// Generates insertion code to implement dynamic tensor load for reduction. |
389 | static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder, |
390 | OpOperand *t) { |
391 | linalg::GenericOp op = env.op(); |
392 | Location loc = op.getLoc(); |
393 | Value identity = env.getCustomRedId(); |
394 | // Direct lexicographic coordinate order, tensor loads as identity. |
395 | if (!env.isExpand()) |
396 | return identity; |
397 | // Load from expanded access pattern if filled, identity otherwise. |
398 | Value values = env.getExpandValues(); |
399 | Value filled = env.getExpandFilled(); |
400 | Value index = genIndex(env, t); |
401 | Value isFilled = builder.create<memref::LoadOp>(loc, filled, index); |
402 | Value valAtIndex = builder.create<memref::LoadOp>(loc, values, index); |
403 | return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity); |
404 | } |
405 | |
406 | /// Generates insertion code to implement dynamic tensor store. |
407 | static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t, |
408 | Value rhs) { |
409 | linalg::GenericOp op = env.op(); |
410 | Location loc = op.getLoc(); |
411 | // Direct insertion in lexicographic coordinate order. |
412 | if (!env.isExpand()) { |
413 | const LoopId numLoops = op.getRank(t); |
414 | // Retrieves the first `numLoops` induction variables. |
415 | SmallVector<Value> ivs = llvm::to_vector(Range: llvm::drop_end( |
416 | RangeOrContainer: env.emitter().getLoopIVsRange(), N: env.getCurrentDepth() - numLoops)); |
417 | Value chain = env.getInsertionChain(); |
418 | if (env.isValidLexInsert()) { |
419 | // Generates runtime check for a valid lex during reduction, |
420 | // to avoid inserting the identity value for empty reductions. |
421 | // if (validLexInsert) then |
422 | // insert(rhs) into chain |
423 | // return updated chain |
424 | // else |
425 | // return unmodified chain |
426 | scf::IfOp ifValidLexInsert = builder.create<scf::IfOp>( |
427 | loc, chain.getType(), env.getValidLexInsert(), |
428 | /*else=*/true); |
429 | // True branch. |
430 | builder.setInsertionPointToStart(ifValidLexInsert.thenBlock()); |
431 | Value res = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs); |
432 | builder.create<scf::YieldOp>(loc, res); |
433 | // False branch. |
434 | builder.setInsertionPointToStart(ifValidLexInsert.elseBlock()); |
435 | builder.create<scf::YieldOp>(loc, chain); |
436 | // Value assignment. |
437 | builder.setInsertionPointAfter(ifValidLexInsert); |
438 | env.updateInsertionChain(chain: ifValidLexInsert.getResult(0)); |
439 | } else { |
440 | // Generates regular insertion chain. |
441 | env.updateInsertionChain( |
442 | builder.create<tensor::InsertOp>(loc, rhs, chain, ivs)); |
443 | } |
444 | return; |
445 | } |
446 | // Generates insertion code along expanded access pattern. |
447 | // if (!expFilled[i]) then |
448 | // expFilled[i] = true |
449 | // expAdded[inserts++] = i |
450 | // endif |
451 | // values[i] = rhs |
452 | Value values = env.getExpandValues(); |
453 | Value filled = env.getExpandFilled(); |
454 | Value added = env.getExpandAdded(); |
455 | Value count = env.getExpandCount(); |
456 | Value index = genIndex(env, t); |
457 | Value fval = constantI1(builder, loc, b: false); |
458 | Value tval = constantI1(builder, loc, b: true); |
459 | // If statement. |
460 | Value isFilled = builder.create<memref::LoadOp>(loc, filled, index); |
461 | Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, |
462 | isFilled, fval); |
463 | scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond, |
464 | /*else=*/true); |
465 | // True branch. |
466 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
467 | builder.create<memref::StoreOp>(loc, tval, filled, index); |
468 | builder.create<memref::StoreOp>(loc, index, added, count); |
469 | Value one = constantIndex(builder, loc, i: 1); |
470 | Value add = builder.create<arith::AddIOp>(loc, count, one); |
471 | builder.create<scf::YieldOp>(loc, add); |
472 | // False branch. |
473 | builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
474 | builder.create<scf::YieldOp>(loc, count); |
475 | builder.setInsertionPointAfter(ifOp); |
476 | // Value assignment. |
477 | env.updateExpandCount(count: ifOp.getResult(0)); |
478 | builder.create<memref::StoreOp>(loc, rhs, values, index); |
479 | } |
480 | |
481 | /// Generates a load on a dense or sparse tensor. |
482 | static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) { |
483 | // Test if the load was hoisted to a higher loop nest. |
484 | Value val = env.exp(e: exp).val; |
485 | if (val) |
486 | return val; |
487 | // Load during insertion. |
488 | linalg::GenericOp op = env.op(); |
489 | OpOperand *t = &op->getOpOperand(env.exp(e: exp).tensor); |
490 | if (env.isSparseOutput(o: t)) { |
491 | if (env.isCustomReduc()) |
492 | return genInsertionLoadReduce(env, builder, t); |
493 | return genInsertionLoad(env, builder, t); |
494 | } |
495 | // Actual load. |
496 | SmallVector<Value> args; |
497 | Value ptr = genSubscript(env, builder, t, args); |
498 | return builder.create<memref::LoadOp>(op.getLoc(), ptr, args); |
499 | } |
500 | |
501 | /// Generates a store on a dense or sparse tensor. |
502 | static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp, |
503 | Value rhs) { |
504 | // Only unary and binary are allowed to return an uninitialized rhs |
505 | // to indicate missing output, or otherwise a custom reduction that |
506 | // received no value to accumulate. |
507 | if (!rhs) { |
508 | assert(env.exp(exp).kind == TensorExp::Kind::kUnary || |
509 | env.exp(exp).kind == TensorExp::Kind::kBinary || |
510 | env.exp(exp).kind == TensorExp::Kind::kReduce); |
511 | return; |
512 | } |
513 | // Test if this is a scalarized reduction. |
514 | if (env.isReduc()) { |
515 | env.updateReduc(val: rhs); |
516 | return; |
517 | } |
518 | // Regular store. |
519 | linalg::GenericOp op = env.op(); |
520 | Location loc = op.getLoc(); |
521 | OpOperand *t = op.getDpsInitOperand(0); |
522 | if (!env.isSparseOutput(o: t)) { |
523 | SmallVector<Value> args; |
524 | Value ptr = genSubscript(env, builder, t, args); |
525 | builder.create<memref::StoreOp>(loc, rhs, ptr, args); |
526 | return; |
527 | } |
528 | // Store during sparse insertion. |
529 | if (env.exp(e: exp).kind != TensorExp::Kind::kSelect) { |
530 | genInsertionStore(env, builder, t, rhs); |
531 | return; |
532 | } |
533 | // Select operation insertion. |
534 | Value chain = env.getInsertionChain(); |
535 | scf::IfOp ifOp = |
536 | builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true); |
537 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
538 | // Existing value was preserved to be used here. |
539 | assert(env.exp(exp).val); |
540 | Value v0 = env.exp(e: exp).val; |
541 | genInsertionStore(env, builder, t, rhs: v0); |
542 | env.merger().clearExprValue(e: exp); |
543 | // Yield modified insertion chain along true branch. |
544 | Value mchain = env.getInsertionChain(); |
545 | builder.create<scf::YieldOp>(op.getLoc(), mchain); |
546 | // Yield original insertion chain along false branch. |
547 | builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
548 | builder.create<scf::YieldOp>(loc, chain); |
549 | // Done with if statement. |
550 | env.updateInsertionChain(chain: ifOp->getResult(0)); |
551 | builder.setInsertionPointAfter(ifOp); |
552 | } |
553 | |
554 | /// Generates an invariant value. |
555 | inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) { |
556 | return env.exp(e: exp).val; |
557 | } |
558 | |
559 | /// Semi-ring branches are simply inlined by the sparsifier. Prior |
560 | /// analysis has verified that all computations are "local" to the inlined |
561 | /// branch or otherwise invariantly defined outside the loop nest, with the |
562 | /// exception of index computations, which need to be relinked to actual |
563 | /// inlined cloned code. |
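/// E.g., a `linalg.index` op inside the inlined branch is replaced by the
/// corresponding loop induction variable, and a block argument of the original
/// linalg.generic is turned into a load from its (dense) tensor operand.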
564 | static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block, |
565 | Value e) { |
566 | if (auto arg = dyn_cast<BlockArgument>(Val&: e)) { |
567 | // Direct arguments of the original linalg op must be converted |
568 | // into dense tensor loads. Note that we should not encounter |
569 | // anything else. This needs to be verified by semi-ring ops. |
570 | linalg::GenericOp op = env.op(); |
571 | if (arg.getOwner()->getParentOp() == op) { |
572 | const TensorId tid = env.makeTensorId(t: arg.getArgNumber()); |
573 | OpOperand *t = &op->getOpOperand(tid); |
574 | assert(!getSparseTensorType(t->get()).hasEncoding()); // dense! |
575 | SmallVector<Value> args; |
576 | Value ptr = genSubscript(env, builder&: rewriter, t, args); |
577 | return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args); |
578 | } |
579 | } else if (Operation *def = e.getDefiningOp()) { |
580 | // Handle index computation. |
581 | if (auto indexOp = dyn_cast<linalg::IndexOp>(def)) |
582 | return env.getLoopVar(i: env.makeLoopId(i: indexOp.getDim())); |
583 | // When still defined in new body, recurse into operands. |
584 | if (def->getBlock() == block) { |
585 | rewriter.setInsertionPoint(def); |
586 | for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) { |
587 | rewriter.modifyOpInPlace(root: def, callable: [&]() { |
588 | def->setOperand( |
589 | idx: i, value: relinkBranch(env, rewriter, block, e: def->getOperand(idx: i))); |
590 | }); |
591 | } |
592 | } |
593 | } |
594 | return e; |
595 | } |
596 | |
597 | /// Recursively generates tensor expression. |
598 | static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) { |
599 | if (e == ::mlir::sparse_tensor::detail::kInvalidId) |
600 | return Value(); |
601 | |
602 | linalg::GenericOp op = env.op(); |
603 | Location loc = op.getLoc(); |
604 | const TensorExp &exp = env.exp(e); |
605 | const auto kind = exp.kind; |
606 | if (kind == TensorExp::Kind::kTensor) |
607 | return genTensorLoad(env, builder&: rewriter, exp: e); |
608 | if (kind == TensorExp::Kind::kInvariant) |
609 | return genInvariantValue(env, exp: e); |
610 | if (kind == TensorExp::Kind::kLoopVar) |
611 | return env.getLoopVar(i: exp.loop); |
612 | |
613 | if (kind == TensorExp::Kind::kReduce) |
614 | env.startCustomReduc(exp: e); // enter custom |
615 | |
616 | // If either lhs/rhs is a synthetic zero, we infer the type for the zero value |
617 | // based on the type of the other operand. |
618 | Value v0, v1; |
619 | if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId && |
620 | env.exp(e: exp.children.e0).kind == TensorExp::Kind::kSynZero) { |
621 | v1 = genExp(env, rewriter, e: exp.children.e1); |
622 | v0 = constantZero(builder&: rewriter, loc, tp: v1.getType()); |
623 | } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId && |
624 | env.exp(e: exp.children.e1).kind == TensorExp::Kind::kSynZero) { |
625 | v0 = genExp(env, rewriter, e: exp.children.e0); |
626 | v1 = constantZero(builder&: rewriter, loc, tp: v0.getType()); |
627 | } else { |
628 | v0 = genExp(env, rewriter, e: exp.children.e0); |
629 | v1 = genExp(env, rewriter, e: exp.children.e1); |
630 | } |
631 | |
632 | Value ee; |
633 | if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) { |
634 | // custom reduce did not receive a value |
635 | } else { |
636 | ee = env.merger().buildExp(rewriter, loc, e, v0, v1); |
637 | if (ee && |
638 | (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary || |
639 | kind == TensorExp::Kind::kBinaryBranch || |
640 | kind == TensorExp::Kind::kReduce || |
641 | kind == TensorExp::Kind::kSelect)) { |
642 | OpBuilder::InsertionGuard guard(rewriter); |
643 | ee = relinkBranch(env, rewriter, block: ee.getParentBlock(), e: ee); |
644 | } |
645 | } |
646 | |
647 | if (kind == TensorExp::Kind::kReduce) |
648 | env.endCustomReduc(); // exit custom |
649 | |
650 | if (kind == TensorExp::Kind::kSelect) |
651 | env.merger().setExprValue(e, v: v0); // Preserve value for later use. |
652 | |
653 | return ee; |
654 | } |
655 | |
656 | /// Hoists loop invariant tensor loads for which indices have been exhausted. |
657 | static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp, |
658 | LoopId curr, bool isStart) { |
659 | if (exp == ::mlir::sparse_tensor::detail::kInvalidId) |
660 | return; |
661 | if (env.exp(e: exp).kind == TensorExp::Kind::kTensor) { |
662 | // Inspect tensor indices. |
663 | linalg::GenericOp op = env.op(); |
664 | OpOperand &t = op->getOpOperand(env.exp(e: exp).tensor); |
665 | const auto map = op.getMatchingIndexingMap(&t); |
666 | const auto stt = getSparseTensorType(val: t.get()); |
667 | const Level lvlRank = stt.getLvlRank(); |
668 | assert(static_cast<Level>(map.getNumResults()) == lvlRank); |
669 | bool isCurrentLoop = curr == 0; // for scalar tensors |
670 | for (Level l = 0; l < lvlRank; l++) { |
671 | const AffineExpr a = map.getResult(l); |
672 | if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop)) |
673 | return; // still in play |
674 | } |
675 | // All exhausted at current level. |
676 | if (!isCurrentLoop) |
677 | return; |
678 | // Generate code for a scalarized reduction or invariant. Note that |
679 | // because custom reduction lhs may occur several times in the IR, |
680 | // we have a built-in safety for only initializing and wrapping-up |
681 | // the scalarized reduction once. |
682 | OpOperand *lhs = op.getDpsInitOperand(0); |
683 | if (lhs == &t) { |
684 | // Start or end a scalarized reduction. |
685 | if (isStart) { |
686 | if (env.isCustomReduc()) { |
687 | if (!env.isReduc()) |
688 | env.startReduc(exp, val: env.getCustomRedId()); |
689 | } else { |
690 | env.startReduc(exp, val: genTensorLoad(env, builder, exp)); |
691 | } |
692 | if (env.hasSparseOutput()) |
693 | env.startValidLexInsert( |
694 | val: constantI1(builder, env.op().getLoc(), false)); |
695 | } else { |
696 | if (!env.isCustomReduc() || env.isReduc()) |
697 | genTensorStore(env, builder, exp, rhs: env.endReduc()); |
698 | if (env.hasSparseOutput()) |
699 | env.endValidLexInsert(); |
700 | } |
701 | } else { |
702 | // Start or end loop invariant hoisting of a tensor load. |
703 | if (isStart) { |
704 | env.merger().setExprValue(e: exp, v: genTensorLoad(env, builder, exp)); |
705 | } else { |
706 | env.merger().clearExprValue(e: exp); |
707 | } |
708 | } |
709 | } else if (env.exp(e: exp).kind != TensorExp::Kind::kInvariant && |
710 | env.exp(e: exp).kind != TensorExp::Kind::kLoopVar && |
711 | env.exp(e: exp).kind != TensorExp::Kind::kSynZero) { |
712 | // Traverse into the binary operations. Note that we only hoist |
713 | // tensor loads, since subsequent MLIR/LLVM passes know how to |
714 | // deal with all other kinds of derived loop invariants. |
715 | if (env.exp(e: exp).kind == TensorExp::Kind::kReduce) |
716 | env.startCustomReduc(exp); // enter custom |
717 | const ExprId e0 = env.exp(e: exp).children.e0; |
718 | const ExprId e1 = env.exp(e: exp).children.e1; |
719 | genInvariants(env, builder, exp: e0, curr, isStart); |
720 | genInvariants(env, builder, exp: e1, curr, isStart); |
721 | if (env.exp(e: exp).kind == TensorExp::Kind::kReduce) |
722 | env.endCustomReduc(); // exit custom |
723 | } |
724 | } |
725 | |
726 | /// Generates an expanded access pattern in innermost dimension. |
727 | static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr, |
728 | bool isStart) { |
729 | linalg::GenericOp op = env.op(); |
730 | OpOperand *lhs = op.getDpsInitOperand(0); |
731 | if (!env.atExpandLevel(o: lhs, rank: op.getRank(lhs), n: curr)) |
732 | return; // not needed at current level |
733 | assert(!env.isReduc()); |
734 | // Generate start or end of an expanded access pattern. Note that because |
735 | // an expansion does not rely on the ongoing contents of the sparse storage |
736 | // scheme, we can use the original tensor as incoming SSA value (which |
737 | // simplifies codegen a bit). If expansion on the actual contents is ever |
738 | // needed, we will need to use the SSA value in the insertion chain instead. |
739 | Value tensor = lhs->get(); |
740 | Location loc = op.getLoc(); |
741 | if (isStart) { |
742 | auto dynShape = {ShapedType::kDynamic}; |
743 | Type etp = cast<ShapedType>(tensor.getType()).getElementType(); |
744 | Type t1 = MemRefType::get(dynShape, etp); |
745 | Type t2 = MemRefType::get(dynShape, builder.getI1Type()); |
746 | Type t3 = MemRefType::get(dynShape, builder.getIndexType()); |
747 | Type t4 = builder.getIndexType(); |
748 | auto r = builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor); |
749 | assert(r.getNumResults() == 4); |
750 | env.startExpand(values: r.getResult(0), filled: r.getResult(1), added: r.getResult(2), |
751 | count: r.getResult(3)); |
752 | } else { |
753 | SmallVector<Value> indices; |
754 | for (LoopId i = 0; i < curr; i++) |
755 | indices.push_back(Elt: env.emitter().getLoopIV(n: i)); |
756 | Value values = env.getExpandValues(); |
757 | Value filled = env.getExpandFilled(); |
758 | Value added = env.getExpandAdded(); |
759 | Value count = env.getExpandCount(); |
760 | Value chain = env.getInsertionChain(); |
761 | Value compress = builder.create<CompressOp>(loc, values, filled, added, |
762 | count, chain, indices); |
763 | env.updateInsertionChain(chain: compress); |
764 | env.endExpand(); |
765 | } |
766 | } |
767 | |
768 | /// Returns whether parallelization applies. Any implicit loop in the Linalg |
769 | /// operation that is marked "parallel" is a candidate. Whether it is actually |
770 | /// converted to a parallel operation depends on the requested strategy. |
771 | static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) { |
772 | // Reject parallelization of sparse output. |
773 | if (env.hasSparseOutput()) |
774 | return false; |
775 | // Parallel loops on tensor expansion can cause data races. |
776 | if (env.isExpand()) |
777 | return false; |
778 | // Inspect strategy. |
779 | switch (env.options().parallelizationStrategy) { |
780 | case SparseParallelizationStrategy::kNone: |
781 | return false; |
782 | case SparseParallelizationStrategy::kDenseOuterLoop: |
783 | return isOuter && !isSparse; |
784 | case SparseParallelizationStrategy::kAnyStorageOuterLoop: |
785 | return isOuter; |
786 | case SparseParallelizationStrategy::kDenseAnyLoop: |
787 | return !isSparse; |
788 | case SparseParallelizationStrategy::kAnyStorageAnyLoop: |
789 | return true; |
790 | } |
791 | llvm_unreachable("unexpected parallelization strategy" ); |
792 | } |
793 | |
794 | /// Whether or not the current loop being generated should be parallelized (if |
795 | /// possible) according to the configuration. |
796 | static bool shouldTryParallize(CodegenEnv &env, LoopId curr, |
797 | ArrayRef<TensorLevel> tidLvls) { |
798 | linalg::GenericOp op = env.op(); |
799 | auto iteratorTypes = op.getIteratorTypesArray(); |
800 | bool isSparse = llvm::any_of(Range&: tidLvls, P: [curr, &env](TensorLevel tidLvl) { |
801 | // Queries the LT based on the tensor and loop id, as requested by |
802 | // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv |
803 | // should be consistent with the LT indexed by <TensorId, Level>. |
804 | const auto lt = env.lt(t: env.unpackTensorLevel(tl: tidLvl).first, i: curr); |
805 | return lt.hasSparseSemantic(); |
806 | }); |
807 | return isParallelFor(env, /*isOuter=*/curr == 0, isSparse); |
808 | } |
809 | |
810 | /// Emit a loop to coiterate over the list of tensor levels. The generated loop |
811 | /// can either be a for loop or while loop depending on whether there is at most |
812 | /// one sparse level in the list. |
813 | static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder, |
814 | ArrayRef<TensorLevel> tidLvls, |
815 | bool tryParallel, bool needsUniv) { |
816 | Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) { |
817 | // Construct while-loop with a parameter for each index. |
818 | return env.emitter().enterCoIterationOverTensorsAtLvls( |
819 | builder, env.op().getLoc(), tidLvls, reduc, tryParallel, needsUniv); |
820 | }); |
821 | assert(loop); |
822 | return loop; |
823 | } |
824 | |
825 | /// Generates a for-loop or a while-loop, depending on whether it implements |
826 | /// singleton iteration or co-iteration over the given conjunction. |
827 | static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, |
828 | bool needsUniv, ArrayRef<TensorLevel> tidLvls) { |
829 | bool tryParallel = shouldTryParallize(env, curr, tidLvls); |
830 | return genCoIteration(env, builder, tidLvls, tryParallel, needsUniv); |
831 | } |
832 | |
833 | /// Generates the induction structure for a while-loop. |
834 | static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, |
835 | bool needsUniv) { |
836 | Location loc = env.op().getLoc(); |
837 | // Finalize each else branch of all if statements. |
838 | if (env.isReduc() || env.isExpand() || env.getInsertionChain()) { |
839 | while (auto ifOp = dyn_cast_or_null<scf::IfOp>( |
840 | builder.getInsertionBlock()->getParentOp())) { |
841 | // Break on IfOp for slicing filtering. |
842 | if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) == |
843 | StringAttr::get(ifOp->getContext(), "slice" )) |
844 | break; |
845 | |
846 | unsigned y = 0; |
847 | SmallVector<Value> yields; |
848 | if (env.isReduc()) { |
849 | yields.push_back(Elt: env.getReduc()); |
850 | env.updateReduc(val: ifOp.getResult(y++)); |
851 | if (env.isValidLexInsert()) { |
852 | yields.push_back(Elt: env.getValidLexInsert()); |
853 | env.updateValidLexInsert(val: ifOp.getResult(y++)); |
854 | } |
855 | } |
856 | if (env.isExpand()) { |
857 | yields.push_back(Elt: env.getExpandCount()); |
858 | env.updateExpandCount(count: ifOp->getResult(y++)); |
859 | } |
860 | if (env.getInsertionChain()) { |
861 | yields.push_back(Elt: env.getInsertionChain()); |
862 | env.updateInsertionChain(chain: ifOp->getResult(y++)); |
863 | } |
864 | assert(y == yields.size()); |
865 | builder.create<scf::YieldOp>(loc, yields); |
866 | builder.setInsertionPointAfter(ifOp); |
867 | } |
868 | } |
869 | // No need to set the insertion point here as LoopEmitter keeps track of the |
870 | // basic block where scf::Yield should be inserted. |
871 | } |
872 | |
873 | /// Generates a single if-statement within a while-loop. |
874 | static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr, |
875 | LatPointId p) { |
876 | Location loc = env.op().getLoc(); |
877 | SmallVector<Type> types; |
878 | Value cond; |
879 | env.merger().foreachTensorLoopId( |
880 | p, /*simple=*/true, |
881 | callback: [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt, |
882 | bool isIdxRed) { |
883 | if (isIdxRed) { |
884 | // Since there is no 1:1 mapping from loop to level (multiple loops |
885 | // are required to resolve one level with non-trivial index |
886 | // expression), we need to reconstruct the tensor level types if this |
887 | // loop requires index reduction condition. |
888 | assert(lvl.has_value() && isUndefLT(lt)); |
889 | auto stt = getSparseTensorType(env.op().getInputs()[tid]); |
890 | lt = stt.getLvlType(*lvl); |
891 | } |
892 | assert(curr == env.merger().loop(b)); |
893 | Value clause; |
894 | if (lt.hasSparseSemantic()) { |
895 | assert(lvl.has_value()); |
896 | const Value crd = env.emitter().getCoord(tid, lvl: *lvl); |
897 | const Value lvar = env.getLoopVar(i: curr); |
898 | clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, |
899 | crd, lvar); |
900 | } else { |
901 | assert(lt.hasDenseSemantic() || isUndefLT(lt)); |
902 | clause = constantI1(builder, loc, b: true); |
903 | } |
904 | cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause; |
905 | }); |
906 | if (env.isReduc()) { |
907 | types.push_back(Elt: env.getReduc().getType()); |
908 | if (env.isValidLexInsert()) |
909 | types.push_back(Elt: env.getValidLexInsert().getType()); |
910 | } |
911 | if (env.isExpand()) |
912 | types.push_back(builder.getIndexType()); |
913 | if (env.getInsertionChain()) |
914 | types.push_back(Elt: env.getInsertionChain().getType()); |
915 | scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true); |
916 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
917 | return ifOp; |
918 | } |
919 | |
920 | /// Generates end of true branch of if-statement within a while-loop. |
921 | static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp, |
922 | Value redInput, Value cntInput, Value insInput, |
923 | Value validIns) { |
924 | SmallVector<Value> operands; |
925 | if (env.isReduc()) { |
926 | operands.push_back(Elt: env.getReduc()); |
927 | env.updateReduc(val: redInput); |
928 | if (env.isValidLexInsert()) { |
929 | // Any overlapping indices during a reduction creates a valid lex insert. |
930 | operands.push_back(Elt: constantI1(builder, env.op().getLoc(), true)); |
931 | env.updateValidLexInsert(val: validIns); |
932 | } |
933 | } |
934 | if (env.isExpand()) { |
935 | operands.push_back(Elt: env.getExpandCount()); |
936 | env.updateExpandCount(count: cntInput); |
937 | } |
938 | if (env.getInsertionChain()) { |
939 | operands.push_back(Elt: env.getInsertionChain()); |
940 | env.updateInsertionChain(chain: insInput); |
941 | } |
942 | if (!operands.empty()) |
943 | builder.create<scf::YieldOp>(env.op().getLoc(), operands); |
944 | builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
945 | } |
946 | |
947 | //===----------------------------------------------------------------------===// |
948 | // Sparsifier synthesis methods (loop sequence). |
949 | //===----------------------------------------------------------------------===// |
950 | |
951 | static bool getAllTidLvlsInLatPoints( |
952 | CodegenEnv &env, LatPointId li, LoopId curr, |
953 | llvm::function_ref<void(TensorLevel, AffineExpr)> callback) { |
954 | const BitVector &simple = env.lat(l: li).simple; |
955 | const TensorId outTid = env.merger().getOutTensorID(); |
956 | const std::optional<Level> outLvl = env.merger().getLvl(t: outTid, i: curr); |
957 | |
958 | unsigned numloopCond = 0; |
959 | bool hasNonUnique = false; |
960 | env.merger().foreachTensorLoopId( |
961 | p: li, callback: [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl, |
962 | LevelType lt, bool isIdxReduc) { |
963 | if (simple[b]) { |
964 | if (isIdxReduc) { |
965 | callback(env.makeTensorLevel(t: tid, l: *lvl), nullptr); |
966 | numloopCond++; |
967 | return; |
968 | } |
969 | if (isUndefLT(lt)) { |
970 | // An undefined lt in the lattices means we probably intend to |
971 | // generate a dense loop according to the synthetic tensor (for |
972 | // invariants and the sparse output tensor). |
973 | if (env.merger().getSynTensorID() == tid) { |
974 | // Coiterating with an invariant |
975 | // e.g., out = prod(in[i][j] op invariant); |
976 | // or a broadcast |
977 | // e.g., out[i][j] = in[i] (j is undef for input) |
978 | // |
979 | // The level of the synthetic tensor is the current loop depth; |
980 | // the rank of the synthetic tensor equals the number of loops. |
981 | assert(curr == env.getCurrentDepth()); |
982 | lvl = curr; |
983 | } else if (!lvl) { |
984 | // Skips invalid lvl (e.g., when this is a zero ranked tensor). |
985 | return; |
986 | } |
987 | } |
988 | hasNonUnique = !isUniqueLT(lt) || hasNonUnique; |
989 | callback(env.makeTensorLevel(t: tid, l: *lvl), nullptr); |
990 | numloopCond++; |
991 | } else if (lt.hasDenseSemantic() || isIdxReduc) { |
992 | callback(env.makeTensorLevel(t: tid, l: *lvl), nullptr); |
993 | } else { |
994 | assert(isUndefLT(lt)); |
995 | linalg::GenericOp op = env.op(); |
996 | if (tid >= op.getNumDpsInputs()) |
997 | // We only handle affine expression on input tensors (for now). |
998 | return; |
999 | OpOperand *operand = &op->getOpOperand(tid); |
1000 | const auto stt = getSparseTensorType(val: operand->get()); |
1001 | // Non-annotated dense tensors require no special handling. |
1002 | if (!stt.hasEncoding()) |
1003 | return; |
1004 | |
1005 | ArrayRef<AffineExpr> affines = |
1006 | op.getMatchingIndexingMap(operand).getResults(); |
1007 | const Level lvlRank = stt.getLvlRank(); |
1008 | assert(affines.size() == static_cast<size_t>(lvlRank)); |
1009 | for (Level l = 0; l < lvlRank; l++) { |
1010 | AffineExpr exp = affines[l]; |
1011 | // Skip simple affine expressions and non-dense levels (which |
1012 | // have their own filter loop). |
1013 | LevelType lt = stt.getLvlType(l); |
1014 | if (isa<AffineDimExpr>(Val: exp) || !lt.hasDenseSemantic()) |
1015 | continue; |
1016 | |
1017 | // Constant affine expressions are handled in genLoop. |
1018 | if (!isa<AffineConstantExpr>(Val: exp)) { |
1019 | bool isCurrentLoop = false; |
1020 | assert(curr == env.getCurrentDepth()); |
1021 | if (isInvariantAffine(a: exp, curr: curr + 1, /*out*/ isCurrentLoop) && |
1022 | isCurrentLoop) { |
1023 | // If the compound affine expression is invariant and we are right |
1024 | // at the level, we need to generate the address according to the |
1025 | // affine expression. This is also the best place to do so, to avoid |
1026 | // putting it inside inner loops. |
1027 | callback(env.makeTensorLevel(t: tid, l), exp); |
1028 | } |
1029 | } |
1030 | } |
1031 | } |
1032 | }); |
1033 | |
1034 | if (isDenseLT(lt: env.lt(t: outTid, i: curr))) { |
1035 | auto stt = getSparseTensorType(env.op().getOutputs().front()); |
1036 | // Note that we generate dense indices of the output tensor unconditionally, |
1037 | // since they may not appear in the lattice, but may be needed for |
1038 | // linearized env. |
1039 | // TODO: we should avoid introducing corner cases for all-dense sparse |
1040 | // tensors. |
1041 | if (stt.hasEncoding() && stt.isAllDense()) |
1042 | callback(env.makeTensorLevel(t: outTid, l: *outLvl), nullptr); |
1043 | } |
1044 | |
1045 | if (numloopCond == 0) { |
1046 | // Corner case where the loop bound is defined by an *unused* operand; in |
1047 | // this case, we just generate a dense "fake" loop by iterating over the |
1048 | // synthetic tensor. |
1049 | callback(env.makeTensorLevel(t: env.merger().getSynTensorID(), l: curr), nullptr); |
1050 | numloopCond++; |
1051 | } |
1052 | // If we only need one loop condition and the condition is not imposed on a |
1053 | // non-unique level, the loop can be generated by a for loop. |
1054 | return numloopCond == 1 && !hasNonUnique; |
1055 | } |
1056 | |
1057 | /// Starts a loop sequence at given level. Returns true if |
1058 | /// the universal loop index must be maintained at this level. |
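/// (The universal index is kept when a subsequent lattice point carries no
/// sparse condition of its own, e.g., when a dense operand must still be
/// traversed after a sparse operand in a disjunction runs out.)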
1059 | static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp, |
1060 | LoopId curr, LatSetId lts) { |
1061 | assert(!env.getLoopVar(curr)); |
1062 | // Emit invariants at this loop sequence level. |
1063 | genInvariants(env, builder, exp, curr, /*isStart=*/true); |
1064 | // Emit access pattern expansion for sparse tensor output. |
1065 | genExpand(env, builder, curr, /*isStart=*/true); |
1066 | // Emit further initialization at this loop sequence level. |
1067 | const LatPointId l0 = env.set(lts)[0]; |
1068 | |
1069 | SmallVector<TensorLevel> tidLvls; |
1070 | getAllTidLvlsInLatPoints(env, li: l0, curr, callback: [&](TensorLevel tl, AffineExpr) { |
1071 | // TODO: remove this! The same tensor level might be added multiple |
1072 | // times due to the special handling for all-dense "sparse" output tensor |
1073 | // (see L1038). |
1074 | if (llvm::find(Range&: tidLvls, Val: tl) != tidLvls.end()) |
1075 | return; |
1076 | tidLvls.emplace_back(Args&: tl); |
1077 | }); |
1078 | |
1079 | env.emitter().enterNewLoopSeq(builder, loc: env.op().getLoc(), tidLvls); |
1080 | |
1081 | // Maintain the universal index only if it is actually |
1082 | // consumed by a subsequent lattice point. |
1083 | for (const LatPointId li : env.set(lts).drop_front()) |
1084 | if (!env.merger().hasAnySparse(bits: env.lat(l: li).simple)) |
1085 | return true; |
1086 | |
1087 | return false; |
1088 | } |
1089 | |
1090 | // Generates dense affine address for encoding. |
1091 | static void genConstantDenseAddressFromLevel(CodegenEnv &env, |
1092 | OpBuilder &builder, TensorId tid, |
1093 | Level startLvl) { |
1094 | // TODO: Handle affine expression on output tensor. |
1095 | linalg::GenericOp op = env.op(); |
1096 | assert(tid < op.getNumDpsInputs()); |
1097 | OpOperand *input = op.getDpsInputOperands()[tid]; |
1098 | const auto lvlExprs = op.getMatchingIndexingMap(input).getResults(); |
1099 | const auto enc = getSparseTensorEncoding(input->get().getType()); |
1100 | if (enc) { |
1101 | const Location loc = op.getLoc(); |
1102 | const TensorId tid = env.makeTensorId(t: input->getOperandNumber()); |
1103 | const Level lvlRank = enc.getLvlRank(); |
1104 | assert(lvlExprs.size() == static_cast<size_t>(lvlRank)); |
1105 | for (Level l = startLvl; l < lvlRank; l++) { |
1106 | AffineExpr lvlExpr = lvlExprs[l]; |
1107 | if (enc.getLvlType(l).hasDenseSemantic() && |
1108 | isa<AffineConstantExpr>(Val: lvlExpr)) |
1109 | env.emitter().locateLvlAtAffineAddress( |
1110 | builder, loc, tidLvl: env.makeTensorLevel(t: tid, l), lvlExpr); |
1111 | else |
1112 | return; // break on first non-dense non-constant level |
1113 | } |
1114 | } |
1115 | } |
1116 | |
1117 | // We can generate addresses for constant affine expressions before any loops, |
1118 | // starting from the first level, as they do not depend on anything. |
1119 | // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two |
1120 | // levels can be determined before loops. |
1121 | static void genInitConstantDenseAddress(CodegenEnv &env, |
1122 | RewriterBase &rewriter) { |
1123 | for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++) |
1124 | genConstantDenseAddressFromLevel(env, builder&: rewriter, tid, startLvl: 0); |
1125 | } |
1126 | |
1127 | /// Returns true if the lattice bit can be iterated by a for loop. |
1128 | static bool translateBitsToTidLvlPairs( |
1129 | CodegenEnv &env, LatPointId li, LoopId curr, |
1130 | SmallVectorImpl<TensorLevel> &tidLvls, |
1131 | SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) { |
1132 | return getAllTidLvlsInLatPoints(env, li, curr, |
1133 | callback: [&](TensorLevel tl, AffineExpr exp) { |
1134 | if (exp) |
1135 | affineTidLvls.emplace_back(Args&: tl, Args&: exp); |
1136 | else |
1137 | tidLvls.emplace_back(Args&: tl); |
1138 | }); |
1139 | } |
1140 | |
1141 | /// Starts a single loop in current sequence. |
1142 | static std::pair<Operation *, bool> startLoop(CodegenEnv &env, |
1143 | OpBuilder &builder, LoopId curr, |
1144 | LatPointId li, bool needsUniv) { |
1145 | // The set of tensors + lvls to generate loops on |
1146 | SmallVector<TensorLevel> tidLvls; |
1147 | |
1148 | // The set of dense tensors with non-trivial affine expressions that just |
1149 | // became invariant, whose addresses are generated at the current level. |
1150 | SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls; |
1151 | bool isSingleCond = |
1152 | translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls); |
1153 | |
1154 | // Emit the for/while-loop control. |
1155 | Operation *loop = genLoop(env, builder, curr, needsUniv, tidLvls); |
1156 | Location loc = env.op().getLoc(); |
1157 | for (auto [tidLvl, exp] : affineTidLvls) { |
1158 | env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, lvlExpr: exp); |
1159 | } |
1160 | |
1161 | // Until now, we have entered every <tid, lvl> pair in {cond, extra, |
1162 | // affine}Tids/Lvls. The addresses of the upcoming levels which are dependent |
1163 | // on constant affine expressions may now be determined. |
1164 | auto allTidLvls = |
1165 | llvm::concat<TensorLevel>(Ranges&: tidLvls, Ranges: llvm::make_first_range(c&: affineTidLvls)); |
1166 | for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) { |
1167 | if (tid != env.merger().getOutTensorID() && |
1168 | tid != env.merger().getSynTensorID()) |
1169 | genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1); |
1170 | } |
1171 | |
1172 | return std::make_pair(x&: loop, y&: isSingleCond); |
1173 | } |
1174 | |
1175 | /// Ends a single loop in the current sequence. Returns the new value for needsUniv. |
1176 | static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop, |
1177 | LatPointId li, bool needsUniv, bool isSingleCond) { |
1178 | // Either a for-loop or a while-loop that iterates over a slice. |
1179 | if (isSingleCond) { |
1180 | // Any iteration creates a valid lex insert. |
1181 | if (env.isReduc() && env.isValidLexInsert()) |
1182 | env.updateValidLexInsert(val: constantI1(rewriter, env.op().getLoc(), true)); |
1183 | } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) { |
1184 | // End a while-loop. |
1185 | finalizeWhileOp(env, builder&: rewriter, needsUniv); |
1186 | } else { |
1187 | needsUniv = false; |
1188 | } |
1189 | env.genLoopBoundary(callback: [&](MutableArrayRef<Value> reduc) { |
1190 | env.emitter().exitCurrentLoop(rewriter, loc: env.op().getLoc(), reduc); |
1191 | return std::nullopt; |
1192 | }); |
1193 | return needsUniv; |
1194 | } |
1195 | |
1196 | /// Ends a loop sequence at given level. |
1197 | static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp, |
1198 | unsigned at) { |
1199 | assert(!env.getLoopVar(at)); |
1200 | env.emitter().exitCurrentLoopSeq(builder, loc: env.op().getLoc()); |
1201 | // Unmark bookkeeping of invariants and loop index. |
1202 | genInvariants(env, builder, exp, curr: at, /*isStart=*/false); |
1203 | // Finalize access pattern expansion for sparse tensor output. |
1204 | genExpand(env, builder, curr: at, /*isStart=*/false); |
1205 | } |
1206 | |
1207 | /// Recursively generates code while computing iteration lattices in order |
1208 | /// to manage the complexity of implementing co-iteration over unions |
1209 | /// and intersections of sparse iteration spaces. |
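/// E.g., for x(i) = a(i) + b(i) with both operands sparse, the lattice for
/// loop i yields points for "a and b", "a only", and "b only", generating a
/// while-loop that co-iterates both operands with if-statements for the
/// overlap and the two remaining singleton regions.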
static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
                    LoopId curr) {
  assert(curr == env.getCurrentDepth());

  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (curr == env.getLoopNum()) {
    Value rhs = genExp(env, rewriter, exp);
    genTensorStore(env, rewriter, exp, rhs);
    return;
  }

  // Construct iteration lattices for current loop index.
  const LatSetId lts =
      env.merger().optimizeSet(env.merger().buildLattices(exp, curr));

  // Start a loop sequence.
  bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);

  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
  // We cannot change this to `for (const LatPointId li : env.set(lts))`
  // because the loop body causes data-movement which invalidates
  // the iterator.
  const unsigned lsize = env.set(lts).size();
  for (unsigned i = 0; i < lsize; i++) {
    const LatPointId li = env.set(lts)[i];
    // Start a loop.
    auto [loop, isSingleCond] = startLoop(env, rewriter, curr, li, needsUniv);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for coiteration.
    Value redInput = env.getReduc();
    Value cntInput = env.getExpandCount();
    Value insInput = env.getInsertionChain();
    Value validIns = env.getValidLexInsert();
    // We cannot change this to `for (const LatPointId lj : env.set(lts))`
    // because the loop body causes data-movement which invalidates the
    // iterator.
    for (unsigned j = 0; j < lsize; j++) {
      const LatPointId lj = env.set(lts)[j];
      const ExprId ej = env.lat(lj).exp;
      if (li == lj || env.merger().latGT(li, lj)) {
        // Recurse into body of each branch.
        if (!isSingleCond) {
          scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
          genStmt(env, rewriter, ej, curr + 1);
          endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
        } else {
          genStmt(env, rewriter, ej, curr + 1);
        }
      }
    }

    // End a loop.
    needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
  }

  // End a loop sequence.
  endLoopSeq(env, rewriter, exp, curr);
  assert(curr == env.getCurrentDepth());
}

/// Converts the result computed by the sparse kernel into the required form.
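/// For a sparse result, this amounts to something like (illustrative syntax):
///   %result = sparse_tensor.load %chain hasInserts : tensor<?x?xf64, #CSR>
/// whereas a dense (non-annotated) result is simply rematerialized through
/// bufferization.to_tensor from its value buffer.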
static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
  linalg::GenericOp op = env.op();
  OpOperand *lhs = op.getDpsInitOperand(0);
  Value tensor = lhs->get();
  Type resType = tensor.getType();
  if (getSparseTensorEncoding(resType)) {
    // The sparse tensor rematerializes from the original sparse tensor's
    // underlying sparse storage format. For an insertion chain, the
    // tensor materializes from the chain with 'hasInserts' enabled.
    bool hasInserts = false;
    if (Value chain = env.getInsertionChain()) {
      hasInserts = true;
      tensor = chain;
    }
    rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
  } else {
    // To rematerialize a non-annotated tensor, simply load it
    // from the bufferized value.
    Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
    rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
  }
}

//===----------------------------------------------------------------------===//
// Sparsifier rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Only accept single output operations with pure tensor semantics.
    if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics())
      return failure();

    // Only accept trivial affine indices.
    if (hasNonTrivialAffineOnSparseOut(op))
      return failure();

    // Only accept scheduled loops.
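    // (Illustrative note: scheduling is performed by a preceding pass; as the
    // diagnostic below suggests, a pipeline along the lines of
    //   mlir-opt --sparse-reinterpret-map --sparsification
    // is assumed to have set the "sorted" attribute by the time we get here.)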
    if (!op->hasAttr("sorted")) {
      return rewriter.notifyMatchFailure(
          op, "Loops not yet scheduled, try running --sparse-reinterpret-map "
              "before sparsification.");
    }

    // Must have been demapped as well if the generic op is sorted.
    assert(!hasAnyNonIdentityOperandsOrResults(op));

    // Sets up a code generation environment.
    const unsigned numTensors = op->getNumOperands();
    const unsigned numLoops = op.getNumLoops();
    bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0;
    // If we have an indexing map like (d0) -> (0, d0), there might be more
    // levels than loops because of the constant index, which means we cannot
    // use numLoops as the upper bound for the ranks of all tensors.
    // TODO: Constant indices are currently not supported on sparse tensors,
    // but are allowed in non-annotated dense tensors. Supporting them would
    // also be required for rank-reducing sparse tensor slices.
    Level maxLvlRank = 0;
    for (auto operand : op.getOperands()) {
      if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
        maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
      }
    }

    // Detects sparse annotations and translates the per-level sparsity
    // information for all tensors to loop indices in the kernel.
    CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
    if (!findSparseAnnotations(env, needIdxRed))
      return failure();

    // Only standard reduction operations (add, sub, or, xor) that can be
    // sparsified by merely reducing the stored values are admissible. More
    // elaborate reduction operations (such as mul, and, min, max) would need
    // to know whether implicit zeros occur as well. They can still be
    // implemented with a custom reduction operation, accepted here as well.
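    // For example (illustrative IR), a reduction region that yields
    //   %r = arith.addf %acc, %v : f64
    // is admissible as-is, whereas a multiplicative reduction would have to
    // be expressed through a custom sparse_tensor.reduce to be accepted.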
    if (op.getNumReductionLoops() > 0) {
      Operation *yield = op.getRegion().front().getTerminator();
      assert(isa<linalg::YieldOp>(yield));
      Operation *redop = yield->getOperand(0).getDefiningOp();
      if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) &&
          !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) &&
          !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) &&
          !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) &&
          !isa<ReduceOp>(redop)) {
        return failure();
      }
    }

    // Constructs the tensor expression tree from `op`; returns failure if the
    // tree cannot be built or the tensor expression is inadmissible.
    if (failed(env.initTensorExp()))
      return failure();

    // Recursively generates code if admissible.
    env.startEmit(options.sparseEmitStrategy);
    genBuffers(env, rewriter);
    // TODO: Constant affine expressions should be handled differently when
    // using slice-based codegen; it does not matter now because we already
    // reject constant expressions at an earlier stage.
    genInitConstantDenseAddress(env, rewriter);
    genStmt(env, rewriter, env.getExprId(), 0);
    genResult(env, rewriter);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
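
// Illustrative usage from a pass (assuming the greedy rewrite driver):
//   RewritePatternSet patterns(&getContext());
//   populateSparsificationPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));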