//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparsifier into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparsifier can generate efficient SIMD (including ArmSVE
// support) with proper separation of concerns as far as sparsification and
// vectorization is concerned. However, this pass is not the final abstraction
// level we want, and not the general vectorizer we want either. It forms a good
// stepping stone for incremental future improvements though.
//
//===----------------------------------------------------------------------===//
18
19#include "Utils/CodegenUtils.h"
20#include "Utils/LoopEmitter.h"
21
22#include "mlir/Dialect/Affine/IR/AffineOps.h"
23#include "mlir/Dialect/Arith/IR/Arith.h"
24#include "mlir/Dialect/Math/IR/Math.h"
25#include "mlir/Dialect/MemRef/IR/MemRef.h"
26#include "mlir/Dialect/SCF/IR/SCF.h"
27#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
28#include "mlir/Dialect/Vector/IR/VectorOps.h"
29#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
30#include "mlir/IR/Matchers.h"
31
32using namespace mlir;
33using namespace mlir::sparse_tensor;
34
35namespace {
36
/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
struct VL {
  unsigned vectorLength;       // number of packed data elements
  bool enableVLAVectorization; // generate vector-length-agnostic (scalable) code
  bool enableSIMDIndex32;      // prefer 32-bit indices in gather/scatter
};
46
47/// Helper test for invariant value (defined outside given block).
48static bool isInvariantValue(Value val, Block *block) {
49 return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
50}
51
52/// Helper test for invariant argument (defined outside given block).
53static bool isInvariantArg(BlockArgument arg, Block *block) {
54 return arg.getOwner() != block;
55}
56
57/// Constructs vector type for element type.
58static VectorType vectorType(VL vl, Type etp) {
59 return VectorType::get(shape: vl.vectorLength, elementType: etp, scalableDims: vl.enableVLAVectorization);
60}
61
62/// Constructs vector type from a memref value.
63static VectorType vectorType(VL vl, Value mem) {
64 return vectorType(vl, etp: getMemRefType(t&: mem).getElementType());
65}
66
67/// Constructs vector iteration mask.
68static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
69 Value iv, Value lo, Value hi, Value step) {
70 VectorType mtp = vectorType(vl, etp: rewriter.getI1Type());
71 // Special case if the vector length evenly divides the trip count (for
72 // example, "for i = 0, 128, 16"). A constant all-true mask is generated
73 // so that all subsequent masked memory operations are immediately folded
74 // into unconditional memory operations.
75 IntegerAttr loInt, hiInt, stepInt;
76 if (matchPattern(value: lo, pattern: m_Constant(bind_value: &loInt)) &&
77 matchPattern(value: hi, pattern: m_Constant(bind_value: &hiInt)) &&
78 matchPattern(value: step, pattern: m_Constant(bind_value: &stepInt))) {
79 if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
80 Value trueVal = constantI1(builder&: rewriter, loc, b: true);
81 return rewriter.create<vector::BroadcastOp>(location: loc, args&: mtp, args&: trueVal);
82 }
83 }
84 // Otherwise, generate a vector mask that avoids overrunning the upperbound
85 // during vector execution. Here we rely on subsequent loop optimizations to
86 // avoid executing the mask in all iterations, for example, by splitting the
87 // loop into an unconditional vector loop and a scalar cleanup loop.
88 auto min = AffineMap::get(
89 /*dimCount=*/2, /*symbolCount=*/1,
90 results: {rewriter.getAffineSymbolExpr(position: 0),
91 rewriter.getAffineDimExpr(position: 0) - rewriter.getAffineDimExpr(position: 1)},
92 context: rewriter.getContext());
93 Value end = rewriter.createOrFold<affine::AffineMinOp>(
94 location: loc, args&: min, args: ValueRange{hi, iv, step});
95 return rewriter.create<vector::CreateMaskOp>(location: loc, args&: mtp, args&: end);
96}
97
98/// Generates a vectorized invariant. Here we rely on subsequent loop
99/// optimizations to hoist the invariant broadcast out of the vector loop.
100static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
101 Value val) {
102 VectorType vtp = vectorType(vl, etp: val.getType());
103 return rewriter.create<vector::BroadcastOp>(location: val.getLoc(), args&: vtp, args&: val);
104}
105
106/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
107/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
108/// that the sparsifier can only generate indirect loads in
109/// the last index, i.e. back().
110static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
111 Value mem, ArrayRef<Value> idxs, Value vmask) {
112 VectorType vtp = vectorType(vl, mem);
113 Value pass = constantZero(builder&: rewriter, loc, tp: vtp);
114 if (llvm::isa<VectorType>(Val: idxs.back().getType())) {
115 SmallVector<Value> scalarArgs(idxs);
116 Value indexVec = idxs.back();
117 scalarArgs.back() = constantIndex(builder&: rewriter, loc, i: 0);
118 return rewriter.create<vector::GatherOp>(location: loc, args&: vtp, args&: mem, args&: scalarArgs,
119 args&: indexVec, args&: vmask, args&: pass);
120 }
121 return rewriter.create<vector::MaskedLoadOp>(location: loc, args&: vtp, args&: mem, args&: idxs, args&: vmask,
122 args&: pass);
123}
124
125/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
126/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
127/// that the sparsifier can only generate indirect stores in
128/// the last index, i.e. back().
129static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
130 ArrayRef<Value> idxs, Value vmask, Value rhs) {
131 if (llvm::isa<VectorType>(Val: idxs.back().getType())) {
132 SmallVector<Value> scalarArgs(idxs);
133 Value indexVec = idxs.back();
134 scalarArgs.back() = constantIndex(builder&: rewriter, loc, i: 0);
135 rewriter.create<vector::ScatterOp>(location: loc, args&: mem, args&: scalarArgs, args&: indexVec, args&: vmask,
136 args&: rhs);
137 return;
138 }
139 rewriter.create<vector::MaskedStoreOp>(location: loc, args&: mem, args&: idxs, args&: vmask, args&: rhs);
140}
141
142/// Detects a vectorizable reduction operations and returns the
143/// combining kind of reduction on success in `kind`.
144static bool isVectorizableReduction(Value red, Value iter,
145 vector::CombiningKind &kind) {
146 if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
147 kind = vector::CombiningKind::ADD;
148 return addf->getOperand(idx: 0) == iter || addf->getOperand(idx: 1) == iter;
149 }
150 if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
151 kind = vector::CombiningKind::ADD;
152 return addi->getOperand(idx: 0) == iter || addi->getOperand(idx: 1) == iter;
153 }
154 if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
155 kind = vector::CombiningKind::ADD;
156 return subf->getOperand(idx: 0) == iter;
157 }
158 if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
159 kind = vector::CombiningKind::ADD;
160 return subi->getOperand(idx: 0) == iter;
161 }
162 if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
163 kind = vector::CombiningKind::MUL;
164 return mulf->getOperand(idx: 0) == iter || mulf->getOperand(idx: 1) == iter;
165 }
166 if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
167 kind = vector::CombiningKind::MUL;
168 return muli->getOperand(idx: 0) == iter || muli->getOperand(idx: 1) == iter;
169 }
170 if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
171 kind = vector::CombiningKind::AND;
172 return andi->getOperand(idx: 0) == iter || andi->getOperand(idx: 1) == iter;
173 }
174 if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
175 kind = vector::CombiningKind::OR;
176 return ori->getOperand(idx: 0) == iter || ori->getOperand(idx: 1) == iter;
177 }
178 if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
179 kind = vector::CombiningKind::XOR;
180 return xori->getOperand(idx: 0) == iter || xori->getOperand(idx: 1) == iter;
181 }
182 return false;
183}
184
185/// Generates an initial value for a vector reduction, following the scheme
186/// given in Chapter 5 of "The Software Vectorization Handbook", where the
187/// initial scalar value is correctly embedded in the vector reduction value,
188/// and a straightforward horizontal reduction will complete the operation.
189/// Value 'r' denotes the initial value of the reduction outside the loop.
190static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
191 Value red, Value iter, Value r,
192 VectorType vtp) {
193 vector::CombiningKind kind;
194 if (!isVectorizableReduction(red, iter, kind))
195 llvm_unreachable("unknown reduction");
196 switch (kind) {
197 case vector::CombiningKind::ADD:
198 case vector::CombiningKind::XOR:
199 // Initialize reduction vector to: | 0 | .. | 0 | r |
200 return rewriter.create<vector::InsertOp>(location: loc, args&: r,
201 args: constantZero(builder&: rewriter, loc, tp: vtp),
202 args: constantIndex(builder&: rewriter, loc, i: 0));
203 case vector::CombiningKind::MUL:
204 // Initialize reduction vector to: | 1 | .. | 1 | r |
205 return rewriter.create<vector::InsertOp>(location: loc, args&: r,
206 args: constantOne(builder&: rewriter, loc, tp: vtp),
207 args: constantIndex(builder&: rewriter, loc, i: 0));
208 case vector::CombiningKind::AND:
209 case vector::CombiningKind::OR:
210 // Initialize reduction vector to: | r | .. | r | r |
211 return rewriter.create<vector::BroadcastOp>(location: loc, args&: vtp, args&: r);
212 default:
213 break;
214 }
215 llvm_unreachable("unknown reduction kind");
216}
217
218/// This method is called twice to analyze and rewrite the given subscripts.
219/// The first call (!codegen) does the analysis. Then, on success, the second
220/// call (codegen) yields the proper vector form in the output parameter
221/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
222/// stay in sync. Note that the analyis part is simple because the sparsifier
223/// only generates relatively simple subscript expressions.
224///
225/// See https://llvm.org/docs/GetElementPtr.html for some background on
226/// the complications described below.
227///
228/// We need to generate a position/coordinate load from the sparse storage
229/// scheme. Narrower data types need to be zero extended before casting
230/// the value into the `index` type used for looping and indexing.
231///
232/// For the scalar case, subscripts simply zero extend narrower indices
233/// into 64-bit values before casting to an index type without a performance
234/// penalty. Indices that already are 64-bit, in theory, cannot express the
235/// full range since the LLVM backend defines addressing in terms of an
236/// unsigned pointer/signed index pair.
237static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
238 VL vl, ValueRange subs, bool codegen,
239 Value vmask, SmallVectorImpl<Value> &idxs) {
240 unsigned d = 0;
241 unsigned dim = subs.size();
242 Block *block = &forOp.getRegion().front();
243 for (auto sub : subs) {
244 bool innermost = ++d == dim;
245 // Invariant subscripts in outer dimensions simply pass through.
246 // Note that we rely on LICM to hoist loads where all subscripts
247 // are invariant in the innermost loop.
248 // Example:
249 // a[inv][i] for inv
250 if (isInvariantValue(val: sub, block)) {
251 if (innermost)
252 return false;
253 if (codegen)
254 idxs.push_back(Elt: sub);
255 continue; // success so far
256 }
257 // Invariant block arguments (including outer loop indices) in outer
258 // dimensions simply pass through. Direct loop indices in the
259 // innermost loop simply pass through as well.
260 // Example:
261 // a[i][j] for both i and j
262 if (auto arg = llvm::dyn_cast<BlockArgument>(Val&: sub)) {
263 if (isInvariantArg(arg, block) == innermost)
264 return false;
265 if (codegen)
266 idxs.push_back(Elt: sub);
267 continue; // success so far
268 }
269 // Look under the hood of casting.
270 auto cast = sub;
271 while (true) {
272 if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
273 cast = icast->getOperand(idx: 0);
274 else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
275 cast = ecast->getOperand(idx: 0);
276 else
277 break;
278 }
279 // Since the index vector is used in a subsequent gather/scatter
280 // operations, which effectively defines an unsigned pointer + signed
281 // index, we must zero extend the vector to an index width. For 8-bit
282 // and 16-bit values, an 32-bit index width suffices. For 32-bit values,
283 // zero extending the elements into 64-bit loses some performance since
284 // the 32-bit indexed gather/scatter is more efficient than the 64-bit
285 // index variant (if the negative 32-bit index space is unused, the
286 // enableSIMDIndex32 flag can preserve this performance). For 64-bit
287 // values, there is no good way to state that the indices are unsigned,
288 // which creates the potential of incorrect address calculations in the
289 // unlikely case we need such extremely large offsets.
290 // Example:
291 // a[ ind[i] ]
292 if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
293 if (!innermost)
294 return false;
295 if (codegen) {
296 SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
297 Location loc = forOp.getLoc();
298 Value vload =
299 genVectorLoad(rewriter, loc, vl, mem: load.getMemRef(), idxs: idxs2, vmask);
300 Type etp = llvm::cast<VectorType>(Val: vload.getType()).getElementType();
301 if (!llvm::isa<IndexType>(Val: etp)) {
302 if (etp.getIntOrFloatBitWidth() < 32)
303 vload = rewriter.create<arith::ExtUIOp>(
304 location: loc, args: vectorType(vl, etp: rewriter.getI32Type()), args&: vload);
305 else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
306 vload = rewriter.create<arith::ExtUIOp>(
307 location: loc, args: vectorType(vl, etp: rewriter.getI64Type()), args&: vload);
308 }
309 idxs.push_back(Elt: vload);
310 }
311 continue; // success so far
312 }
313 // Address calculation 'i = add inv, idx' (after LICM).
314 // Example:
315 // a[base + i]
316 if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
317 Value inv = load.getOperand(i: 0);
318 Value idx = load.getOperand(i: 1);
319 // Swap non-invariant.
320 if (!isInvariantValue(val: inv, block)) {
321 inv = idx;
322 idx = load.getOperand(i: 0);
323 }
324 // Inspect.
325 if (isInvariantValue(val: inv, block)) {
326 if (auto arg = llvm::dyn_cast<BlockArgument>(Val&: idx)) {
327 if (isInvariantArg(arg, block) || !innermost)
328 return false;
329 if (codegen)
330 idxs.push_back(
331 Elt: rewriter.create<arith::AddIOp>(location: forOp.getLoc(), args&: inv, args&: idx));
332 continue; // success so far
333 }
334 }
335 }
336 return false;
337 }
338 return true;
339}
340
// Analyzes/builds a vectorized unary operation of type `xxx` over `vx`.
#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

