//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparsifier into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparsifier can generate efficient SIMD (including ArmSVE
// support) with proper separation of concerns as far as sparsification and
// vectorization are concerned. However, this pass is not the final abstraction
// level we want, and not the general vectorizer we want either. It forms a good
// stepping stone for incremental future improvements though.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
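///
/// For example (illustrative only), {16, false, true} targets fixed-width
/// vector<16xf32> with 32-bit gather/scatter indices, whereas {4, true, false}
/// targets scalable vector<[4]xf32> as used for ArmSVE.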
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};

/// Helper test for invariant value (defined outside given block).
static bool isInvariantValue(Value val, Block *block) {
  return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
}

/// Helper test for invariant argument (defined outside given block).
static bool isInvariantArg(BlockArgument arg, Block *block) {
  return arg.getOwner() != block;
}

/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
  return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
}

/// Constructs vector type from a memref value.
static VectorType vectorType(VL vl, Value mem) {
  return vectorType(vl, getMemRefType(mem).getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
    }
  }
  // Otherwise, generate a vector mask that avoids overrunning the upperbound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
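  // A sketch of the masked form this produces (assuming a fixed vector
  // length of 16; the scalable case uses a vector<[16]xi1> mask instead):
  //   %end  = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>(%hi, %iv)[%step]
  //   %mask = vector.create_mask %end : vector<16xi1>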
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end = rewriter.createOrFold<affine::AffineMinOp>(
      loc, min, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparsifier can only generate indirect loads in
/// the last index, i.e. back().
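///
/// A sketch of the two generated forms (assuming a fixed vector length of 16,
/// f64 elements, and an index-typed coordinate vector):
///   %v = vector.gather %mem[%c0][%ind], %vmask, %pass
///        : memref<?xf64>, vector<16xindex>, vector<16xi1>, vector<16xf64> into vector<16xf64>
///   %v = vector.maskedload %mem[%iv], %vmask, %pass
///        : memref<?xf64>, vector<16xi1>, vector<16xf64> into vector<16xf64>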
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value mem, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, mem);
  Value pass = constantZero(rewriter, loc, vtp);
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(loc, vtp, mem, scalarArgs,
                                             indexVec, vmask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, mem, idxs, vmask,
                                               pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparsifier can only generate indirect stores in
/// the last index, i.e. back().
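///
/// A sketch of the two generated forms (same assumptions as for the load):
///   vector.scatter %mem[%c0][%ind], %vmask, %rhs
///     : memref<?xf64>, vector<16xindex>, vector<16xi1>, vector<16xf64>
///   vector.maskedstore %mem[%iv], %vmask, %rhs
///     : memref<?xf64>, vector<16xi1>, vector<16xf64>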
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, mem, scalarArgs, indexVec, vmask,
                                       rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, mem, idxs, vmask, rhs);
}

/// Detects a vectorizable reduction operation and, on success, returns the
/// combining kind of the reduction in `kind`.
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
    kind = vector::CombiningKind::ADD;
    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
  }
  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
    kind = vector::CombiningKind::ADD;
    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
  }
  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
    kind = vector::CombiningKind::ADD;
    return subf->getOperand(0) == iter;
  }
  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
    kind = vector::CombiningKind::ADD;
    return subi->getOperand(0) == iter;
  }
  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
    kind = vector::CombiningKind::MUL;
    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
  }
  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
    kind = vector::CombiningKind::MUL;
    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
  }
  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
    kind = vector::CombiningKind::AND;
    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
  }
  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
    kind = vector::CombiningKind::OR;
    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
  }
  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
    kind = vector::CombiningKind::XOR;
    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
  }
  return false;
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// Value 'r' denotes the initial value of the reduction outside the loop.
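///
/// For example (a sketch with vl = 4), a sum reduction with initial value
/// r = 5.0 starts from | 0 | 0 | 0 | 5.0 |, so the final horizontal add of
/// the loop result folds r in exactly once; a product reduction starts from
/// | 1 | 1 | 1 | r | for the same reason.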
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync. Note that the analysis part is simple because the sparsifier
/// only generates relatively simple subscript expressions.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a position/coordinate load from the sparse storage
/// scheme. Narrower data types need to be zero extended before casting
/// the value into the `index` type used for looping and indexing.
///
/// For the scalar case, subscripts simply zero extend narrower indices
/// into 64-bit values before casting to an index type without a performance
/// penalty. Indices that already are 64-bit, in theory, cannot express the
/// full range since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
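///
/// A sketch of the vector case (assuming 8-bit stored coordinates and a fixed
/// vector length of 16), where a subscript ind[i] becomes an index vector:
///   %ci = vector.maskedload %ind[%iv], %vmask, %pass : ... into vector<16xi8>
///   %wi = arith.extui %ci : vector<16xi8> to vector<16xi32>
/// and %wi is then used as the index vector of a gather/scatter.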
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  unsigned d = 0;
  unsigned dim = subs.size();
  Block *block = &forOp.getRegion().front();
  for (auto sub : subs) {
    bool innermost = ++d == dim;
    // Invariant subscripts in outer dimensions simply pass through.
    // Note that we rely on LICM to hoist loads where all subscripts
    // are invariant in the innermost loop.
    // Example:
    //   a[inv][i] for inv
    if (isInvariantValue(sub, block)) {
      if (innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Invariant block arguments (including outer loop indices) in outer
    // dimensions simply pass through. Direct loop indices in the
    // innermost loop simply pass through as well.
    // Example:
    //   a[i][j] for both i and j
    if (auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
      if (isInvariantArg(arg, block) == innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (true) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in a subsequent gather/scatter
    // operation, which effectively defines an unsigned pointer + signed
    // index, we must zero extend the vector to an index width. For 8-bit
    // and 16-bit values, a 32-bit index width suffices. For 32-bit values,
    // zero extending the elements into 64-bit loses some performance since
    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
    // index variant (if the negative 32-bit index space is unused, the
    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential of incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
    // Example:
    //   a[ ind[i] ]
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (!innermost)
        return false;
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = llvm::cast<VectorType>(vload.getType()).getElementType();
        if (!llvm::isa<IndexType>(etp)) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    // Address calculation 'i = add inv, idx' (after LICM).
    // Example:
    //   a[base + i]
    if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
      Value inv = load.getOperand(0);
      Value idx = load.getOperand(1);
      // Swap non-invariant.
      if (!isInvariantValue(inv, block)) {
        inv = idx;
        idx = load.getOperand(0);
      }
      // Inspect.
      if (isInvariantValue(inv, block)) {
        if (auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
          if (isInvariantArg(arg, block) || !innermost)
            return false;
          if (codegen)
            idxs.push_back(
                rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
          continue; // success so far
        }
      }
    }
    return false;
  }
  return true;
}

#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

#define TYPEDUNAOP(xxx)                                                        \
  if (auto x = dyn_cast<xxx>(def)) {                                           \
    if (codegen) {                                                             \
      VectorType vtp = vectorType(vl, x.getType());                            \
      vexp = rewriter.create<xxx>(loc, vtp, vx);                               \
    }                                                                          \
    return true;                                                               \
  }

#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync. Note
/// that the analysis part is simple because the sparsifier only generates
/// relatively simple expressions inside the for-loops.
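///
/// For example (a sketch, assuming a fixed vector length of 8 and f64 data),
/// a right-hand side a[i] * s inside the loop body becomes, roughly:
///   %va = vector.maskedload %a[%iv], %vmask, %pass : ... into vector<8xf64>
///   %vs = vector.broadcast %s : f64 to vector<8xf64>
///   %v  = arith.mulf %va, %vs : vector<8xf64>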
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  Location loc = forOp.getLoc();
  // Reject unsupported types.
  if (!VectorType::isValidElementType(exp.getType()))
    return false;
  // A block argument is invariant/reduction/index.
  if (auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
    if (arg == forOp.getInductionVar()) {
      // We encountered a single, innermost index inside the computation,
      // such as a[i] = i, which must convert to [i, i+1, ...].
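      // A sketch of the fixed-length form this generates (assuming vl = 4):
      //   %b = vector.broadcast %i : index to vector<4xindex>
      //   %s = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
      //   %v = arith.addi %b, %s : vector<4xindex>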
      if (codegen) {
        VectorType vtp = vectorType(vl, arg.getType());
        Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
        Value incr;
        if (vl.enableVLAVectorization) {
          Type stepvty = vectorType(vl, rewriter.getI64Type());
          Value stepv = rewriter.create<LLVM::StepVectorOp>(loc, stepvty);
          incr = rewriter.create<arith::IndexCastOp>(loc, vtp, stepv);
        } else {
          SmallVector<APInt> integers;
          for (unsigned i = 0, l = vl.vectorLength; i < l; i++)
            integers.push_back(APInt(/*width=*/64, i));
          auto values = DenseElementsAttr::get(vtp, integers);
          incr = rewriter.create<arith::ConstantOp>(loc, vtp, values);
        }
        vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
      }
      return true;
    }
    // An invariant or reduction. In both cases, we treat this as an
    // invariant value, and rely on later replacing and folding to
    // construct a proper reduction chain for the latter case.
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Something defined outside the loop-body is invariant.
  Operation *def = exp.getDefiningOp();
  Block *block = &forOp.getRegion().front();
  if (def->getBlock() != block) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Proper load operations. These are either values involved in the
  // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
  // or coordinate values inside the computation that are now fetched from
  // the sparse storage coordinates arrays, such as a[i] = i becomes
  // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
  // and 'hi = lo + vl - 1'.
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
    return false;
  }
  // Inside loop-body unary and binary operations. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
  //
  // TODO: avoid visiting CSEs multiple times
  //
  if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
      TYPEDUNAOP(arith::TruncFOp)
      TYPEDUNAOP(arith::ExtFOp)
      TYPEDUNAOP(arith::FPToSIOp)
      TYPEDUNAOP(arith::FPToUIOp)
      TYPEDUNAOP(arith::SIToFPOp)
      TYPEDUNAOP(arith::UIToFPOp)
      TYPEDUNAOP(arith::ExtSIOp)
      TYPEDUNAOP(arith::ExtUIOp)
      TYPEDUNAOP(arith::IndexCastOp)
      TYPEDUNAOP(arith::TruncIOp)
      TYPEDUNAOP(arith::BitcastOp)
      // TODO: complex?
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      // We only accept shift-by-invariant (where the same shift factor applies
      // to all packed elements). In the vector dialect, this is still
      // represented with an expanded vector at the right-hand-side, however,
      // so that we do not have to special case the code generation.
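      // For example, a[i] = b[i] << %c2 with %c2 defined outside the loop
      // body is accepted (the shift amount is simply broadcast), whereas
      // a[i] = b[i] << c[i] is rejected.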
      if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
          isa<arith::ShRSIOp>(def)) {
        Value shiftFactor = def->getOperand(1);
        if (!isInvariantValue(shiftFactor, block))
          return false;
      }
      // Generate code.
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
      BINOP(arith::ShLIOp)
      BINOP(arith::ShRUIOp)
      BINOP(arith::ShRSIOp)
      // TODO: complex?
    }
  }
  return false;
}

#undef UNAOP
#undef TYPEDUNAOP
#undef BINOP

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Block &block = forOp.getRegion().front();
  // For-loops with only a single yield statement (as below) can be generated
  // when a custom reduction is used with a unary operation.
  // for (...)
  //   yield c_0
  if (block.getOperations().size() <= 1)
    return false;

  Location loc = forOp.getLoc();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial set up during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
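  // A sketch of the rewritten reduction case (assuming vl = 8 and an f32 sum
  // reduction; names are illustrative):
  //   %vinit = vector.insertelement %init, %zero[%c0 : index] : vector<8xf32>
  //   %vfor  = scf.for %i = %lo to %hi step %c8 iter_args(%vacc = %vinit)
  //            -> (vector<8xf32>) { ... }
  //   %res   = vector.reduction <add>, %vfor : vector<8xf32> into f32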
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      forOpNew->setAttr(
          LoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      rewriter.modifyOpInPlace(forOp, [&]() { forOp.setStep(step); });
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops either are terminated by a non-empty yield operation
  // (reduction loop) or otherwise by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
        rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
        rewriter.replaceAllUsesWith(forOp.getInductionVar(),
                                    forOpNew.getInductionVar());
        rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
                                    forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}

/// Basic for-loop vectorizer.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context),
        vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for single block, unit-stride for-loop that is generated by
    // sparsifier, which means no data dependence analysis is required,
    // and its loop-body is very restricted in form.
    if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) ||
        !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)      v = for { }
///   u = expand(s) -> for (v) { }
///   for (u) { }
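///
/// For example (a sketch): when the expand (a vector.broadcast or
/// vector.insertelement) of a vector.reduction over a sparsifier loop result
/// %v feeds the next vector loop, the reduction/expand pair is folded away and
/// %v is used directly, keeping the value in SIMD form across loops.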
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
public:
  using OpRewritePattern<VectorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VectorOp op,
                                PatternRewriter &rewriter) const override {
    Value inp = op.getSource();
    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
        if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
          rewriter.replaceOp(op, redOp.getVector());
          return success();
        }
      }
    }
    return failure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
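///
/// A minimal usage sketch (hypothetical caller; assumes the greedy rewrite
/// driver from mlir/Transforms/GreedyPatternRewriteDriver.h):
///   RewritePatternSet patterns(op->getContext());
///   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/16,
///                                       /*enableVLAVectorization=*/false,
///                                       /*enableSIMDIndex32=*/true);
///   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));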
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  assert(vectorLength > 0);
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
}