//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparsifier into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparsifier can generate efficient SIMD (including Arm SVE
// support) with a proper separation of concerns as far as sparsification and
// vectorization are concerned. However, this pass is neither the final
// abstraction level we want nor the general vectorizer we want. It forms a
// good stepping stone for incremental future improvements, though.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;
namespace {

/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. Arm SVE)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};

/// Helper test for invariant value (defined outside given block).
static bool isInvariantValue(Value val, Block *block) {
  return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
}

/// Helper test for invariant argument (defined outside given block).
static bool isInvariantArg(BlockArgument arg, Block *block) {
  return arg.getOwner() != block;
}

/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
  return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
}

/// Constructs vector type from a memref value.
static VectorType vectorType(VL vl, Value mem) {
  return vectorType(vl, getMemRefType(mem).getElementType());
}
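// Illustrative note (not part of the pass logic): for f32 elements, the
// helpers above yield roughly the following types, depending on the VL
// configuration; the exact shapes follow VectorType::get as used above.
//
//   VL{16, /*VLA=*/false, /*idx32=*/false}  ->  vector<16xf32>
//   VL{4,  /*VLA=*/true,  /*idx32=*/false}  ->  vector<[4]xf32>  (scalable)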

/// Constructs vector iteration mask.
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
    }
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end = rewriter.createOrFold<affine::AffineMinOp>(
      loc, min, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}
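// Roughly, for a non-constant trip count the mask materializes as something
// like the following (illustrative IR; SSA names are schematic):
//
//   %end  = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>(%hi, %iv)[%step]
//   %mask = vector.create_mask %end : vector<16xi1>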

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparsifier can only generate indirect loads in
/// the last index, i.e. back().
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value mem, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, mem);
  Value pass = constantZero(rewriter, loc, vtp);
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs);
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(loc, vtp, mem, scalarArgs,
                                             indexVec, vmask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, mem, idxs, vmask,
                                               pass);
}
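// As an illustration (schematic IR, not emitted verbatim), a direct load
// becomes a masked load, while an indirect load becomes a gather:
//
//   %v = vector.maskedload %a[%i], %mask, %pass
//          : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
//   %v = vector.gather %a[%c0][%ind], %mask, %pass
//          : memref<?xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
//            into vector<16xf32>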

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparsifier can only generate indirect stores in
/// the last index, i.e. back().
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs);
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, mem, scalarArgs, indexVec, vmask,
                                       rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, mem, idxs, vmask, rhs);
}
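// The store counterpart follows the same shape (illustrative IR only):
//
//   vector.maskedstore %a[%i], %mask, %rhs
//     : memref<?xf32>, vector<16xi1>, vector<16xf32>
//   vector.scatter %a[%c0][%ind], %mask, %rhs
//     : memref<?xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>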

/// Detects a vectorizable reduction operation and, on success, returns
/// the combining kind of the reduction in `kind`.
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
    kind = vector::CombiningKind::ADD;
    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
  }
  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
    kind = vector::CombiningKind::ADD;
    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
  }
  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
    kind = vector::CombiningKind::ADD;
    return subf->getOperand(0) == iter;
  }
  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
    kind = vector::CombiningKind::ADD;
    return subi->getOperand(0) == iter;
  }
  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
    kind = vector::CombiningKind::MUL;
    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
  }
  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
    kind = vector::CombiningKind::MUL;
    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
  }
  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
    kind = vector::CombiningKind::AND;
    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
  }
  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
    kind = vector::CombiningKind::OR;
    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
  }
  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
    kind = vector::CombiningKind::XOR;
    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
  }
  return false;
}
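// For instance (schematic), a loop body that yields
//
//   %red = arith.addf %iter, %x : f32
//
// is recognized as CombiningKind::ADD, since %iter (the region iter_arg)
// feeds one operand of the addition; a subtraction only qualifies when the
// iteration argument is the first operand.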

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// Value 'r' denotes the initial value of the reduction outside the loop.
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}
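// As a sketch, for a sum reduction with scalar init %r and vector<16xf32>,
// the initializer is roughly (SSA names schematic):
//
//   %zero  = arith.constant dense<0.0> : vector<16xf32>
//   %c0    = arith.constant 0 : index
//   %vinit = vector.insertelement %r, %zero[%c0 : index] : vector<16xf32>
//
// so that a later vector.reduction <add> over the running vector adds the
// embedded %r exactly once.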

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync. Note that the analysis part is simple because the sparsifier
/// only generates relatively simple subscript expressions.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a position/coordinate load from the sparse storage
/// scheme. Narrower data types need to be zero extended before casting
/// the value into the `index` type used for looping and indexing.
///
/// For the scalar case, narrower indices are simply zero extended into
/// 64-bit values before casting to the index type, without a performance
/// penalty. Indices that already are 64-bit, in theory, cannot express the
/// full range, since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
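///
/// As an illustration of the vectorized case (schematic IR only), an i32
/// coordinate load that feeds a gather is zero extended to an i64 index
/// vector unless enableSIMDIndex32 is set:
///
///   %ci   = vector.maskedload %ind[%i], %mask, %pass
///             : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
///   %ci64 = arith.extui %ci : vector<16xi32> to vector<16xi64>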
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  unsigned d = 0;
  unsigned dim = subs.size();
  Block *block = &forOp.getRegion().front();
  for (auto sub : subs) {
    bool innermost = ++d == dim;
    // Invariant subscripts in outer dimensions simply pass through.
    // Note that we rely on LICM to hoist loads where all subscripts
    // are invariant in the innermost loop.
    // Example:
    //   a[inv][i] for inv
    if (isInvariantValue(sub, block)) {
      if (innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Invariant block arguments (including outer loop indices) in outer
    // dimensions simply pass through. Direct loop indices in the
    // innermost loop simply pass through as well.
    // Example:
    //   a[i][j] for both i and j
    if (auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
      if (isInvariantArg(arg, block) == innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (true) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in a subsequent gather/scatter
    // operation, which effectively defines an unsigned pointer + signed
    // index, we must zero extend the vector to an index width. For 8-bit
    // and 16-bit values, a 32-bit index width suffices. For 32-bit values,
    // zero extending the elements to 64 bits loses some performance since
    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
    // index variant (if the negative 32-bit index space is unused, the
    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential of incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
    // Example:
    //   a[ ind[i] ]
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (!innermost)
        return false;
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = llvm::cast<VectorType>(vload.getType()).getElementType();
        if (!llvm::isa<IndexType>(etp)) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    // Address calculation 'i = add inv, idx' (after LICM).
    // Example:
    //   a[base + i]
    if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
      Value inv = load.getOperand(0);
      Value idx = load.getOperand(1);
      // Swap non-invariant.
      if (!isInvariantValue(inv, block)) {
        inv = idx;
        idx = load.getOperand(0);
      }
      // Inspect.
      if (isInvariantValue(inv, block)) {
        if (auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
          if (isInvariantArg(arg, block) || !innermost)
            return false;
          if (codegen)
            idxs.push_back(
                rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
          continue; // success so far
        }
      }
    }
    return false;
  }
  return true;
}

#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

#define TYPEDUNAOP(xxx)                                                        \
  if (auto x = dyn_cast<xxx>(def)) {                                           \
    if (codegen) {                                                             \
      VectorType vtp = vectorType(vl, x.getType());                            \
      vexp = rewriter.create<xxx>(loc, vtp, vx);                               \
    }                                                                          \
    return true;                                                               \
  }

#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }
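// For reference (illustrative only), BINOP(arith::AddFOp) expands to the
// following guarded pattern inside vectorizeExpr, where 'def', 'vx', 'vy',
// 'vexp', and 'codegen' are locals of that function:
//
//   if (isa<arith::AddFOp>(def)) {
//     if (codegen)
//       vexp = rewriter.create<arith::AddFOp>(loc, vx, vy);
//     return true;
//   }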

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync. Note
/// that the analysis part is simple because the sparsifier only generates
/// relatively simple expressions inside the for-loops.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  Location loc = forOp.getLoc();
  // Reject unsupported types.
  if (!VectorType::isValidElementType(exp.getType()))
    return false;
  // A block argument is invariant/reduction/index.
  if (auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
    if (arg == forOp.getInductionVar()) {
      // We encountered a single, innermost index inside the computation,
      // such as a[i] = i, which must convert to [i, i+1, ...].
      if (codegen) {
        VectorType vtp = vectorType(vl, arg.getType());
        Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
        Value incr = rewriter.create<vector::StepOp>(loc, vtp);
        vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
      }
      return true;
    }
    // An invariant or reduction. In both cases, we treat this as an
    // invariant value, and rely on later replacing and folding to
    // construct a proper reduction chain for the latter case.
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Something defined outside the loop-body is invariant.
  Operation *def = exp.getDefiningOp();
  Block *block = &forOp.getRegion().front();
  if (def->getBlock() != block) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Proper load operations. These are either values involved in the
  // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
  // or coordinate values inside the computation that are now fetched from
  // the sparse storage coordinates arrays, such as a[i] = i becomes
  // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
  // and 'hi = lo + vl - 1'.
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
    return false;
  }
  // Inside loop-body unary and binary operations. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
  //
  // TODO: avoid visiting CSEs multiple times
  //
  if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
      TYPEDUNAOP(arith::TruncFOp)
      TYPEDUNAOP(arith::ExtFOp)
      TYPEDUNAOP(arith::FPToSIOp)
      TYPEDUNAOP(arith::FPToUIOp)
      TYPEDUNAOP(arith::SIToFPOp)
      TYPEDUNAOP(arith::UIToFPOp)
      TYPEDUNAOP(arith::ExtSIOp)
      TYPEDUNAOP(arith::ExtUIOp)
      TYPEDUNAOP(arith::IndexCastOp)
      TYPEDUNAOP(arith::TruncIOp)
      TYPEDUNAOP(arith::BitcastOp)
      // TODO: complex?
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      // We only accept shift-by-invariant (where the same shift factor applies
      // to all packed elements). In the vector dialect, this is still
      // represented with an expanded vector at the right-hand-side, however,
      // so that we do not have to special case the code generation.
      if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
          isa<arith::ShRSIOp>(def)) {
        Value shiftFactor = def->getOperand(1);
        if (!isInvariantValue(shiftFactor, block))
          return false;
      }
      // Generate code.
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
      BINOP(arith::ShLIOp)
      BINOP(arith::ShRUIOp)
      BINOP(arith::ShRSIOp)
      // TODO: complex?
    }
  }
  return false;
}

#undef UNAOP
#undef TYPEDUNAOP
#undef BINOP

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Block &block = forOp.getRegion().front();
  // For-loops with only a single yield statement (as below) can be generated
  // when a custom reduction is used with a unary operation.
  //   for (...)
  //     yield c_0
  if (block.getOperations().size() <= 1)
    return false;

  Location loc = forOp.getLoc();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial setup during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations, which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      forOpNew->setAttr(
          LoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      rewriter.modifyOpInPlace(forOp, [&]() { forOp.setStep(step); });
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops are terminated either by a non-empty yield operation
  // (reduction loop) or by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
        rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
        rewriter.replaceAllUsesWith(forOp.getInductionVar(),
                                    forOpNew.getInductionVar());
        rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
                                    forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}
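// End-to-end sketch (illustrative, with vl = 16 and schematic SSA names):
// a scalar reduction loop generated by the sparsifier, such as
//
//   %sum = scf.for %i = %lo to %hi step %c1 iter_args(%iter = %init) -> (f32) {
//     %x = memref.load %b[%i] : memref<?xf32>
//     %r = arith.addf %iter, %x : f32
//     scf.yield %r : f32
//   }
//
// is rewritten into a masked vector loop followed by a horizontal reduction
// (after the nop broadcast of the iteration argument has folded away):
//
//   %vsum = scf.for %i = %lo to %hi step %c16 iter_args(%viter = %vinit)
//       -> (vector<16xf32>) {
//     %mask = ... (see genVectorMask)
//     %vx   = vector.maskedload %b[%i], %mask, %pass : ...
//     %vr   = arith.addf %viter, %vx : vector<16xf32>
//     %sel  = arith.select %mask, %vr, %viter : vector<16xi1>, vector<16xf32>
//     scf.yield %sel : vector<16xf32>
//   }
//   %sum = vector.reduction <add>, %vsum : vector<16xf32> into f32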

/// Basic for-loop vectorizer.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization,
                                      enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for a single-block, unit-stride for-loop that was generated by
    // the sparsifier, which means no data dependence analysis is required,
    // and its loop-body is very restricted in form.
    if (!op.getRegion().hasOneBlock() || !isOneInteger(op.getStep()) ||
        !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)              v = for { }
///   u = expand(s)      ->    for (v) { }
///   for (u) { }
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
public:
  using OpRewritePattern<VectorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VectorOp op,
                                PatternRewriter &rewriter) const override {
    Value inp = op.getSource();
    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
        if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
          rewriter.replaceOp(op, redOp.getVector());
          return success();
        }
      }
    }
    return failure();
  }
};
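// In IR terms (schematic), the pattern matches a chain such as
//
//   %v = scf.for ... -> (vector<16xf32>) { ... }
//   %s = vector.reduction <add>, %v : vector<16xf32> into f32
//   %u = vector.broadcast %s : f32 to vector<16xf32>
//
// and replaces %u with %v directly, eliminating the redundant reduce/expand
// pair between two sparsifier-generated vector loops.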

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  assert(vectorLength > 0);
  vector::populateVectorStepLoweringPatterns(patterns);
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
}
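// Typical driver usage (illustrative sketch only; the calling pass and the
// greedy driver entry point are assumptions, not part of this file):
//
//   RewritePatternSet patterns(ctx);
//   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/16,
//                                       /*enableVLAVectorization=*/false,
//                                       /*enableSIMDIndex32=*/false);
//   (void)applyPatternsGreedily(op, std::move(patterns));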
| 674 | |