// Same as UNAOP, but for a type-converting unary operation, which needs the
// vectorized result type spelled out explicitly.
#define TYPEDUNAOP(xxx)                                                        \
  if (auto x = dyn_cast<xxx>(def)) {                                           \
    if (codegen) {                                                             \
      VectorType vtp = vectorType(vl, x.getType());                            \
      vexp = rewriter.create<xxx>(loc, vtp, vx);                               \
    }                                                                          \
    return true;                                                               \
  }

// Analyzes/builds a vectorized binary operation of type `xxx` over `vx`/`vy`.
#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }
363
364/// This method is called twice to analyze and rewrite the given expression.
365/// The first call (!codegen) does the analysis. Then, on success, the second
366/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
367/// This mechanism ensures that analysis and rewriting code stay in sync. Note
368/// that the analyis part is simple because the sparsifier only generates
369/// relatively simple expressions inside the for-loops.
370static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
371 Value exp, bool codegen, Value vmask, Value &vexp) {
372 Location loc = forOp.getLoc();
373 // Reject unsupported types.
374 if (!VectorType::isValidElementType(t: exp.getType()))
375 return false;
376 // A block argument is invariant/reduction/index.
377 if (auto arg = llvm::dyn_cast<BlockArgument>(Val&: exp)) {
378 if (arg == forOp.getInductionVar()) {
379 // We encountered a single, innermost index inside the computation,
380 // such as a[i] = i, which must convert to [i, i+1, ...].
381 if (codegen) {
382 VectorType vtp = vectorType(vl, etp: arg.getType());
383 Value veci = rewriter.create<vector::BroadcastOp>(location: loc, args&: vtp, args&: arg);
384 Value incr = rewriter.create<vector::StepOp>(location: loc, args&: vtp);
385 vexp = rewriter.create<arith::AddIOp>(location: loc, args&: veci, args&: incr);
386 }
387 return true;
388 }
389 // An invariant or reduction. In both cases, we treat this as an
390 // invariant value, and rely on later replacing and folding to
391 // construct a proper reduction chain for the latter case.
392 if (codegen)
393 vexp = genVectorInvariantValue(rewriter, vl, val: exp);
394 return true;
395 }
396 // Something defined outside the loop-body is invariant.
397 Operation *def = exp.getDefiningOp();
398 Block *block = &forOp.getRegion().front();
399 if (def->getBlock() != block) {
400 if (codegen)
401 vexp = genVectorInvariantValue(rewriter, vl, val: exp);
402 return true;
403 }
404 // Proper load operations. These are either values involved in the
405 // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
406 // or coordinate values inside the computation that are now fetched from
407 // the sparse storage coordinates arrays, such as a[i] = i becomes
408 // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
409 // and 'hi = lo + vl - 1'.
410 if (auto load = dyn_cast<memref::LoadOp>(Val: def)) {
411 auto subs = load.getIndices();
412 SmallVector<Value> idxs;
413 if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
414 if (codegen)
415 vexp = genVectorLoad(rewriter, loc, vl, mem: load.getMemRef(), idxs, vmask);
416 return true;
417 }
418 return false;
419 }
420 // Inside loop-body unary and binary operations. Note that it would be
421 // nicer if we could somehow test and build the operations in a more
422 // concise manner than just listing them all (although this way we know
423 // for certain that they can vectorize).
424 //
425 // TODO: avoid visiting CSEs multiple times
426 //
427 if (def->getNumOperands() == 1) {
428 Value vx;
429 if (vectorizeExpr(rewriter, forOp, vl, exp: def->getOperand(idx: 0), codegen, vmask,
430 vexp&: vx)) {
431 UNAOP(math::AbsFOp)
432 UNAOP(math::AbsIOp)
433 UNAOP(math::CeilOp)
434 UNAOP(math::FloorOp)
435 UNAOP(math::SqrtOp)
436 UNAOP(math::ExpM1Op)
437 UNAOP(math::Log1pOp)
438 UNAOP(math::SinOp)
439 UNAOP(math::TanhOp)
440 UNAOP(arith::NegFOp)
441 TYPEDUNAOP(arith::TruncFOp)
442 TYPEDUNAOP(arith::ExtFOp)
443 TYPEDUNAOP(arith::FPToSIOp)
444 TYPEDUNAOP(arith::FPToUIOp)
445 TYPEDUNAOP(arith::SIToFPOp)
446 TYPEDUNAOP(arith::UIToFPOp)
447 TYPEDUNAOP(arith::ExtSIOp)
448 TYPEDUNAOP(arith::ExtUIOp)
449 TYPEDUNAOP(arith::IndexCastOp)
450 TYPEDUNAOP(arith::TruncIOp)
451 TYPEDUNAOP(arith::BitcastOp)
452 // TODO: complex?
453 }
454 } else if (def->getNumOperands() == 2) {
455 Value vx, vy;
456 if (vectorizeExpr(rewriter, forOp, vl, exp: def->getOperand(idx: 0), codegen, vmask,
457 vexp&: vx) &&
458 vectorizeExpr(rewriter, forOp, vl, exp: def->getOperand(idx: 1), codegen, vmask,
459 vexp&: vy)) {
460 // We only accept shift-by-invariant (where the same shift factor applies
461 // to all packed elements). In the vector dialect, this is still
462 // represented with an expanded vector at the right-hand-side, however,
463 // so that we do not have to special case the code generation.
464 if (isa<arith::ShLIOp>(Val: def) || isa<arith::ShRUIOp>(Val: def) ||
465 isa<arith::ShRSIOp>(Val: def)) {
466 Value shiftFactor = def->getOperand(idx: 1);
467 if (!isInvariantValue(val: shiftFactor, block))
468 return false;
469 }
470 // Generate code.
471 BINOP(arith::MulFOp)
472 BINOP(arith::MulIOp)
473 BINOP(arith::DivFOp)
474 BINOP(arith::DivSIOp)
475 BINOP(arith::DivUIOp)
476 BINOP(arith::AddFOp)
477 BINOP(arith::AddIOp)
478 BINOP(arith::SubFOp)
479 BINOP(arith::SubIOp)
480 BINOP(arith::AndIOp)
481 BINOP(arith::OrIOp)
482 BINOP(arith::XOrIOp)
483 BINOP(arith::ShLIOp)
484 BINOP(arith::ShRUIOp)
485 BINOP(arith::ShRSIOp)
486 // TODO: complex?
487 }
488 }
489 return false;
490}
491
492#undef UNAOP
493#undef TYPEDUNAOP
494#undef BINOP
495
496/// This method is called twice to analyze and rewrite the given for-loop.
497/// The first call (!codegen) does the analysis. Then, on success, the second
498/// call (codegen) rewriters the IR into vector form. This mechanism ensures
499/// that analysis and rewriting code stay in sync.
500static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
501 bool codegen) {
502 Block &block = forOp.getRegion().front();
503 // For loops with single yield statement (as below) could be generated
504 // when custom reduce is used with unary operation.
505 // for (...)
506 // yield c_0
507 if (block.getOperations().size() <= 1)
508 return false;
509
510 Location loc = forOp.getLoc();
511 scf::YieldOp yield = cast<scf::YieldOp>(Val: block.getTerminator());
512 auto &last = *++block.rbegin();
513 scf::ForOp forOpNew;
514
515 // Perform initial set up during codegen (we know that the first analysis
516 // pass was successful). For reductions, we need to construct a completely
517 // new for-loop, since the incoming and outgoing reduction type
518 // changes into SIMD form. For stores, we can simply adjust the stride
519 // and insert in the existing for-loop. In both cases, we set up a vector
520 // mask for all operations which takes care of confining vectors to
521 // the original iteration space (later cleanup loops or other
522 // optimizations can take care of those).
523 Value vmask;
524 if (codegen) {
525 Value step = constantIndex(builder&: rewriter, loc, i: vl.vectorLength);
526 if (vl.enableVLAVectorization) {
527 Value vscale =
528 rewriter.create<vector::VectorScaleOp>(location: loc, args: rewriter.getIndexType());
529 step = rewriter.create<arith::MulIOp>(location: loc, args&: vscale, args&: step);
530 }
531 if (!yield.getResults().empty()) {
532 Value init = forOp.getInitArgs()[0];
533 VectorType vtp = vectorType(vl, etp: init.getType());
534 Value vinit = genVectorReducInit(rewriter, loc, red: yield->getOperand(idx: 0),
535 iter: forOp.getRegionIterArg(index: 0), r: init, vtp);
536 forOpNew = rewriter.create<scf::ForOp>(
537 location: loc, args: forOp.getLowerBound(), args: forOp.getUpperBound(), args&: step, args&: vinit);
538 forOpNew->setAttr(
539 name: LoopEmitter::getLoopEmitterLoopAttrName(),
540 value: forOp->getAttr(name: LoopEmitter::getLoopEmitterLoopAttrName()));
541 rewriter.setInsertionPointToStart(forOpNew.getBody());
542 } else {
543 rewriter.modifyOpInPlace(root: forOp, callable: [&]() { forOp.setStep(step); });
544 rewriter.setInsertionPoint(yield);
545 }
546 vmask = genVectorMask(rewriter, loc, vl, iv: forOp.getInductionVar(),
547 lo: forOp.getLowerBound(), hi: forOp.getUpperBound(), step);
548 }
549
550 // Sparse for-loops either are terminated by a non-empty yield operation
551 // (reduction loop) or otherwise by a store operation (pararallel loop).
552 if (!yield.getResults().empty()) {
553 // Analyze/vectorize reduction.
554 if (yield->getNumOperands() != 1)
555 return false;
556 Value red = yield->getOperand(idx: 0);
557 Value iter = forOp.getRegionIterArg(index: 0);
558 vector::CombiningKind kind;
559 Value vrhs;
560 if (isVectorizableReduction(red, iter, kind) &&
561 vectorizeExpr(rewriter, forOp, vl, exp: red, codegen, vmask, vexp&: vrhs)) {
562 if (codegen) {
563 Value partial = forOpNew.getResult(i: 0);
564 Value vpass = genVectorInvariantValue(rewriter, vl, val: iter);
565 Value vred = rewriter.create<arith::SelectOp>(location: loc, args&: vmask, args&: vrhs, args&: vpass);
566 rewriter.create<scf::YieldOp>(location: loc, args&: vred);
567 rewriter.setInsertionPointAfter(forOpNew);
568 Value vres = rewriter.create<vector::ReductionOp>(location: loc, args&: kind, args&: partial);
569 // Now do some relinking (last one is not completely type safe
570 // but all bad ones are removed right away). This also folds away
571 // nop broadcast operations.
572 rewriter.replaceAllUsesWith(from: forOp.getResult(i: 0), to: vres);
573 rewriter.replaceAllUsesWith(from: forOp.getInductionVar(),
574 to: forOpNew.getInductionVar());
575 rewriter.replaceAllUsesWith(from: forOp.getRegionIterArg(index: 0),
576 to: forOpNew.getRegionIterArg(index: 0));
577 rewriter.eraseOp(op: forOp);
578 }
579 return true;
580 }
581 } else if (auto store = dyn_cast<memref::StoreOp>(Val&: last)) {
582 // Analyze/vectorize store operation.
583 auto subs = store.getIndices();
584 SmallVector<Value> idxs;
585 Value rhs = store.getValue();
586 Value vrhs;
587 if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
588 vectorizeExpr(rewriter, forOp, vl, exp: rhs, codegen, vmask, vexp&: vrhs)) {
589 if (codegen) {
590 genVectorStore(rewriter, loc, mem: store.getMemRef(), idxs, vmask, rhs: vrhs);
591 rewriter.eraseOp(op: store);
592 }
593 return true;
594 }
595 }
596
597 assert(!codegen && "cannot call codegen when analysis failed");
598 return false;
599}
600
601/// Basic for-loop vectorizer.
602struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
603public:
604 using OpRewritePattern<scf::ForOp>::OpRewritePattern;
605
606 ForOpRewriter(MLIRContext *context, unsigned vectorLength,
607 bool enableVLAVectorization, bool enableSIMDIndex32)
608 : OpRewritePattern(context), vl{.vectorLength: vectorLength, .enableVLAVectorization: enableVLAVectorization,
609 .enableSIMDIndex32: enableSIMDIndex32} {}
610
611 LogicalResult matchAndRewrite(scf::ForOp op,
612 PatternRewriter &rewriter) const override {
613 // Check for single block, unit-stride for-loop that is generated by
614 // sparsifier, which means no data dependence analysis is required,
615 // and its loop-body is very restricted in form.
616 if (!op.getRegion().hasOneBlock() || !isOneInteger(v: op.getStep()) ||
617 !op->hasAttr(name: LoopEmitter::getLoopEmitterLoopAttrName()))
618 return failure();
619 // Analyze (!codegen) and rewrite (codegen) loop-body.
620 if (vectorizeStmt(rewriter, forOp: op, vl, /*codegen=*/false) &&
621 vectorizeStmt(rewriter, forOp: op, vl, /*codegen=*/true))
622 return success();
623 return failure();
624 }
625
626private:
627 const VL vl;
628};
629
630static LogicalResult cleanReducChain(PatternRewriter &rewriter, Operation *op,
631 Value inp) {
632 if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
633 if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
634 if (forOp->hasAttr(name: LoopEmitter::getLoopEmitterLoopAttrName())) {
635 rewriter.replaceOp(op, newValues: redOp.getVector());
636 return success();
637 }
638 }
639 }
640 return failure();
641}
642
643/// Reduction chain cleanup.
644/// v = for { }
645/// s = vsum(v) v = for { }
646/// u = broadcast(s) -> for (v) { }
647/// for (u) { }
648struct ReducChainBroadcastRewriter
649 : public OpRewritePattern<vector::BroadcastOp> {
650public:
651 using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
652
653 LogicalResult matchAndRewrite(vector::BroadcastOp op,
654 PatternRewriter &rewriter) const override {
655 return cleanReducChain(rewriter, op, inp: op.getSource());
656 }
657};
658
659/// Reduction chain cleanup.
660/// v = for { }
661/// s = vsum(v) v = for { }
662/// u = insert(s) -> for (v) { }
663/// for (u) { }
664struct ReducChainInsertRewriter : public OpRewritePattern<vector::InsertOp> {
665public:
666 using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
667
668 LogicalResult matchAndRewrite(vector::InsertOp op,
669 PatternRewriter &rewriter) const override {
670 return cleanReducChain(rewriter, op, inp: op.getValueToStore());
671 }
672};
673} // namespace
674
675//===----------------------------------------------------------------------===//
676// Public method for populating vectorization rules.
677//===----------------------------------------------------------------------===//
678
679/// Populates the given patterns list with vectorization rules.
680void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
681 unsigned vectorLength,
682 bool enableVLAVectorization,
683 bool enableSIMDIndex32) {
684 assert(vectorLength > 0);
685 vector::populateVectorStepLoweringPatterns(patterns);
686 patterns.add<ForOpRewriter>(arg: patterns.getContext(), args&: vectorLength,
687 args&: enableVLAVectorization, args&: enableSIMDIndex32);
688 patterns.add<ReducChainInsertRewriter, ReducChainBroadcastRewriter>(
689 arg: patterns.getContext());
690}
691

source code of mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp