//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements converting sparse tensor types to actual sparse code.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenEnv.h"
#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Sparsifier analysis methods.
//===----------------------------------------------------------------------===//

/// Returns true iff affine expression is invariant. Sets the
/// parameter `isCurrentLoop` when the expression just became invariant.
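/// For example, with loops emitted in the order d0, d1, d2, the expression
/// d0 + d1 is not yet invariant while generating the d1 loop, but becomes
/// invariant (and sets `isCurrentLoop`) once `curr` reaches the d2 loop,
/// since both of its dimensions have been generated by then.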
static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    const LoopId i = cast<AffineDimExpr>(a).getPosition();
    if (i + 1 == curr) {
      isCurrentLoop = true;
      return true; // becomes invariant at current loop
    }
    return i < curr; // invariant when already generated
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return isInvariantAffine(binOp.getLHS(), curr, isCurrentLoop) &&
           isInvariantAffine(binOp.getRHS(), curr, isCurrentLoop);
  }
  default: {
    assert(isa<AffineConstantExpr>(a));
    return true;
  }
  }
}

/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used more than once. Also rejects compound affine
/// expressions in sparse dimensions.
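/// For example, a map result such as `d0` appearing twice for the same tensor
/// (as in A[i][i]) is rejected, while a compound expression such as `d0 + d1`
/// is only accepted when it indexes a level with dense semantics.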
static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
                       LevelType lt, bool setLvlFormat = true) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
    if (!isUndefLT(merger.getLvlType(tid, idx)))
      return false; // used more than once
    if (setLvlFormat)
      merger.setLevelAndType(tid, idx, lvl, lt);
    return true;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul:
  case AffineExprKind::Constant: {
    assert(lt.hasDenseSemantic());
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
      // We do not set the dim level format for an affine expression like
      // d0 + d1 on either loop index d0 or d1. We continue the recursion
      // merely to check whether the current affine expression is admissible.
      return findAffine(merger, tid, lvl, binOp.getLHS(), lt,
                        /*setLvlFormat=*/false) &&
             findAffine(merger, tid, lvl, binOp.getRHS(), lt,
                        /*setLvlFormat=*/false);
    }
    // Falls through when it is a constant affine expression.
    return true;
  }
  default:
    return false;
  }
}

/// Helper method to inspect affine expressions for index variable reduction
/// based codegen. It finds the dependent index set for all tensor levels in the
/// current expression we are generating.
///
/// For example, when handling A[i+j][j+k], we build the two way mapping in
/// merger between (tensor, level) pairs and their dependent index variable set:
/// A_0 <=> [i, j] and A_1 <=> [j, k]
///
/// It rejects cases (returns false)
/// 1st, when the same index is used more than once, e.g., A[i+j][i]
/// 2nd, when multiplication is used in the non-trivial index expression.
/// 3rd, when a constant operand is used in the non-trivial index expression.
///
/// TODO: constant should be easy to handle.
static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
                          AffineExpr a, LevelType lt, bool isSubExp = false,
                          int64_t coefficient = 1) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    // Only allow positive coefficients on AffineDimExpr.
    if (coefficient <= 0)
      return false;

    const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
    if (!isUndefLT(merger.getLvlType(tensor, idx)))
      return false; // used more than once, e.g., A[i][i]

    // TODO: Generalize the following two cases. A[i] (with a trivial index
    // expression) can be treated as a special affine index expression. We do
    // not necessarily need to differentiate them.
    if (!isSubExp) {
      assert(coefficient == 1);
      merger.setLevelAndType(tensor, idx, lvl, lt);
    }

    if (isSubExp) {
      // The current loop appears in more than one affine expression on the
      // same tensor. We cannot handle this case, e.g., in A[i+j][i+k], `i` is
      // used twice.
      if (merger.hasDependentLvl(idx, tensor)) {
        // TODO: This can be supported by coiterating slices if the loop idx
        // appears in affine indices of different tensors, or by taking slices
        // on multiple dimensions when it is on the same tensor.
        // E.g.,
        // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
        // d0_1 = getNextSliceOffset t0 along lvl0
        // d0_2 = getNextSliceOffset t1 along lvl0
        // if d0_1 == d0_2 then d0 = d0_1 = d0_2
        // else increase min(d0_1, d0_2).
        return false;
      }
      merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
    }
    return true;
  }
  case AffineExprKind::Constant:
  case AffineExprKind::Mul: {
    // TODO: Support index expressions like `2 * d0` on their own; we currently
    // only support them inside more complicated cases like `2 * d0 + d1`.
    if (!isSubExp)
      return false;

    // TODO: Support constant AffineExpr for slice-based codegen.
    if (isa<AffineConstantExpr>(a))
      llvm_unreachable("Not yet implemented");

    auto binOp = cast<AffineBinaryOpExpr>(a);
    auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
    if (isa<AffineConstantExpr>(rhs))
      std::swap(lhs, rhs);
    // Must be in form of `constant * d`.
    assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
    int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue();
    return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient);
  }
  case AffineExprKind::Add: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt,
                         /*isSubExp=*/true) &&
           findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt,
                         /*isSubExp=*/true);
  }
  default:
    return false;
  }
}

/// Gets the total number of compound affine expressions in the
/// `getMatchingIndexingMap` for the given tensor. For the following inputs:
///
/// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
///
/// Returns 1 (because the first level is compressed and its corresponding
/// indexing-expression is `d0 + d1`)
static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
                                                   Value tensor) {
  // The `tensor` is not guaranteed to have `RankedTensorType`, therefore
  // we can't use `getRankedTensorType`/`getSparseTensorType` here.
  // However, we don't need to handle `StorageSpecifierType`, so we
  // can use `SparseTensorType` once we guard against non-tensors.
  const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
  if (!rtp)
    return 0;
  const SparseTensorType stt(rtp);

  const Level lvlRank = stt.getLvlRank();
  const auto exprs = map.getResults();
  assert(static_cast<Dimension>(exprs.size()) == lvlRank &&
         "AffineMap does not have dimension-rank many results");
  unsigned num = 0;
  for (Level l = 0; l < lvlRank; l++) {
    if (!isa<AffineDimExpr>(exprs[l]) && !stt.getLvlType(l).hasDenseSemantic())
      num++;
  }
  return num;
}

/// Gets the total number of sparse levels with compound affine
/// expressions, summed over all operands of the `GenericOp`.
static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
  unsigned num = 0;
  for (OpOperand &t : op->getOpOperands())
    num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t),
                                              t.get());
  return num;
}

// Returns true iff output has nontrivial affine indices.
static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
  OpOperand *out = op.getDpsInitOperand(0);
  if (getSparseTensorType(out->get()).isAllDense())
    return false;
  return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out),
                                            out->get());
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible. Returns false if
/// no annotations are found or inadmissible constructs occur.
/// We currently support two different ways to handle non-trivial index
/// expressions on sparse tensors, and they accept different affine expressions.
/// The dependent index reduction-based approach currently only supports
/// affine addition index expressions.
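/// For example, an access such as A[i+j][j+k] on sparse levels can only be
/// handled by the index reduction approach (see findDepIdxSet above), whereas
/// the same compound expressions on dense levels can simply rely on random
/// access (see findAffine above).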
static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
  bool annotated = false;
  for (OpOperand &t : env.op()->getOpOperands()) {
    const TensorId tid = env.makeTensorId(t.getOperandNumber());
    const auto map = env.op().getMatchingIndexingMap(&t);
    const auto enc = getSparseTensorEncoding(t.get().getType());
    if (enc)
      annotated = true;
    const Level lvlRank = map.getNumResults();
    assert(!enc || lvlRank == enc.getLvlRank());
    assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
    // We only need to do index reduction if there is at least one
    // non-trivial index expression on sparse levels. If all non-trivial
    // index expressions are on dense levels, we can efficiently rely on
    // random access to locate the element.
    bool needIdxReduc =
        enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0;
    // If the current tensor being inspected requires an affine index, it
    // needs to be sliced.
    for (Level l = 0; l < lvlRank; l++) {
      const AffineExpr a = map.getResult(l);
      const LevelType lt = enc.getLvlType(l);
      if (idxReducBased && needIdxReduc) {
        if (!findDepIdxSet(env.merger(), tid, l, a, lt))
          return false; // inadmissible affine expression
      } else {
        if (!findAffine(env.merger(), tid, l, a, lt))
          return false; // inadmissible affine expression
      }
    }
  }
  return annotated;
}

//===----------------------------------------------------------------------===//
// Sparsifier synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Local bufferization of all dense and sparse data structures.
static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  assert(op.getNumOperands() == op.getNumDpsInputs() + 1);

  SmallVector<Range, 4> loopRange =
      llvm::cast<linalg::LinalgOp>(op.getOperation())
          .createLoopRanges(builder, loc);

  env.emitter().initializeLoopEmit(
      builder, loc,
      /// Generates buffer for the output tensor.
      /// Note that all sparse kernels assume that when all elements are
      /// written to (viz. x(i) = y(i) * z(i)), the output buffer is already
      /// initialized to all zeroes and only nonzero values are computed and
      /// written out. For updates (viz. x(i) += y(i) * z(i)), only nonzero
      /// values are used for the updates and no assumption on the original
      /// contents of the output buffer is necessary.
      [&op](OpBuilder &builder, Location loc, Value memref,
            Value tensor) -> Value {
        // Must not be a sparse tensor.
        assert(!getSparseTensorEncoding(tensor.getType()));
        // Two output tensor references should point to the same object.
        OpOperand *lhs = op.getDpsInitOperand(0);
        assert(lhs->get() == tensor);
        // An output tensor can simply materialize from the buffer of the
        // tensor that appears in the outs() clause. For updates, this has the
        // advantage that only the nonzero values are involved in the
        // computation, keeping the operation O(nnz). In all other cases, we
        // are forced to zero out the buffer to enforce the assumption above,
        // which may negatively impact running complexity (viz. O(n^2 + nnz)
        // vs. O(nnz) for matrices).
        // TODO: use better analysis to avoid zeroing out the buffer?
        bool isInit = op.isInitTensor(lhs);
        Value init = memref;
        if (!isInit) {
          Value zero = constantZero(builder, loc,
                                    getElementTypeOrSelf(tensor.getType()));
          builder.create<linalg::FillOp>(loc, ValueRange{zero},
                                         ValueRange{init});
        }
        return init;
      },
      [&loopRange](OpBuilder &b, Location loc, Level l) {
        assert(l < loopRange.size());
        return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size);
      });
}

/// Generates index for load/store on sparse tensor.
static Value genIndex(CodegenEnv &env, OpOperand *t) {
  const auto map = env.op().getMatchingIndexingMap(t);
  const auto stt = getSparseTensorType(t->get());
  const Level lvlRank = stt.getLvlRank();
  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
  const AffineExpr a = map.getResult(lvlRank - 1);
  assert(a.getKind() == AffineExprKind::DimId);
  const LoopId idx = env.makeLoopId(cast<AffineDimExpr>(a).getPosition());
  return env.getLoopVar(idx);
}

/// Generates subscript for load/store on a dense or sparse tensor.
static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                          SmallVectorImpl<Value> &args) {
  const Location loc = env.op().getLoc();
  const TensorId tid = env.makeTensorId(t->getOperandNumber());
  const auto map = env.op().getMatchingIndexingMap(t);
  const auto stt = getSparseTensorType(t->get());
  if (stt.hasEncoding()) {
    // For sparse tensors we only push the last level's position onto `args`.
    const auto pos = env.emitter().getValPosits(tid);
    assert(!pos.empty());
    args.append(pos);
    // Simply returns the tensor to extract value using iterators.
    if (env.options().sparseEmitStrategy == SparseEmitStrategy::kSparseIterator)
      return t->get();
  } else {
    // For dense tensors we push all levels' coordinates onto `args`.
    const Level lvlRank = stt.getLvlRank();
    assert(static_cast<Level>(map.getNumResults()) == lvlRank);
    for (Level l = 0; l < lvlRank; l++) {
      const auto lvlExpr = map.getResult(l);
      const auto lvlCrd = env.emitter().genAffine(builder, loc, lvlExpr);
      args.push_back(lvlCrd);
    }
  }
  return env.emitter().getValBuffer()[tid];
}

/// Generates insertion code to implement dynamic tensor load.
static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
                              OpOperand *t) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  // Direct lexicographic coordinate order, tensor loads as zero.
  if (!env.isExpand()) {
    Type tp = getElementTypeOrSelf(t->get().getType());
    return constantZero(builder, loc, tp);
  }
  // Load from expanded access pattern.
  Value index = genIndex(env, t);
  return builder.create<memref::LoadOp>(loc, env.getExpandValues(), index);
}

/// Generates insertion code to implement dynamic tensor load for reduction.
static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
                                    OpOperand *t) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  Value identity = env.getCustomRedId();
  // Direct lexicographic coordinate order, tensor loads as identity.
  if (!env.isExpand())
    return identity;
  // Load from expanded access pattern if filled, identity otherwise.
  Value values = env.getExpandValues();
  Value filled = env.getExpandFilled();
  Value index = genIndex(env, t);
  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
  Value valAtIndex = builder.create<memref::LoadOp>(loc, values, index);
  return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
}

static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
                                  Value sparseOut, ValueRange ivs, Value v) {
  scf::IfOp condInsert =
      builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, true);
  // True branch.
  builder.setInsertionPointToStart(condInsert.thenBlock());
  Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs);
  builder.create<scf::YieldOp>(loc, res);
  // False branch.
  builder.setInsertionPointToStart(condInsert.elseBlock());
  builder.create<scf::YieldOp>(loc, sparseOut);
  // Value assignment.
  builder.setInsertionPointAfter(condInsert);
  return condInsert.getResult(0);
}

/// Generates insertion code to implement dynamic tensor store.
static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                              Value rhs) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  // Direct insertion in lexicographic coordinate order.
  if (!env.isExpand()) {
    const LoopId numLoops = op.getRank(t);
    // Retrieves the first `numLoops` induction variables.
    SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end(
        env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops));
    Value chain = env.getInsertionChain();
    if (env.isValidLexInsert()) {
      // Generates runtime check for a valid lex during reduction,
      // to avoid inserting the identity value for empty reductions.
      //   if (validLexInsert) then
      //     insert(rhs) into chain
      //     return updated chain
      //   else
      //     return unmodified chain
      Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
                                       chain, ivs, rhs);
      env.updateInsertionChain(out);
    } else {
      Value sparseOut;
      if (!hasAnySparseType(env.op().getInputs().getTypes())) {
        // This is an all-dense -> sparse kernel, test rhs != 0 before
        // insertion.
        Value nz = genIsNonzero(builder, loc, rhs);
        sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
      } else {
        sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
      }
      // Generates regular insertion chain.
      env.updateInsertionChain(sparseOut);
    }
    return;
  }
  // Generates insertion code along expanded access pattern.
  //   if (!expFilled[i]) then
  //     expFilled[i] = true
  //     expAdded[inserts++] = i
  //   endif
  //   values[i] = rhs
  Value values = env.getExpandValues();
  Value filled = env.getExpandFilled();
  Value added = env.getExpandAdded();
  Value count = env.getExpandCount();
  Value index = genIndex(env, t);
  Value fval = constantI1(builder, loc, false);
  Value tval = constantI1(builder, loc, true);
  // If statement.
  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                             isFilled, fval);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
                                             /*else=*/true);
  // True branch.
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  builder.create<memref::StoreOp>(loc, tval, filled, index);
  builder.create<memref::StoreOp>(loc, index, added, count);
  Value one = constantIndex(builder, loc, 1);
  Value add = builder.create<arith::AddIOp>(loc, count, one);
  builder.create<scf::YieldOp>(loc, add);
  // False branch.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, count);
  builder.setInsertionPointAfter(ifOp);
  // Value assignment.
  env.updateExpandCount(ifOp.getResult(0));
  builder.create<memref::StoreOp>(loc, rhs, values, index);
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = env.exp(exp).val;
  if (val)
    return val;
  // Get tensor operand.
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
  // Fold binary-valued tensor into explicit value.
  const auto stt = getSparseTensorType(t->get());
  if (auto explVal = stt.getExplicitVal())
    return genValFromAttr(builder, loc, explVal);
  // Load during insertion.
  if (env.isSparseOutput(t)) {
    if (env.isCustomReduc())
      return genInsertionLoadReduce(env, builder, t);
    return genInsertionLoad(env, builder, t);
  }

  // Actual load.
  SmallVector<Value> args;
  Value ptr = genSubscript(env, builder, t, args);
  if (llvm::isa<TensorType>(ptr.getType())) {
    assert(env.options().sparseEmitStrategy ==
           SparseEmitStrategy::kSparseIterator);
    return builder.create<ExtractValOp>(loc, ptr, llvm::getSingleElement(args));
  }
  return builder.create<memref::LoadOp>(loc, ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                           Value rhs) {
  // Only unary and binary are allowed to return an uninitialized rhs
  // to indicate missing output, or otherwise a custom reduction that
  // received no value to accumulate.
  if (!rhs) {
    assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
           env.exp(exp).kind == TensorExp::Kind::kBinary ||
           env.exp(exp).kind == TensorExp::Kind::kReduce);
    return;
  }
  // Test if this is a scalarized reduction.
  if (env.isReduc()) {
    env.updateReduc(rhs);
    return;
  }
  // Regular store.
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  OpOperand *t = op.getDpsInitOperand(0);
  if (!env.isSparseOutput(t)) {
    SmallVector<Value> args;
    Value ptr = genSubscript(env, builder, t, args);
    builder.create<memref::StoreOp>(loc, rhs, ptr, args);
    return;
  }
  // Store during sparse insertion.
  if (env.exp(exp).kind != TensorExp::Kind::kSelect) {
    genInsertionStore(env, builder, t, rhs);
    return;
  }
  // Select operation insertion.
  Value chain = env.getInsertionChain();
  scf::IfOp ifOp =
      builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  // Existing value was preserved to be used here.
  assert(env.exp(exp).val);
  Value v0 = env.exp(exp).val;
  genInsertionStore(env, builder, t, v0);
  env.merger().clearExprValue(exp);
  // Yield modified insertion chain along true branch.
  Value mchain = env.getInsertionChain();
  builder.create<scf::YieldOp>(op.getLoc(), mchain);
  // Yield original insertion chain along false branch.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, chain);
  // Done with if statement.
  env.updateInsertionChain(ifOp->getResult(0));
  builder.setInsertionPointAfter(ifOp);
}

/// Generates an invariant value.
inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
  return env.exp(exp).val;
}

/// Semi-ring branches are simply inlined by the sparsifier. Prior
/// analysis has verified that all computations are "local" to the inlined
/// branch or otherwise invariantly defined outside the loop nest, with the
/// exception of index computations, which need to be relinked to actual
/// inlined cloned code.
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
                          Value e) {
  if (auto arg = dyn_cast<BlockArgument>(e)) {
    // Direct arguments of the original linalg op must be converted
    // into dense tensor loads. Note that we should not encounter
    // anything else. This needs to be verified by semi-ring ops.
    linalg::GenericOp op = env.op();
    if (arg.getOwner()->getParentOp() == op) {
      const TensorId tid = env.makeTensorId(arg.getArgNumber());
      OpOperand *t = &op->getOpOperand(tid);
      assert(!getSparseTensorType(t->get()).hasEncoding()); // dense!
      SmallVector<Value> args;
      Value ptr = genSubscript(env, rewriter, t, args);
      return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
    }
  } else if (Operation *def = e.getDefiningOp()) {
    // Handle index computation.
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return env.getLoopVar(env.makeLoopId(indexOp.getDim()));
    // When still defined in new body, recurse into operands.
    if (def->getBlock() == block) {
      rewriter.setInsertionPoint(def);
      for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
        rewriter.modifyOpInPlace(def, [&]() {
          def->setOperand(
              i, relinkBranch(env, rewriter, block, def->getOperand(i)));
        });
      }
    }
  }
  return e;
}

/// Recursively generates tensor expression.
static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
  if (e == ::mlir::sparse_tensor::detail::kInvalidId)
    return Value();

  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  const TensorExp &exp = env.exp(e);
  const auto kind = exp.kind;
  if (kind == TensorExp::Kind::kTensor)
    return genTensorLoad(env, rewriter, e);
  if (kind == TensorExp::Kind::kInvariant)
    return genInvariantValue(env, e);
  if (kind == TensorExp::Kind::kLoopVar)
    return env.getLoopVar(exp.loop);

  if (kind == TensorExp::Kind::kReduce)
    env.startCustomReduc(e); // enter custom

  // If either lhs/rhs is a synthetic zero, we infer the type for the zero value
  // based on the type of the other operand.
  Value v0, v1;
  if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
      env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
    v1 = genExp(env, rewriter, exp.children.e1);
    v0 = constantZero(rewriter, loc, v1.getType());
  } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
             env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
    v0 = genExp(env, rewriter, exp.children.e0);
    v1 = constantZero(rewriter, loc, v0.getType());
  } else {
    v0 = genExp(env, rewriter, exp.children.e0);
    v1 = genExp(env, rewriter, exp.children.e1);
  }

  Value ee;
  if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) {
    // custom reduce did not receive a value
  } else {
    ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
    if (ee &&
        (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary ||
         kind == TensorExp::Kind::kBinaryBranch ||
         kind == TensorExp::Kind::kReduce ||
         kind == TensorExp::Kind::kSelect)) {
      OpBuilder::InsertionGuard guard(rewriter);
      ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
    }
  }

  if (kind == TensorExp::Kind::kReduce)
    env.endCustomReduc(); // exit custom

  if (kind == TensorExp::Kind::kSelect)
    env.merger().setExprValue(e, v0); // Preserve value for later use.

  return ee;
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                          LoopId curr, bool isStart) {
  if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
    return;
  if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
    // Inspect tensor indices.
    linalg::GenericOp op = env.op();
    OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
    const auto map = op.getMatchingIndexingMap(&t);
    const auto stt = getSparseTensorType(t.get());
    const Level lvlRank = stt.getLvlRank();
    assert(static_cast<Level>(map.getNumResults()) == lvlRank);
    bool isCurrentLoop = curr == 0; // for scalar tensors
    for (Level l = 0; l < lvlRank; l++) {
      const AffineExpr a = map.getResult(l);
      if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop))
        return; // still in play
    }
    // All exhausted at current level.
    if (!isCurrentLoop)
      return;
    // Generate code for a scalarized reduction or invariant. Note that
    // because custom reduction lhs may occur several times in the IR,
    // we have a built-in safety for only initializing and wrapping-up
    // the scalarized reduction once.
    OpOperand *lhs = op.getDpsInitOperand(0);
    if (lhs == &t) {
      // Start or end a scalarized reduction.
      if (isStart) {
        if (env.isCustomReduc()) {
          if (!env.isReduc())
            env.startReduc(exp, env.getCustomRedId());
        } else {
          env.startReduc(exp, genTensorLoad(env, builder, exp));
        }
        if (env.hasSparseOutput())
          env.startValidLexInsert(
              constantI1(builder, env.op().getLoc(), false));
      } else {
        if (!env.isCustomReduc() || env.isReduc())
          genTensorStore(env, builder, exp, env.endReduc());
        if (env.hasSparseOutput())
          env.endValidLexInsert();
      }
    } else {
      // Start or end loop invariant hoisting of a tensor load.
      if (isStart) {
        env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
      } else {
        env.merger().clearExprValue(exp);
      }
    }
  } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
             env.exp(exp).kind != TensorExp::Kind::kLoopVar &&
             env.exp(exp).kind != TensorExp::Kind::kSynZero) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    if (env.exp(exp).kind == TensorExp::Kind::kReduce)
      env.startCustomReduc(exp); // enter custom
    const ExprId e0 = env.exp(exp).children.e0;
    const ExprId e1 = env.exp(exp).children.e1;
    genInvariants(env, builder, e0, curr, isStart);
    genInvariants(env, builder, e1, curr, isStart);
    if (env.exp(exp).kind == TensorExp::Kind::kReduce)
      env.endCustomReduc(); // exit custom
  }
}

/// Generates an expanded access pattern in innermost dimension.
static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                      bool isStart) {
  linalg::GenericOp op = env.op();
  OpOperand *lhs = op.getDpsInitOperand(0);
  if (!env.atExpandLevel(lhs, op.getRank(lhs), curr))
    return; // not needed at current level
  assert(!env.isReduc());
  // Generate start or end of an expanded access pattern. Note that because
  // an expansion does not rely on the ongoing contents of the sparse storage
  // scheme, we can use the original tensor as incoming SSA value (which
  // simplifies codegen a bit). If expansion on the actual contents is ever
  // needed, we will need to use the SSA value in the insertion chain instead.
  Value tensor = lhs->get();
  Location loc = op.getLoc();
  if (isStart) {
    auto dynShape = {ShapedType::kDynamic};
    Type etp = cast<ShapedType>(tensor.getType()).getElementType();
    Type t1 = MemRefType::get(dynShape, etp);
    Type t2 = MemRefType::get(dynShape, builder.getI1Type());
    Type t3 = MemRefType::get(dynShape, builder.getIndexType());
    Type t4 = builder.getIndexType();
    auto r = builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
    assert(r.getNumResults() == 4);
    env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
                    r.getResult(3));
  } else {
    SmallVector<Value> indices;
    for (LoopId i = 0; i < curr; i++)
      indices.push_back(env.emitter().getLoopIV(i));
    Value values = env.getExpandValues();
    Value filled = env.getExpandFilled();
    Value added = env.getExpandAdded();
    Value count = env.getExpandCount();
    Value chain = env.getInsertionChain();
    Value compress = builder.create<CompressOp>(loc, values, filled, added,
                                                count, chain, indices);
    env.updateInsertionChain(compress);
    env.endExpand();
  }
}

/// Returns parallelization strategy. Any implicit loop in the Linalg
/// operation that is marked "parallel" is a candidate. Whether it is actually
/// converted to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
  // Reject parallelization of sparse output.
  if (env.hasSparseOutput())
    return false;
  // Parallel loops on tensor expansion can cause data races.
  if (env.isExpand())
    return false;
  // Inspect strategy.
  switch (env.options().parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return true;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Whether or not the current loop being generated should be parallelized
/// (if possible) according to the configuration.
static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
                               ArrayRef<TensorLevel> tidLvls) {
  linalg::GenericOp op = env.op();
  auto iteratorTypes = op.getIteratorTypesArray();
  bool isSparse = llvm::any_of(tidLvls, [curr, &env](TensorLevel tidLvl) {
    // Queries the LT based on the tensor and loop id, as requested by
    // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
    // should be consistent with the LT indexed by <TensorId, Level>.
    const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr);
    return lt.hasSparseSemantic();
  });
  return isParallelFor(env, /*isOuter=*/curr == 0, isSparse);
}

/// Emit a loop to coiterate over the list of tensor levels. The generated loop
/// can either be a for loop or while loop depending on whether there is at most
/// one sparse level in the list.
static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
                                 ArrayRef<TensorLevel> tidLvls,
                                 unsigned numCases, bool tryParallel,
                                 bool needsUniv) {
  Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
    // Construct while-loop with a parameter for each index.
    return env.emitter().enterCoIterationOverTensorsAtLvls(
        builder, env.op().getLoc(), tidLvls, numCases, reduc, tryParallel,
        needsUniv);
  });
  assert(loop);
  return loop;
}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                          unsigned numCases, bool needsUniv,
                          ArrayRef<TensorLevel> tidLvls) {
  bool tryParallel = shouldTryParallize(env, curr, tidLvls);
  return genCoIteration(env, builder, tidLvls, numCases, tryParallel,
                        needsUniv);
}

/// Generates the induction structure for a while-loop.
static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
                            bool needsUniv) {
  Location loc = env.op().getLoc();
  // Finalize each else branch of all if statements.
  if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
    while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
               builder.getInsertionBlock()->getParentOp())) {
      // Break on IfOp for slicing filtering.
      if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ==
          StringAttr::get(ifOp->getContext(), "slice"))
        break;

      unsigned y = 0;
      SmallVector<Value> yields;
      if (env.isReduc()) {
        yields.push_back(env.getReduc());
        env.updateReduc(ifOp.getResult(y++));
        if (env.isValidLexInsert()) {
          yields.push_back(env.getValidLexInsert());
          env.updateValidLexInsert(ifOp.getResult(y++));
        }
      }
      if (env.isExpand()) {
        yields.push_back(env.getExpandCount());
        env.updateExpandCount(ifOp->getResult(y++));
      }
      if (env.getInsertionChain()) {
        yields.push_back(env.getInsertionChain());
        env.updateInsertionChain(ifOp->getResult(y++));
      }
      assert(y == yields.size());
      builder.create<scf::YieldOp>(loc, yields);
      builder.setInsertionPointAfter(ifOp);
    }
  }
  // No need to set the insertion point here as LoopEmitter keeps track of the
  // basic block where scf::Yield should be inserted.
}

/// Generates a case region in the coiterate operation.
static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
                               unsigned caseIdx, LatPointId allCase,
                               LatPointId curCase,
                               MutableArrayRef<Value> reduc) {
  assert(allCase == curCase || env.merger().latGT(allCase, curCase));
  const BitVector &allCaseBits = env.merger().lat(allCase).simple;
  const BitVector &curCaseBits = env.merger().lat(curCase).simple;

  /// Computes the subset of iterators that are valid in the current case being
  /// generated.
  I64BitSet caseBit(0);
  for (auto [idx, set] : llvm::enumerate(allCaseBits.set_bits()))
    if (curCaseBits.test(set))
      caseBit.set(idx);

  env.emitter().enterCurrentCoIterationCase(builder, env.op().getLoc(), caseBit,
                                            caseIdx, reduc);
}

/// Generates a single if-statement within a while-loop.
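///
/// Roughly, for a lattice point involving sparse tensors A and B this emits
///   %c = arith.andi (%a_crd == %iv), (%b_crd == %iv)
///   scf.if %c { ... } else { ... }
/// so the body only runs when every sparse coordinate matches the current
/// loop variable (dense or undefined inputs contribute a constant true).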
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                       LatPointId p) {
  Location loc = env.op().getLoc();
  SmallVector<Type> types;
  Value cond;
  env.merger().foreachTensorLoopId(
      p, /*simple=*/true,
      [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt,
          bool isIdxRed) {
        if (isIdxRed) {
          // Since there is no 1:1 mapping from loop to level (multiple loops
          // are required to resolve one level with non-trivial index
          // expression), we need to reconstruct the tensor level types if this
          // loop requires index reduction condition.
          assert(lvl.has_value() && isUndefLT(lt));
          auto stt = getSparseTensorType(env.op().getInputs()[tid]);
          lt = stt.getLvlType(*lvl);
        }
        assert(curr == env.merger().loop(b));
        Value clause;
        if (lt.hasSparseSemantic()) {
          assert(lvl.has_value());
          const Value crd = env.emitter().getCoord(tid, *lvl);
          const Value lvar = env.getLoopVar(curr);
          clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 crd, lvar);
        } else {
          assert(lt.hasDenseSemantic() || isUndefLT(lt));
          clause = constantI1(builder, loc, true);
        }
        cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
      });
  if (env.isReduc()) {
    types.push_back(env.getReduc().getType());
    if (env.isValidLexInsert())
      types.push_back(env.getValidLexInsert().getType());
  }
  if (env.isExpand())
    types.push_back(builder.getIndexType());
  if (env.getInsertionChain())
    types.push_back(env.getInsertionChain().getType());
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  return ifOp;
}

/// Generates end of true branch of if-statement within a while-loop.
static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
                  Value redInput, Value cntInput, Value insInput,
                  Value validIns) {
  SmallVector<Value> operands;
  if (env.isReduc()) {
    operands.push_back(env.getReduc());
    env.updateReduc(redInput);
    if (env.isValidLexInsert()) {
      // Any overlapping indices during a reduction creates a valid lex insert.
      operands.push_back(constantI1(builder, env.op().getLoc(), true));
      env.updateValidLexInsert(validIns);
    }
  }
  if (env.isExpand()) {
    operands.push_back(env.getExpandCount());
    env.updateExpandCount(cntInput);
  }
  if (env.getInsertionChain()) {
    operands.push_back(env.getInsertionChain());
    env.updateInsertionChain(insInput);
  }
  if (!operands.empty())
    builder.create<scf::YieldOp>(env.op().getLoc(), operands);
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
}

//===----------------------------------------------------------------------===//
// Sparsifier synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//

static bool getAllTidLvlsInLatPoints(
    CodegenEnv &env, LatPointId li, LoopId curr,
    llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
  const BitVector &simple = env.lat(li).simple;
  const TensorId outTid = env.merger().getOutTensorID();
  const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);

  unsigned numloopCond = 0;
  bool hasNonUnique = false;
  env.merger().foreachTensorLoopId(
      li, [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
                    LevelType lt, bool isIdxReduc) {
        if (simple[b]) {
          if (isIdxReduc) {
            callback(env.makeTensorLevel(tid, *lvl), nullptr);
            numloopCond++;
            return;
          }
          if (isUndefLT(lt)) {
            // An undefined lt in the lattices; we probably mean to
            // generate a dense loop according to the synthetic tensor (for
            // invariants and sparse output tensor).
            if (env.merger().getSynTensorID() == tid) {
              // Coiterating with an invariant
              //   e.g., out = prod(in[i][j] op invariant);
              // or a broadcast
              //   e.g., out[i][j] = in[i] (j is undef for input)
              //
              // The level of the synthetic tensor is the current loop depth;
              // the rank of the synthetic tensor equals the number of loops.
              assert(curr == env.getCurrentDepth());
              lvl = curr;
            } else if (!lvl) {
              // Skips invalid lvl (e.g., when this is a zero ranked tensor).
              return;
            }
          }
          hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
          callback(env.makeTensorLevel(tid, *lvl), nullptr);
          numloopCond++;
        } else if (lt.hasDenseSemantic() || isIdxReduc) {
          callback(env.makeTensorLevel(tid, *lvl), nullptr);
        } else {
          assert(isUndefLT(lt));
          linalg::GenericOp op = env.op();
          if (tid >= op.getNumDpsInputs())
            // We only handle affine expressions on input tensors (for now).
            return;
          OpOperand *operand = &op->getOpOperand(tid);
          const auto stt = getSparseTensorType(operand->get());
          // Non-annotated dense tensors require no special handling.
          if (!stt.hasEncoding())
            return;

          ArrayRef<AffineExpr> affines =
              op.getMatchingIndexingMap(operand).getResults();
          const Level lvlRank = stt.getLvlRank();
          assert(affines.size() == static_cast<size_t>(lvlRank));
          for (Level l = 0; l < lvlRank; l++) {
            AffineExpr exp = affines[l];
            // Skip simple affine expressions and non-dense levels (which
            // have their own filter loop).
            LevelType lt = stt.getLvlType(l);
            if (isa<AffineDimExpr>(exp) || !lt.hasDenseSemantic())
              continue;

            // Constant affine expressions are handled in genLoop.
            if (!isa<AffineConstantExpr>(exp)) {
              bool isCurrentLoop = false;
              assert(curr == env.getCurrentDepth());
              if (isInvariantAffine(exp, curr + 1, /*out*/ isCurrentLoop) &&
                  isCurrentLoop) {
                // If the compound affine expression is invariant and we are
                // right at the level, we need to generate the address
                // according to the affine expression. This is also the best
                // place to do so, to avoid putting it inside inner loops.
                callback(env.makeTensorLevel(tid, l), exp);
              }
            }
          }
        }
      });

  if (isDenseLT(env.lt(outTid, curr))) {
    auto stt = getSparseTensorType(env.op().getOutputs().front());
    // Note that we generate dense indices of the output tensor
    // unconditionally, since they may not appear in the lattice, but may be
    // needed for linearized env.
    // TODO: we should avoid introducing corner cases for all-dense sparse
    // tensors.
    if (stt.hasEncoding() && stt.isAllDense())
      callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
  }

  if (numloopCond == 0) {
    // Corner case where the loop bound is defined by an *unused* operand; in
    // this case, we just generate a dense "fake" loop by iterating over the
    // synthetic tensor.
    callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
    numloopCond++;
  }
  // If we just need one loop condition and the condition is not imposed on
  // a non-unique level, the loop can be generated by a for loop.
  // Or, if we are generating sparse-iterator-based loops, we always generate
  // `sparse_tensor.iterate` regardless of whether the level is unique or not.
  return numloopCond == 1 &&
         (!hasNonUnique || env.options().sparseEmitStrategy ==
                               SparseEmitStrategy::kSparseIterator);
}

/// Starts a loop sequence at given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                         LoopId curr, LatSetId lts) {
  assert(!env.getLoopVar(curr));
  // Emit invariants at this loop sequence level.
  genInvariants(env, builder, exp, curr, /*isStart=*/true);
  // Emit access pattern expansion for sparse tensor output.
  genExpand(env, builder, curr, /*isStart=*/true);
  // Emit further initialization at this loop sequence level.
  const LatPointId l0 = env.set(lts)[0];

  SmallVector<TensorLevel> tidLvls;
  getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
    // TODO: remove this! The same tensor level might be added multiple
    // times due to the special handling for all-dense "sparse" output tensor
    // (see L1038).
    if (llvm::is_contained(tidLvls, tl))
      return;
    tidLvls.emplace_back(tl);
  });

  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);

  // Maintain the universal index only if it is actually
  // consumed by a subsequent lattice point.
  for (const LatPointId li : env.set(lts).drop_front())
    if (!env.merger().hasAnySparse(env.lat(li).simple))
      return true;

  return false;
}

// Generates dense affine address for encoding.
static void genConstantDenseAddressFromLevel(CodegenEnv &env,
                                             OpBuilder &builder, TensorId tid,
                                             Level startLvl) {
  // TODO: Handle affine expression on output tensor.
  linalg::GenericOp op = env.op();
  assert(tid < op.getNumDpsInputs());
  OpOperand *input = op.getDpsInputOperands()[tid];
  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
  const auto enc = getSparseTensorEncoding(input->get().getType());
  if (enc) {
    const Location loc = op.getLoc();
    const TensorId tid = env.makeTensorId(input->getOperandNumber());
    const Level lvlRank = enc.getLvlRank();
    assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
    for (Level l = startLvl; l < lvlRank; l++) {
      AffineExpr lvlExpr = lvlExprs[l];
      if (enc.getLvlType(l).hasDenseSemantic() &&
          isa<AffineConstantExpr>(lvlExpr))
        env.emitter().locateLvlAtAffineAddress(
            builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
      else
        return; // break on first non-dense non-constant level
    }
  }
}

// We can generate address for constant affine expression before any loops
// starting from the first level as they do not depend on anything.
// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
// levels can be determined before loops.
static void genInitConstantDenseAddress(CodegenEnv &env,
                                        RewriterBase &rewriter) {
  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
    genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
}

/// Returns true if the lattice bit can be iterated by a for loop.
static bool translateBitsToTidLvlPairs(
    CodegenEnv &env, LatPointId li, LoopId curr,
    SmallVectorImpl<TensorLevel> &tidLvls,
    SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
  return getAllTidLvlsInLatPoints(env, li, curr,
                                  [&](TensorLevel tl, AffineExpr exp) {
                                    if (exp)
                                      affineTidLvls.emplace_back(tl, exp);
                                    else
                                      tidLvls.emplace_back(tl);
                                  });
}
1197 | |
1198 | /// Starts a single loop in current sequence. |
1199 | static std::pair<Operation *, bool> startLoop(CodegenEnv &env, |
1200 | OpBuilder &builder, LoopId curr, |
1201 | LatPointId li, unsigned numCases, |
1202 | bool needsUniv) { |
1203 | // TODO: numCases only used when generating iterator-based loops. Cleanup |
1204 | // after fully migration. |
1205 | // The set of tensors + lvls to generate loops on |
1206 | SmallVector<TensorLevel> tidLvls; |
1207 | |
1208 | // The set of dense tensors with non-trivial affine expression that just |
1209 | // becomes invariant and the address are generated at the current level. |
1210 | SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls; |
1211 | bool isSingleCond = |
1212 | translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls); |
1213 | |
1214 | // Emit the for/while-loop control. |
1215 | Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls); |
1216 | Location loc = env.op().getLoc(); |
1217 | for (auto [tidLvl, exp] : affineTidLvls) { |
1218 | env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, lvlExpr: exp); |
1219 | } |
1220 | |
1221 | // Until now, we have entered every <tid, lvl> pair in {cond, extra, |
1222 | // affine}Tids/Lvls. The addresses of the upcoming levels which are dependent |
1223 | // on constant affines expression may now be determined. |
1224 | auto allTidLvls = |
1225 | llvm::concat<TensorLevel>(Ranges&: tidLvls, Ranges: llvm::make_first_range(c&: affineTidLvls)); |
1226 | for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) { |
1227 | if (tid != env.merger().getOutTensorID() && |
1228 | tid != env.merger().getSynTensorID()) |
1229 | genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1); |
1230 | } |
1231 | |
  return std::make_pair(loop, isSingleCond);
1233 | } |
1234 | |
/// Ends a single loop in current sequence. Returns the new needsUniv value.
1236 | static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop, |
1237 | LatPointId li, bool needsUniv, bool isSingleCond) { |
1238 | // Either a for-loop or a while-loop that iterates over a slice. |
1239 | if (isSingleCond) { |
1240 | // Any iteration creates a valid lex insert. |
1241 | if (env.isReduc() && env.isValidLexInsert()) |
      env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
1243 | } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) { |
1244 | // End a while-loop. |
    finalizeWhileOp(env, rewriter, needsUniv);
1246 | } else { |
1247 | needsUniv = false; |
1248 | } |
  env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
    env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
1251 | return std::nullopt; |
1252 | }); |
1253 | return needsUniv; |
1254 | } |
1255 | |
1256 | /// Ends a loop sequence at given level. |
1257 | static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp, |
1258 | unsigned at) { |
1259 | assert(!env.getLoopVar(at)); |
  env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
  // Unmark bookkeeping of invariants and loop index.
  genInvariants(env, builder, exp, at, /*isStart=*/false);
  // Finalize access pattern expansion for sparse tensor output.
  genExpand(env, builder, at, /*isStart=*/false);
1265 | } |
1266 | |
/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
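///
/// As an illustration (conceptual sketch only): for x(i) = a(i) + b(i) with
/// both a and b sparse, the optimized lattice for loop i contains the points
/// {a/\b, a, b}, which is emitted roughly as
///
///   while (a and b both have pending entries) {    // co-iterate intersection
///     if (ia == i && ib == i) x(i) = a(i) + b(i)
///     else if (ia == i)       x(i) = a(i)
///     else                    x(i) = b(i)
///   }
///   for (remaining entries of a) x(i) = a(i)       // leftover tails
///   for (remaining entries of b) x(i) = b(i)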
1270 | static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp, |
1271 | LoopId curr) { |
1272 | assert(curr == env.getCurrentDepth()); |
1273 | |
1274 | // At each leaf, assign remaining tensor (sub)expression to output tensor. |
1275 | if (curr == env.getLoopNum()) { |
    Value rhs = genExp(env, rewriter, exp);
    genTensorStore(env, rewriter, exp, rhs);
1278 | return; |
1279 | } |
1280 | |
1281 | // Construct iteration lattices for current loop index. |
  const LatSetId lts =
      env.merger().optimizeSet(env.merger().buildLattices(exp, curr));
1284 | |
1285 | // Start a loop sequence. |
  bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);
1287 | |
  // When using sparse-iterator-based loops, we only need a single loop, as
  // opposed to a loop sequence, to cover all the iteration spaces.
1290 | const unsigned lsize = env.set(lts).size(); |
1291 | if (env.generatingSparseIterator()) { |
1292 | // Get the largest lattice point and start a loop. |
1293 | const LatPointId li = env.set(lts)[0]; |
1294 | auto [loop, isSingleCond] = |
        startLoop(env, rewriter, curr, li, /*numCases=*/lsize, needsUniv);
1296 | assert(isSingleCond == llvm::isa<IterateOp>(loop)); |
1297 | // We cannot change this to `for (const LatPointId li : env.set(lts))` |
1298 | // because the loop body causes data-movement which invalidates |
1299 | // the iterator. |
1300 | for (unsigned j = 0; j < lsize; j++) { |
1301 | const LatPointId lj = env.set(lts)[j]; |
      const ExprId ej = env.lat(lj).exp;
1303 | // Recurse into body of each branch. |
1304 | if (!isSingleCond) { |
        env.genLoopBoundary([&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
          genCoIterationCase(env, rewriter, /*caseIdx*/ j, li, lj, reduc);
          genStmt(env, rewriter, ej, curr + 1);
          // TODO: handle yield values.
          assert(reduc.empty() && "Not Implemented");
          rewriter.create<sparse_tensor::YieldOp>(env.op().getLoc());
1311 | return std::nullopt; |
1312 | }); |
1313 | // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns); |
1314 | } else { |
        genStmt(env, rewriter, ej, curr + 1);
1316 | } |
1317 | } |
1318 | // End a loop. |
    needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1320 | } else { |
1321 | // Emit a loop for every lattice point L0 >= Li in this loop sequence. |
1322 | for (unsigned i = 0; i < lsize; i++) { |
1323 | const LatPointId li = env.set(lts)[i]; |
1324 | // Start a loop. |
1325 | auto [loop, isSingleCond] = |
          startLoop(env, rewriter, curr, li, /*numCases=*/lsize, needsUniv);
1327 | |
      // Visit all lattice points with Li >= Lj to generate the
      // loop-body, possibly with if statements for coiteration.
1330 | Value redInput = env.getReduc(); |
1331 | Value cntInput = env.getExpandCount(); |
1332 | Value insInput = env.getInsertionChain(); |
1333 | Value validIns = env.getValidLexInsert(); |
1334 | // We cannot change this to `for (const LatPointId lj : env.set(lts))` |
1335 | // because the loop body causes data-movement which invalidates the |
1336 | // iterator. |
1337 | for (unsigned j = 0; j < lsize; j++) { |
1338 | const LatPointId lj = env.set(lts)[j]; |
        const ExprId ej = env.lat(lj).exp;
        if (li == lj || env.merger().latGT(li, lj)) {
1341 | // Recurse into body of each branch. |
1342 | if (!isSingleCond) { |
1343 | scf::IfOp ifOp = genIf(env, rewriter, curr, lj); |
            genStmt(env, rewriter, ej, curr + 1);
            endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
          } else {
            genStmt(env, rewriter, ej, curr + 1);
1348 | } |
1349 | } |
1350 | } |
1351 | |
1352 | // End a loop. |
      needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1354 | } |
1355 | } |
1356 | |
1357 | // End a loop sequence. |
  endLoopSeq(env, rewriter, exp, curr);
1359 | assert(curr == env.getCurrentDepth()); |
1360 | } |
1361 | |
1362 | /// Converts the result computed by the sparse kernel into the required form. |
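///
/// Roughly (illustrative sketch, not verbatim IR), a sparse (annotated) result
/// rematerializes through its storage format,
///
///   %r = sparse_tensor.load %tensor_or_chain hasInserts
///
/// whereas a dense (non-annotated) result is simply loaded back from its
/// bufferized value,
///
///   %r = bufferization.to_tensor %buffer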
1363 | static void genResult(CodegenEnv &env, RewriterBase &rewriter) { |
1364 | linalg::GenericOp op = env.op(); |
1365 | OpOperand *lhs = op.getDpsInitOperand(0); |
1366 | Value tensor = lhs->get(); |
1367 | Type resType = tensor.getType(); |
  if (getSparseTensorEncoding(resType)) {
1369 | // The sparse tensor rematerializes from the original sparse tensor's |
1370 | // underlying sparse storage format. For an insertion chain, the |
1371 | // tensor materializes from the chain with 'hasInserts' enabled. |
1372 | bool hasInserts = false; |
1373 | if (Value chain = env.getInsertionChain()) { |
1374 | hasInserts = true; |
1375 | tensor = chain; |
1376 | } |
1377 | rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts); |
1378 | } else { |
    // To rematerialize a non-annotated tensor, simply load it
    // from the bufferized value.
1381 | Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()]; |
1382 | rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val); |
1383 | } |
1384 | } |
1385 | |
1386 | //===----------------------------------------------------------------------===// |
1387 | // Sparsifier rewriting methods. |
1388 | //===----------------------------------------------------------------------===// |
1389 | |
1390 | namespace { |
1391 | |
/// Sparse rewriting rule for generic Linalg operation.
1393 | struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> { |
1394 | public: |
1395 | GenericOpSparsifier(MLIRContext *context, SparsificationOptions o) |
1396 | : OpRewritePattern<linalg::GenericOp>(context), options(o) {} |
1397 | |
1398 | LogicalResult matchAndRewrite(linalg::GenericOp op, |
1399 | PatternRewriter &rewriter) const override { |
1400 | // Only accept single output operations with pure tensor semantics. |
1401 | if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics()) |
1402 | return failure(); |
1403 | |
1404 | // Only accept trivial affine indices. |
1405 | if (hasNonTrivialAffineOnSparseOut(op)) |
1406 | return failure(); |
1407 | |
1408 | // Only accept scheduled loops. |
    if (!op->hasAttr("sorted")) {
      return rewriter.notifyMatchFailure(
          op, "Loops not yet scheduled, try running --sparse-reinterpret-map "
              "before sparsification.");
1413 | } |
1414 | |
1415 | // Must have been demapped as well if the generic op is sorted. |
1416 | assert(!hasAnyNonIdentityOperandsOrResults(op)); |
1417 | |
1418 | // Sets up a code generation environment. |
1419 | const unsigned numTensors = op->getNumOperands(); |
1420 | const unsigned numLoops = op.getNumLoops(); |
1421 | bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0; |
    // If we have an indexing map like (d0) -> (0, d0), there might be more
    // levels than loops because of the constant index, which means we cannot
    // use numLoops as the upper bound for the ranks of all tensors.
    // TODO: Constant indices are currently not supported on sparse tensors,
    // but are allowed in non-annotated dense tensors. Supporting them would
    // also be required for rank-reducing sparse tensor slices.
1428 | Level maxLvlRank = 0; |
1429 | for (auto operand : op.getOperands()) { |
1430 | if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) { |
1431 | maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank()); |
1432 | } |
1433 | } |
1434 | |
1435 | // Detects sparse annotations and translates the per-level sparsity |
1436 | // information for all tensors to loop indices in the kernel. |
1437 | CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank); |
    if (!findSparseAnnotations(env, needIdxRed))
1439 | return failure(); |
1440 | |
1441 | // Only standard reduction operations (add, sub, or, xor) that can be |
1442 | // sparsified by merely reducing the stored values are admissible. More |
1443 | // elaborate reduction operations (such as mul, and, min, max) would need |
1444 | // to know whether implicit zeros occur as well. They can still be |
1445 | // implemented with a custom reduction operation, accepted here as well. |
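    // For example (illustrative): a sum reduction such as x += a(i) * b(i)
    // ending in arith.addf is accepted, since entries that are implicit zeros
    // contribute nothing to the sum, whereas a product reduction ending in
    // arith.mulf is rejected below unless it is expressed through a custom
    // sparse_tensor ReduceOp region instead.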
1446 | if (op.getNumReductionLoops() > 0) { |
1447 | Operation *yield = op.getRegion().front().getTerminator(); |
1448 | assert(isa<linalg::YieldOp>(yield)); |
      Operation *redop = yield->getOperand(0).getDefiningOp();
1450 | if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) && |
1451 | !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) && |
1452 | !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) && |
1453 | !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) && |
1454 | !isa<ReduceOp>(redop)) { |
1455 | return failure(); |
1456 | } |
1457 | } |
1458 | |
    // Constructs the tensor expression tree from `op`; returns failure if the
    // tree cannot be built or the tensor expression is inadmissible.
    if (failed(env.initTensorExp()))
1462 | return failure(); |
1463 | |
1464 | // Recursively generates code if admissible. |
    env.startEmit(options.sparseEmitStrategy);
    genBuffers(env, rewriter);
    // TODO: Constant affine expressions should be handled differently when
    // using slice-based codegen; it does not matter for now because we
    // already reject constant expressions at an earlier stage.
    genInitConstantDenseAddress(env, rewriter);
    genStmt(env, rewriter, env.getExprId(), /*curr=*/0);
1472 | genResult(env, rewriter); |
1473 | return success(); |
1474 | } |
1475 | |
1476 | private: |
1477 | /// Options to control sparse code generation. |
1478 | SparsificationOptions options; |
1479 | }; |
1480 | |
1481 | } // namespace |
1482 | |
1483 | /// Populates the given patterns list with rewriting rules required for |
1484 | /// the sparsification of linear algebra operations. |
1485 | void mlir::populateSparsificationPatterns( |
1486 | RewritePatternSet &patterns, const SparsificationOptions &options) { |
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1488 | } |
1489 | |